In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy as np

In [None]:
class config:
    """Notebook-wide settings (currently only the RNG seed)."""

    # Seed handed to set_seed() so every run draws the same random numbers.
    seed: int = 10

In [14]:
def set_seed(seed):
    """Seed every RNG this notebook draws from so runs are repeatable.

    Args:
        seed: integer seed applied, in order, to Python's ``random`` module,
            NumPy's legacy global generator (``np.random.seed``) and
            PyTorch's default generator (``torch.manual_seed``).

    Returns:
        None.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

# Seed all RNGs up front, before any model weights or dummy tensors are drawn.
set_seed(seed=config.seed)

In [15]:
class Encoder(nn.Module):
    """One contracting U-Net stage: two 3x3 conv + ReLU layers, then a 2x2 max-pool.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both 3x3 convolutions.

    ``forward`` returns ``(features, pooled)``:

    - ``features`` keeps the input's spatial size (3x3 kernels with padding=1)
      and is the tensor used for the skip connection.
    - ``pooled`` is ``features`` after ``MaxPool2d(kernel_size=2)``, which halves
      H and W with floor rounding.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        same_size_conv = dict(kernel_size=3, padding=1)
        self.c1 = nn.Conv2d(in_channels, out_channels, **same_size_conv)
        self.c2 = nn.Conv2d(out_channels, out_channels, **same_size_conv)
        self.pool = nn.MaxPool2d(kernel_size=2)
        # A single parameter-free ReLU module is reused after both convs.
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        features = x
        for conv in (self.c1, self.c2):
            features = self.act(conv(features))
        return features, self.pool(features)

In [16]:
# Smoke test for one Encoder stage: the skip output must keep the input's
# spatial size and the pooled output must halve it.
EncoderTest = Encoder(8, 64)
dummy_input = torch.randn(1, 8, 512, 512)
with torch.no_grad():  # shape check only; don't retain an autograd graph
    out = EncoderTest(dummy_input)
skip_features, pooled_features = out
assert skip_features.shape == (1, 64, 512, 512)
assert pooled_features.shape == (1, 64, 256, 256)
print(pooled_features.shape)

torch.Size([1, 64, 256, 256])


In [17]:
class Decoder(nn.Module):
    """One expanding U-Net stage: 2x transposed-conv upsample, concat skip, two 3x3 conv + ReLU.

    Args:
        in_channels: channels of the deeper feature map ``z``. ``c1`` also
            consumes ``in_channels`` channels after concatenation. The skip
            tensor ``y`` must therefore carry ``in_channels - out_channels``
            channels. That equals ``out_channels`` when
            ``in_channels == 2 * out_channels``, as in the usual U-Net layout.
        out_channels: channels produced by the up-conv and by both 3x3 convs.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.upConv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.c1 = nn.Conv2d(in_channels, out_channels, kernel_size = 3, padding = 1)
        self.c2 = nn.Conv2d(out_channels, out_channels, kernel_size = 3, padding = 1)
        self.act = nn.ReLU(inplace = True)

    def forward(self, z, y):
        """Upsample ``z``, align it to ``y``'s spatial size and fuse the two.

        Args:
            z: deeper feature map of shape (N, in_channels, h, w).
            y: skip-connection feature map of shape (N, C_skip, H, W).

        Returns:
            Tensor of shape (N, out_channels, H, W).
        """
        z = self.upConv(z)
        # The up-conv yields exactly (2h, 2w). A 2x2 max-pool floors odd
        # sizes, so upsampled maps can be a pixel short of their skip tensor.
        # torch.cat would then raise. Pad z (negative amounts crop) so its
        # H/W match y, splitting any odd difference between the two sides.
        diff_h = y.size(-2) - z.size(-2)
        diff_w = y.size(-1) - z.size(-1)
        if diff_h or diff_w:
            z = F.pad(z, [diff_w // 2, diff_w - diff_w // 2,
                          diff_h // 2, diff_h - diff_h // 2])
        x = torch.cat([z, y], dim=1)
        x = self.act(self.c1(x))
        x = self.act(self.c2(x))
        return x

In [19]:
# Smoke test for one Decoder stage: a 64-channel map is upsampled 2x and fused
# with a 32-channel skip of twice its spatial size.
DecoderTest = Decoder(in_channels=64, out_channels=32)
dummy_input1 = torch.randn(1, 64, 512, 512)
dummy_input2 = torch.randn(1, 32, 1024, 1024)
# Shape check only: without no_grad the intermediate 1024x1024 activations
# would be kept alive for an autograd graph nobody uses.
with torch.no_grad():
    out = DecoderTest(dummy_input1, dummy_input2)
assert out.shape == (1, 32, 1024, 1024)
print(out.shape)

torch.Size([1, 32, 1024, 1024])


In [20]:
class UnetModel(nn.Module):
    """U-Net: five Encoder stages, four Decoder stages, then a 1x1 conv head.

    Args:
        in_channels: channels of the input batch.
        out_channels: channels produced by the final 1x1 convolution.
        base_channels: width of the first stage. Stage ``i`` (0-based) uses
            ``base_channels * 2**i`` channels. The default 64 gives widths
            64, 128, 256, 512 and 1024.

    Raises:
        ValueError: if ``base_channels`` is not positive.

    Note:
        Four 2x2 max-pools lie between the input and the bottleneck.
        - Inputs whose H and W are multiples of 16 line up exactly with every
          skip connection.
        - Other sizes leave some up-conv outputs smaller than their skip
          tensors. Whether that works depends on ``Decoder.forward``.
    """

    def __init__(self, in_channels, out_channels, base_channels=64):
        super().__init__()
        if base_channels < 1:
            raise ValueError(f"base_channels must be positive, got {base_channels}")
        w1, w2, w3, w4, w5 = (base_channels * 2 ** i for i in range(5))

        # Modules are created in the same order as before so parameter
        # initialisation consumes the RNG identically for the default width.
        self.downsample1 = Encoder(in_channels, w1)
        self.downsample2 = Encoder(w1, w2)
        self.downsample3 = Encoder(w2, w3)
        self.downsample4 = Encoder(w3, w4)
        self.downsample5 = Encoder(w4, w5)

        self.upsample1 = Decoder(w5, w4)
        self.upsample2 = Decoder(w4, w3)
        self.upsample3 = Decoder(w3, w2)
        self.upsample4 = Decoder(w2, w1)

        self.lastconv = nn.Conv2d(w1, out_channels, kernel_size=1)

    def forward(self, x):
        """Run the U-Net on ``x`` of shape (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, out_channels, H, W) when H and W are multiples of 16.
        """
        x1, x = self.downsample1(x)
        x2, x = self.downsample2(x)
        x3, x = self.downsample3(x)
        x4, x = self.downsample4(x)
        # Bottleneck: only the un-pooled features are used. The Encoder still
        # computes a pooled copy, which is discarded here.
        x5, _ = self.downsample5(x)

        x = self.upsample1(x5, x4)
        x = self.upsample2(x, x3)
        x = self.upsample3(x, x2)
        x = self.upsample4(x, x1)

        return self.lastconv(x)