In [None]:
import numpy as np

#Building the model using Pytorch
import torch
import torch.nn as nn
#Additional torch stuff, like loss functions:
from torch.nn import functional as F

import tensorflow as tf # To get the dataset
from tensorflow.keras import layers, Model

#For train test split
from sklearn.model_selection import train_test_split

#For plotting
import matplotlib.pyplot as plt

# Dataset and Pre-processing

I'm using the MNIST dataset, and loading it using tensorflow's datasets package.
You can obviously use any other dataset, but it will affect a few key things:
- n_classes when we're defining the model. It's currently set to 10, since there are 10 classes (digits) in MNIST, but will vary with dataset.
- Loss. This, as well as the above will change if you're not classifying data. The loss *'criterion'* can be swapped for BCELoss (Binary cross-entropy loss as opposed to normal Cross-Entropy loss). Also, n_classes should be set to 1.
- Potentially sigmoid(x) instead of softmax(x) at the end of the model.

In [None]:
# Load the MNIST dataset
(x_train, y_train), (_, _) = tf.keras.datasets.mnist.load_data()

# Preprocess the data
x_train = x_train.astype("float32") / 255.0  # Normalize to [0, 1]
x_train = np.expand_dims(x_train, axis=-1)   # Add channel dimension

# Create masks (e.g., binary segmentation of digits)
# Here, we'll use the digits as masks for demonstration purposes.
y_train = (y_train[..., None] > 0).astype("float32")  # Convert labels to binary masks

# Train-test split
X_train, X_test, Y_train, Y_test = train_test_split(x_train, y_train, test_size=0.2, random_state=42)


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step


In [None]:
y_train[-1]

array([1.], dtype=float32)

In [None]:
from torch.utils.data import DataLoader, TensorDataset
# Convert NumPy arrays to PyTorch tensors
X_train_tensor = torch.tensor(X_train, dtype=torch.float32).permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
X_test_tensor = torch.tensor(X_test, dtype=torch.float32).permute(0, 3, 1, 2)
# Ensure Y_train and Y_test have 4 dimensions
# Y_train_tensor = torch.tensor(Y_train, dtype=torch.float32).unsqueeze(1)  # Add channel dimension
# Y_test_tensor = torch.tensor(Y_test, dtype=torch.float32).unsqueeze(1)    # Add channel dimension

Y_train_tensor = torch.tensor(Y_train, dtype=torch.long)  # No unsqueeze for labels
Y_test_tensor = torch.tensor(Y_test, dtype=torch.long)


# Create dataloaders
batch_size = 32
train_dataset = TensorDataset(X_train_tensor, Y_train_tensor)
test_dataset = TensorDataset(X_test_tensor, Y_test_tensor)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Model

A convolutional neural network essentially is a specialised case of a neural network for image analysis. Since the number of weights would grow exponentially with dimensions of the image in a typical fully connected neural net, CNNs were developed. CNNs use a number of filters, which slice up the image into overlapping chunks and analyse each as one unit.
The CNN will compute the output of neurons that are connected to local regions in the input, each computing a dot product between their weights and a small region they are connected to in the input volume. This may result in volume such as 32x32x12 if we decided to use 12 filters on a say 32x32x3 image.

A ReLU or other activation function will likely follow, integrating some non-linearity, but leaving the dimensions unchanged

An example of a basic CONVNET would go something like
[Input] - *[Conv] - [ReLU] -[Conv] - [ReLU]* - [Pool] - *[Conv] - [ReLU] - [Conv] - [ReLU] *- [Pool] - *[Conv] - [ReLU] - [Conv] - [ReLU]* - [Pool] -- [FC] - Output

The POOL operation downsamples the image along width and height, so 36x36x12 will become 18x18x12 for instance.
A Fully Connected (FC) layer is used at the end to compute scores for instance.

A UNet model contains ~ 4 downblocks, a bottleneck and 4 upblocks![UNet arch.PNG](https://miro.medium.com/v2/resize:fit:933/1*hzufVrlQYVs_Lj-vHw8SgQ.png)

**How do we pick parameters?**
- Obviously it varies with the dataset, which is why I've kept the parameters as global variables that can be changed neatly outside of the confusing looking model code
- However, there are some guidelines that we can keep in mind:
  - Starting with a small number of filters during the downward steps and increasing until it reaches the bottleneck
  - Using powers of two for these filters
  - Matching the filter sizes, in reverse, for the upward steps
  - For kernel size, 3×3 is the most common choice for convolution layers in UNet. It provides a balance between capturing local details and computational efficiency. 1×1 convolutions can be used at the output layer for final predictions.

In [None]:
#Initialising some constants:
p_kernel = (2,2)
p_stride = (2,2)
strides = 1

# filters = [16, 32, 64, 128, 256]
input_ch = 1
filters = [64, 128, 256, 512, 1024]
kernel = (3,3)
learning_rate = 1e-4

Usually the max pool kernel and stride are just fixed during the creation of down_block, but I want the model to be a tad more modular, so that I can change any variable from the initialising section alone.

You can also implement batchnorms in between like so:

```
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)
```



In [None]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            #nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            #nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)

**Important Notes about the blocks**
****
1. The 'down' blocks include two conv2ds. Each of them contain a convolution layer followed by a relu activation layer.
2. Now, since the 'filters' variable is the same for both, it means that the dimension of the output of the first will be the same as that of the second - e.g: (32,32,3) --conv2d--> (32,32,filters) --conv2d--> (32,32,filters)
3. This does not (as I initially thought) mean that the second convolution is meaningless, since it can extract further details from the basic features extracted by the first layer. This is the significance of 2 relus.
4. Additionally, the receptive field increases over successive convolutions even if filters remain the same, because the convolution is working on the output of the first. i.e. Each point in the output of the first conv. corresponds to a 3x3 kernel on the source data
  ****
5. The 'up layer'

In [None]:
class DownBlock(nn.Module):
  def __init__(self, in_ch, out_ch):
    super().__init__()
    self.double_conv = DoubleConv(in_ch,out_ch)
    self.down_sample = nn.MaxPool2d(2)

  def forward(self,x):
    skip_out = self.double_conv(x)
    down_out = self.down_sample(skip_out)
    return skip_out, down_out

class UpBlock(nn.Module):
  def __init__(self,in_ch, out_ch):
    super().__init__()
    self.up_sample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
    self.double_conv = DoubleConv(in_ch,out_ch)

  def forward(self,x,skip_connection):
    x = self.up_sample(x)
    '''
    if x.shape[-2:] != skip_input.shape[-2:]:
      diff_h = skip_input.size(2) - x.size(2)
      diff_w = skip_input.size(3) - x.size(3)
      skip_input = skip_input[:, :, diff_h // 2:-(diff_h - diff_h // 2), diff_w // 2:-(diff_w - diff_w // 2)]
    x = torch.cat([x, skip_input], dim=1)
    '''
    print("upblock_xshape:", x.shape, "skip connection_shape:",skip_connection.shape)
    x = torch.cat([x,skip_connection],dim=1)
    x = self.double_conv(x)
    return x

In [None]:
class UNet(nn.Module):
    def __init__(self, out_classes=10):
        super().__init__()
        filters = [64, 128, 256, 512, 1024]
        # Downsampling Path
        self.down_conv1 = DownBlock(1,64)
        self.down_conv2 = DownBlock(64, 128)
        self.down_conv3 = DownBlock(128, 256)
        self.down_conv4 = DownBlock(256, 512)
        # Bottleneck
        self.double_conv = DoubleConv(512, 1024)
        # Upsampling Path
        self.up_conv4 = UpBlock(512 + 1024, 512)
        self.up_conv3 = UpBlock(256 + 512, 256)
        self.up_conv2 = UpBlock(128 + 256, 128)
        self.up_conv1 = UpBlock(128 + 64, 64)
        # Final Convolution
        self.conv_last = nn.Conv2d(64, out_classes, kernel_size=1)

    def forward(self, x):
        x, skip1_out = self.down_conv1(x)
        print("x shape ",x.shape," skip1 ",skip1_out.shape)
        x, skip2_out = self.down_conv2(x)
        x, skip3_out = self.down_conv3(x)
        x, skip4_out = self.down_conv4(x)
        x = self.double_conv(x)
        x = self.up_conv4(x, skip4_out)
        x = self.up_conv3(x, skip3_out)
        x = self.up_conv2(x, skip2_out)
        x = self.up_conv1(x, skip1_out)
        x = self.conv_last(x)
        return x


# Get UNet model
model = UNet()


# Train

In [None]:
# Define device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cpu


In [None]:
from torch import optim

model = model.to(device)

# Define loss and optimizer
#criterion = nn.BCELoss()  # Binary Cross-Entropy Loss
criterion = nn.CrossEntropyLoss()  # Suitable for multi-class classification

optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
images.shape

torch.Size([32, 1, 28, 28])

In [None]:
from tqdm import tqdm
# Training loop
num_epochs = 5  # Increase for better results
for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for images, masks in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        images, masks = images.to(device), masks.to(device)

        #Forward pass
        outputs = model(images)  # Shape: (N, 10, H, W)
        loss = criterion(outputs, masks)  # Ground truth masks: (N, H, W)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    avg_train_loss = train_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_train_loss:.4f}")

Epoch 1/5:   0%|          | 0/1500 [00:00<?, ?it/s]

x shape  torch.Size([32, 64, 28, 28])  skip1  torch.Size([32, 64, 14, 14])


Epoch 1/5:   0%|          | 0/1500 [00:17<?, ?it/s]

upblock_xshape: torch.Size([32, 1024, 56, 56]) skip connection_shape: torch.Size([32, 512, 14, 14])





RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 56 but got size 14 for tensor number 1 in the list.

# Testing and visualising

In [None]:
# Testing loop
model.eval()
test_loss = 0
with torch.no_grad():
    for images, masks in tqdm(test_loader, desc="Testing"):
        images, masks = images.to(device), masks.to(device)
        outputs = model(images)
        loss = criterion(outputs, masks)
        test_loss += loss.item()

avg_test_loss = test_loss / len(test_loader)
print(f"Test Loss: {avg_test_loss:.4f}")

In [None]:
# Visualize a few predictions
import matplotlib.pyplot as plt

model.eval()
for i in range(5):  # Visualize 5 examples
    image = X_test_tensor[i:i+1].to(device)
    mask = Y_test_tensor[i].cpu().numpy().squeeze()

    with torch.no_grad():
        pred = model(image).cpu().numpy().squeeze()

    plt.figure(figsize=(12, 4))
    plt.subplot(1, 3, 1)
    plt.title("Input")
    plt.imshow(image.cpu().squeeze(), cmap='gray')

    plt.subplot(1, 3, 2)
    plt.title("Ground Truth")
    plt.imshow(mask, cmap='gray')

    plt.subplot(1, 3, 3)
    plt.title("Prediction")
    plt.imshow(pred, cmap='gray')
    plt.show()
