In [None]:
# SECTION 1: Initial Setup and libraries import

from google.colab import drive
drive.mount('/content/drive')               # Mount Google Drive (dataset and checkpoints are read/written there)
import os
import random
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.auto import tqdm                  # Same API as tqdm.tqdm, renders as a notebook widget when available

# Deep Learning libraries
import torch, torchvision
print(torch.__version__, torch.cuda.is_available())

# Seed the Python, NumPy and PyTorch RNGs so that weight initialisation and data shuffling
# are repeatable across runs (GPU kernels may still be slightly non-deterministic).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Working directory
current_dir = '/content/drive/MyDrive/Progetto/'

# SECTION 2: U-Net definition

import torch.nn as nn                       # Modules for defining neural networks
import torch.nn.functional as F             # Operative functions

class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Each convolution uses padding=1 with a 3x3 kernel, so the spatial size is kept;
    only the channel count changes, from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, second keeps out_channels
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        # Attribute name and layer indices are what the state_dict keys (and saved checkpoints) use
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.double_conv(x)

class UNet(nn.Module):
    """U-Net encoder/decoder with skip connections.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output map; ``forward`` returns raw logits
            (no activation is applied after the final 1x1 convolution).
        init_filters: filters of the first encoder stage, doubled at every deeper stage.
        depth: number of encoder (and decoder) stages; the bottleneck sits below them.
        bilinear: if True, upsample with bilinear interpolation followed by a 1x1 conv
            that halves the channels; otherwise use a 2x2 stride-2 transposed convolution.
    """

    def __init__(self, in_channels=1, out_channels=1, init_filters=24, depth=4, bilinear=True):
        super().__init__()
        self.depth = depth
        self.down_layers = nn.ModuleList()  # Encoder stages
        self.up_layers = nn.ModuleList()    # Decoder stages ('up' + 'conv' per level)
        self.pool = nn.MaxPool2d(2)         # Halves the spatial size between encoder stages

        # Encoder widths: init_filters, 2*init_filters, ..., 2**(depth-1)*init_filters
        widths = [init_filters * 2 ** level for level in range(depth)]
        channels_in = in_channels
        for width in widths:
            self.down_layers.append(DoubleConv(channels_in, width))
            channels_in = width

        # Bottleneck doubles the width of the deepest encoder stage
        self.bottleneck = DoubleConv(channels_in, init_filters * 2 ** depth)

        # Decoder mirrors the encoder, from the deepest stage back to the first
        for width in reversed(widths):
            if bilinear:
                upsample = nn.Sequential(
                    nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                    nn.Conv2d(width * 2, width, kernel_size=1),
                )
            else:
                upsample = nn.ConvTranspose2d(width * 2, width, kernel_size=2, stride=2)
            # After concatenation the conv sees skip (width) + upsampled (width) channels
            merge = DoubleConv(width * 2, width)
            self.up_layers.append(nn.ModuleDict({'up': upsample, 'conv': merge}))

        # Output layer
        self.out_conv = nn.Conv2d(init_filters, out_channels, kernel_size=1)

    def forward(self, x):
        skips = []
        for encoder_stage in self.down_layers:
            x = encoder_stage(x)
            skips.append(x)                 # Saved for the matching decoder level
            x = self.pool(x)

        x = self.bottleneck(x)

        # Deepest skip pairs with the first decoder stage
        for skip, stage in zip(reversed(skips), self.up_layers):
            upsampled = stage['up'](x)
            # Pooling floors odd sizes; resize (F.interpolate default mode) to the skip's size
            if upsampled.size() != skip.size():
                upsampled = F.interpolate(upsampled, size=skip.shape[2:])
            x = stage['conv'](torch.cat([skip, upsampled], dim=1))

        return self.out_conv(x)             # Final output (logits)


# SECTION 3: Model testing

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # Use the GPU if it is available
# NOTE: this architecture check uses init_filters=24; the model actually trained in
# Section 7 is rebuilt with the init_filters value set in Section 4.
model = UNet(in_channels=1, out_channels=1, init_filters=24, depth=4)
model.to(device)                                                      # Move to device

# `summary` was never imported in this notebook (NameError on a fresh kernel).
# torchsummary is treated as optional so that a missing package does not stop the run.
try:
    from torchsummary import summary
except ImportError:
    summary = None

if summary is not None:
    # torchsummary defaults to device="cuda"; pass the real device so this also works on CPU
    summary(model, input_size=(1, 1024, 1024), device=device.type)  # Show the model
else:
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(model)
    print(f"Trainable parameters: {n_trainable:,}")

class DiceLoss(nn.Module):
    """Soft Dice loss computed on raw logits.

    The logits go through a sigmoid, then a single Dice coefficient is computed over
    every element of the batch at once (a global score, not a per-image average).
    The returned loss is ``1 - Dice``.

    Args:
        smooth: small constant added to numerator and denominator so the ratio stays
            defined when prediction and target sums are both (near) zero.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, inputs, targets):
        # reshape (not view) so non-contiguous tensors, e.g. permuted or channels_last
        # layouts, are flattened instead of raising a RuntimeError
        inputs = torch.sigmoid(inputs).reshape(-1)   # Convert logits to probabilities
        targets = targets.reshape(-1)
        intersection = (inputs * targets).sum()
        dice = (2. * intersection + self.smooth) / (inputs.sum() + targets.sum() + self.smooth)
        return 1 - dice                               # Loss = 1 - Dice


# SECTION 4: Training Configuration

input_size = (1024, 1024)                     # Image size (height, width) used by the Resize transforms
in_channels = 1                               # Number of input channels
out_channels = 1                              # Number of output channels (1 = binary segmentation)
init_filters = 16                             # Filters of the first U-Net stage (doubled at each deeper stage)
depth = 4                                     # Number of encoder/decoder stages
criterion = DiceLoss()                        # Loss function
n_epochs = 30
batch_size = 7
learning_rate = 0.0003
checkpoint_freq = 1                           # Save a checkpoint every `checkpoint_freq` epochs
# CHANGE THIS PATH to choose where checkpoints and the training parameters are written
checkpoint_dir = os.path.join(current_dir, 'CODICE ONLINE POST TUNING', 'checkpoints')

# Create the directory if needed; exist_ok replaces the separate exists() check (and its race)
os.makedirs(checkpoint_dir, exist_ok=True)


# SECTION 5: Dataset and DataLoader

# Dataset layout: <current_dir>/Dataset_vessel_stu/{train,val}/{image,manual_py}
dataset_root = os.path.join(current_dir, 'Dataset_vessel_stu')
train_img_dir = os.path.join(dataset_root, 'train', 'image')
train_mask_dir = os.path.join(dataset_root, 'train', 'manual_py')
val_img_dir = os.path.join(dataset_root, 'val', 'image')
val_mask_dir = os.path.join(dataset_root, 'val', 'manual_py')

import json

# Hyper-parameters of this run, written next to its checkpoints
params = dict(
    input_size=input_size,
    in_channels=in_channels,
    out_channels=out_channels,
    init_filters=init_filters,
    depth=depth,
    n_epochs=n_epochs,
    batch_size=batch_size,
    learning_rate=learning_rate,
    checkpoint_freq=checkpoint_freq,
)

params_path = os.path.join(checkpoint_dir, 'training_params.json')
with open(params_path, 'w') as params_file:
    json.dump(params, params_file, indent=4)
print(f"Training parameters saved to {params_path}")

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import cv2
from skimage import exposure

class RetinalDataset(Dataset):
    """Retinal images paired with binary vessel masks.

    Images and masks are paired by position after sorting each directory listing,
    so both folders must contain matching files in the same sorted order.

    Args:
        image_dir: folder with the input images.
        mask_dir: folder with the ground-truth masks.
        transform: optional callable invoked as ``transform(image=..., mask=...)`` that
            returns a dict with ``'image'`` and ``'mask'`` (e.g. an albumentations Compose).
            If not given, the preprocessed image is only scaled to [0, 1] and converted
            to a tensor.
        multiplier: how many times the file list is repeated (each copy is loaded and
            transformed independently).
    """

    def __init__(self, image_dir, mask_dir, transform=None, multiplier=1):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.image_list = sorted(os.listdir(image_dir)) * multiplier
        self.mask_list = sorted(os.listdir(mask_dir)) * multiplier
        # Pairing is by position, so both lists must have the same length
        assert len(self.image_list) == len(self.mask_list), (
            f"{len(self.image_list)} images but {len(self.mask_list)} masks"
        )

    def __len__(self):
        return len(self.image_list)

    @staticmethod
    def _preprocess_image(img_path):
        """Load an image and return the preprocessed single-channel uint8 array.

        Steps: weighted red/green combination (the blue channel is discarded),
        3x3 Gaussian blur with sigma 1, gamma correction with gamma 0.9.
        """
        rgb = np.array(Image.open(img_path).convert('RGB'))
        red = rgb[:, :, 0].astype(np.float32)
        green = rgb[:, :, 1].astype(np.float32)
        img_rg = (0.337 * red + 0.663 * green).astype(np.uint8)
        img = cv2.GaussianBlur(img_rg, (3, 3), 1)
        img_gamma = exposure.adjust_gamma(img / 255, gamma=0.9)
        return (img_gamma * 255).astype(np.uint8)

    @staticmethod
    def _image_to_tensor(image):
        """Return ``image`` as a (C, H, W) tensor.

        Tensors (e.g. produced by ToTensorV2) are returned unchanged; NumPy arrays of
        shape (H, W) or (H, W, C) are converted to float32 tensors.
        """
        if torch.is_tensor(image):
            return image
        array = np.asarray(image, dtype=np.float32)
        array = array[np.newaxis] if array.ndim == 2 else array.transpose(2, 0, 1)
        return torch.from_numpy(np.ascontiguousarray(array))

    @staticmethod
    def _mask_to_tensor(mask):
        """Binarize ``mask`` (non-zero -> 1.0) as a float tensor, adding a channel axis if it is 2-D."""
        if not torch.is_tensor(mask):
            mask = torch.from_numpy(np.ascontiguousarray(mask))
        mask = (mask > 0).float()
        if mask.ndim == 2:
            mask = mask.unsqueeze(0)
        return mask

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.image_list[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_list[idx])

        image = self._preprocess_image(img_path)
        mask = np.array(Image.open(mask_path).convert('L'))

        if self.transform:
            # The transform processes image and mask together (same spatial augmentation)
            augmented = self.transform(image=image, mask=mask)
            image, mask = augmented['image'], augmented['mask']
        else:
            # No pipeline: previously the raw RGB array was returned here and the mask
            # (.float() on a NumPy array) crashed. Scale to [0, 1] like ToFloat(max_value=255).
            image = image.astype(np.float32) / 255.0

        return self._image_to_tensor(image), self._mask_to_tensor(mask)


# SECTION 6: Transformations and DataLoader. This section defines the transformations applied to the images
#             and initializes the datasets and dataloaders.

import albumentations as A
from albumentations.pytorch import ToTensorV2

def make_transform(augment):
    """Build the albumentations pipeline: Resize, optional flips/rotations, scaling, tensor conversion.

    ToFloat(max_value=255) only divides pixel values by 255 (range [0, 1]); no mean/std
    normalisation is applied.
    """
    steps = [A.Resize(height=input_size[0], width=input_size[1])]
    if augment:
        steps += [
            A.HorizontalFlip(p=0.6),
            A.VerticalFlip(p=0.6),
            A.RandomRotate90(p=0.6),
            A.Transpose(p=0.6),
        ]
    steps += [A.ToFloat(max_value=255.0), ToTensorV2()]
    return A.Compose(steps)


train_transform = make_transform(augment=True)    # Resize + augmentation + scaling
val_transform = make_transform(augment=False)     # Resize + scaling only

# Each training image appears 3 times per epoch (independently augmented each time)
train_dataset = RetinalDataset(train_img_dir, train_mask_dir, transform=train_transform, multiplier=3)
val_dataset = RetinalDataset(val_img_dir, val_mask_dir, transform=val_transform, multiplier=1)

loader_kwargs = dict(batch_size=batch_size, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)


# SECTION 7: Model initialization and optimizer

import torch.optim as optim

# Fresh model built from the Section 4 configuration (replaces the Section 3 probe model)
model = UNet(
    in_channels=in_channels,
    out_channels=out_channels,
    init_filters=init_filters,
    depth=depth,
)
model = model.to(device)                                  # nn.Module.to returns the module itself
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)   # Adam optimizer


# SECTION 8: clDice metric definition

from skimage.morphology import skeletonize

def calculate_cldice(y_true, y_pred):
    """Centerline Dice (clDice) between a ground-truth and a predicted mask.

    clDice = 2 * Tprec * Tsens / (Tprec + Tsens), where
      Tprec = |skel(pred) ∩ true| / |skel(pred)|   (topology precision)
      Tsens = |skel(true) ∩ pred| / |skel(true)|   (topology sensitivity)

    A ratio whose skeleton is empty counts as 0; if both ratios are 0 (e.g. both
    masks empty) the function returns 0.0.
    """
    true_mask = y_true.astype(bool)
    pred_mask = y_pred.astype(bool)

    def _fraction_inside(skeleton, mask):
        # Share of skeleton pixels that fall inside `mask` (0 for an empty skeleton)
        length = np.sum(skeleton)
        return 0 if length == 0 else np.sum(skeleton & mask) / length

    topo_precision = _fraction_inside(skeletonize(pred_mask), true_mask)
    topo_sensitivity = _fraction_inside(skeletonize(true_mask), pred_mask)

    denominator = topo_precision + topo_sensitivity
    if denominator == 0:                                # avoid division by zero
        return 0.0
    return 2.0 * topo_precision * topo_sensitivity / denominator

from skimage.measure import label
from skimage.morphology import skeletonize

def clean_by_skeleton_length(mask, min_length):
    """Remove connected components whose skeleton is shorter than ``min_length`` pixels.

    Used as post-processing on predicted vessel masks to drop short, isolated blobs.

    Args:
        mask: binary mask; any non-zero value is foreground.
        min_length: minimum number of skeleton pixels a component needs to be kept.

    Returns:
        uint8 array with the shape of ``mask``: 1 on kept components, 0 elsewhere.
    """
    from skimage.measure import regionprops

    labeled = label(mask.astype(np.uint8))
    refined = np.zeros_like(mask, dtype=np.uint8)

    for region in regionprops(labeled):
        # Skeletonize only the component's bounding box, padded by one background pixel so
        # every neighbourhood matches the full image. Same skeleton as before, but the cost
        # now scales with the component size instead of (image size x number of components).
        component = np.pad(region.image, 1)
        length = np.sum(skeletonize(component))

        # Keep only components that are long enough
        if length >= min_length:
            refined[region.slice][region.image] = 1

    return refined

# SECTION 9: Training and validation

from sklearn.metrics import jaccard_score, f1_score, accuracy_score, precision_score, recall_score

# Per-epoch history of each metric (only the mean over the validation set is kept)
train_losses = []
val_losses = []
val_dscs = []
val_cldices = []
val_precs = []
val_recalls = []

# Gradient accumulation: with batch_size = 7 and accumulation_steps = 2 each optimizer
# update uses the gradient of 14 images.
accumulation_steps = 2
min_component_length = 40      # Predicted components with a shorter skeleton (pixels) are removed
n_train_batches = len(train_loader)

for epoch in range(n_epochs):

    # ---------------- Training ----------------
    model.train()
    train_loss = 0

    # Gradients are reset here and after every optimizer step, so they accumulate
    # across the batches of each accumulation group
    optimizer.zero_grad()

    for batch_idx, (images, masks) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{n_epochs} - Training")):
        images = images.to(device)
        masks = masks.to(device)

        outputs = model(images)
        loss = criterion(outputs, masks)        # Loss of this batch

        # Scale by the size of the group this batch belongs to. Groups have accumulation_steps
        # batches, except possibly the last one of the epoch; dividing by the real group size
        # keeps the accumulated gradient an average in that case too.
        group_start = (batch_idx // accumulation_steps) * accumulation_steps
        group_size = min(accumulation_steps, n_train_batches - group_start)
        (loss / group_size).backward()          # Backpropagation

        # Update the weights at the end of each group, including a final partial group
        if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == n_train_batches:
            optimizer.step()
            optimizer.zero_grad()

        # Log the unscaled batch loss, weighted by the number of images in the batch
        train_loss += loss.item() * images.size(0)

    train_loss /= len(train_loader.dataset)
    print(f"Epoch {epoch+1} Train Loss: {train_loss:.4f}")

    # ---------------- Validation ----------------
    model.eval()
    val_loss = 0.0

    # Per-image metrics of this epoch
    epoch_dscs = []
    epoch_cldices = []
    epoch_precs = []
    epoch_recalls = []

    with torch.no_grad():                       # No gradients needed for validation
        for images, masks in val_loader:
            images = images.to(device)
            masks = masks.to(device)
            outputs = model(images)

            loss = criterion(outputs, masks)
            val_loss += loss.item() * images.size(0)

            if out_channels == 1:               # Binary segmentation: sigmoid + 0.5 threshold
                probs = torch.sigmoid(outputs)
                preds = (probs > 0.5).long()
                true = masks.long()
            else:                               # Multi-class segmentation (not used in this project)
                preds = torch.argmax(outputs, dim=1)
                true = masks.long()

            # squeeze() drops the channel axis, and also the batch axis for a batch of one
            preds_np = preds.squeeze().cpu().numpy()
            if preds_np.ndim == 2:
                preds_np = np.expand_dims(preds_np, axis=0)

            # Post-processing: remove components whose skeleton is too short
            preds_np = np.stack(
                [clean_by_skeleton_length(pred, min_length=min_component_length) for pred in preds_np],
                axis=0,
            )

            # Ground truth
            true_np = true.squeeze().cpu().numpy()
            if true_np.ndim == 2:
                true_np = np.expand_dims(true_np, axis=0)

            # Metrics for each image of the batch
            for pred_img, true_img in zip(preds_np, true_np):
                # sklearn needs flat label vectors for pixel-wise metrics
                flat_true = true_img.flatten()
                flat_pred = pred_img.flatten()

                # zero_division=0 gives the same 0.0 that sklearn's default ("warn") falls back to,
                # without emitting an UndefinedMetricWarning for every empty image
                epoch_dscs.append(f1_score(flat_true, flat_pred, average='binary', zero_division=0))
                epoch_precs.append(precision_score(flat_true, flat_pred, average='binary', zero_division=0))
                epoch_recalls.append(recall_score(flat_true, flat_pred, average='binary', zero_division=0))

                # clDice evaluates the topological agreement of the vessel tree
                epoch_cldices.append(calculate_cldice(true_img, pred_img))

    val_loss /= len(val_loader.dataset)         # Batch losses averaged with batch-size weights

    # Mean and standard deviation of each per-image metric
    dsc = np.mean(epoch_dscs)
    dsc_std = np.std(epoch_dscs)

    cldice = np.mean(epoch_cldices)
    cldice_std = np.std(epoch_cldices)

    precision = np.mean(epoch_precs)
    prec_std = np.std(epoch_precs)

    recall = np.mean(epoch_recalls)
    rec_std = np.std(epoch_recalls)

    # Print as Mean ± Std
    print(f"Epoch {epoch+1} Val Loss: {val_loss:.4f}")
    print(f"Metrics (Mean ± Std):")
    print(f"DSC:    {dsc:.4f} ± {dsc_std:.4f}")
    print(f"clDice: {cldice:.4f} ± {cldice_std:.4f}")
    print(f"Prec:   {precision:.4f} ± {prec_std:.4f}")
    print(f"Recall: {recall:.4f} ± {rec_std:.4f}")

    # Keep only the epoch means in the history lists
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    val_dscs.append(dsc)
    val_cldices.append(cldice)
    val_precs.append(precision)
    val_recalls.append(recall)

    if (epoch + 1) % checkpoint_freq == 0:      # Periodic checkpoint (model weights only)
        checkpoint_path = os.path.join(checkpoint_dir, f"model_epoch_{epoch+1}.pt")
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Checkpoint saved at {checkpoint_path}")

# SECTION 10: Plot metrics

# Use the number of recorded epochs (not n_epochs) so the plots still work
# when training was interrupted before the last epoch
epochs = range(1, len(train_losses) + 1)

fig, axes = plt.subplots(2, 2, figsize=(15, 10))
ax_loss, ax_overlap, ax_pr, ax_unused = axes.flat

# Loss
ax_loss.plot(epochs, train_losses, label='Train Loss')
ax_loss.plot(epochs, val_losses, label='Val Loss')
ax_loss.set(xlabel='Epoch', ylabel='Loss', title='Training & Validation Loss')
ax_loss.legend()

# clDice and Dice
ax_overlap.plot(epochs, val_cldices, label='clDice', color='green')
ax_overlap.plot(epochs, val_dscs, label='Dice')
ax_overlap.set(xlabel='Epoch', ylabel='Score', title='Dice & clDice over Epochs')
ax_overlap.legend()

# Precision and Recall
ax_pr.plot(epochs, val_precs, label='Precision')
ax_pr.plot(epochs, val_recalls, label='Recall')
ax_pr.set(xlabel='Epoch', ylabel='Score', title='Precision & Recall')
ax_pr.legend()

ax_unused.axis('off')   # Only three panels are needed; hide the fourth grid cell

fig.tight_layout()
plt.show()



In [None]:
from google.colab import runtime
# Release the Colab runtime at the end of the run (presumably so it stops consuming compute).
# NOTE: all in-memory state (model, metric lists) is lost afterwards; only files written to
# Drive, such as the checkpoints and training_params.json, persist.
runtime.unassign()