In [154]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

In [69]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace = True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace = True),  
        )
    
    def forward(self, x):
        return self.conv(x)

In [151]:
class unet(nn.Module):
    def __init__(self, in_channels = 3, out_channels = 1, channels = [64,128,256,512]):
        super(unet, self).__init__()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size = 2, stride = 2)
        self.ups = nn.ModuleList()
        
        # building downconv part of unet
        for channel in channels:
            self.downs.append(DoubleConv(in_channels, channel))
            in_channels = channel
            
        # note that we end channels at 512 and then add a convolution for 1024 channels because of the for loop
        # this 1024 is only used once going down, but other values are used going down and up
        self.bottleneck = DoubleConv(channels[-1], channels[-1]*2)
            
        # building upconv part of unet
        # note that the length of ups is 2x that of downs bc for every channel you append ConvTranspose2d and DoubleConv
        for channel in channels[::-1]:
            self.ups.append(nn.ConvTranspose2d(channel*2, channel, kernel_size = 2, stride = 2))
            self.ups.append(DoubleConv(channel*2, channel))
            
        self.final_conv = DoubleConv(channels[0], out_channels)
        
    def forward(self, x):
        skip_connections = []
        
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)
            print('downs:',x.shape)
            
        x = self.bottleneck(x)
        print('bottleneck:',x.shape)
        
        #reverse list in place because channels of skip connections going up is in opposite direction of going down
        skip_connections = skip_connections[::-1]
        
        for i in range(0, len(self.ups), 2):
            x = self.ups[i](x) #doing the upconv
            sc = skip_connections[i//2]
            
            if x.shape != sc.shape: #only take :2 because only want height and width, don't watch batch size or channels
                x = TF.resize(x, size = sc.shape[:2])
            
            x = torch.cat((x, sc), dim = 1)
            x = self.ups[i+1](x) #doing the downconv
            print('ups:',x.shape)
            
#         print(x.shape, self.final_conv)
        x = self.final_conv(x)
        
        return(x)

In [152]:
def test_model():
    x = torch.randn((3,1,160,160))
    model = unet(in_channels = 1, out_channels = 1)
    preds = model(x)
    print(x.shape, preds.shape)
    assert x.shape == preds.shape

In [153]:
test_model()

downs: torch.Size([3, 64, 80, 80])
downs: torch.Size([3, 128, 40, 40])
downs: torch.Size([3, 256, 20, 20])
downs: torch.Size([3, 512, 10, 10])
bottleneck: torch.Size([3, 1024, 10, 10])
ups: torch.Size([3, 512, 20, 20])
ups: torch.Size([3, 256, 40, 40])
ups: torch.Size([3, 128, 80, 80])
ups: torch.Size([3, 64, 160, 160])
torch.Size([3, 1, 160, 160]) torch.Size([3, 1, 160, 160])
