In [1]:
import os
import logging
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as T
import torch.optim as optim

# Enable debug-level logging output.
# UNet.forward() reports tensor shapes through logging.debug on every forward
# pass, so at DEBUG level each batch emits several "shape of x" lines.
# Raise the level to logging.INFO to silence them.
logging.basicConfig(level=logging.DEBUG)

In [2]:
class DoubleConv(nn.Module):
    """
    Basic UNet building block: two stacked (Conv2d 3x3 -> BatchNorm2d -> ReLU) stages.

    in_channels: channel count of the incoming feature map
    out_channels: channel count produced by both convolution stages
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        stages = []
        # The first stage maps in_channels -> out_channels, the second keeps out_channels.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                # kernel=3, stride=1, padding=1 keeps the spatial size unchanged.
                # The conv bias is disabled because the BatchNorm that follows has its own bias term.
                nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                # inplace=True lets ReLU overwrite its input instead of allocating a new tensor
                nn.ReLU(inplace=True),
            ])

        # Layers live under `self.conv` at indices 0..5 (conv, bn, relu, conv, bn, relu)
        self.conv = nn.Sequential(*stages)

    #! The forward() function defines how data flows through this block
    def forward(self, x):
        out = self.conv(x)
        return out

In [3]:
# Build UNet from scratch
class UNet(nn.Module):
    """
    UNet encoder/decoder built from DoubleConv blocks.

    in_channels: number of input channels (e.g., RGB = 3)
    out_channels: number of output channels (e.g., 1 for binary segmentation mask)
    features: channel sizes used for each downsampling stage (any non-empty sequence of ints)

    forward() returns the output of the final 1x1 convolution (raw logits, no sigmoid).
    """

    # The default is an immutable tuple so it cannot be shared and mutated between instances.
    def __init__(self, in_channels=3, out_channels=1, features=(64, 128, 256, 512)):
        super().__init__()

        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel size")

        self.downs = nn.ModuleList()   # Encoder path (downsampling)
        self.ups = nn.ModuleList()     # Decoder path (upsampling)

        #! Build the encoder (downsampling path); max pooling is performed in forward()
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        #! Bottleneck block: doubles the channel count of the deepest encoder stage
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # Build the decoder (upsampling path), deepest stage first
        for feature in reversed(features):
            #! Learnable upsampling (ConvTranspose2d): doubles spatial size, halves channels
            self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))

            #! After concatenation with the skip connection the channel count is feature*2 again
            self.ups.append(DoubleConv(feature * 2, feature))

        #! Final 1x1 convolution to produce output mask
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """
        Forward pass of UNet:
        1. Go down the encoder and save skip connections.
        2. Apply bottleneck.
        3. Go up the decoder, concatenate skip connections, apply double convs.

        Inputs whose height/width are not divisible by 2**len(features) are supported:
        the upsampled map is resized to the skip connection before concatenation.
        """
        skip_connections = []

        # Encoder forward pass
        for down in self.downs:
            # Lazy %-formatting: the message is only built when DEBUG logging is enabled
            logging.debug("shape of x: %s", x.shape)
            x = down(x)
            skip_connections.append(x)
            x = F.max_pool2d(x, kernel_size=2)  # Downsample (floors odd sizes)

        logging.debug("shape of x (after encoder): %s", x.shape)

        # Bottleneck
        x = self.bottleneck(x)

        # Reverse skip connections so the deepest one is consumed first
        skip_connections.reverse()

        # Decoder forward pass
        #! ups contains pairs: (ConvTranspose2d, DoubleConv)
        for i in range(0, len(self.ups), 2):
            logging.debug("shape of x (before upsample): %s", x.shape)

            # 1. Upsample
            x = self.ups[i](x)

            # 2. Retrieve corresponding skip connection
            skip_connection = skip_connections[i // 2]

            # 3. max_pool2d floors odd sizes, so the upsampled map can be one pixel
            #    smaller than the skip connection; match the sizes or torch.cat fails.
            if x.shape[2:] != skip_connection.shape[2:]:
                x = F.interpolate(x, size=skip_connection.shape[2:])

            #! 4. Concatenate along the channel dimension (B, C, H, W)
            x = torch.cat((skip_connection, x), dim=1)

            # 5. Apply double convolution block
            x = self.ups[i + 1](x)

        # Final conv to reduce channel dimension
        return self.final_conv(x)

In [4]:
class MyDataset(Dataset):
    """Paired image / mask dataset read from two directories."""

    def __init__(self, image_dir, mask_dir, img_transform, mask_transform):
        """
        image_dir: directory containing input images
        mask_dir: directory containing corresponding mask images
        img_transform: torchvision transformation applied to image
        mask_transform: torchvision transformation applied to mask
        """
        super().__init__()
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.img_transform = img_transform
        self.mask_transform = mask_transform

        # Store the filenames found in the image directory.
        # os.listdir() returns entries in arbitrary order and also lists hidden
        # files and sub-directories, so sort for a stable index -> file mapping
        # and keep only regular, non-hidden files.
        self.images = sorted(
            name
            for name in os.listdir(image_dir)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(image_dir, name))
        )

    def __len__(self):
        # Total number of samples in the dataset
        return len(self.images)

    def __getitem__(self, index):
        """
        Returns:
            transformed image tensor
            transformed mask tensor
        """
        image_name = self.images[index]
        img_path = os.path.join(self.image_dir, image_name)

        # Mask naming convention: "<name>.jpg" -> "<name>_mask.gif"
        mask_name = image_name.replace(".jpg", "_mask.gif")
        mask_path = os.path.join(self.mask_dir, mask_name)

        # Read image and mask using PIL. convert() returns a new, fully loaded
        # image, so the file can be closed as soon as the `with` block ends.
        with Image.open(img_path) as image_file:
            image = image_file.convert("RGB")  # always three channels
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert("L")  # grayscale mask

        # Apply transform to image and mask
        return self.img_transform(image), self.mask_transform(mask)

In [5]:
def train(model, num_epochs, train_loader, loss_function, optimizer, device, val_loader=None):
    """
    Train the model and, when validation data is available, evaluate after every epoch.

    model: the UNet model
    num_epochs: number of training epochs
    train_loader: dataloader (any iterable of (x, y) batches) for training data
    loss_function: nn.BCELoss() (sigmoid is applied to the model output before the loss)
    optimizer: optimizer used for updating model weights
    device: gpu or cpu
    val_loader: optional dataloader for validation data. When omitted, a
        module-level `val_loader` is used if one exists (the behaviour older
        callers rely on); if there is none, evaluation is skipped.
    """
    if val_loader is None:
        # Backward compatibility with callers that only define a global val_loader
        val_loader = globals().get("val_loader")

    for epoch in range(num_epochs):
        # The per-epoch evaluation switches the model to eval mode, so training
        # mode has to be restored at the start of every epoch (once is enough).
        model.train()

        total_loss = 0.0
        num_batches = 0  # counted manually so loaders without len() also work

        # Iterate over all mini-batches
        for x, y in train_loader:
            # Move data to device
            x = x.to(device)
            y = y.to(device)

            # Forward pass; sigmoid converts logits into probability (0~1 range)
            out = torch.sigmoid(model(x))

            # Compute loss between prediction and ground-truth mask
            loss = loss_function(out, y)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        # An empty loader has no meaningful average; report NaN instead of dividing by zero
        avg_loss = total_loss / num_batches if num_batches else float("nan")
        print(f"Epoch [{epoch+1}] Train Loss: {avg_loss:.4f}")

        # Evaluate model on validation set at the end of each epoch.
        # `eval` here is the notebook's evaluation function, not the Python builtin.
        if val_loader is not None:
            eval(model, val_loader, epoch, device)

In [6]:
def eval(model, val_loader, epoch, device):
    """
    Compute pixel accuracy on the validation set and save one sample per epoch.

    model: the UNet model (left in evaluation mode afterwards)
    val_loader: dataloader (any iterable of (x, y) batches) for validation data
    epoch: zero-based epoch index, used in the printed message and file names
    device: gpu or cpu

    Side effects: writes sample_epoch{N}_input.png and sample_epoch{N}_pred.png
    for the first sample of the first batch, and prints the accuracy.

    Returns:
        pixel accuracy as a float, or NaN when val_loader yields no pixels
    """
    model.eval()  # Set model to evaluation mode
    num_correct = 0
    num_pixels = 0

    with torch.no_grad():  # Disable gradient computation
        for i, (x, y) in enumerate(val_loader):
            x = x.to(device)
            y = y.to(device)

            # Forward pass
            out_img = model(x)
            probability = torch.sigmoid(out_img)

            # Prediction threshold: > 0.5 -> foreground class
            predictions = (probability > 0.5)

            #! for each epoch, save the first image from the first batch (index=0)
            if i == 0:
                #! Hint: Move to CPU, convert to NumPy, reshape RGB:(H,W,C) or Grayscale:(H,W), scale to 0-255, convert to uint8
                img_input = x[0].cpu().permute(1, 2, 0).numpy()
                # Scale from 0-1 back to 0-255 and convert to uint8
                img_input = (img_input * 255).astype(np.uint8)
                Image.fromarray(img_input).save(f"sample_epoch{epoch+1}_input.png")

                # Boolean prediction -> 0/255 grayscale image
                img_pred = predictions[0].float().cpu().squeeze().numpy()
                img_pred = (img_pred * 255).astype(np.uint8)
                Image.fromarray(img_pred, mode='L').save(f"sample_epoch{epoch+1}_pred.png")

            # Count correct pixels for accuracy calculation.
            # The mask is binarised with the same 0.5 threshold so that a mask value
            # that is not exactly 0.0 or 1.0 is still compared class-to-class.
            targets = (y > 0.5)
            num_correct += (predictions == targets).sum().item()
            num_pixels += y.numel()  # safer than BATCH_SIZE*W*H

    # An empty loader has no pixels; report NaN instead of dividing by zero
    accuracy = num_correct / num_pixels if num_pixels else float("nan")
    print(f"Epoch [{epoch+1}] Acc : {accuracy:.4f}")
    return accuracy

In [7]:
import zipfile

# Path to .zip file
local_zip = "Cars.zip"

# Extract zip file into "images/" directory.
# The context manager closes the archive even if extraction raises.
with zipfile.ZipFile(local_zip, "r") as zip_ref:
    zip_ref.extractall("images/")

In [8]:
import torch
# Report the installed PyTorch version and whether CUDA is available (one value per line)
print(torch.__version__, torch.cuda.is_available(), sep="\n")

2.6.0+cu124
True


In [9]:
# Seed the RNG before anything random happens so that weight initialisation,
# the train/validation split and mini-batch shuffling repeat from run to run.
SEED = 42
torch.manual_seed(SEED)

# setting torch.cuda
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device: {device}")

#Create a UNet model object, and move to device
model = UNet(in_channels=3, out_channels=1, features=[64, 128, 256, 512]).to(device)

# hyper params
BATCH_SIZE = 16
NUM_EPOCHS = 3
IMG_WIDTH = 240
IMG_HEIGHT = 160

# ToTensor maps pixel values from [0,255] to [0,1]
img_transform = T.Compose([
    T.Resize((IMG_HEIGHT, IMG_WIDTH)),
    T.ToTensor(),
])
# Unlike normal images, masks should NOT be resized with bilinear interpolation,
# because it will create unwanted grayscale values such as 0.0003 or 0.0007.
# We use NEAREST mode to preserve the original 0/255 binary mask.
mask_transform = T.Compose([
    T.Resize((IMG_HEIGHT, IMG_WIDTH), interpolation=T.InterpolationMode.NEAREST),
    T.ToTensor(),
])

# Load data
all_data = MyDataset(image_dir="images/Cars/train/", mask_dir="images/Cars/masks/", img_transform=img_transform, mask_transform=mask_transform)

#Split Data to train_data(70 %) and validate_data(30 %), Hint: len(all_data) can get total number of samples in the dataset
train_size = int(0.7 * len(all_data))
val_size = len(all_data) - train_size
# A dedicated seeded generator makes random_split pick the same indices on every run
split_generator = torch.Generator().manual_seed(SEED)
train_data, val_data = torch.utils.data.random_split(all_data, [train_size, val_size], generator=split_generator)

# create loader for mini-batch gradient descent
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False)

# The loss function for binary classification (train() applies sigmoid before calling it)
loss_function = nn.BCELoss()

# Choosing Adam as our optimizer
optimizer = optim.Adam(model.parameters(), lr=1e-3)

#! train
train(model, NUM_EPOCHS, train_loader, loss_function, optimizer, device)

device: cuda


DEBUG:root:shape of x: torch.Size([16, 3, 160, 240])
DEBUG:root:shape of x: torch.Size([16, 64, 80, 120])
DEBUG:root:shape of x: torch.Size([16, 128, 40, 60])
DEBUG:root:shape of x: torch.Size([16, 256, 20, 30])
DEBUG:root:shape of x (after encoder): torch.Size([16, 512, 10, 15])
DEBUG:root:shape of x (before upsample): torch.Size([16, 1024, 10, 15])
DEBUG:root:shape of x (before upsample): torch.Size([16, 512, 20, 30])
DEBUG:root:shape of x (before upsample): torch.Size([16, 256, 40, 60])
DEBUG:root:shape of x (before upsample): torch.Size([16, 128, 80, 120])
DEBUG:root:shape of x: torch.Size([16, 3, 160, 240])
DEBUG:root:shape of x: torch.Size([16, 64, 80, 120])
DEBUG:root:shape of x: torch.Size([16, 128, 40, 60])
DEBUG:root:shape of x: torch.Size([16, 256, 20, 30])
DEBUG:root:shape of x (after encoder): torch.Size([16, 512, 10, 15])
DEBUG:root:shape of x (before upsample): torch.Size([16, 1024, 10, 15])
DEBUG:root:shape of x (before upsample): torch.Size([16, 512, 20, 30])
DEBUG:roo

Epoch [1] Train Loss: 0.1769


DEBUG:root:shape of x: torch.Size([16, 3, 160, 240])
DEBUG:root:shape of x: torch.Size([16, 64, 80, 120])
DEBUG:root:shape of x: torch.Size([16, 128, 40, 60])
DEBUG:root:shape of x: torch.Size([16, 256, 20, 30])
DEBUG:root:shape of x (after encoder): torch.Size([16, 512, 10, 15])
DEBUG:root:shape of x (before upsample): torch.Size([16, 1024, 10, 15])
DEBUG:root:shape of x (before upsample): torch.Size([16, 512, 20, 30])
DEBUG:root:shape of x (before upsample): torch.Size([16, 256, 40, 60])
DEBUG:root:shape of x (before upsample): torch.Size([16, 128, 80, 120])
DEBUG:root:shape of x: torch.Size([16, 3, 160, 240])
DEBUG:root:shape of x: torch.Size([16, 64, 80, 120])
DEBUG:root:shape of x: torch.Size([16, 128, 40, 60])
DEBUG:root:shape of x: torch.Size([16, 256, 20, 30])
DEBUG:root:shape of x (after encoder): torch.Size([16, 512, 10, 15])
DEBUG:root:shape of x (before upsample): torch.Size([16, 1024, 10, 15])
DEBUG:root:shape of x (before upsample): torch.Size([16, 512, 20, 30])
DEBUG:roo

Epoch [1] Acc : 0.9841


DEBUG:root:shape of x: torch.Size([16, 3, 160, 240])
DEBUG:root:shape of x: torch.Size([16, 64, 80, 120])
DEBUG:root:shape of x: torch.Size([16, 128, 40, 60])
DEBUG:root:shape of x: torch.Size([16, 256, 20, 30])
DEBUG:root:shape of x (after encoder): torch.Size([16, 512, 10, 15])
DEBUG:root:shape of x (before upsample): torch.Size([16, 1024, 10, 15])
DEBUG:root:shape of x (before upsample): torch.Size([16, 512, 20, 30])
DEBUG:root:shape of x (before upsample): torch.Size([16, 256, 40, 60])
DEBUG:root:shape of x (before upsample): torch.Size([16, 128, 80, 120])
DEBUG:root:shape of x: torch.Size([16, 3, 160, 240])
DEBUG:root:shape of x: torch.Size([16, 64, 80, 120])
DEBUG:root:shape of x: torch.Size([16, 128, 40, 60])
DEBUG:root:shape of x: torch.Size([16, 256, 20, 30])
DEBUG:root:shape of x (after encoder): torch.Size([16, 512, 10, 15])
DEBUG:root:shape of x (before upsample): torch.Size([16, 1024, 10, 15])
DEBUG:root:shape of x (before upsample): torch.Size([16, 512, 20, 30])
DEBUG:roo

Epoch [2] Train Loss: 0.0586


DEBUG:root:shape of x: torch.Size([16, 3, 160, 240])
DEBUG:root:shape of x: torch.Size([16, 64, 80, 120])
DEBUG:root:shape of x: torch.Size([16, 128, 40, 60])
DEBUG:root:shape of x: torch.Size([16, 256, 20, 30])
DEBUG:root:shape of x (after encoder): torch.Size([16, 512, 10, 15])
DEBUG:root:shape of x (before upsample): torch.Size([16, 1024, 10, 15])
DEBUG:root:shape of x (before upsample): torch.Size([16, 512, 20, 30])
DEBUG:root:shape of x (before upsample): torch.Size([16, 256, 40, 60])
DEBUG:root:shape of x (before upsample): torch.Size([16, 128, 80, 120])
DEBUG:root:shape of x: torch.Size([16, 3, 160, 240])
DEBUG:root:shape of x: torch.Size([16, 64, 80, 120])
DEBUG:root:shape of x: torch.Size([16, 128, 40, 60])
DEBUG:root:shape of x: torch.Size([16, 256, 20, 30])
DEBUG:root:shape of x (after encoder): torch.Size([16, 512, 10, 15])
DEBUG:root:shape of x (before upsample): torch.Size([16, 1024, 10, 15])
DEBUG:root:shape of x (before upsample): torch.Size([16, 512, 20, 30])
DEBUG:roo

Epoch [2] Acc : 0.9900


DEBUG:root:shape of x: torch.Size([16, 3, 160, 240])
DEBUG:root:shape of x: torch.Size([16, 64, 80, 120])
DEBUG:root:shape of x: torch.Size([16, 128, 40, 60])
DEBUG:root:shape of x: torch.Size([16, 256, 20, 30])
DEBUG:root:shape of x (after encoder): torch.Size([16, 512, 10, 15])
DEBUG:root:shape of x (before upsample): torch.Size([16, 1024, 10, 15])
DEBUG:root:shape of x (before upsample): torch.Size([16, 512, 20, 30])
DEBUG:root:shape of x (before upsample): torch.Size([16, 256, 40, 60])
DEBUG:root:shape of x (before upsample): torch.Size([16, 128, 80, 120])
DEBUG:root:shape of x: torch.Size([16, 3, 160, 240])
DEBUG:root:shape of x: torch.Size([16, 64, 80, 120])
DEBUG:root:shape of x: torch.Size([16, 128, 40, 60])
DEBUG:root:shape of x: torch.Size([16, 256, 20, 30])
DEBUG:root:shape of x (after encoder): torch.Size([16, 512, 10, 15])
DEBUG:root:shape of x (before upsample): torch.Size([16, 1024, 10, 15])
DEBUG:root:shape of x (before upsample): torch.Size([16, 512, 20, 30])
DEBUG:roo

Epoch [3] Train Loss: 0.0412


DEBUG:root:shape of x: torch.Size([16, 3, 160, 240])
DEBUG:root:shape of x: torch.Size([16, 64, 80, 120])
DEBUG:root:shape of x: torch.Size([16, 128, 40, 60])
DEBUG:root:shape of x: torch.Size([16, 256, 20, 30])
DEBUG:root:shape of x (after encoder): torch.Size([16, 512, 10, 15])
DEBUG:root:shape of x (before upsample): torch.Size([16, 1024, 10, 15])
DEBUG:root:shape of x (before upsample): torch.Size([16, 512, 20, 30])
DEBUG:root:shape of x (before upsample): torch.Size([16, 256, 40, 60])
DEBUG:root:shape of x (before upsample): torch.Size([16, 128, 80, 120])
DEBUG:root:shape of x: torch.Size([16, 3, 160, 240])
DEBUG:root:shape of x: torch.Size([16, 64, 80, 120])
DEBUG:root:shape of x: torch.Size([16, 128, 40, 60])
DEBUG:root:shape of x: torch.Size([16, 256, 20, 30])
DEBUG:root:shape of x (after encoder): torch.Size([16, 512, 10, 15])
DEBUG:root:shape of x (before upsample): torch.Size([16, 1024, 10, 15])
DEBUG:root:shape of x (before upsample): torch.Size([16, 512, 20, 30])
DEBUG:roo

Epoch [3] Acc : 0.9897


In [10]:
# Inspection cell for the TA: prints a per-layer summary of the network.
# The hyper-parameters are repeated here so the cell can be read on its own.
BATCH_SIZE, NUM_EPOCHS = 16, 3
IMG_WIDTH, IMG_HEIGHT = 240, 160

from torchinfo import summary

# NOTE(review): this rebinds `model` to a freshly initialised UNet that is not
# moved to `device`; the network trained earlier is no longer reachable via `model`.
model = UNet(in_channels=3, out_channels=1)
summary(model, input_size=(BATCH_SIZE, 3, IMG_HEIGHT, IMG_WIDTH))

DEBUG:root:shape of x: torch.Size([16, 3, 160, 240])
DEBUG:root:shape of x: torch.Size([16, 64, 80, 120])
DEBUG:root:shape of x: torch.Size([16, 128, 40, 60])
DEBUG:root:shape of x: torch.Size([16, 256, 20, 30])
DEBUG:root:shape of x (after encoder): torch.Size([16, 512, 10, 15])
DEBUG:root:shape of x (before upsample): torch.Size([16, 1024, 10, 15])
DEBUG:root:shape of x (before upsample): torch.Size([16, 512, 20, 30])
DEBUG:root:shape of x (before upsample): torch.Size([16, 256, 40, 60])
DEBUG:root:shape of x (before upsample): torch.Size([16, 128, 80, 120])


Layer (type:depth-idx)                   Output Shape              Param #
UNet                                     [16, 1, 160, 240]         --
├─ModuleList: 1-1                        --                        --
│    └─DoubleConv: 2-1                   [16, 64, 160, 240]        --
│    │    └─Sequential: 3-1              [16, 64, 160, 240]        38,848
│    └─DoubleConv: 2-2                   [16, 128, 80, 120]        --
│    │    └─Sequential: 3-2              [16, 128, 80, 120]        221,696
│    └─DoubleConv: 2-3                   [16, 256, 40, 60]         --
│    │    └─Sequential: 3-3              [16, 256, 40, 60]         885,760
│    └─DoubleConv: 2-4                   [16, 512, 20, 30]         --
│    │    └─Sequential: 3-4              [16, 512, 20, 30]         3,540,992
├─DoubleConv: 1-2                        [16, 1024, 10, 15]        --
â”‚    â””â”€Sequential: 2-5                   [16, 1024, 10, 15]     