## Implementing a U-Net model

This code implements a U-Net model for semantic segmentation from the paper [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/abs/1505.04597).

A U-Net consists of an encoder - a series of convolution and pooling layers  which reduce the spatial resolution of the input, followed by a decoder - a series of transposed convolution and upsampling layers which increase the spatial resolution of the input. The encoder and decoder are connected by a bottleneck layer which is responsible for reducing the number of channels in the input.

The key innovation of U-Net is the addition of skip connections that connect the contracting path to the corresponding layers in the expanding path, allowing the network to recover fine-grained details lost during downsampling.

In [61]:
import torch
import torch.nn as nn
import torchvision.transforms.functional

# Implement the double 3X3 convolution blocks
# The original paper did not use padding, but we will use padding to keep the image size the same

class double_convolution(nn.Module):
    """
    This class implements the double convolution block which consists of two 3X3 convolution layers,
    each followed by a ReLU activation function.

    """
    def __init__(self, in_channels, out_channels): # Initialize the class
        super().__init__() # Initialize the parent class

        # First 3X3 convolution layer
        self.first_cnn = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, padding = 1)
        self.act1 = nn.ReLU()

        # Second 3X3 convolution layer
        self.second_cnn = nn.Conv2d(in_channels = out_channels, out_channels = out_channels, kernel_size = 3, padding = 1)
        self.act2 = nn.ReLU()

    # Pass the input through the double convolution block
    def forward(self, x):
        x = self.first_cnn(x)
        x = self.act1(x)
        x = self.act2(self.second_cnn(x))
        return x


# Implement the Downsample block that occurs after each double convolution block
class down_sample(nn.Module):
    """
    This class implements the downsample block which consists of a Max Pooling layer with a kernel size of 2.
    The Max Pooling layer halves the image size reducing the spatial resolution of the feature maps
    while retaining the most important features.
    """
    def __init__(self):
        super().__init__()
        self.max_pool = nn.MaxPool2d(kernel_size = 2, stride = 2)
    
    # Pass the input through the downsample block
    def forward(self, x):
        x = self.max_pool(x)
        return x
    
# Implement the UpSample block that occurs in the decoder part of the network
class up_sample(nn.Module):
    """
    This class implements the upsample block which consists of a convolution transpose layer with a kernel size of 2.
    The convolution transpose layer doubles the image size increasing the spatial resolution of the feature maps
    while retaining the learned features.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Convolution transpose layer
        self.up_sample = nn.ConvTranspose2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 2, stride = 2)

    # Pass the input through the upsample block
    def forward(self, x):
        x = self.up_sample(x)
        return x 

# Implement the crop and concatenate block that occurs in the decoder part of the network
# This block concatenates the output of the upsample block with the output of the corresponding downsample block
# The output of the crop and concatenate block is then passed through a double convolution block
class crop_and_concatenate(nn.Module):
    """
    This class implements the crop and concatenate block which combines the output of the upsample block
    with the corresponding features from the contracting path through skip connections,
    allowing the network to recover the fine-grained details lost during downsampling
    and produce a high-resolution output segmentation map.
    """ 
    # def forward(self, upsampled, bypass):
    #     # Crop the feature map from the contacting path to match the size of the upsampled feature map
    #     bypass = torchvision.transforms.functional.center_crop(img = bypass, output_size = [upsampled.shape[2], upsampled.shape[3]]) 
    #     # Concatenate the upsampled feature map with the cropped feature map from the contracting path
    #     x = torch.cat([upsampled, bypass], dim = 1) # Concatenate along the channel dimension
    #     return x
    # Alternatively crop the upsampled feature map to match the size of the feature map from the contracting path
    def forward(self, upsampled, bypass):
        upsampled = torchvision.transforms.functional.resize(img = upsampled, size = bypass.shape[2:], antialias=True)
        x = torch.cat([upsampled, bypass], dim = 1) # Concatenate along the channel dimension
        return x

# m = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
# input = torch.randn(1, 1024, 28, 28)
# m(input).shape 

# m = nn.MaxPool2d(kernel_size = 2, stride = 2)
# xx = torch.randn(1, 1, 143, 143)
# m(xx).shape

## Implement the UNet architecture
class UNet(nn.Module):
    # in_channels: number of channels in the input image
    # out_channels: number of channels in the output image
    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Define the contracting path: convolution blocks followed by downsample blocks
        self.down_conv = nn.ModuleList(double_convolution(in_chans, out_chans) for in_chans, out_chans in
                                       [(in_channels, 64), (64, 128), (128, 256), (256, 512)]) # List of downsample blocks
        
        self.down_samples = nn.ModuleList(down_sample() for _ in range(4))

        # Define the bottleneck layer
        self.bottleneck = double_convolution(in_channels = 512, out_channels = 1024)

        # Define the expanding path: upsample blocks followed by convolution blocks
        self.up_samples = nn.ModuleList(up_sample(in_chans, out_chans) for in_chans, out_chans in
                                        [(1024, 512), (512, 256), (256, 128), (128, 64)]) # List of upsample blocks
        
        self.concat = nn.ModuleList(crop_and_concatenate() for _ in range(4))

        self.up_conv = nn.ModuleList(double_convolution(in_chans, out_chans) for in_chans, out_chans in
                                        [(1024, 512), (512, 256), (256, 128), (128, 64)]) # List of convolution blocks
        
        # Final 1X1 convolution layer to produce the output segmentation map:
        # The primary purpose of 1x1 convolutions is to transform the channel dimension of the feature map,
        # while leaving the spatial dimensions unchanged.
        self.final_conv = nn.Conv2d(in_channels = 64, out_channels = out_channels, kernel_size = 1)

    # Pass the input through the UNet architecture
    def forward(self, x):
        # Pass the input through the contacting path
        skip_connections = [] # List to store the outputs of the downsample blocks
        for down_conv, down_sample in zip(self.down_conv, self.down_samples):
            x = down_conv(x)
            skip_connections.append(x)
            x = down_sample(x)
        
        # Pass the output of the contacting path through the bottleneck layer
        x = self.bottleneck(x)

        # Pass the output of the bottleneck layer through the expanding path
        skip_connections = skip_connections[::-1] # Reverse the list of skip connections
        for up_sample, concat, up_conv in zip(self.up_samples, self.concat, self.up_conv):
            x = up_sample(x)
            x = concat(x, skip_connections.pop(0)) # Remove the first element from the list of skip connections
            x = up_conv(x)
        
        # Pass the output of the expanding path through the final convolution layer
        x = self.final_conv(x)
        return x
        

### Sanity check for the model

In [95]:
import torchsummary
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(in_channels = 3, out_channels = 1).to(device)
dummy_input = torch.randn((1, 3, 572, 572)).to(device)
mask = model(dummy_input)
mask.shape
torchsummary.summary(model, input_size = (3, 572, 572))



----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 572, 572]           1,792
              ReLU-2         [-1, 64, 572, 572]               0
            Conv2d-3         [-1, 64, 572, 572]          36,928
              ReLU-4         [-1, 64, 572, 572]               0
double_convolution-5         [-1, 64, 572, 572]               0
         MaxPool2d-6         [-1, 64, 286, 286]               0
       down_sample-7         [-1, 64, 286, 286]               0
            Conv2d-8        [-1, 128, 286, 286]          73,856
              ReLU-9        [-1, 128, 286, 286]               0
           Conv2d-10        [-1, 128, 286, 286]         147,584
             ReLU-11        [-1, 128, 286, 286]               0
double_convolution-12        [-1, 128, 286, 286]               0
        MaxPool2d-13        [-1, 128, 143, 143]               0
      down_sample-14        [-1, 128, 

### Create a UNet model with skip connections

This section implemets a U-Net model that incorporates some of the recent advances in deep learning, that is:
- [Residual networks](https://arxiv.org/abs/1512.03385): The key idea behind ResNets is the use of residual connections, which allow for the direct propagation of information through the network without being modified by the layers in between. The residual connection is achieved by adding the input of a layer to its output, so that the output of the layer becomes: `y = f(x) + x`. **The shortcut connection skips one or more layers, with the change in dimensions, if any, compensated with a 1x1 convolutional layer.** 

- [Group normalization](https://arxiv.org/abs/1803.08494): works by normalizing the activations of a layer across groups of channels instead of the entire batch. See more explanations and comparisons between different normalizations in [this blog post](https://gaoxiangluo.github.io/2021/08/01/Group-Norm-Batch-Norm-Instance-Norm-which-is-better/).

- [Swish activation function](https://arxiv.org/abs/1710.05941): is a self-gated activation function that is defined as `f(x) = x * sigmoid(x)`. It has been shown to outperform ReLU and other activation functions on deeper models across a number of challenging datasets.

- [Attention gated Unets](https://arxiv.org/abs/1804.03999): is a modification of the U-Net architecture that uses attention gates to selectively focus on the most relevant parts of the input image. The attention gates are implemented as a 1x1 convolutional layer that learns a weight for each channel in the input. The output of the attention gate is then multiplied with the input to the layer, so that the output of the layer becomes: `y = f(x) * sigmoid(x)`. The attention gates are applied to the output of the contracting path and the input of the expanding path.

In [201]:
import torch
import torch.nn as nn
import torchvision

# Define a Residual block
class residual_block(nn.Module):
    """
    This class implements a residual block which consists of two convolution layers with group normalization
    """
    def __init__(self, in_channels, out_channels, n_groups = 8):
        super().__init__()
        # First convolution layer
        self.first_conv = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, padding = 1)
        self.first_norm = nn.GroupNorm(num_groups = n_groups, num_channels = out_channels)
        self.act1 = nn.SiLU() # Swish activation function

        # Second convolution layer
        self.second_conv = nn.Conv2d(in_channels = out_channels, out_channels = out_channels, kernel_size = 3, padding = 1)
        self.second_norm = nn.GroupNorm(num_groups = n_groups, num_channels = out_channels)
        self.act2 = nn.SiLU() # Swish activation function

        # If the number of input channels is not equal to the number of output channels,
        # then use a 1X1 convolution layer to compensate for the difference in dimensions
        # This allows the input to have the same dimensions as the output of the residual block
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 1)
        else:
            # Pass the input as is
            self.shortcut = nn.Identity()

    # Pass the input through the residual block
    def forward(self, x):
        # Store the input
        input = x

        # Pass input through the first convolution layer
        x = self.act1(self.second_norm(self.first_conv(x)))

        # Pass the output of the first convolution layer through the second convolution layer
        x = self.act2(self.second_norm(self.second_conv(x)))

        # Add the input to the output of the second convolution layer
        # This is the skip connection
        x = x + self.shortcut(input)
        return x

# Implement the DownSample block that occurs after each residual block
class down_sample(nn.Module):
    def __init__(self):
        super().__init__()
        self.max_pool = nn.MaxPool2d(kernel_size = 2, stride = 2)

    # Pass the input through the downsample block
    def forward(self, x):
        x = self.max_pool(x)
        return x

# Implement the UpSample block that occurs in the decoder path/expanding path
class up_sample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Convolution transpose layer to upsample the input
        self.up_sample = nn.ConvTranspose2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 2, stride = 2)

    # Pass the input through the upsample block
    def forward(self, x):
        x = self.up_sample(x)
        return x

# Implement the crop and concatenate layer
class crop_and_concatenate(nn.Module):
    def forward(self, upsampled, bypass):
        # Crop the upsampled feature map to match the dimensions of the bypass feature map
        upsampled = torchvision.transforms.functional.resize(upsampled, size = bypass.shape[2:], antialias=True)
        x = torch.cat([upsampled, bypass], dim = 1) # Concatenate along the channel dimension
        return x

# Implement an attention block
class attention_block(nn.Module):
    def __init__(self, skip_channels, gate_channels, inter_channels = None, n_groups = 8):
        super().__init__()

        if inter_channels is None:
            inter_channels = skip_channels // 2

        # Implement W_g i.e the convolution layer that operates on the gate signal
        # Upsample gate signal to be the same size as the skip connection
        self.W_g = up_sample(in_channels = gate_channels, out_channels = skip_channels)
        #self.W_g_norm = nn.GroupNorm(num_groups = n_groups, num_channels = skip_channels)
        #self.W_g_act = nn.SiLU() # Swish activation function

        # Implement W_x i.e the convolution layer that operates on the skip connection
        self.W_x = nn.Conv2d(in_channels = skip_channels, out_channels = inter_channels, kernel_size = 1)
        #self.W_x_norm = nn.GroupNorm(num_groups = n_groups, num_channels = inter_channels)
        #self.W_x_act = nn.SiLU() # Swish activation function

        # Implement phi i.e the convolution layer that operates on the output of W_x + W_g
        self.phi = nn.Conv2d(in_channels = inter_channels, out_channels = 1, kernel_size = 1)
        #self.phi_norm = nn.GroupNorm(num_groups = n_groups, num_channels = 1)
        #self.phi_act = nn.SiLU() # Swish activation function

        # Implement the sigmoid activation function
        self.sigmoid = nn.Sigmoid()
        # Implement the Swish activation function
        self.act = nn.SiLU()

        # Implement final group normalization layer
        self.final_norm = nn.GroupNorm(num_groups = n_groups, num_channels = skip_channels)

    # Pass the input through the attention block
    def forward(self, skip_connection, gate_signal):
        # Upsample the gate signal to match the channels of the skip connection
        gate_signal = self.W_g(gate_signal)
        # Ensure that the sizes of the skip connection and the gate signal match before addition
        if gate_signal.shape[2:] != skip_connection.shape[2:]:
            gate_signal = torchvision.transforms.functional.resize(gate_signal, size = skip_connection.shape[2:], antialias=True)
        # Project to the intermediate channels
        gate_signal = self.W_x(gate_signal)

        # Project the skip connection to the intermediate channels
        skip_signal = self.W_x(skip_connection)

        # Add the skip connection and the gate signal
        add_xg = gate_signal + skip_signal

        # Pass the output of the addition through the activation function
        add_xg = self.act(add_xg)

        # Pass the output of attention through a 1x1 convolution layer to obtain the attention map
        attention_map = self.sigmoid(self.phi(add_xg))

        # Multiply the skip connection with the attention map
        # Perform element-wise multiplication
        skip_connection = torch.mul(skip_connection, attention_map)

        skip_connection = nn.Conv2d(in_channels = skip_connection.shape[1], out_channels = skip_connection.shape[1], kernel_size = 1)(skip_connection)
        skip_connection = self.act(self.final_norm(skip_connection))

        return skip_connection


## Implement a residual attention U-Net
class ResidualAttentionUnet(nn.Module):
    def __init__(self, in_channels, out_channels, n_groups = 4, n_channels = [64, 128, 256, 512, 1024]):
        super().__init__()

        # Define the contracting path: residual blocks followed by downsampling
        self.down_conv = nn.ModuleList(residual_block(in_chans, out_chans) for in_chans, out_chans in
                                       [(in_channels, n_channels[0]), (n_channels[0], n_channels[1]), (n_channels[1], n_channels[2]), (n_channels[2], n_channels[3])])
        self.down_samples = nn.ModuleList(down_sample() for _ in range(4))

        # Define the bottleneck residual block
        self.bottleneck = residual_block(n_channels[3], n_channels[4])


        # Define the attention blocks
        self.attention_blocks = nn.ModuleList(attention_block(skip_channels = residuals_chans, gate_channels = gate_chans) for gate_chans, residuals_chans in
                                              [(n_channels[4], n_channels[3]), (n_channels[3], n_channels[2]), (n_channels[2], n_channels[1]), (n_channels[1], n_channels[0])])


        # Define the expanding path: upsample blocks, followed by crop and concatenate, followed by residual blocks
        self.upsamples = nn.ModuleList(up_sample(in_chans, out_chans) for in_chans, out_chans in
                                       [(n_channels[4], n_channels[3]), (n_channels[3], n_channels[2]), (n_channels[2], n_channels[1]), (n_channels[1], n_channels[0])])
        
        self.concat = nn.ModuleList(crop_and_concatenate() for _ in range(4))

        self.up_conv = nn.ModuleList(residual_block(in_chans, out_chans) for in_chans, out_chans in
                                     [(n_channels[4], n_channels[3]), (n_channels[3], n_channels[2]), (n_channels[2], n_channels[1]), (n_channels[1], n_channels[0])])
        
        # Final 1X1 convolution layer to produce the output segmentation map:
        # The primary purpose of 1x1 convolutions is to transform the channel dimension of the feature map,
        # while leaving the spatial dimensions unchanged.
        self.final_conv = nn.Conv2d(in_channels = n_channels[0] , out_channels = out_channels, kernel_size = 1)

    # Pass the input through the residual attention U-Net
    def forward(self, x):
        # Store the skip connections
        skip_connections = []
        # # Store the gate signals
        # gate_signals = []

        # Pass the input through the contracting path
        for down_conv, down_sample in zip(self.down_conv, self.down_samples):
            x = down_conv(x)
            skip_connections.append(x)
            #gate_signals.append(x)
            x = down_sample(x)
        
        # Pass the output of the contracting path through the bottleneck
        x = self.bottleneck(x)
        skip_connections.append(x)

        # Attention on the residual connections
        #skip_connections = skip_connections[::-1]
        n = len(skip_connections)
        indices = [(n - 1 - i, n - 2 - i) for i in range(n - 1)]
        attentions = []
        for i, g_x in enumerate(indices):
            g_gate = g_x[0]
            x_residual = g_x[1]
            attn = self.attention_blocks[i](skip_connections[x_residual], skip_connections[g_gate])
            attentions.append(attn)

        #attentions = attentions[::-1]
    
        # Pass the output of the attention blocks through the expanding path
        for up_sample, concat, up_conv in zip(self.upsamples, self.concat, self.up_conv):
            x = up_sample(x)
            x = concat(x, attentions.pop(0))
            x = up_conv(x)

        # Pass the output of the expanding path through the final convolution layer
        x = self.final_conv(x)
        return x

In [204]:
## Sanity check
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResidualAttentionUnet(in_channels = 3, out_channels = 1).to(device)
x = torch.randn((1, 3, 572, 572)).to(device)
mask = model(x)
mask.shape

torch.Size([1, 1, 572, 572])

In [196]:
for i, g_x in enumerate(indices):
    print(i, g_x[0], g_x[1])

0 4 3
1 3 2
2 2 1
3 1 0


In [193]:
attention_blocks

ModuleList(
  (0): attention_block(
    (W_g): up_sample(
      (up_sample): ConvTranspose2d(1024, 512, kernel_size=(2, 2), stride=(2, 2))
    )
    (W_x): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
    (phi): Conv2d(256, 1, kernel_size=(1, 1), stride=(1, 1))
    (sigmoid): Sigmoid()
    (act): SiLU()
    (final_norm): GroupNorm(32, 512, eps=1e-05, affine=True)
  )
  (1): attention_block(
    (W_g): up_sample(
      (up_sample): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2))
    )
    (W_x): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
    (phi): Conv2d(128, 1, kernel_size=(1, 1), stride=(1, 1))
    (sigmoid): Sigmoid()
    (act): SiLU()
    (final_norm): GroupNorm(32, 256, eps=1e-05, affine=True)
  )
  (2): attention_block(
    (W_g): up_sample(
      (up_sample): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2))
    )
    (W_x): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
    (phi): Conv2d(64, 1, kernel_size=(1, 1), stride=(1, 1))
    (

In [154]:
n_channels = [64, 128, 256, 512, 1024]
attention_blocks = nn.ModuleList(attention_block(skip_channels = residuals_chans, gate_channels = gate_chans) for gate_chans, residuals_chans in
                                            [(n_channels[4], n_channels[3]), (n_channels[3], n_channels[2]), (n_channels[2], n_channels[1]), (n_channels[1], n_channels[0])])

attention_blocks[3]

attention_block(
  (W_g): up_sample(
    (up_sample): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2))
  )
  (W_x): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
  (phi): Conv2d(32, 1, kernel_size=(1, 1), stride=(1, 1))
  (sigmoid): Sigmoid()
  (act): SiLU()
  (final_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
)

In [141]:
n = 5
indices = [(n-1-i, n-2-i) for i in range(n-1)]
indices

[(4, 3), (3, 2), (2, 1), (1, 0)]

In [140]:
x = [1, 2, 3]
y = [4, 5, 6]
for a, b in enumerate(zip(x, y)):
    print(a, b[0], b[1])

0 1 4
1 2 5
2 3 6


In [199]:
# Test attention block
att_b = attention_block(skip_channels=512, gate_channels = 1024)
x = torch.randn(1, 512, 4, 4)
g = torch.randn(1, 1024, 2, 2)
y = att_b(x, g)
y.shape

torch.Size([1, 512, 4, 4])

In [147]:
a = [0, 1, 2, 3, 4]
for i, j in indices:
    print(a[i], a[j])

4 3
3 2
2 1
1 0


In [109]:
skip_connection = torch.randn((1, 512, 71, 71))
skip_connection_shape = skip_connection.shape
input = skip_connection
gate_signal = torch.randn((1, 1024, 35, 35))
# Upsample gate signal
upp = up_sample(in_channels=1024, out_channels=512)
gate_signal = upp(gate_signal)
if gate_signal.shape[2:] != skip_connection_shape[2:]:
    gate_signal = torchvision.transforms.functional.resize(gate_signal, size = skip_connection_shape[2:], antialias=True)
gate_signal.shape



torch.Size([1, 512, 71, 71])

In [106]:
gate_signal.shape[2:] == gate_signal.shape[2:]

True

In [None]:


# Resize the skip connection to match the dimensions of the gate_signal
skip_connection = nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = 1)(skip_connection)
skip_connection.shape
#skip_connection = torchvision.transforms.functional.resize(skip_connection, size = gate_signal.shape[2:])


In [None]:
# Match the channel dimensions of the skip connection and the gate signal
skip_connection = nn.Conv2d(in_channels = 512, out_channels = 1024, kernel_size = 1)(skip_connection)

# Add the skip connection and the gate signal
attn = gate_signal + skip_connection

# Apply a non-linear activation function
attn = nn.ReLU()(attn)

# Psi
psi = nn.Conv2d(in_channels = 1024, out_channels = 1, kernel_size = 1)(attn)

# Apply a sigmoid activation function to transform the output to a range between 0 and 1
psi = nn.Sigmoid()(psi)

# Upsample the output to the size of the skip connection
upsampled_psi = torch.nn.UpsamplingNearest2d(size = skip_connection_shape[2:])(psi)

# Multiply the skip connection with the upsampled output
attn = torch.mul(upsampled_psi, input)
attn.shape

In [102]:
attn.shape

torch.Size([1, 512, 64, 64])

In [89]:
skip_connection.shape[2:]

torch.Size([32, 32])

In [99]:
input.shape

torch.Size([1, 512, 64, 64])

In [100]:
upsampled_psi.shape

torch.Size([1, 1, 64, 64])

In [104]:
upsampled_psi * torch.randn((1, 512, 72, 72))

RuntimeError: The size of tensor a (64) must match the size of tensor b (72) at non-singleton dimension 3

In [None]:
skip_connection

In [80]:
attn

tensor([[[[-0.5226, -0.5840,  1.4101,  ...,  0.2814,  1.1430, -1.2001],
          [ 0.3523, -0.7049,  1.7028,  ...,  0.3190,  1.3144, -0.2240],
          [ 0.0247,  0.2418,  0.2530,  ...,  0.0825, -0.6044,  1.6606],
          ...,
          [ 1.3069,  0.4218,  0.2553,  ...,  1.0103,  1.9774, -0.4712],
          [-0.5925,  0.7382, -0.3403,  ...,  0.6515, -0.1832,  0.0458],
          [-0.5867, -1.6009, -1.1576,  ..., -0.0732,  0.6630,  1.6776]],

         [[ 0.1562,  0.1886,  0.6047,  ..., -0.2483,  0.4273,  0.1439],
          [ 0.1227,  0.2132, -0.6852,  ..., -1.6059,  0.4704, -1.9172],
          [ 0.8583, -2.0347, -1.0380,  ..., -0.7814,  0.2466,  1.1290],
          ...,
          [-1.1470, -1.5456,  0.8223,  ...,  0.2230,  0.9603, -0.3140],
          [ 0.1218, -2.3055,  1.5006,  ...,  1.8543,  0.9587, -0.8977],
          [ 1.4204,  1.6201,  0.8486,  ...,  0.6392, -1.4300, -1.0203]],

         [[-0.1478,  1.3004, -0.0982,  ..., -1.8086,  1.1026,  0.3511],
          [ 0.4592, -0.6684,  

In [69]:
y = torch.randn(2,2)
b= nn.Identity()
xx = b(y)
y, xx

(tensor([[-1.2981,  0.4661],
         [ 0.8026,  0.5173]]),
 tensor([[-1.2981,  0.4661],
         [ 0.8026,  0.5173]]))

In [79]:
# See how 1x1 works
m = nn.Conv2d(512, 1024, kernel_size=1)
input = torch.randn(1, 512, 32, 32)
out = m(input)
ss = out + torch.randn(1, 1024, 32, 32)
ss.shape	

torch.Size([1, 1024, 32, 32])

In [73]:
out.shape

torch.Size([1, 1024, 32, 32])

In [57]:
mask = model(dummy_input)
mask.shape

torch.Size([10, 1, 32, 32])

In [48]:
#@title Unet Definition

import math
from typing import Optional, Tuple, Union, List

import torch
from torch import nn

# A fancy activation function
class Swish(nn.Module):
    """
    ### Swish actiavation function
    $$x \cdot \sigma(x)$$
    """

    def forward(self, x):
        return x * torch.sigmoid(x)

# The time embedding 
class TimeEmbedding(nn.Module):
    """
    ### Embeddings for $t$
    """

    def __init__(self, n_channels: int):
        """
        * `n_channels` is the number of dimensions in the embedding
        """
        super().__init__()
        self.n_channels = n_channels
        # First linear layer
        self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels)
        # Activation
        self.act = Swish()
        # Second linear layer
        self.lin2 = nn.Linear(self.n_channels, self.n_channels)

    def forward(self, t: torch.Tensor):
        # Create sinusoidal position embeddings
        # [same as those from the transformer](../../transformers/positional_encoding.html)
        #
        # \begin{align}
        # PE^{(1)}_{t,i} &= sin\Bigg(\frac{t}{10000^{\frac{i}{d - 1}}}\Bigg) \\
        # PE^{(2)}_{t,i} &= cos\Bigg(\frac{t}{10000^{\frac{i}{d - 1}}}\Bigg)
        # \end{align}
        #
        # where $d$ is `half_dim`
        half_dim = self.n_channels // 8
        emb = math.log(10_000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=1)

        # Transform with the MLP
        emb = self.act(self.lin1(emb))
        emb = self.lin2(emb)

        #
        return emb

# Residual blocks include 'skip' connections
class ResidualBlock(nn.Module):
    """
    ### Residual block
    A residual block has two convolution layers with group normalization.
    Each resolution is processed with two residual blocks.
    """

    def __init__(self, in_channels: int, out_channels: int, time_channels: int, n_groups: int = 32):
        """
        * `in_channels` is the number of input channels
        * `out_channels` is the number of input channels
        * `time_channels` is the number channels in the time step ($t$) embeddings
        * `n_groups` is the number of groups for [group normalization](../../normalization/group_norm/index.html)
        """
        super().__init__()
        # Group normalization and the first convolution layer
        self.norm1 = nn.GroupNorm(n_groups, in_channels)
        self.act1 = Swish()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))

        # Group normalization and the second convolution layer
        self.norm2 = nn.GroupNorm(n_groups, out_channels)
        self.act2 = Swish()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))

        # If the number of input channels is not equal to the number of output channels we have to
        # project the shortcut connection
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
        else:
            self.shortcut = nn.Identity()

        # Linear layer for time embeddings
        self.time_emb = nn.Linear(time_channels, out_channels)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        * `x` has shape `[batch_size, in_channels, height, width]`
        * `t` has shape `[batch_size, time_channels]`
        """
        # First convolution layer
        h = self.conv1(self.act1(self.norm1(x)))
        # Add time embeddings
        h += self.time_emb(t)[:, :, None, None]
        # Second convolution layer
        h = self.conv2(self.act2(self.norm2(h)))

        # Add the shortcut connection and return
        return h + self.shortcut(x)

# Ahh yes, magical attention...
class AttentionBlock(nn.Module):
    """
    ### Attention block
    This is similar to [transformer multi-head attention](../../transformers/mha.html).
    """

    def __init__(self, n_channels: int, n_heads: int = 1, d_k: int = None, n_groups: int = 32):
        """
        * `n_channels` is the number of channels in the input
        * `n_heads` is the number of heads in multi-head attention
        * `d_k` is the number of dimensions in each head
        * `n_groups` is the number of groups for [group normalization](../../normalization/group_norm/index.html)
        """
        super().__init__()

        # Default `d_k`
        if d_k is None:
            d_k = n_channels
        # Normalization layer
        self.norm = nn.GroupNorm(n_groups, n_channels)
        # Projections for query, key and values
        self.projection = nn.Linear(n_channels, n_heads * d_k * 3)
        # Linear layer for final transformation
        self.output = nn.Linear(n_heads * d_k, n_channels)
        # Scale for dot-product attention
        self.scale = d_k ** -0.5
        #
        self.n_heads = n_heads
        self.d_k = d_k

    def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None):
        """
        * `x` has shape `[batch_size, in_channels, height, width]`
        * `t` has shape `[batch_size, time_channels]`
        """
        # `t` is not used, but it's kept in the arguments because for the attention layer function signature
        # to match with `ResidualBlock`.
        _ = t
        # Get shape
        batch_size, n_channels, height, width = x.shape
        # Change `x` to shape `[batch_size, seq, n_channels]`
        x = x.view(batch_size, n_channels, -1).permute(0, 2, 1)
        # Get query, key, and values (concatenated) and shape it to `[batch_size, seq, n_heads, 3 * d_k]`
        qkv = self.projection(x).view(batch_size, -1, self.n_heads, 3 * self.d_k)
        # Split query, key, and values. Each of them will have shape `[batch_size, seq, n_heads, d_k]`
        q, k, v = torch.chunk(qkv, 3, dim=-1)
        # Calculate scaled dot-product $\frac{Q K^\top}{\sqrt{d_k}}$
        attn = torch.einsum('bihd,bjhd->bijh', q, k) * self.scale
        # Softmax along the sequence dimension $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$
        attn = attn.softmax(dim=1)
        # Multiply by values
        res = torch.einsum('bijh,bjhd->bihd', attn, v)
        # Reshape to `[batch_size, seq, n_heads * d_k]`
        res = res.view(batch_size, -1, self.n_heads * self.d_k)
        # Transform to `[batch_size, seq, n_channels]`
        res = self.output(res)

        # Add skip connection
        res += x

        # Change to shape `[batch_size, in_channels, height, width]`
        res = res.permute(0, 2, 1).view(batch_size, n_channels, height, width)

        #
        return res


class DownBlock(nn.Module):
    """
    ### Down block
    This combines `ResidualBlock` and `AttentionBlock`. These are used in the first half of U-Net at each resolution.
    """

    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
        super().__init__()
        self.res = ResidualBlock(in_channels, out_channels, time_channels)
        if has_attn:
            self.attn = AttentionBlock(out_channels)
        else:
            self.attn = nn.Identity()

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        x = self.res(x, t)
        x = self.attn(x)
        return x


class UpBlock(nn.Module):
    """
    ### Up block
    This combines `ResidualBlock` and `AttentionBlock`. These are used in the second half of U-Net at each resolution.
    """

    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
        super().__init__()
        # The input has `in_channels + out_channels` because we concatenate the output of the same resolution
        # from the first half of the U-Net
        self.res = ResidualBlock(in_channels + out_channels, out_channels, time_channels)
        if has_attn:
            self.attn = AttentionBlock(out_channels)
        else:
            self.attn = nn.Identity()

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        x = self.res(x, t)
        x = self.attn(x)
        return x


class MiddleBlock(nn.Module):
    """
    ### Middle block
    It combines a `ResidualBlock`, `AttentionBlock`, followed by another `ResidualBlock`.
    This block is applied at the lowest resolution of the U-Net.
    """

    def __init__(self, n_channels: int, time_channels: int):
        super().__init__()
        self.res1 = ResidualBlock(n_channels, n_channels, time_channels)
        self.attn = AttentionBlock(n_channels)
        self.res2 = ResidualBlock(n_channels, n_channels, time_channels)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        x = self.res1(x, t)
        x = self.attn(x)
        x = self.res2(x, t)
        return x


class Upsample(nn.Module):
    """
    ### Scale up the feature map by $2 \times$
    """

    def __init__(self, n_channels):
        super().__init__()
        self.conv = nn.ConvTranspose2d(n_channels, n_channels, (4, 4), (2, 2), (1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        # `t` is not used, but it's kept in the arguments because for the attention layer function signature
        # to match with `ResidualBlock`.
        _ = t
        return self.conv(x)


class Downsample(nn.Module):
    """
    ### Scale down the feature map by $\frac{1}{2} \times$
    """

    def __init__(self, n_channels):
        super().__init__()
        self.conv = nn.Conv2d(n_channels, n_channels, (3, 3), (2, 2), (1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        # `t` is not used, but it's kept in the arguments because for the attention layer function signature
        # to match with `ResidualBlock`.
        _ = t
        return self.conv(x)

# The core class definition (aka the important bit)
class UNet(nn.Module):
    """
    ## U-Net
    """

    def __init__(self, image_channels: int = 3, n_channels: int = 64,
                 ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4),
                 is_attn: Union[Tuple[bool, ...], List[int]] = (False, False, True, True),
                 n_blocks: int = 2):
        """
        * `image_channels` is the number of channels in the image. $3$ for RGB.
        * `n_channels` is number of channels in the initial feature map that we transform the image into
        * `ch_mults` is the list of channel numbers at each resolution. The number of channels is `ch_mults[i] * n_channels`
        * `is_attn` is a list of booleans that indicate whether to use attention at each resolution
        * `n_blocks` is the number of `UpDownBlocks` at each resolution
        """
        super().__init__()

        # Number of resolutions
        n_resolutions = len(ch_mults)

        # Project image into feature map
        self.image_proj = nn.Conv2d(image_channels, n_channels, kernel_size=(3, 3), padding=(1, 1))

        # Time embedding layer. Time embedding has `n_channels * 4` channels
        self.time_emb = TimeEmbedding(n_channels * 4)

        # #### First half of U-Net - decreasing resolution
        down = []
        # Number of channels
        out_channels = in_channels = n_channels
        # For each resolution
        for i in range(n_resolutions):
            # Number of output channels at this resolution
            out_channels = in_channels * ch_mults[i]
            # Add `n_blocks`
            for _ in range(n_blocks):
                down.append(DownBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
                in_channels = out_channels
            # Down sample at all resolutions except the last
            if i < n_resolutions - 1:
                down.append(Downsample(in_channels))

        # Combine the set of modules
        self.down = nn.ModuleList(down)

        # Middle block
        self.middle = MiddleBlock(out_channels, n_channels * 4, )

        # #### Second half of U-Net - increasing resolution
        up = []
        # Number of channels
        in_channels = out_channels
        # For each resolution
        for i in reversed(range(n_resolutions)):
            # `n_blocks` at the same resolution
            out_channels = in_channels
            for _ in range(n_blocks):
                up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
            # Final block to reduce the number of channels
            out_channels = in_channels // ch_mults[i]
            up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
            in_channels = out_channels
            # Up sample at all resolutions except last
            if i > 0:
                up.append(Upsample(in_channels))

        # Combine the set of modules
        self.up = nn.ModuleList(up)

        # Final normalization and convolution layer
        self.norm = nn.GroupNorm(8, n_channels)
        self.act = Swish()
        self.final = nn.Conv2d(in_channels, image_channels, kernel_size=(3, 3), padding=(1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        * `x` has shape `[batch_size, in_channels, height, width]`
        * `t` has shape `[batch_size]`
        """

        # Get time-step embeddings
        t = self.time_emb(t)

        # Get image projection
        x = self.image_proj(x)

        # `h` will store outputs at each resolution for skip connection
        h = [x]
        # First half of U-Net
        for m in self.down:
            x = m(x, t)
            h.append(x)

        # Middle (bottom)
        x = self.middle(x, t)

        # Second half of U-Net
        for m in self.up:
            if isinstance(m, Upsample):
                x = m(x, t)
            else:
                # Get the skip connection from first half of U-Net and concatenate
                s = h.pop()
                x = torch.cat((x, s), dim=1)
                #
                x = m(x, t)

        # Final normalization and convolution
        return self.final(self.act(self.norm(x)))

In [50]:
# Let's see it in action on dummy data:

# A dummy batch of 10 3-channel 32px images
x = torch.randn(1, 3, 572, 572)

# 't' - what timestep are we on
t = torch.tensor([50.], dtype=torch.long)

# Define the unet model
unet = UNet()

# The foreward pass (takes both x and t)
model_output = unet(x, t)

# The output shape matches the input.
model_output.shape

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 144 but got size 143 for tensor number 1 in the list.

In [113]:
from torch import nn
class AttentionBlock(nn.Module):
    def __init__(self, f_g, f_l, f_int):
        super().__init__()
        
        self.w_g = nn.Sequential(
                                nn.Conv2d(f_g, f_int,
                                            kernel_size=1, stride=1,
                                            padding=0, bias=True),
                                nn.BatchNorm2d(f_int)
        )
        
        self.w_x = nn.Sequential(
                                nn.Conv2d(f_l, f_int,
                                            kernel_size=1, stride=1,
                                            padding=0, bias=True),
                                nn.BatchNorm2d(f_int)
        )
        
        self.psi = nn.Sequential(
                                nn.Conv2d(f_int, 1,
                                            kernel_size=1, stride=1,
                                            padding=0,  bias=True),
                                nn.BatchNorm2d(1),
                                nn.Sigmoid(),
        )
        
        self.relu = nn.ReLU(inplace=True)
        
    def forward(self, g, x):
        g1 = self.w_g(g)
        x1 = self.w_x(x)
        psi = self.relu(g1+x1)
        psi = self.psi(psi)
        
        return psi*x

In [119]:
f_g=512
f_l=512
f_int = 256
#  decoder + concat
# d5 = self.up5(x5)
# x4 = self.att5(g=d5, x=x4)
g = torch.randn(1, 512, 28, 28)
x = torch.randn(1, 512, 28, 28)
att5 = AttentionBlock(f_g, f_l, f_int)
att5(g, x).shape

torch.Size([1, 512, 28, 28])

In [131]:
g1 = nn.Conv2d(f_g, f_int, kernel_size=1, stride=1, padding=0, bias=True)(g)
g1 = nn.BatchNorm2d(f_int)(g1)
x1 = nn.Conv2d(f_l, f_int, kernel_size=1, stride=1, padding=0, bias=True)(x)
x1 = nn.BatchNorm2d(f_int)(x1)
psi = nn.ReLU(inplace=True)(g1+x1)
psi = nn.Conv2d(f_int, 1, kernel_size=1, stride=1, padding=0,  bias=True)(psi)
psi = nn.BatchNorm2d(1)(psi)
psi = nn.Sigmoid()(psi)
out = torch.mul(psi, x)
out.shape

torch.Size([1, 512, 28, 28])

In [128]:
x.shape

torch.Size([1, 512, 28, 28])

In [129]:
psi.shape

torch.Size([1, 1, 28, 28])

In [116]:
att5

AttentionBlock(
  (w_g): Sequential(
    (0): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (w_x): Sequential(
    (0): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (psi): Sequential(
    (0): Conv2d(256, 1, kernel_size=(1, 1), stride=(1, 1))
    (1): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Sigmoid()
  )
  (relu): ReLU(inplace=True)
)