## **Medical Image Processing - Retinal Vessel Challenge**
### Test code


In [None]:
# Mount Google Drive at /content/drive so the dataset and checkpoints below are reachable
from google.colab import drive
drive.mount(mountpoint='/content/drive')

**1. Import libraries and define the U-Net**


In [None]:
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
# Report the installed PyTorch version and whether a CUDA GPU is usable
import torch
import torchvision
cuda_ok = torch.cuda.is_available()
print(f"{torch.__version__} {cuda_ok}")

In [None]:
# Root of the mounted Google Drive; all data, checkpoints and scripts live below it
current_dir = os.path.join("/content", "drive", "MyDrive")

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    padding=1 keeps the spatial size unchanged; the first conv maps
    `in_channels` to `out_channels`, the second keeps `out_channels`.
    The layers live in `self.double_conv` (indices 0..5), which fixes the
    state_dict keys used by saved checkpoints.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.double_conv(x)

class UNet(nn.Module):
    """U-Net with a configurable number of resolution levels.

    Encoder: `depth` DoubleConv stages of widths
    init_filters, 2*init_filters, ..., init_filters * 2**(depth-1), each
    followed by 2x2 max-pooling. Bottleneck: a DoubleConv producing
    init_filters * 2**depth channels.
    Decoder: mirrors the encoder; each stage upsamples by 2 (bilinear
    upsampling + 1x1 conv, or a stride-2 transposed conv if bilinear=False),
    concatenates the matching encoder output and applies a DoubleConv.
    Output: a 1x1 conv to `out_channels`, returned without any activation.

    The submodule names (down_layers, bottleneck, up_layers[i]['up'/'conv'],
    out_conv) define the state_dict keys and must stay stable so saved
    checkpoints keep loading.
    """
    def __init__(self, in_channels=1, out_channels=1, init_filters=16, depth=4, bilinear=True):
        super().__init__()
        self.depth = depth
        self.down_layers = nn.ModuleList()
        self.up_layers = nn.ModuleList()
        self.pool = nn.MaxPool2d(2)

        # Channel width of each encoder level (and, reversed, of each decoder level)
        level_widths = [init_filters * 2 ** level for level in range(depth)]

        channels_in = in_channels
        for width in level_widths:
            self.down_layers.append(DoubleConv(channels_in, width))
            channels_in = width

        self.bottleneck = DoubleConv(channels_in, init_filters * 2 ** depth)

        for width in reversed(level_widths):
            # The tensor arriving from the level below has 2 * width channels
            if bilinear:
                upsampler = nn.Sequential(
                    nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                    nn.Conv2d(2 * width, width, kernel_size=1),
                )
            else:
                upsampler = nn.ConvTranspose2d(2 * width, width, kernel_size=2, stride=2)
            # After concatenation with the skip, the DoubleConv sees 2 * width channels
            self.up_layers.append(nn.ModuleDict({'up': upsampler, 'conv': DoubleConv(2 * width, width)}))

        self.out_conv = nn.Conv2d(init_filters, out_channels, kernel_size=1)

    def forward(self, x):
        skips = []
        for encoder_stage in self.down_layers:
            x = encoder_stage(x)
            skips.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # Decoder stages consume the skips deepest-first
        for decoder_stage, skip in zip(self.up_layers, reversed(skips)):
            upsampled = decoder_stage['up'](x)
            if upsampled.size() != skip.size():
                # Pooling floors odd sizes, so the upsampled map can be 1px short
                upsampled = F.interpolate(upsampled, size=skip.shape[2:])
            x = decoder_stage['conv'](torch.cat([skip, upsampled], dim=1))

        return self.out_conv(x)

**2. Load the configuration and weights of the trained U-Net model**

In [None]:
import json
import torch

def load_model_from_checkpoint(checkpoint_dir, epoch_to_load):
    """Rebuild the trained U-Net and load the weights saved at a given epoch.

    Args:
        checkpoint_dir: directory containing ``training_params.json`` and the
            ``model_epoch_<epoch>.pt`` state-dict files.
        epoch_to_load: epoch number used in the checkpoint filename.

    Returns:
        (model, params): the UNet with the loaded weights, and the dict of
        training parameters read from ``training_params.json``. The model is
        neither moved to a device nor switched to eval mode here.

    Raises:
        FileNotFoundError: if the JSON or checkpoint file does not exist.
        KeyError: if a required architecture parameter is missing from the JSON.
        RuntimeError: (from load_state_dict) if the weights do not match the
            rebuilt architecture.
    """
    # Path to the JSON file with saved parameters
    params_path = os.path.join(checkpoint_dir, 'training_params.json')

    # Load parameters from JSON
    with open(params_path, 'r', encoding='utf-8') as f:
        params = json.load(f)

    print("Loaded parameters:", params)

    # Rebuild the architecture exactly as trained. `bilinear` changes the
    # decoder's submodules (and hence the state_dict keys), so honour it when the
    # training run recorded it; otherwise fall back to the UNet default (True).
    model = UNet(
        in_channels=params['in_channels'],
        out_channels=params['out_channels'],
        init_filters=params['init_filters'],
        depth=params['depth'],
        bilinear=params.get('bilinear', True),
    )

    # Path to the checkpoint
    checkpoint_path = os.path.join(checkpoint_dir, f"model_epoch_{epoch_to_load}.pt")

    # map_location lets a GPU-saved checkpoint load on a CPU-only runtime;
    # load_state_dict then copies the tensors into the model's own parameters.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    print(f"Checkpoint for epoch {epoch_to_load} loaded from {checkpoint_path}")

    return model, params

In [None]:
# Training run and epoch whose weights are evaluated
epoch_number = 27   # epoch to load
checkpoint_dir = os.path.join(current_dir, 'Ale', 'ultimissima', 'checkpoints')  # checkpoint directory

model, params = load_model_from_checkpoint(checkpoint_dir=checkpoint_dir, epoch_to_load=epoch_number)

# Network input size recorded at training time; the preprocessing cell passes
# it to A.Resize as (height, width)
input_size = tuple(params["input_size"])

**3. Apply the trained model to the test set**




In [None]:
#from torchvision import transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
from skimage import exposure
from skimage.morphology import skeletonize
from skimage.measure import label
# Define paths (modify if necessary)
test_root = os.path.join(current_dir, 'Progetto', 'Dataset_vessel_stu', 'test')
test_images_dir = os.path.join(test_root, 'image')                    # input images directory
test_masks_dir = os.path.join(test_root, 'manual_py')                 # manual masks directory
output_masks_dir = os.path.join(test_root, 'predictions_final_test')  # output masks directory

# Create the output folder if it does not exist yet
os.makedirs(output_masks_dir, exist_ok=True)

# Move the model to the available device and switch to inference mode
# (in eval mode BatchNorm uses its running statistics)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = model.to(device).eval()

# Preprocessing of the enhanced grayscale image: resize to the training input
# size (input_size[0] = height, input_size[1] = width), scale uint8 to [0, 1]
# floats, and convert to a torch tensor
preprocess = A.Compose([
    A.Resize(height=input_size[0], width=input_size[1]),
    A.ToFloat(max_value=255.0),
    ToTensorV2(),
])
# Post-processing: remove connected components whose skeleton is too short
def clean_by_skeleton_length(mask, min_length):
    """Drop connected components whose skeleton has fewer than `min_length` pixels.

    Args:
        mask: 2-D binary mask (nonzero = vessel).
        min_length: minimum number of skeleton pixels a component needs to be kept.

    Returns:
        uint8 array with the same shape as `mask`: 1 on kept components, 0 elsewhere.
    """
    foreground = mask.astype(np.uint8) > 0
    # label() defaults to full (8-) connectivity in 2-D, so distinct components
    # never touch each other, not even diagonally.
    labeled, num = label(foreground, return_num=True)
    if num == 0:
        return np.zeros_like(mask, dtype=np.uint8)

    # The default 2-D skeletonization thins each pixel based on its 3x3
    # neighbourhood. Non-touching components therefore thin independently, and
    # one whole-image pass gives the same skeleton as skeletonizing each
    # component separately. It costs O(H*W) instead of O(num_components * H*W).
    skeleton = skeletonize(foreground)

    # Skeleton length (pixel count) per label; index 0 is the background
    skeleton_lengths = np.bincount(labeled[skeleton], minlength=num + 1)
    keep = skeleton_lengths >= min_length
    keep[0] = False
    return keep[labeled].astype(np.uint8)

# Post-processing / thresholding constants
PROB_THRESHOLD = 0.5       # sigmoid probability above which a pixel is vessel
MIN_SKELETON_LENGTH = 40   # components with shorter skeletons are discarded

# List test image files (.png/.jpg/.jpeg), sorted so runs are reproducible
image_files = sorted(
    f for f in os.listdir(test_images_dir)
    if f.lower().endswith(('.png', '.jpg', '.jpeg'))
)

for img_name in tqdm(image_files, desc="Predicting test masks"):
    img_path = os.path.join(test_images_dir, img_name)

    # Load the image and apply the preprocessing chain:
    # weighted red/green grayscale -> Gaussian blur -> gamma correction.
    # NOTE(review): presumably mirrors the training preprocessing — confirm.
    image = np.array(Image.open(img_path).convert('RGB'))
    R = image[:, :, 0].astype(np.float32)
    G = image[:, :, 1].astype(np.float32)
    # astype(np.uint8) truncates rather than rounds; keep as-is to match training
    img_RG = (0.337 * R + 0.663 * G).astype(np.uint8)
    # Gaussian filter (kernel_size = 3; sigma = 1)
    img = cv2.GaussianBlur(img_RG, (3, 3), 1)
    # Gamma correction (gamma = 0.9) on [0, 1] intensities
    img = img / 255
    img_gamma = exposure.adjust_gamma(img, gamma=0.9)
    img_gamma = (img_gamma * 255).astype(np.uint8)
    image_pre = preprocess(image=img_gamma)
    input_tensor = image_pre["image"].unsqueeze(0).to(device)

    # Inference: the model returns raw scores, sigmoid maps them to probabilities
    with torch.no_grad():
        output = model(input_tensor)
        pred_mask = torch.sigmoid(output).cpu().squeeze().numpy()

    # Threshold
    pred_mask_bin = (pred_mask > PROB_THRESHOLD).astype(np.uint8)

    # Post-processing: remove small components based on skeleton length
    pred_mask_bin = clean_by_skeleton_length(pred_mask_bin, min_length=MIN_SKELETON_LENGTH)

    # Always save as lossless PNG. Keeping a .jpg/.jpeg input name would store the
    # binary mask with JPEG compression artifacts, and the metrics cell only
    # reads .png files.
    out_name = os.path.splitext(img_name)[0] + '.png'
    pred_mask_img = Image.fromarray((pred_mask_bin * 255).astype(np.uint8))
    pred_mask_img.save(os.path.join(output_masks_dir, out_name))

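The following cell computes per-image Dice, precision and recall (with `sklearn.metrics`) and clDice (with the `calculate_cldice` function) between the manual and predicted masks, then reports their mean and standard deviation. These implementations can be replaced with alternative evaluation scripts if required.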

In [None]:

import numpy as np
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt

from sklearn.metrics import f1_score, precision_score, recall_score
from skimage.morphology import skeletonize
def calculate_cldice(gt, pred):
    """Centerline Dice (clDice) between a ground-truth and a predicted mask.

    Topology precision is the fraction of the predicted skeleton lying inside
    the ground truth. Topology sensitivity is the fraction of the ground-truth
    skeleton lying inside the prediction. clDice is their harmonic mean.
    Nonzero pixels count as foreground. The function returns 0.0 when either
    skeleton is empty.
    """
    eps = 1e-8  # guards the divisions against empty skeletons
    gt_fg = gt > 0
    pred_fg = pred > 0

    skel_gt = skeletonize(gt_fg)
    skel_pred = skeletonize(pred_fg)

    topo_precision = np.logical_and(skel_pred, gt_fg).sum() / (skel_pred.sum() + eps)
    topo_sensitivity = np.logical_and(skel_gt, pred_fg).sum() / (skel_gt.sum() + eps)

    return 2 * topo_precision * topo_sensitivity / (topo_precision + topo_sensitivity + eps)

# ============================================================
# METRIC EVALUATION: Dice, Precision, Recall and clDice
# ============================================================

def load_binary_mask(path, size_hw):
    """Load a mask image as a {0, 1} uint8 array resized to `size_hw`.

    Args:
        path: path to a mask image (any mode PIL can convert to 'L').
        size_hw: target size as (height, width), the order in which the
            preprocessing cell passes input_size to A.Resize.

    Returns:
        uint8 array of shape (height, width), 1 where the resized mask is nonzero.
    """
    mask = Image.open(path).convert('L')
    # PIL's resize takes (width, height), so swap the pair; passing input_size
    # directly would transpose non-square sizes. Nearest keeps the mask binary.
    mask = mask.resize((size_hw[1], size_hw[0]), Image.NEAREST)
    return (np.array(mask) > 0).astype(np.uint8)


Dices, Precisions, Recalls, clDices = [], [], [], []

# Sorted lists of manual and predicted masks. They are paired by position, so
# both folders must hold the same number of files in matching sorted order.
list_mask_manual = sorted(
    os.path.join(test_masks_dir, f)
    for f in os.listdir(test_masks_dir)
    if f.lower().endswith('.png')
)
list_mask_auto = sorted(
    os.path.join(output_masks_dir, f)
    for f in os.listdir(output_masks_dir)
    if f.lower().endswith('.png')
)

# zip() would silently drop unmatched files, and an empty set would yield NaN
# statistics, so fail loudly instead.
if not list_mask_manual:
    raise ValueError(f"No .png manual masks found in {test_masks_dir}")
if len(list_mask_manual) != len(list_mask_auto):
    raise ValueError(
        f"{len(list_mask_manual)} manual masks in {test_masks_dir} but "
        f"{len(list_mask_auto)} predicted masks in {output_masks_dir}; "
        "masks are paired by sorted filename, so the counts must match"
    )

for manual_path, auto_path in tqdm(zip(list_mask_manual, list_mask_auto),
                                   total=len(list_mask_manual),
                                   desc="Computing test metrics"):
    manual_mask = load_binary_mask(manual_path, input_size)  # ground truth
    auto_mask = load_binary_mask(auto_path, input_size)      # prediction

    # Pixel-wise metrics on the flattened masks
    flat_true = manual_mask.ravel()
    flat_pred = auto_mask.ravel()

    Dices.append(f1_score(flat_true, flat_pred, average='binary'))
    Precisions.append(precision_score(flat_true, flat_pred, average='binary'))
    Recalls.append(recall_score(flat_true, flat_pred, average='binary'))
    # Topology-aware metric on the 2-D masks
    clDices.append(calculate_cldice(manual_mask, auto_mask))

# ============================================================
# Statistics (std is the population standard deviation, np.std default)
# ============================================================

print("Dice:     mean {:.4f} std {:.4f}".format(np.mean(Dices), np.std(Dices)))
print("Precision:mean {:.4f} std {:.4f}".format(np.mean(Precisions), np.std(Precisions)))
print("Recall:   mean {:.4f} std {:.4f}".format(np.mean(Recalls), np.std(Recalls)))
print("clDice:   mean {:.4f} std {:.4f}".format(np.mean(clDices), np.std(clDices)))

# ============================================================
# Boxplot: per-image distribution of each metric
# ============================================================

metric_values = {"Dice": Dices, "Precision": Precisions, "Recall": Recalls, "clDice": clDices}
fig, ax = plt.subplots(1, len(metric_values), figsize=(12, 4))
for axis, (metric_name, values) in zip(ax, metric_values.items()):
    axis.boxplot(values)
    axis.set_title(metric_name)

plt.tight_layout()
plt.show()

