In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF

In [3]:
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Both convolutions use stride 1 and padding 1, so the spatial size is
    preserved; only the channel count changes, from ``in_channels`` to
    ``out_channels`` in the first convolution. The convolutions carry no bias
    because each is immediately followed by a BatchNorm layer, which has its
    own learnable shift.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        stages = []
        # First stage maps in_channels -> out_channels, second keeps out_channels.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        # Kept as a single nn.Sequential named `conv` so state_dict keys are conv.0 ... conv.5.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv/BN/ReLU stages to ``x`` and return the result."""
        return self.conv(x)


In [23]:
class UNET(nn.Module):
    """U-Net style encoder/decoder producing a per-pixel output map.

    Structure:
      * Encoder: four DoubleConv stages (in_channels -> 64 -> 128 -> 256 -> 512).
        Each stage output is kept as a skip connection and then 2x2 max-pooled.
      * Bottom: ``conv5`` (512 -> 1024) followed by ``bottleneck``, a second
        1024 -> 1024 DoubleConv at the lowest resolution.
      * Decoder: four stages. Each upsamples 2x with a transposed convolution,
        concatenates the matching encoder output along the channel dimension
        and refines with a DoubleConv (1024 -> 512 -> 256 -> 128 -> 64).
      * Head: a 1x1 convolution to ``out_channels``.

    The output has the same spatial size as the input, including odd sizes
    (see ``_up``). H and W must each be at least 16 so that the four 2x2
    poolings leave a non-empty feature map.

    Args:
        in_channels: channels of the input image tensor.
        out_channels: channels of the predicted output map.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super(UNET, self).__init__()
        # Encoder stages.
        self.conv1 = DoubleConv(in_channels, 64)
        self.conv2 = DoubleConv(64, 128)
        self.conv3 = DoubleConv(128, 256)
        self.conv4 = DoubleConv(256, 512)
        self.conv5 = DoubleConv(512, 1024)

        # Shared 2x downsampling between encoder stages.
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Decoder refinement stages; input channels = upsampled + skip channels.
        self.conv6 = DoubleConv(1024, 512)
        self.conv7 = DoubleConv(512, 256)
        self.conv8 = DoubleConv(256, 128)
        self.conv9 = DoubleConv(128, 64)

        # 2x learned upsampling, halving the channel count each time.
        self.tconv1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.tconv2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.tconv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.tconv4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)

        self.bottleneck = DoubleConv(1024, 1024)
        # 1x1 projection to the requested number of output channels.
        self.out = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        """Run the network on ``x`` of shape (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, out_channels, H, W).
        """
        # Encoder: keep every stage output for the matching decoder stage.
        c1 = self.conv1(x)
        c2 = self.conv2(self.maxpool(c1))
        c3 = self.conv3(self.maxpool(c2))
        c4 = self.conv4(self.maxpool(c3))

        out = self.bottleneck(self.conv5(self.maxpool(c4)))

        # Decoder: consume skips from deepest (c4) to shallowest (c1).
        out = self._up(out, c4, self.tconv1, self.conv6)
        out = self._up(out, c3, self.tconv2, self.conv7)
        out = self._up(out, c2, self.tconv3, self.conv8)
        out = self._up(out, c1, self.tconv4, self.conv9)

        return self.out(out)

    @staticmethod
    def _up(x, skip, upsample, conv):
        """One decoder stage: upsample ``x``, match ``skip``'s size, concat, refine."""
        x = upsample(x)
        # Max-pooling floors odd sizes, so doubling can come back one pixel short
        # of the skip (e.g. 161 -> 80 -> 160). Resize the spatial dims to match.
        if x.shape[-2:] != skip.shape[-2:]:
            x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
        return conv(torch.cat([x, skip], dim=1))
    




In [26]:
def test():
    """Smoke-test UNET on an odd-sized (161x161) batch.

    Prints the prediction shape and then the input shape, and asserts that the
    network returns a tensor with exactly the input's shape.
    """
    inputs = torch.randn((5, 3, 161, 161))
    net = UNET(in_channels=3, out_channels=3)
    outputs = net(inputs)
    for shape in (outputs.shape, inputs.shape):
        print(shape)
    assert outputs.shape == inputs.shape

In [27]:
# Run the UNET shape smoke test: prints prediction and input shapes, then asserts they match.
test()

torch.Size([5, 3, 161, 161])
torch.Size([5, 3, 161, 161])


In [None]:
# torch.cat along dim=1 stacks feature maps channel-wise: two (N, 3, H, W)
# tensors become one (N, 6, H, W). UNET's decoder merges each upsampled map
# with its encoder skip connection the same way.
upsampled = torch.rand(32, 3, 161, 161)
skip = torch.rand(32, 3, 161, 161)
merged = torch.cat([upsampled, skip], dim=1)
assert merged.shape == (32, 6, 161, 161), merged.shape
merged.shape


torch.Size([32, 6, 161, 161])
