In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [5]:
class Conv(nn.Module):
    def __init__(self,input_size, output_size):
        super(Conv, self).__init__()
        
        self.input_size = input_size
        self.output_size = output_size
        self.conv1 = nn.Conv2d(in_channels=self.input_size, out_channels=self.output_size,
                               kernel_size=3,padding='same')
        self.conv2_bn = nn.BatchNorm2d(20)
        
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2_bn(x)
        x = F.relu(x)
        
        return x
           

In [18]:
model = Conv(256,512)
model

Conv(
  (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=same)
  (conv2_bn): BatchNorm2d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [8]:
class DownConv(nn.Module):
    def __init__(self,input_size, output_size):
        super(DownConv, self).__init__()
        
        self.input_size = input_size
        self.output_size = output_size
        self.conv1 = nn.Conv2d(in_channels=self.input_size, out_channels=self.output_size,
                               kernel_size=3,padding='same')
        self.conv2_bn = nn.BatchNorm2d(20)
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2_bn(x)
        x = F.max_pool2d(x, 2)
        x = F.relu(x)
        
        return x
             
    

In [17]:
model = DownConv(256,512)
model

DownConv(
  (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=same)
  (conv2_bn): BatchNorm2d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [10]:
class InConv(nn.Module):
    def __init__(self,input_size, output_size):
        super(InConv, self).__init__()
        
        self.input_size = input_size
        self.output_size = output_size
        self.conv1 = nn.Conv2d(in_channels=self.input_size, out_channels=self.output_size,
                               kernel_size=3,padding='same')
        self.conv1_bn = nn.BatchNorm2d(20)
        self.conv2 = nn.Conv2d(in_channels=self.input_size, out_channels=self.output_size,
                               kernel_size=3,padding='same')
        self.conv2_bn = nn.BatchNorm2d(20)
        
    def forward(self,x):
        x = self.conv1(x)
        x = self.conv1_bn(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.conv2_bn(x)
        x = F.relu(x)
        
        return x
        


In [16]:
model = InConv(256,512)
model

InConv(
  (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=same)
  (conv1_bn): BatchNorm2d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=same)
  (conv2_bn): BatchNorm2d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [14]:
class Down(nn.Module):
    def __init__(self,input_size, output_size):
        super(Down, self).__init__()
        
        self.input_size = input_size
        self.output_size = output_size
        self.conv1 = nn.Conv2d(in_channels=self.input_size, out_channels=self.output_size,
                               kernel_size=3,padding='same')
        self.conv1_bn = nn.BatchNorm2d(20)
        self.conv2 = nn.Conv2d(in_channels=self.input_size, out_channels=self.output_size,
                               kernel_size=3,padding='same')
        self.conv2_bn = nn.BatchNorm2d(20)
        
        
    def forward(self,x):
        x = F.max_pool2d(x, 2)
        x = self.conv1(x)
        x = self.conv1_bn(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.conv2_bn(x)
        x = F.relu(x)
        
        return x
        
        

In [15]:
model = Down(256,512)
model

Down(
  (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=same)
  (conv1_bn): BatchNorm2d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=same)
  (conv2_bn): BatchNorm2d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [19]:
class OutConv(nn.Module):
    def __init__(self,input_size, output_size):
        super(OutConv, self).__init__()
        
        self.input_size = input_size
        self.output_size = output_size
        self.conv1 = nn.Conv2d(in_channels=self.input_size, out_channels=self.output_size,
                               kernel_size=1)
    def forward(self,x):
        x = self.conv1(x)
        return x
        

In [None]:
#Still need to work on this module
class Up(nn.Module):
    def __init__(self,input_size,output_size):
        
        self.input_size = input_size
        self.output_size = output_size
        self.conv2d_transpose = nn.ConvTranspose2d(in_channels=self.input_size, out_channels=self.output_size,
                               kernel_size=3, stride = 2, padding='same')
        
        #F.conv_transpose2d(input, weight)
        

In [20]:
class unet(nn.Module):
    def __init__(self,input_size, output_size):
        super(unet, self).__init__()

        self.input_size = input_size
        self.output_size = output_size

        self.input1 = InConv(input_size, 64)
        
        self.downsc1 = Down(64, 128)
        self.downsc2 = Down(128, 256)
        self.downsc3 = Down(256, 512)
        self.downsc4 = Down(512, 1024)
        
        self.upsc1 = Up(1024, 512)
        self.upsc2 = Up(512, 256)
        self.upsc3 = Up(256, 128)
        self.upsc4 = Up(128, 64)
        self.outc = OutConv(64, output_size)

    def forward(self, x):
    
        in1 = self.input1(x)
        in2 = self.downsc1(in1)
        in3 = self.downsc2(in2)
        in4 = self.downsc3(in3)
        in5 = self.downsc4(in4)
        x = self.upsc1(in5, in4)
        x = self.upsc2(x, in3)
        x = self.upsc3(x, in2)
        x = self.upsc4(x, in1)
        output = self.outc(x)
        return output

In [23]:
class Discriminator(nn.Module):
    def __init__(self,input_size):
        super(Discriminator, self).__init__()
        
        self.input_size = input_size
        
        self.conv1 = DownConv(input_size, 64)
        self.conv2 = DownConv(64,128)
        self.conv3 = DownConv(128,256)
        self.conv4 = DownConv(256,512)
        
    def forward(self,x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        
        dim = np.prod(x.shape[1:])
        x = x.view(-1,dim)
        x = nn.linear(x,1)
        x = nn.Sigmoid()
        
        return x
        

In [24]:
model = Discriminator(256)
model

Discriminator(
  (conv1): DownConv(
    (conv1): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (conv2_bn): BatchNorm2d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv2): DownConv(
    (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (conv2_bn): BatchNorm2d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv3): DownConv(
    (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (conv2_bn): BatchNorm2d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv4): DownConv(
    (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (conv2_bn): BatchNorm2d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)