In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision.transforms as transform

In [2]:
class unet(nn.Module):
    def __init__(self):
        super(unet, self).__init__()
        self.conv1 = nn.Conv2d(1,64,3) # (in_channels, out_channels, kernel_size)
        self.conv2 = nn.Conv2d(64,64,3) # (in_channels, out_channels, kernel_size)
        self.conv3 = nn.Conv2d(64,128,3) # (in_channels, out_channels, kernel_size)
        self.conv4 = nn.Conv2d(128,128,3) # (in_channels, out_channels, kernel_size)
        self.conv5 = nn.Conv2d(128,256,3) # (in_channels, out_channels, kernel_size)
        self.conv6 = nn.Conv2d(256,256,3) # (in_channels, out_channels, kernel_size)
        self.conv7 = nn.Conv2d(256,512,3) # (in_channels, out_channels, kernel_size)
        self.conv8 = nn.Conv2d(512,512,3) # (in_channels, out_channels, kernel_size)
        self.conv9 = nn.Conv2d(512,1024,3) # (in_channels, out_channels, kernel_size)
        self.conv10 = nn.Conv2d(1024,1024,3) # (in_channels, out_channels, kernel_size)
        self.convtrans1 = nn.ConvTranspose2d(1024,512,2) # (in_channels, out_channels, kernel_size)
        self.conv11 = nn.Conv2d(1024,512,3) # (in_channels, out_channels, kernel_size)
        self.conv12 = nn.Conv2d(512,512,3) # (in_channels, out_channels, kernel_size)
        self.convtrans2 = nn.ConvTranspose2d(512,256,2) # (in_channels, out_channels, kernel_size)
        self.conv13 = nn.Conv2d(512,256,3) # (in_channels, out_channels, kernel_size)
        self.conv14 = nn.Conv2d(256,256,3) # (in_channels, out_channels, kernel_size)
        self.convtrans3 = nn.ConvTranspose2d(256,128,2) # (in_channels, out_channels, kernel_size)
        self.conv15 = nn.Conv2d(256,128,3) # (in_channels, out_channels, kernel_size)
        self.conv16 = nn.Conv2d(128,128,3) # (in_channels, out_channels, kernel_size)
        self.convtrans4 = nn.ConvTranspose2d(128,64,2) # (in_channels, out_channels, kernel_size)
        self.conv17 = nn.Conv2d(128,64,3) # (in_channels, out_channels, kernel_size)
        self.conv18 = nn.Conv2d(64,64,3) # (in_channels, out_channels, kernel_size)
        self.conv19 = nn.Conv2d(64,2,1) # (in_channels, out_channels, kernel_size)        
        self.maxpool = nn.MaxPool2d(2) # (kernel_size, stride, padding)   
        
    def forward(self,x):
        
        x = F.relu(self.conv1(x))
        out1 = F.relu(self.conv2(x))
        out1_transform = nn.Upsample(size=(392, 392), mode='bilinear')(out1)
        
        x = self.maxpool(out1)
        x = F.relu(self.conv3(x))
        out2 = F.relu(self.conv4(x))
        out2_transform = nn.Upsample(size=(200, 200), mode='bilinear')(out2)
        
        x = self.maxpool(out2)
        x = F.relu(self.conv5(x))
        out3 = F.relu(self.conv6(x))
        out3_transform = nn.Upsample(size=(104, 104), mode='bilinear')(out3)
        
        x = self.maxpool(out3)
        x = F.relu(self.conv7(x))
        out4 = F.relu(self.conv8(x))
        out4_transform = nn.Upsample(size=(56, 56), mode='bilinear')(out4)
        
        x = self.maxpool(out4)
        x = F.relu(self.conv9(x))
        x = F.relu(self.conv10(x))
        out5 = self.convtrans1(x)
        out5_transform = nn.Upsample(size=(56, 56), mode='bilinear')(out5)
        
        print(out4_transform.shape)
        print(out5.shape)
        
        x = F.relu(self.conv11(torch.cat((out4_transform,out5_transform), 1)))
        x = F.relu(self.conv12(x))
        out6 = self.convtrans2(x)
        out6_transform = nn.Upsample(size=(104, 104), mode='bilinear')(out6)
        
        print(out3_transform.shape)
        print(out6.shape)
        
        x = F.relu(self.conv13(torch.cat((out3_transform,out6_transform), 1)))
        x = F.relu(self.conv14(x))
        out7 = self.convtrans3(x)
        out7_transform = nn.Upsample(size=(200, 200), mode='bilinear')(out7)
        
        print(out2_transform.shape)
        print(out7.shape)
        
        x = F.relu(self.conv15(torch.cat((out2_transform,out7_transform), 1)))
        x = F.relu(self.conv16(x))
        out8 = self.convtrans4(x)
        out8_transform = nn.Upsample(size=(392, 392), mode='bilinear')(out8)
        
        print(out1_transform.shape)
        print(out8.shape)
        
        x = F.relu(self.conv17(torch.cat((out1_transform,out8_transform), 1)))
        x = F.relu(self.conv18(x))
        x = F.relu(self.conv19(x))
        
        return x
        

In [3]:
unet_inst = unet().float()

In [4]:
img = torch.from_numpy(np.zeros((574,574))).float().unsqueeze_(0)
print(img.shape)

torch.Size([1, 574, 574])


In [5]:
unet_inst(img.unsqueeze_(0))

  "See the documentation of nn.Upsample for details.".format(mode))


torch.Size([1, 512, 56, 56])
torch.Size([1, 512, 29, 29])
torch.Size([1, 256, 104, 104])
torch.Size([1, 256, 53, 53])
torch.Size([1, 128, 200, 200])
torch.Size([1, 128, 101, 101])
torch.Size([1, 64, 392, 392])
torch.Size([1, 64, 197, 197])


tensor(1.00000e-03 *
       [[[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          ...,
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

         [[ 6.7505,  6.8766,  6.8726,  ...,  6.8507,  6.8471,  6.7724],
          [ 6.6623,  6.7422,  6.7122,  ...,  6.6760,  6.6658,  6.5747],
          [ 6.6249,  6.6822,  6.6259,  ...,  6.5896,  6.5803,  6.5006],
          ...,
          [ 6.5805,  6.6454,  6.5919,  ...,  6.5525,  6.5532,  6.4831],
          [ 6.5576,  6.6463,  6.5925,  ...,  6.5428,  6.5348,  6.4673],
          [ 6.5626,  6.6432,  6.5915,  ...,  6.5410,  6.5283,  6.4575]]]])