In [1]:
import torch
import torch.nn.functional as F #?? 
from torch import nn

In [2]:
# This belongs in a shared utilities module, but I couldn't figure out how to import from another notebook.
def initialize_weights(*models):
    """Re-initialize conv, linear and normalization layers of the given models in place.

    Conv2d / Linear weights are drawn with Kaiming (He) normal initialization
    using PyTorch's default arguments, and their biases (when present) are
    zeroed. BatchNorm2d / GroupNorm affine parameters (when present) are reset
    to scale 1 and shift 0.

    Args:
        *models: any number of ``nn.Module`` instances. Every submodule yielded
            by ``model.modules()`` is visited.

    Returns:
        None. Parameters are modified in place.
    """
    for model in models:
        for module in model.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                # Use the trailing-underscore (in-place) initializer: the
                # un-suffixed ``nn.init.kaiming_normal`` is deprecated and
                # emits a warning on every call.
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                # Norm layers built with affine=False carry weight/bias = None,
                # so guard each parameter before touching it.
                if module.weight is not None:
                    nn.init.ones_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

In [3]:
class _EncoderBlock(nn.Module):
    """One contracting-path stage: optional 2x2 max-pool, then two conv/norm/GELU units.

    Args:
        input_channels: channel count of the incoming feature map.
        output_channels: channel count produced by both convolutions. With the
            default GroupNorm (``bn=False``) this must be divisible by 32.
        dropout: when True, append ``nn.Dropout()`` (default probability)
            after the second activation.
        polling: when True, downsample the input with a stride-2 max-pool
            before the convolutions. (Spelling kept for existing callers.)
        bn: when True normalize with BatchNorm2d, otherwise GroupNorm with
            32 groups.
    """

    def __init__(self, input_channels, output_channels, dropout=False, polling=True, bn=False):
        super().__init__()

        def normalizer():
            # Both norm layers act on `output_channels` feature maps.
            if bn:
                return nn.BatchNorm2d(output_channels)
            return nn.GroupNorm(32, output_channels)

        # Two identical units; only the first one changes the channel count.
        stack = []
        for in_ch in (input_channels, output_channels):
            stack.extend((
                # 3x3 conv with 1-pixel reflect padding keeps H and W unchanged.
                nn.Conv2d(in_ch, output_channels, kernel_size=3, padding=1, padding_mode='reflect'),
                normalizer(),
                nn.GELU(),
            ))
        if dropout:
            stack.append(nn.Dropout())

        self.encode = nn.Sequential(*stack)
        # `pool` stays None when no downsampling was requested.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2) if polling else None

    def forward(self, x):
        """Downsample `x` if a pool is configured, then apply the conv stack."""
        pooled = x if self.pool is None else self.pool(x)
        return self.encode(pooled)

In [4]:
class _DecoderBlock(nn.Module):
    """Two conv/norm/GELU units mapping input -> middle -> output channels.

    The block itself does no pooling or upsampling; spatial size is preserved
    by the 3x3 reflect-padded convolutions.

    Args:
        input_channels: channel count of the incoming feature map.
        middle_channels: channel count after the first convolution.
        output_channels: channel count after the second convolution.
        bn: when True normalize with BatchNorm2d, otherwise GroupNorm with
            32 groups (channel counts must then be divisible by 32).
    """

    def __init__(self, input_channels, middle_channels, output_channels, bn=False):
        super().__init__()

        def make_norm(channels):
            return nn.BatchNorm2d(channels) if bn else nn.GroupNorm(32, channels)

        channel_plan = (
            (input_channels, middle_channels),
            (middle_channels, output_channels),
        )
        stages = []
        for in_ch, out_ch in channel_plan:
            stages += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, padding_mode='reflect'),
                make_norm(out_ch),
                nn.GELU(),
            ]

        self.decode = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv units to `x`."""
        return self.decode(x)

In [5]:
class UNet(nn.Module):
    """U-Net style encoder/decoder built from _EncoderBlock and _DecoderBlock.

    Four encoder stages (64 -> 128 -> 256 -> 512 channels) feed a 1024-channel
    bottleneck. Four decoder stages then upsample back, each one concatenating
    the matching encoder output as a skip connection. A final 1x1 convolution
    produces `classes` channels at the input's spatial size; no activation is
    applied to that output.

    Args:
        classes: number of output channels.
        input_channels: number of channels in the input image.
        bn: when True use BatchNorm2d everywhere, otherwise GroupNorm(32, C).
    """

    def __init__(self, classes, input_channels=3, bn=False):
        super().__init__()

        # Contracting path: enc1 keeps full resolution, enc2-enc4 each
        # max-pool by 2 before their convolutions.
        self.enc1 = _EncoderBlock(input_channels, 64, polling=False, bn=bn)
        self.enc2 = _EncoderBlock(64, 128, bn=bn)
        self.enc3 = _EncoderBlock(128, 256, bn=bn)
        self.enc4 = _EncoderBlock(256, 512, bn=bn)

        # Bottleneck: one more 2x downsample (average pool) followed by a
        # 512 -> 1024 -> 512 conv block at the lowest resolution.
        self.polling = nn.AvgPool2d(kernel_size=2, stride=2)
        self.center = _DecoderBlock(512, 1024, 512, bn=bn)

        # Expanding path: each input is [upsampled deeper features, skip],
        # hence the doubled input channel counts.
        self.dec4 = _DecoderBlock(1024, 512, 256, bn=bn)
        self.dec3 = _DecoderBlock(512, 256, 128, bn=bn)
        self.dec2 = _DecoderBlock(256, 128, 64, bn=bn)
        # Last decoder stage: a 3x3 reflect-padded conv followed by a 1x1 conv.
        self.dec1 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1, padding_mode='reflect'),
            nn.BatchNorm2d(64) if bn else nn.GroupNorm(32, 64),
            nn.GELU(),
            nn.Conv2d(64, 64, kernel_size=1, padding=0),
            nn.BatchNorm2d(64) if bn else nn.GroupNorm(32, 64),
            nn.GELU(),
        )

        # 1x1 projection to the requested number of output channels.
        self.final = nn.Conv2d(64, classes, kernel_size=1)
        initialize_weights(self)

    @staticmethod
    def _merge_skip(deep, skip):
        """Bilinearly resize `deep` to `skip`'s H x W and concatenate on the channel axis."""
        resized = F.interpolate(deep, skip.size()[-2:], align_corners=False, mode='bilinear')
        return torch.cat([resized, skip], 1)

    def forward(self, x):
        """Run the full encoder/bottleneck/decoder pass on an image batch `x`."""
        skip1 = self.enc1(x)
        skip2 = self.enc2(skip1)
        skip3 = self.enc3(skip2)
        skip4 = self.enc4(skip3)

        bottleneck = self.center(self.polling(skip4))

        # Walk back up, merging with the encoder output of matching depth.
        up4 = self.dec4(self._merge_skip(bottleneck, skip4))
        up3 = self.dec3(self._merge_skip(up4, skip3))
        up2 = self.dec2(self._merge_skip(up3, skip2))
        up1 = self.dec1(self._merge_skip(up2, skip1))

        return self.final(up1)

In [6]:
# Smoke test: build a single-channel UNet, print its layer structure, and push
# one random 200x200 image through it to confirm the output shape.
if __name__ == '__main__':
    model = UNet(input_channels=1, classes=1)
    print(model)
    # One sample, one channel, 200x200 pixels of Gaussian noise.
    x = torch.randn(1, 1, 200, 200)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        final = model(x)
        # The recorded run printed torch.Size([1, 1, 200, 200]), i.e. same H x W as the input.
        print(final.shape)

  nn.init.kaiming_normal(module.weight)


UNet(
  (enc1): _EncoderBlock(
    (encode): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=reflect)
      (1): GroupNorm(32, 64, eps=1e-05, affine=True)
      (2): GELU()
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=reflect)
      (4): GroupNorm(32, 64, eps=1e-05, affine=True)
      (5): GELU()
    )
  )
  (enc2): _EncoderBlock(
    (encode): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=reflect)
      (1): GroupNorm(32, 128, eps=1e-05, affine=True)
      (2): GELU()
      (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=reflect)
      (4): GroupNorm(32, 128, eps=1e-05, affine=True)
      (5): GELU()
    )
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (enc3): _EncoderBlock(
    (encode): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 3), st

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


torch.Size([1, 1, 200, 200])


Things to do:

Remove dropout. It doesn't seem to be used in the model and was probably just a carryover from their testing. Need to check whether it's used anywhere else.

Figure out how to import functions from other directories. The initialize_weights function should be imported from the utils folder instead of being defined in here.