In [None]:
class UNet(nn.Module):
    """U-Net style encoder/decoder for dense (per-pixel) prediction.

    Four encoder stages (64 -> 128 -> 256 -> 512 channels) are separated by
    2x2 max-pooling, followed by a 1024-channel bottleneck at 1/16 of the
    input resolution. Four ``UNetUpBlock`` decoder stages then upsample back
    to full resolution. Each one fuses the matching encoder output through a
    skip connection. A final 1x1 convolution maps the 64 decoder channels to
    ``out_channels``.

    Args:
        in_channels: Number of channels in the input tensor.
        out_channels: Number of channels produced by the final 1x1 conv
            (e.g. one per segmentation class). No activation is applied, so
            the output is raw scores/logits.

    The input is expected as a batched ``(N, in_channels, H, W)`` tensor. The
    output has shape ``(N, out_channels, H, W)``. ``H`` and ``W`` must each
    be at least 16 so that four 2x2 poolings leave a non-empty feature map.
    Odd sizes are fine: the up blocks pad each upsampled map to the size of
    its skip connection.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Encoder (feature extraction; resolution is reduced by self.pool)
        self.encoder1 = UNetConvBlock(in_channels, 64)
        self.encoder2 = UNetConvBlock(64, 128)
        self.encoder3 = UNetConvBlock(128, 256)
        self.encoder4 = UNetConvBlock(256, 512)

        # Shared 2x2 max-pool that halves H and W between encoder stages.
        # Without it, every stage runs at full resolution. The stride-2
        # up-convolutions in the decoder would then overshoot their skip
        # connections, and the alignment padding would crop the result.
        # It has no parameters, so the state_dict keys are unaffected.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Bottleneck
        self.bottleneck = UNetConvBlock(512, 1024)

        # Decoder (upsampling + skip-connection fusion)
        self.decoder4 = UNetUpBlock(1024, 512)
        self.decoder3 = UNetUpBlock(512, 256)
        self.decoder2 = UNetUpBlock(256, 128)
        self.decoder1 = UNetUpBlock(128, 64)

        # Final convolution layer to output segmentation mask
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        """Run the encoder, bottleneck, and decoder; return per-pixel logits."""
        # Encoder path: each stage after the first works at half the
        # resolution of the previous one.
        enc1 = self.encoder1(x)                  # H
        enc2 = self.encoder2(self.pool(enc1))    # H/2
        enc3 = self.encoder3(self.pool(enc2))    # H/4
        enc4 = self.encoder4(self.pool(enc3))    # H/8

        # Bottleneck at H/16
        bottleneck = self.bottleneck(self.pool(enc4))

        # Decoder path with skip connections. Each up block doubles the
        # resolution and aligns it to the matching encoder output.
        dec4 = self.decoder4(bottleneck, enc4)
        dec3 = self.decoder3(dec4, enc3)
        dec2 = self.decoder2(dec3, enc2)
        dec1 = self.decoder1(dec2, enc1)

        # Final 1x1 projection to out_channels (no activation)
        return self.final_conv(dec1)

class UNetConvBlock(nn.Module):
    """Two same-size 3x3 convolutions, each followed by ReLU, then BatchNorm.

    Padding of 1 keeps the spatial size unchanged, so only the channel count
    changes (``in_channels`` -> ``out_channels``). Normalization is applied
    once, after the second ReLU, rather than after each convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        same_3x3 = {"kernel_size": 3, "padding": 1}
        # Attribute names and creation order are kept stable so state_dict
        # keys and parameter initialisation order stay the same.
        self.conv1 = nn.Conv2d(in_channels, out_channels, **same_3x3)
        # One ReLU module is reused for both activations; it runs in place
        # on each convolution's fresh output tensor.
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **same_3x3)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Apply conv -> ReLU twice, then batch normalization."""
        for conv in (self.conv1, self.conv2):
            x = self.relu(conv(x))
        return self.bn(x)

class UNetUpBlock(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, skip fusion, conv block.

    The transposed convolution halves the channel count
    (``in_channels`` -> ``in_channels // 2``) and doubles H and W. The result
    is padded to the skip connection's spatial size and concatenated with it
    along the channel axis. A ``UNetConvBlock`` then maps the concatenation
    (expected to have ``in_channels`` channels) to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(
            in_channels, in_channels // 2, kernel_size=2, stride=2
        )
        self.conv_block = UNetConvBlock(in_channels, out_channels)

    def forward(self, x, skip_connection):
        """Upsample ``x``, align it with ``skip_connection``, fuse, and convolve."""
        upsampled = self.up(x)

        # Height/width mismatch against the skip tensor. The pad is split as
        # evenly as possible, with any odd remainder going to bottom/right.
        # A negative difference makes F.pad crop instead of pad.
        d_h = skip_connection.shape[2] - upsampled.shape[2]
        d_w = skip_connection.shape[3] - upsampled.shape[3]
        top, left = d_h // 2, d_w // 2
        upsampled = F.pad(upsampled, (left, d_w - left, top, d_h - top))

        fused = torch.cat([upsampled, skip_connection], dim=1)
        return self.conv_block(fused)