In [None]:
# Mount Google Drive at /content/drive; the project folder used in the next cell lives under it.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os
# Work from the project folder on the mounted Drive: the relative paths used later in the notebook
# (e.g. ./images-1024x768/...) are resolved against the current working directory.
project_path = "/content/drive/MyDrive/DL/U-Net"
os.chdir(project_path)
print("Current directory:", os.getcwd())

In [None]:
!pip install -q segmentation-models-pytorch
!pip install -q --upgrade torch torchvision

#### * * * * * Requirements * * * * * 
##### numpy==2.2.6
##### torch==2.8.0+cu126
##### wandb==0.21.3
##### matplotlib==3.10.0
##### segmentation-models-pytorch==0.5.0
##### torchvision==0.23.0+cu126
##### opencv-python==4.12.0.88
##### scikit-learn==1.6.1

In [None]:
import wandb

In [None]:
# Authenticate with Weights & Biases; training_loop() and plots() below both talk to wandb.
wandb.login()

In [None]:
# torch is imported here because this cell runs before the notebook's other import cells;
# without it a fresh "Restart & Run All" stops with a NameError on this line.
import torch

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#### * * * * * Class And Methods For Data Handling & Preprocessing * * * * * ####

In [None]:
import cv2
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset

In [None]:
class ImageDataset(Dataset):

    '''
    Thin torch Dataset wrapper around an in-memory sequence of (image, mask) samples.

    Args:
        data (list): Sequence of samples, each of which unpacks into exactly two values: the image and its mask.
    '''

    def __init__(self, data):
        self.data = data

    def __len__(self):
        samples = self.data
        return len(samples)

    def __getitem__(self, idx):
        # Unpacking into two names makes a malformed sample (not exactly two items) fail right here.
        sample = self.data[idx]
        image, mask = sample
        return (image, mask)

In [None]:
def get_data(type_, image_nums):

    '''
    Loads image/mask pairs from disk and converts them to tensors.

    Images are read from ./images-1024x768/{type_}/image-{n}.png and masks from ./masks-1024x768/{type_}/mask-{n}.png, both
    relative to the current working directory.

    Args:
        type_ (string): Name of the sub-folder to read from (the notebook uses 'train', 'val' and 'test').
        image_nums (iterable): Numbers n identifying which image-{n}.png / mask-{n}.png files to load.

    Returns:
        list: One [image, mask] pair per requested number. image is a float32 tensor laid out as (channels, height, width), in RGB
              channel order and scaled to [0, 1]. mask is a long tensor of shape (height, width) holding 1 wherever the mask file
              is non-zero and 0 elsewhere.

    Raises:
        FileNotFoundError: If an image or mask file is missing or cannot be decoded.
    '''

    data = []

    for image_num in image_nums:

        image_path = f'./images-1024x768/{type_}/image-{image_num}.png'
        mask_path = f'./masks-1024x768/{type_}/mask-{image_num}.png'

        # cv2.imread reports failure by returning None instead of raising, so check explicitly.
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f'Could not read image file: {image_path}')

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f'Could not read mask file: {mask_path}')

        # OpenCV decodes colour images as BGR. The ImageNet-pretrained encoders used later in the notebook were
        # trained on RGB input, so convert before building the tensor.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = np.transpose(image, (2, 0, 1))  # (height, width, channels) -> (channels, height, width)
        image = torch.tensor(image, dtype = torch.float32)
        image = image / 255.0

        # Any non-zero mask pixel counts as foreground (class 1).
        mask = (mask > 0).astype('int')
        mask = torch.tensor(mask, dtype = torch.long)
        data.append([image, mask])

    return data

#### * * * * * Implementation Of The U-Net Architecture * * * * ####

In [6]:
import copy
import torch
from torch import nn
import segmentation_models_pytorch as smp

ModuleNotFoundError: No module named 'torch'

In [None]:
class TwoConvolutions(nn.Module):

    '''
    Two consecutive 3x3 convolutions, each followed by a ReLU, that leave the image resolution unchanged. The first convolution
    maps in_channels to out_channels (doubling or halving them depending on whether the block sits in the encoder or the
    decoder), the second one keeps the number of channels.

    Args:
        in_channels (int): Number of input channels of the first convolution.
        out_channels (int): Number of output channels of both convolutions.
    '''

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # A 3x3 kernel with stride 1 and one pixel of padding on every side keeps height and width as they are.
        conv_settings = dict(kernel_size = 3, stride = 1, padding = (1,1))

        layers = [
            nn.Conv2d(in_channels = in_channels, out_channels = out_channels, **conv_settings),
            nn.ReLU(),
            nn.Conv2d(in_channels = out_channels, out_channels = out_channels, **conv_settings),
            nn.ReLU(),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, input_):

        return self.block(input_)

In [None]:
class EncoderBlock(nn.Module):

    '''
    A single encoder block of the contracting path of the U-Net architecture: two resolution-preserving convolutions
    (TwoConvolutions) followed by a 2x2 max-pooling with stride 2 that halves the width and height.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels of the two convolutions, i.e. prior to the max-pooling operation.
    '''

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.TwoConvolutions = TwoConvolutions(in_channels, out_channels)
        self.max_pooling = nn.MaxPool2d(kernel_size = 2, stride = 2)

    def forward(self, image):

        # The un-pooled features are returned as well: they feed the skip connection to the decoder.
        skip_features = self.TwoConvolutions(image)

        return self.max_pooling(skip_features), skip_features

In [None]:
class DecoderBlock(nn.Module):

    '''
    Implementation of a single decoder block found in the expansive part of the U-Net architecture. The features coming from
    the previous (deeper) layer are upsampled to twice their resolution, concatenated along the channel dimension with the
    skip features produced by the matching encoder block, and passed through two convolutions that halve the number of
    concatenated channels.

    Channel arithmetic (all integer):
        transposed convolution: in_channels -> in_channels // 2 by the upsampling, concatenated = in_channels // 2 + skip_channels
        bilinear:               in_channels unchanged by the upsampling,     concatenated = in_channels + skip_channels
        output of the block:    concatenated // 2

    Args:
        in_channels (int): Number of channels of the incoming features, prior to the upsampling operation.
        skip_channels (int): Number of channels of the skip features that are concatenated after upsampling.
        bilinear (bool): The argument dictates the type of upsampling operation to use, allowing one to choose between the two variants,
                         set True to use torch.nn.Upsample for bilinear upsampling or False to use torch.nn.ConvTranspose2d for upsampling.
    '''

    def __init__(self, in_channels, skip_channels, bilinear = False):
        super().__init__()

        if bilinear:
            # Interpolation has no weights and therefore cannot change the number of channels.
            upsampled_channels = in_channels
            self.upSampling = nn.Upsample(scale_factor = 2, mode = 'bilinear')
        else:
            # Channel counts must be ints: "/" would produce floats, which the convolution layers reject.
            upsampled_channels = in_channels // 2
            self.upSampling = nn.ConvTranspose2d(in_channels, upsampled_channels, kernel_size = 2, stride = 2)

        self.conc_channels = upsampled_channels + skip_channels
        self.out_channels = self.conc_channels // 2
        self.TwoConvolutions = TwoConvolutions(self.conc_channels, self.out_channels)

    def forward(self, input_, skip_input):

        features = self.upSampling(input_)
        # Concatenate along the channel dimension; the spatial sizes of both tensors must already match.
        features = torch.cat([features, skip_input], dim = 1)
        features = self.TwoConvolutions(features)

        return features

In [None]:
class U_Net(nn.Module):

    '''
    The implementation allows for the creation of a U-Net model given any number of channels and image resolution with specified depth,
    not just 512x512 images with 3 channels. Because every encoder block halves the resolution and every decoder block doubles it
    again, the image height and width should be divisible by 2 ** depth so that the skip features line up for concatenation.

    Args:
        in_channels (int): Number of channels of the input image.
        out_channels (int): Number of output channels which corresponds with predicted classes.
        start_channels (int): This is the number of channels to output in the first convolution of the first encoder layer, from this value
                              we double the number of channels at each encoder block until we reach the bottleneck.
        depth (int): Depth of the model, can also be interpreted as the number of downsampling operations in the encoder until the bottleneck
                     or the number of upsampling operations from the bottleneck to the last output layer in the decoder.
        bilinear (bool): The argument dictates the type of upsampling operation to use, allowing one to choose between the two variants,
                         set True to use torch.nn.Upsample for bilinear upsampling or False to use torch.nn.ConvTranspose2d for upsampling.
    '''

    def __init__(self, in_channels, out_channels, start_channels, depth, bilinear = False):
        super().__init__()

        self.bilinear = bilinear

        # encoder_channels[0] is the input image, encoder_channels[i] (1 <= i <= depth) is the output width of encoder
        # block i - 1, and encoder_channels[depth + 1] is the width of the bottleneck.
        encoder_channels = [in_channels] + [start_channels * (2 ** i) for i in range(depth + 1)]
        self.Encoder = nn.ModuleList([
            EncoderBlock(encoder_channels[i], encoder_channels[i+1]) for i in range(depth)
        ])

        self.Bottleneck = TwoConvolutions(encoder_channels[depth], encoder_channels[depth + 1])

        # Build the decoder from the bottleneck upwards. Each block receives the channels produced by the previous
        # stage together with the channels of the skip features of the matching encoder block (deepest first).
        decoder_blocks = []
        features_channels = encoder_channels[depth + 1]

        for level in range(depth):
            skip_channels = encoder_channels[depth - level]
            decoder_blocks.append(DecoderBlock(features_channels, skip_channels, self.bilinear))

            # Mirror DecoderBlock's channel arithmetic to know what the next stage receives: bilinear upsampling keeps
            # the channels, the transposed convolution halves them, and the two convolutions halve the concatenation.
            upsampled_channels = features_channels if self.bilinear else features_channels // 2
            features_channels = (upsampled_channels + skip_channels) // 2

        self.Decoder = nn.ModuleList(decoder_blocks)

        self.FinalConvolution = nn.Conv2d(in_channels = features_channels, out_channels = out_channels, kernel_size = 3, stride = 1, padding = (1,1))

    def forward(self, image):

        # Imported explicitly so that the checkpoint submodule is guaranteed to be loaded.
        from torch.utils.checkpoint import checkpoint

        encoder_features = []
        features = image

        # Activation checkpointing trades compute for memory: intermediate activations of each block are recomputed
        # during the backward pass instead of being stored.
        for block in self.Encoder:
            features, skip_features = checkpoint(block, features, use_reentrant = False)
            encoder_features.append(skip_features)

        features = checkpoint(self.Bottleneck, features, use_reentrant = False)

        # The decoder consumes the skip features in the opposite order to the one in which the encoder produced them.
        for block, skip_features in zip(self.Decoder, reversed(encoder_features)):
            features = checkpoint(block, features, skip_features, use_reentrant = False)

        mask = self.FinalConvolution(features)

        return mask


#### * * * * * Training And Model Evaluation Loops * * * * * ####

In [None]:
import copy
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

In [None]:
def metrics(masks, predictions):

    '''
    Computes pixel-level metrics (accuracy, precision, recall, f1, iou) for the performance evaluation of the model, treating
    class 1 as the positive class.

    Args:
        masks (list): True masks as a sequence of arrays (performance_report passes one array per batch); they are
                      concatenated and flattened into a single vector of labels.
        predictions (list): Predicted masks, in the same layout and order as masks.

    Returns:
        tuple: (accuracy, precision, recall, f1, iou). The IoU is 0.0 when neither the masks nor the predictions contain
               a positive pixel.
    '''

    positive = 1
    y_pred = np.concatenate(predictions).flatten()
    y_true = np.concatenate(masks).flatten()

    # Shared settings of the three binary scores; zero_division = 0 returns 0 instead of warning on empty denominators.
    binary_settings = dict(average = 'binary', pos_label = positive, zero_division = 0)

    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, **binary_settings)
    recall = recall_score(y_true, y_pred, **binary_settings)
    f1 = f1_score(y_true, y_pred, **binary_settings)

    true_is_positive = (y_true == positive)
    pred_is_positive = (y_pred == positive)
    overlap = np.logical_and(true_is_positive, pred_is_positive).sum()
    combined = np.logical_or(true_is_positive, pred_is_positive).sum()

    if combined > 0:
        iou = overlap / combined
    else:
        iou = 0.0

    return accuracy, precision, recall, f1, iou

In [None]:
def performance_report(model, data, batch_size, device):

    '''
    Computes a report of the model's performance on a dataset: the average cross-entropy loss and the metrics returned by
    metrics() (accuracy, precision, recall, f1, iou). The model is moved to device and switched to evaluation mode.

    Args:
        model (nn.Module): Segmentation model returning per-class logits.
        data (list): Data could be training, validation or testing data, as a list of [image, mask] pairs.
        batch_size (int): Batch size
        device (string): Device to perform forward propagations through model.

    Returns:
        tuple: (average_loss, accuracy, precision, recall, f1, iou).

    Raises:
        ValueError: If data is empty.
    '''

    if len(data) == 0:
        raise ValueError("performance_report needs at least one sample, but data is empty.")

    dataset = ImageDataset(data)
    dataloader = DataLoader(dataset, batch_size, shuffle = False)
    loss_function = nn.CrossEntropyLoss()
    total_loss = 0.0
    num_samples = 0
    predictions = []
    masks = []
    model = model.to(device)
    model.eval()

    with torch.no_grad():

        for image, mask in dataloader:
            image = image.to(device)
            mask = mask.to(device)
            logits = model(image)
            loss = loss_function(logits, mask)

            # loss is the mean over this batch. Weight it by the batch size so that a smaller last batch does not
            # count as much as a full one in the dataset average.
            batch_count = image.size(0)
            total_loss += loss.item() * batch_count
            num_samples += batch_count

            # The predicted class of every pixel is the one with the highest logit.
            prediction = torch.argmax(logits, dim = 1)
            predictions.append(prediction.cpu().numpy())
            masks.append(mask.cpu().numpy())

    average_loss = total_loss / num_samples
    accuracy, precision, recall, f1, iou = metrics(masks, predictions)

    return average_loss, accuracy, precision, recall, f1, iou

In [None]:
def training_loop(model, training_data, validation_data, run_name, batch = 1, learning_rate = 1e-2, num_epochs = 10, device = "cpu", checkpoint_dir = None):

    '''
    Training loop for U-Net architecture models (any backbone), loop also performs model performance evaluation using validation dataset,
    during training (for each epoch) we log metric data on the validation dataset to wandb. We keep the state of the model with the best
    validation IoU seen after any epoch and load it back into the model before returning it. If the validation IoU never rises above 0,
    the model is returned as it is after the last epoch.

    Args:
        model: U-Net model to train, model could be custom implementation or one with pretrained backbone from Segmentation Models pytorch library.
        training_data (list): Training dataset.
        validation_data (list): Validation dataset.
        run_name (string): name of run used to identify logged data and configuration details on wandb.
        batch (int): Batch size.
        learning_rate (float): Learning rate.
        num_epochs (int): Number of epochs.
        device (string): Device used to train model.
        checkpoint_dir (string): Optional folder in which a checkpoint (epoch, model state, optimizer state, training loss) is saved
                                 after every epoch as "{run_name}_epoch{epoch}.pth". No checkpoints are written when None (default).

    Returns:
        The trained model (on device) holding the best state by validation IoU.

    Raises:
        ValueError: If training_data is empty.
    '''

    import os

    if len(training_data) == 0:
        raise ValueError("training_data is empty, there is nothing to train on.")

    dataset = ImageDataset(training_data)
    dataloader = DataLoader(dataset, batch_size = batch, shuffle = True)

    model = model.to(device)
    best_iou = 0.0
    best_model_state = None

    loss_function = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)

    if checkpoint_dir is not None:
        os.makedirs(checkpoint_dir, exist_ok = True)

    wandb.init(
        entity = "computer-vision-wits",
        project = "U-Net",
        name = run_name,
        config = {'epochs': num_epochs, 'batch_size': batch, 'learning_rate': learning_rate}
    )

    # The run must be closed even when training is interrupted, otherwise it stays open and the next wandb.init
    # in the notebook collides with it.
    run_failed = True

    try:
        for epoch in range(num_epochs):
            model.train()  # performance_report() leaves the model in evaluation mode
            total_loss = 0.0
            num_samples = 0

            for image, mask in dataloader:
                image = image.to(device)
                mask = mask.to(device)
                optimizer.zero_grad()
                logits = model(image)
                loss = loss_function(logits, mask)
                loss.backward()
                optimizer.step()

                # Weight the batch mean by the batch size so that a smaller last batch is not over-represented.
                batch_count = image.size(0)
                total_loss += loss.item() * batch_count
                num_samples += batch_count

            train_loss = total_loss / num_samples
            val_loss, accuracy, precision, recall, f1, iou = performance_report(model, validation_data, batch, device)

            wandb.log({
                'Training Loss': train_loss,
                'Validation Loss': val_loss,
                'Validation Accuracy': accuracy,
                'Validation Precision': precision,
                'Validation Recall': recall,
                'Validation F1': f1,
                'Validation IoU': iou
            })

            if best_iou < iou:
                best_iou = iou
                # Deep copy: state_dict() returns references to the live tensors, which keep changing during training.
                best_model_state = copy.deepcopy(model.state_dict())

            if checkpoint_dir is not None:
                checkpoint = {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': train_loss
                }
                checkpoint_name = os.path.join(checkpoint_dir, f"{run_name}_epoch{epoch}.pth")
                torch.save(checkpoint, checkpoint_name)

            print(f'Epoch: {epoch}')
            print(f'Training | Loss: {train_loss}')
            print(f"Validation | Loss: {val_loss} | IoU: {iou} | Accuracy: {accuracy} | Precision: {precision} | Recall: {recall} | F1: {f1} \n")

        run_failed = False

    finally:
        wandb.finish(exit_code = 1 if run_failed else 0)

    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return model

#### * * * * * Retrieve Data * * * * #####

In [None]:
# Load the three splits into memory; the numbers select which image-N.png / mask-N.png files of each
# split folder ('train', 'val', 'test') are read.
training_data = get_data('train', [2, 7, 10, 12, 21, 24, 27, 28, 30, 43])
validation_data = get_data('val', [1, 11, 22, 32])
testing_data = get_data('test', [4, 16, 29, 36])

#### * * * * * Helper Functions For Experiments * * * * #####

In [None]:
import matplotlib.pyplot as plt

In [None]:
def plots(run_id):

    '''
    Retrieves model training data of a specified run from wandb and creates a plot to view training loss, validation loss and IoU over each epoch.

    Args:
        run_id (string): ID used to identify run in wandb (the run is looked up under computer-vision-wits/U-Net).

    Raises:
        ValueError: If run_id is empty.
    '''

    # An empty ID would produce the run path "computer-vision-wits/U-Net/", which does not name a run.
    if not run_id:
        raise ValueError("run_id is empty: pass the ID of the wandb run to plot.")

    api = wandb.Api()
    run = api.run(f"computer-vision-wits/U-Net/{run_id}")
    history = run.history()

    fig, ax = plt.subplots(figsize=(10, 6))

    # The column names are the keys logged by training_loop(); each one doubles as its legend label.
    for column in ('Training Loss', 'Validation Loss', 'Validation IoU'):
        ax.plot(history[column], label=column)

    ax.set_title('Loss & IoU over Epochs')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss/IoU')
    ax.set_ylim(0, 2)
    ax.legend()
    ax.grid(True)
    plt.show()

In [None]:
def best_model_report(model, testing_data, model_name, batch_size = 1, device = "cpu"):

    '''
    Reports model performance on the testing dataset by printing the loss and metrics computed by performance_report().

    Args:
        model (nn.Module): Model to evaluate.
        testing_data (list): Testing dataset.
        model_name (string): Name of the model, only used in the printed report header.
        batch_size (int): Batch size used for the evaluation (default 1).
        device (string): Device used for the evaluation (default "cpu"); the model is moved to it by performance_report().
    '''

    loss, accuracy, precision, recall, f1, iou = performance_report(model = model, data = testing_data, batch_size = batch_size, device = device)

    print(f"Performance Report of Best {model_name} Model By Validation IoU")
    print(f"Loss: {loss} | IoU: {iou} | Accuracy: {accuracy} | Precision: {precision} | Recall: {recall} | F1: {f1} \n")

In [None]:
def predict_masks(model, testing_data, device = "cpu"):

    '''
    Runs the model on every image of a dataset, one image at a time, and returns the predicted masks.

    Args:
        model (nn.Module): Segmentation model returning per-class logits.
        testing_data (list): Dataset of (image, mask) pairs; only the images are used.
        device (string): Device used to perform computations.

    Returns:
        list: One CPU tensor per image holding the index of the highest-scoring class for every pixel.
    '''

    model = model.to(device)
    model.eval()

    def _predict_one(image):
        # Add a batch dimension for the forward pass and remove it again from the prediction.
        logits = model(image.to(device).unsqueeze(0))
        return torch.argmax(logits, dim=1).squeeze(0).cpu()

    with torch.no_grad():
        predicted_masks = [_predict_one(image) for image, _ in testing_data]

    return predicted_masks

In [None]:
def display_masks(masks, cmap = 'binary'):

    '''
    Shows the given masks side by side in a single row, with the axes hidden.

    Args:
        masks (list): Masks to display.
        cmap (string): Colormap, applied only to masks with exactly two dimensions.
    '''

    count = len(masks)

    # squeeze = False always returns a 2-D array of axes, so one mask and several masks follow the same path.
    fig, axes_grid = plt.subplots(1, count, figsize = (count * 3, 4), squeeze = False)

    for ax, mask in zip(axes_grid[0], masks):

        if not cmap or len(mask.shape) != 2:
            ax.imshow(mask)
        else:
            ax.imshow(mask, cmap=cmap)

        ax.axis("off")

    plt.tight_layout()
    plt.show()

#### * * * * * U-NET BASIC ARCHITECTURE EXPERIMENTS (Upsampling vs Bilinear Upsampling) * * * * * ####

In [None]:
# U-Net Basic Architecture using Upsampling (transposed convolutions, bilinear = False).

unet_up = U_Net(in_channels = 3, out_channels = 2, start_channels = 64, depth = 4, bilinear = False)
unet_up = training_loop(model = unet_up, training_data = training_data, validation_data = validation_data, run_name = "unet_upsampling", batch = 2, learning_rate = 1e-3, num_epochs = 15, device = device)

# Hand unused cached GPU memory back before the next experiment.
torch.cuda.empty_cache()

In [None]:
# TODO: replace "" with the wandb run ID of the "unet_upsampling" run; an empty ID does not identify a run.
plots("")

In [None]:
# Print the test-set metrics of the trained model, then predict and display its masks for the test images.
best_model_report(model = unet_up, testing_data = testing_data, model_name = "U-Net Upsampling")

predicted_masks = predict_masks(model = unet_up, testing_data = testing_data, device = "cpu")
display_masks(masks = predicted_masks)

In [None]:
# U-Net Basic Architecture using Bilinear Upsampling (bilinear = True).

unet_bilinear = U_Net(in_channels = 3, out_channels = 2, start_channels = 64, depth = 4, bilinear = True)
unet_bilinear = training_loop(model = unet_bilinear, training_data = training_data, validation_data = validation_data, run_name = "unet_bilinear", batch = 2, learning_rate = 1e-3, num_epochs = 15, device = device)

# Hand unused cached GPU memory back before the next experiment.
torch.cuda.empty_cache()

In [None]:
# TODO: replace "" with the wandb run ID of the "unet_bilinear" run; an empty ID does not identify a run.
plots("")

In [None]:
# Print the test-set metrics of the trained model, then predict and display its masks for the test images.
best_model_report(model = unet_bilinear, testing_data = testing_data, model_name = "U-Net Bilinear")

predicted_masks = predict_masks(model = unet_bilinear, testing_data = testing_data, device = "cpu")
display_masks(masks = predicted_masks)

##### Experiments Comments: ...

#### * * * * * U-NET BACKBONE ARCHITECTURES EXPERIMENTS * * * * * ####

In [None]:
# U-Net with ResNet Backbone (resnet34 encoder initialised with ImageNet weights).

unet_res = smp.Unet(encoder_name = "resnet34", encoder_weights = "imagenet", in_channels = 3, classes = 2)
unet_res = training_loop(model = unet_res, training_data = training_data, validation_data = validation_data, run_name = "unet_res", batch = 2, learning_rate = 1e-3, num_epochs = 15, device = device)

# Hand unused cached GPU memory back before the next experiment.
torch.cuda.empty_cache()

In [None]:
# TODO: replace "" with the wandb run ID of the "unet_res" run; an empty ID does not identify a run.
plots("")

In [None]:
# Print the test-set metrics of the trained model, then predict and display its masks for the test images.
best_model_report(model = unet_res, testing_data = testing_data, model_name = "U-Net ResNet")

predicted_masks = predict_masks(model = unet_res, testing_data = testing_data, device = "cpu")
display_masks(masks = predicted_masks)

In [None]:
# U-Net with EfficientNet Backbone (efficientnet-b0 encoder initialised with ImageNet weights).

unet_eff = smp.Unet(encoder_name = "efficientnet-b0", encoder_weights = "imagenet", in_channels = 3, classes = 2)
unet_eff = training_loop(model = unet_eff, training_data = training_data, validation_data = validation_data, run_name = "unet_eff", batch = 2, learning_rate = 1e-3, num_epochs = 15, device = device)

# Hand unused cached GPU memory back before the next experiment.
torch.cuda.empty_cache()

In [None]:
# TODO: replace "" with the wandb run ID of the "unet_eff" run; an empty ID does not identify a run.
plots("")

In [None]:
# Print the test-set metrics of the trained model, then predict and display its masks for the test images.
best_model_report(model = unet_eff, testing_data = testing_data, model_name = "U-Net EfficientNet")

predicted_masks = predict_masks(model = unet_eff, testing_data = testing_data, device = "cpu")
display_masks(masks = predicted_masks)

In [None]:
# U-Net with MobileNet V2 Backbone (mobilenet_v2 encoder initialised with ImageNet weights).

unet_mobile = smp.Unet(encoder_name = "mobilenet_v2", encoder_weights = "imagenet", in_channels = 3, classes = 2)
unet_mobile = training_loop(model = unet_mobile, training_data = training_data, validation_data = validation_data, run_name = "unet_mobile", batch = 2, learning_rate = 1e-3, num_epochs = 15, device = device)

# Hand unused cached GPU memory back before the next experiment.
torch.cuda.empty_cache()

In [None]:
# TODO: replace "" with the wandb run ID of the "unet_mobile" run; an empty ID does not identify a run.
plots("")

In [None]:
# Print the test-set metrics of the trained model, then predict and display its masks for the test images.
best_model_report(model = unet_mobile, testing_data = testing_data, model_name = "U-Net MobileNet")

predicted_masks = predict_masks(model = unet_mobile, testing_data = testing_data, device = "cpu")
display_masks(masks = predicted_masks)

##### Experiments Comments: ...

#### * * * * * DEEPLAB V3+ BACKBONE ARCHITECTURES EXPERIMENTS * * * * * ####

In [None]:
# DeepLab V3+ with ResNet Backbone (resnet34 encoder initialised with ImageNet weights).

deep_res = smp.DeepLabV3Plus(encoder_name = "resnet34", encoder_weights = "imagenet", in_channels = 3, classes = 2)
deep_res = training_loop(model = deep_res, training_data = training_data, validation_data = validation_data, run_name = "deep_res", batch = 2, learning_rate = 1e-3, num_epochs = 15, device = device)

# Hand unused cached GPU memory back before the next experiment.
torch.cuda.empty_cache()

In [None]:
# TODO: replace "" with the wandb run ID of the "deep_res" run; an empty ID does not identify a run.
plots("")

In [None]:
# Print the test-set metrics of the trained model, then predict and display its masks for the test images.
best_model_report(model = deep_res, testing_data = testing_data, model_name = "DeepLab ResNet")

predicted_masks = predict_masks(model = deep_res, testing_data = testing_data, device = "cpu")
display_masks(masks = predicted_masks)

In [None]:
# DeepLab V3+ with EfficientNet Backbone (efficientnet-b0 encoder initialised with ImageNet weights).

deep_eff = smp.DeepLabV3Plus(encoder_name = "efficientnet-b0", encoder_weights = "imagenet", in_channels = 3, classes = 2)
deep_eff = training_loop(model = deep_eff, training_data = training_data, validation_data = validation_data, run_name = "deep_eff", batch = 2, learning_rate = 1e-3, num_epochs = 15, device = device)

# Hand unused cached GPU memory back before the next experiment.
torch.cuda.empty_cache()

In [None]:
# TODO: replace "" with the wandb run ID of the "deep_eff" run; an empty ID does not identify a run.
plots("")

In [None]:
# Print the test-set metrics of the trained model, then predict and display its masks for the test images.
best_model_report(model = deep_eff, testing_data = testing_data, model_name = "DeepLab EfficientNet")

predicted_masks = predict_masks(model = deep_eff, testing_data = testing_data, device = "cpu")
display_masks(masks = predicted_masks)

In [None]:
# DeepLab V3+ with MobileNet V2 Backbone (mobilenet_v2 encoder initialised with ImageNet weights).

deep_mobile = smp.DeepLabV3Plus(encoder_name = "mobilenet_v2", encoder_weights = "imagenet", in_channels = 3,classes = 2)
deep_mobile = training_loop(model = deep_mobile, training_data = training_data, validation_data = validation_data, run_name = "deep_mobile", batch = 2, learning_rate = 1e-3, num_epochs = 15, device = device)

# Hand unused cached GPU memory back before the next experiment.
torch.cuda.empty_cache()

In [None]:
# TODO: replace "" with the wandb run ID of the "deep_mobile" run; an empty ID does not identify a run.
plots("")

In [None]:
# Print the test-set metrics of the trained model, then predict and display its masks for the test images.
best_model_report(model = deep_mobile, testing_data = testing_data, model_name = "DeepLab MobileNet")

predicted_masks = predict_masks(model = deep_mobile, testing_data = testing_data, device = "cpu")
display_masks(masks = predicted_masks)

##### Experiments Comments: ...