# Import Dependencies

In [None]:
# Check system install
import torch
print(torch.__version__)
print(torch.cuda.is_available())  # Should return True if GPU is detected

# General Imports
import numpy as np
import pandas as pd
import random
from tqdm import tqdm 
from torch.optim import Adam, SGD
import monai
import matplotlib.pyplot as plt
from matplotlib import patches
from torch.utils.data import DataLoader
from concurrent.futures import ThreadPoolExecutor
from torch import nn
import torch.nn.functional as F
from torch.nn.functional import threshold, normalize
from monai.losses import DiceLoss
from statistics import mean
import os



# Hyperparameter Tuning
import ray
from ray import tune
from ray.tune import ExperimentAnalysis
from ray.tune.schedulers import ASHAScheduler







# Class Imports
# Reload modules so classes are reloaded every time
import importlib
import image_mask_dataset
import model_evaluator

from image_mask_dataset import ImageMaskDataset
from model_evaluator import ModelEvaluator




# MedSAM
from transformers import SamModel, SamProcessor, SamConfig

from segment_anything import sam_model_registry

#from MedSAM.utils.demo import BboxPromptDemo






2.6.0+cpu
False


## Gather data for each split

In [13]:
# `from image_mask_dataset import ImageMaskDataset` in the import cell bound the class object
# at import time; importlib.reload() alone does NOT update that name. Rebind it from the
# reloaded module so edits to image_mask_dataset.py are actually picked up.
image_mask_dataset = importlib.reload(image_mask_dataset)
ImageMaskDataset = image_mask_dataset.ImageMaskDataset

# SAM processor shared by every split
processor = SamProcessor.from_pretrained("facebook/sam-vit-large")

# Create dataset objects for each split
dataset_path = "Datasets/Dental project.v19i.coco-1"

test_dataset = ImageMaskDataset(dataset_path, "test", processor)
train_dataset = ImageMaskDataset(dataset_path, "train", processor)
valid_dataset = ImageMaskDataset(dataset_path, "valid", processor)

# Sanity check: shape of the first test sample's pixel_values
test_dataset[0]["pixel_values"].shape

# Test using a random image
# test_dataset.show_image_mask(random.randint(0, len(test_dataset) - 1))
# train_dataset.show_image_mask(random.randint(0, len(train_dataset) - 1))
# valid_dataset.show_image_mask(random.randint(0, len(valid_dataset) - 1))



100%|██████████| 85/85 [00:00<00:00, 17366.45it/s]


Total valid image-mask pairs found: 82


100%|██████████| 1703/1703 [00:00<00:00, 18209.84it/s]


Total valid image-mask pairs found: 1700


100%|██████████| 157/157 [00:00<00:00, 18212.40it/s]

Total valid image-mask pairs found: 154





(640, 640, 3)

# Base Model Evaluation

### Initialise MedSAM Model

In [14]:
MedSAM_checkpoint = "Models/medsam_vit_b.pth"

# Fall back to CPU when CUDA is unavailable instead of failing on a hard-coded .to("cuda")
device = "cuda" if torch.cuda.is_available() else "cpu"

# Fail early with the resolved path so a missing download is obvious
if not os.path.isfile(MedSAM_checkpoint):
    raise FileNotFoundError(
        f"MedSAM checkpoint not found at {os.path.abspath(MedSAM_checkpoint)!r}; "
        "place medsam_vit_b.pth in the Models/ directory first."
    )

# NOTE(review): segment_anything loads the checkpoint with torch.load() and no map_location;
# if the file holds CUDA tensors, CPU-only loading may need a manual map_location="cpu" load.
medsam_model = sam_model_registry['vit_b'](checkpoint=MedSAM_checkpoint)
medsam_model = medsam_model.to(device)

medsam_model.eval()

FileNotFoundError: [Errno 2] No such file or directory: 'Models/medsam_vit_b.pth'

### Box Prompt Inference Demo

In [None]:
# Select random image
image_idx = random.randint(0, len(test_dataset)-1)
image = test_dataset.image_mask_pairs[image_idx][0]

# Display image
%matplotlib inline

test_dataset.show_image_mask(image_idx)

# Segment image
%matplotlib widget
bbox_prompt_demo = BboxPromptDemo(medsam_model)
bbox_prompt_demo.show(image)


In [None]:
%matplotlib inline

def show_mask(mask, ax, random_color=False):
    """Overlay ``mask`` on ``ax`` as a semi-transparent RGBA image.

    Args:
        mask: Array whose last two dimensions are (height, width); any leading
            dimensions must have size 1 so it can be reshaped to (H, W, 1).
        ax: Axes-like object; only its ``imshow`` method is called.
        random_color: Use a random RGB colour instead of the default blue.
            The alpha channel is 0.6 in both cases.
    """
    if random_color:
        rgb = np.random.random(3)
    else:
        rgb = np.array([30/255, 144/255, 255/255])
    rgba = np.concatenate([rgb, np.array([0.6])], axis=0)

    # Broadcast the (H, W, 1) mask against a (1, 1, 4) colour -> (H, W, 4) overlay
    height, width = mask.shape[-2:]
    overlay = mask.reshape(height, width, 1) * rgba[np.newaxis, np.newaxis, :]
    ax.imshow(overlay)
    
def show_box(box, ax):
    """Draw an (x0, y0, x1, y1) box on ``ax`` as an unfilled green rectangle (line width 2)."""
    left, top = box[0], box[1]
    right, bottom = box[2], box[3]
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),
        lw=2,
    )
    ax.add_patch(outline)


In [None]:

importlib.reload(model_evaluator)
evaluator = ModelEvaluator(medsam_model, processor, test_dataset)


%matplotlib inline

# Get correct preprocessing
test_dataset.return_as_medsam = True
test_dataset.resize_mask = False

# Load random image
image_idx = random.randint(0, len(test_dataset)-1)


# Get tensors
img_np, box_np, gt_masks, bounding_boxes = test_dataset[image_idx].values()

# Get original image
test_dataset.return_as_medsam = False
img_original = test_dataset[image_idx]["pixel_values"]
W, H, _ = img_original.shape

# Show image
test_dataset.show_image_mask(image_idx)

# image embedding
with torch.no_grad():
    image_embedding = medsam_model.image_encoder(img_np)


# Run inference for all boxes in a batch
with torch.no_grad():
    seg_masks = evaluator.medsam_inference(image_embedding, box_np, H, W)  # List of 5 masks

# Plot results
fig, ax = plt.subplots(1, 2, figsize=(15, 7))


if len(seg_masks.shape) == 2:
    seg_masks = [seg_masks]


# Original image with bounding boxes
ax[0].imshow(img_original)
for box in bounding_boxes:
    show_box(box, ax[0])
ax[0].set_title("Input Image and Bounding Boxes")

# Image with segmentation masks
ax[1].imshow(img_original)
for box, mask in zip(bounding_boxes, seg_masks):  # Iterate over all boxes and their masks
    show_mask(mask, ax[1]) # random_colour = True
    show_box(box, ax[1])
ax[1].set_title("Base MedSAM Segmentation")

plt.show()

### MedSAM Base Model Evaluation

In [None]:
# Evaluate the untuned MedSAM model with the evaluator built on the test split, then print a summary.
# NOTE(review): runs with whatever dataset flags the previous cell left set (return_as_medsam=False,
# resize_mask=False); confirm evaluate_medsam_base_model configures what it needs itself.
results = evaluator.evaluate_medsam_base_model()
evaluator.print_results()

# Set up Dataloaders

### Set and verify dataloaders

In [None]:
# Switch every split to fine-tuning output: preprocessed tensors, resized masks and one
# entry per object. Each flag is applied to all three splits before moving to the next flag.
split_datasets = (test_dataset, train_dataset, valid_dataset)
for flag_name in ("preprocess_for_fine_tuning", "resize_mask", "return_individual_objects"):
    for split_dataset in split_datasets:
        setattr(split_dataset, flag_name, True)

# Verify item sizes
example = train_dataset[0]

for key, value in example.items():
    print(f"{key:<25} Shape: {str(value.shape):<30} Dtype: {value.dtype}")

train_dataset.show_image_mask(0)

In [None]:
# Training uses batches of 2 (incomplete last batch dropped); validation/test one image at a time
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True, drop_last=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=1, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Verify batch item sizes
batch = next(iter(train_dataloader))
for key, value in batch.items():
    print(f"{key:<25} Shape: {str(value.shape):<30} Dtype: {value.dtype}")

# First sample of the batch: CHW tensor -> HWC numpy array for imshow
image = batch["pixel_values"][0].detach().cpu().numpy().transpose(1, 2, 0)
# First object's ground-truth mask with a trailing channel axis added: (H, W) -> (H, W, 1)
ground_truth = np.expand_dims(batch["obj_ground_truth_masks"][0][0].detach().cpu().numpy(), axis=-1)

print("Image Shape:", image.shape)

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(24, 8))

# Panel 1: input image with every prompt box; the first box is green, the rest red
ax1.imshow(image)
for box_idx, box in enumerate(batch["input_boxes"][0]):
    rect = patches.Rectangle(
        (box[0], box[1]),   # top-left corner
        box[2] - box[0],    # width
        box[3] - box[1],    # height
        linewidth=2,
        edgecolor='green' if box_idx == 0 else 'red',
        facecolor='none',
        label='Predicted Box',
    )
    ax1.add_patch(rect)
ax1.set_title("Example Input")

print("Ground Truth Shape", ground_truth.shape)

# Panel 2: first object's mask with the first box scaled by 256/1024
# (presumably input-image coordinates -> 256x256 mask grid)
ax2.imshow(ground_truth, cmap='gray')
box = batch["input_boxes"][0][0]
box = (box / torch.tensor([1024, 1024, 1024, 1024], device="cpu")) * 256
rect = patches.Rectangle(
    (box[0], box[1]),
    box[2] - box[0],
    box[3] - box[1],
    linewidth=2,
    edgecolor='green',
    facecolor='none',
    label='Predicted Box',
)
ax2.add_patch(rect)
ax2.set_title("Example GT (Loss)")

# Panel 3: the batch's `ground_truth_mask` entry for the first sample
ax3.imshow(batch["ground_truth_mask"][0], cmap="gray")
ax3.set_title("GT Mask")

plt.show()




# Hyperparameter Tuning (Ray Tune)

### Set Up Remote Actor for Valid Dataloader
 __**NOTE: The printed output of the tuning loops appears under this cell.**__

In [None]:


# Create a Ray actor to distribute validation data across workers
@ray.remote
class DataLoaderActor:
    """Ray actor that serves batches from a wrapped DataLoader, one per call.

    Once the underlying iterator is exhausted, a fresh iterator is created, so
    ``get_batch`` cycles through the DataLoader repeatedly.
    """

    def __init__(self, dataloader):
        self.dataloader = dataloader
        self.iterator = iter(self.dataloader)

    def get_batch(self):
        """Return the next batch, restarting the DataLoader after its last batch."""
        exhausted = object()
        batch = next(self.iterator, exhausted)
        if batch is exhausted:
            # Start a new pass over the data
            self.iterator = iter(self.dataloader)
            batch = next(self.iterator)
        return batch

    def get_length(self):
        """Number of batches in one pass of the wrapped DataLoader."""
        return len(self.dataloader)



# Wrap the validation DataLoader inside a Ray actor so tuning trials can fetch batches remotely.
# NOTE(review): the tuning functions below call optimizer.step() on the batches served here, so
# hyperparameters are fitted on the validation split; consider serving train_dataloader instead.
valid_dataloader_actor = DataLoaderActor.remote(valid_dataloader)


# Short per-trial directory names for Ray Tune (Windows limits paths to 260 characters)
def short_dirname(trial, max_id_chars=13):
    """Return a compact log-directory name for a Ray Tune trial.

    Passed as ``trial_dirname_creator`` so nested Ray result paths stay under
    the Windows 260-character path limit.

    Args:
        trial: Ray Tune trial object; only its ``trial_id`` string is read.
        max_id_chars: Number of leading ``trial_id`` characters to keep
            (default 13).

    Returns:
        ``"trial_"`` followed by the truncated trial id.
    """
    return "trial_" + str(trial.trial_id[:max_id_chars])




In [None]:
def remove_invalid_boxes(input_boxes, obj_ground_truth_masks):
    """Drop all-zero padding boxes, and their masks, for a single image.

    Args:
        input_boxes: Tensor of one image's boxes; the last dimension holds the
            4 coordinates. Rows equal to [0, 0, 0, 0] are treated as padding.
        obj_ground_truth_masks: Tensor whose leading dimensions match
            ``input_boxes`` without its last axis; filtered with the same mask.

    Returns:
        ``(boxes, masks)`` with padding rows removed and a leading batch
        dimension of size 1 added to each.
    """
    zero_box = torch.zeros(4, dtype=input_boxes.dtype, device=input_boxes.device)
    is_padding = (input_boxes == zero_box).all(dim=-1)
    keep = ~is_padding

    kept_boxes = input_boxes[keep]
    kept_masks = obj_ground_truth_masks[keep]
    return kept_boxes.unsqueeze(0), kept_masks.unsqueeze(0)



### Tuning The Optimisers

In [None]:
# Function to train and evaluate model with different optimizer settings
def tune_optimizer(config, dataloader_actor):
    """Ray Tune trainable: train MedSAM's mask decoder for one pass with a given optimizer config.

    A fresh ``flaviagiammarino/medsam-vit-base`` model is loaded, the vision and
    prompt encoders are frozen, and only the mask decoder is optimised with
    MONAI Focal loss. ``get_length()`` batches are pulled from ``dataloader_actor``;
    the mean loss over the batches that had at least one valid box is reported
    to Ray Tune as ``"loss"``.

    Args:
        config: Ray Tune sample with keys ``"optimizer"`` (``"adam"`` or ``"sgd"``),
            ``"lr"``, ``"weight_decay"`` and, for SGD, ``"momentum"``.
        dataloader_actor: Ray actor handle exposing ``get_length`` and ``get_batch``.

    Raises:
        ValueError: if ``config["optimizer"]`` is neither ``"adam"`` nor ``"sgd"``.

    NOTE(review): the reported metric is the running *training* loss on whichever
    split the actor serves; it is not a held-out evaluation.
    """
    # Clear CUDA memory at the start of each call
    torch.cuda.empty_cache()

    # Initialize model
    medsam_model = SamModel.from_pretrained("flaviagiammarino/medsam-vit-base")

    # Freeze encoder layers: only the mask decoder is trained
    for name, param in medsam_model.named_parameters():
        if name.startswith("vision_encoder") or name.startswith("prompt_encoder"):
            param.requires_grad_(False)

    medsam_model.to("cuda")

    # Define optimizer based on config
    if config["optimizer"] == "adam":
        optimizer = Adam(
            params=medsam_model.mask_decoder.parameters(),
            lr=config["lr"],
            weight_decay=config["weight_decay"]
        )
    elif config["optimizer"] == "sgd":
        optimizer = SGD(
            params=medsam_model.mask_decoder.parameters(),
            lr=config["lr"],
            momentum=config["momentum"],
            weight_decay=config["weight_decay"]
        )
    else:
        raise ValueError(f"Unsupported optimizer {config['optimizer']!r}; expected 'adam' or 'sgd'")

    # Focal loss with fixed, commonly used parameters. MONAI's FocalLoss applies the sigmoid
    # itself, so it must receive raw logits; sigmoid probabilities would be squashed twice.
    loss_fn = monai.losses.FocalLoss(
        gamma=2.0,
        reduction="mean",
        include_background=True
    )

    # Train for one pass over the actor's DataLoader with the current configuration
    epoch_losses = []
    num_batches = ray.get(dataloader_actor.get_length.remote())
    for _ in tqdm(range(num_batches)):
        batch = ray.get(dataloader_actor.get_batch.remote())  # Fetch batch using Ray actor

        # Get batch values for inference
        pixel_values = batch["pixel_values"].to("cuda")
        input_boxes = batch["input_boxes"].to("cuda")
        obj_ground_truth_masks = batch["obj_ground_truth_masks"].float().to("cuda").squeeze(1)  # Remove extra singleton dimension

        # Per-image losses; averaged below so each optimizer step uses the whole batch
        batch_loss_values = []

        # Process images one at a time so each image's padding boxes/masks can be stripped
        for image, input_box, obj_mask in zip(pixel_values, input_boxes, obj_ground_truth_masks):
            input_box, obj_mask = remove_invalid_boxes(input_box, obj_mask)

            if input_box.shape[1] == 0:
                # Debug print to catch images whose boxes were all padding
                print("No masks found for:", input_box)
                continue

            outputs = medsam_model(
                pixel_values=image.unsqueeze(0),  # Add batch dimension back
                input_boxes=input_box,
                multimask_output=False)

            # Drop the singleton mask axis (multimask_output=False): one logit map per box
            mask_logits = outputs.pred_masks.squeeze(2)

            # Binary ground truth for the MONAI loss
            obj_mask = (obj_mask > 0).float()

            batch_loss_values.append(loss_fn(mask_logits, obj_mask))
            print("Object Loss: ", batch_loss_values[-1])

        # Skip the update when no image in the batch had a valid box. Test against the list,
        # not tensor truthiness: a loss of exactly 0.0 is still a valid step.
        if not batch_loss_values:
            print("Batch Loss: ", None)
            continue

        loss = torch.stack(batch_loss_values).mean()
        print("Batch Loss: ", loss)

        # backward pass
        optimizer.zero_grad()
        loss.backward()

        # optimize
        optimizer.step()
        epoch_losses.append(loss.item())

    # statistics.mean raises on an empty list; report +inf so such a trial is never "best"
    mean_epoch_loss = mean(epoch_losses) if epoch_losses else float("inf")

    # Store metrics for these hyperparameters
    metrics = {
        "loss": mean_epoch_loss
    }

    tune.report(metrics)


In [None]:

# Search space for the optimizer study. "momentum" is sampled for every trial but is
# only read by tune_optimizer when the SGD optimizer is chosen.
search_space_optimizer = {
    "optimizer":    tune.choice(["adam", "sgd"]),
    "lr":           tune.loguniform(1e-5, 1e-3),
    "weight_decay": tune.uniform(0, 1e-3),
    "momentum":     tune.uniform(0, 0.99),
}

# Use an absolute path for Ray's storage_path
storage_path = os.path.abspath("Ray Results/Optimizer")

optimizer_result = tune.run(
    lambda config: tune_optimizer(config, valid_dataloader_actor),  # closure passes the DataLoader actor
    config=search_space_optimizer,
    num_samples=1,                          # a single random sample; increase to actually compare configs
    trial_dirname_creator=short_dirname,    # short trial directories (Windows path-length limit)
    resources_per_trial={"gpu": 1},         # one GPU per trial
    storage_path=storage_path,
)

In [None]:

# Get the best trial based on the lowest reported loss (None when no trial reported "loss")
best_trial = optimizer_result.get_best_trial(metric="loss", mode="min")

if best_trial is None:
    print("No trial reported a 'loss' metric.")
else:
    best_optimizer = best_trial.config["optimizer"]
    # Momentum is only meaningful for SGD
    best_momentum = best_trial.config["momentum"] if best_optimizer == "sgd" else None

    # Single quotes for keys inside the f-strings: reusing the enclosing double quotes is a
    # SyntaxError on Python versions before 3.12.
    print(f"Best Loss Value: {best_trial.last_result['loss']}")
    print(f"Best Optimizer: {best_optimizer}")
    print(f"Best Learning Rate: {best_trial.config['lr']:.6f}")
    print(f"Best Weight Decay: {best_trial.config['weight_decay']:.6f}")
    if best_momentum is not None:
        print(f"Best Momentum: {best_momentum:.6f}")


### Find best saved params for optimizer:

In [None]:
import os
from ray.tune import ExperimentAnalysis

# Get the absolute path for 'Ray Results/Optimizer'
results_dir = os.path.abspath("Ray Results/Optimizer")

# Optimizer types to summarise
optimizers = ["adam", "sgd"]

if not os.path.exists(results_dir):
    print(f"Directory {results_dir} does not exist.")
else:
    # Each sub-directory of results_dir is one tune.run experiment; load each experiment once
    for subdir in sorted(os.listdir(results_dir)):
        subdir_path = os.path.join(results_dir, subdir)

        # Only proceed if it's a directory (skip files)
        if not os.path.isdir(subdir_path):
            continue

        try:
            analysis = ExperimentAnalysis(subdir_path)
        except Exception as e:
            print(f"Error processing {subdir_path}: {e}")
            continue

        print(f"\nExperiment: {subdir}")

        # Errored/unfinished trials have no "loss" in last_result and cannot be ranked
        finished_trials = [t for t in analysis.trials if "loss" in (t.last_result or {})]

        for optimizer_fn in optimizers:
            print(f"\nBest Parameters for {optimizer_fn.capitalize()} Optimizer:")

            filtered_trials = [t for t in finished_trials if t.config.get("optimizer") == optimizer_fn]
            if not filtered_trials:
                print(f"        No valid trial found for {optimizer_fn} optimizer.")
                continue

            # Best trial = minimum final reported loss
            best_trial = min(filtered_trials, key=lambda t: t.last_result["loss"])
            print(f"        Config: {best_trial.config}")
            print(f"        Best Loss Value: {best_trial.last_result['loss']:.6f}")


### Tuning The Loss Functions


In [None]:

# Function to train and evaluate model with different loss function settings
def tune_loss(config, dataloader_actor):
    """Ray Tune trainable: train MedSAM's mask decoder for one pass with a given loss config.

    A fresh ``flaviagiammarino/medsam-vit-base`` model is loaded, the vision and
    prompt encoders are frozen, and the mask decoder is trained with a fixed Adam
    optimizer (lr=1e-4), so only the loss function varies between trials.
    ``get_length()`` batches are pulled from ``dataloader_actor``; the mean loss
    over batches with at least one valid box is reported to Ray Tune as ``"loss"``.

    Args:
        config: Ray Tune sample with ``"loss"`` (``"dice"``, ``"focal"`` or
            ``"tversky"``) and ``"include_background"``, plus ``"squared_pred"``
            (dice), ``"gamma"`` (focal) or ``"alpha"``/``"beta"`` (tversky).
        dataloader_actor: Ray actor handle exposing ``get_length`` and ``get_batch``.

    Raises:
        ValueError: if ``config["loss"]`` is not one of the supported losses.

    NOTE(review): each trial reports the value of its *own* loss function, so
    "lowest loss" across dice/focal/tversky trials is not a like-for-like comparison;
    TODO evaluate all trials with a common metric (e.g. Dice score). The reported value
    is also a training loss on whichever split the actor serves.
    """
    # Clear CUDA memory at the start of each call
    torch.cuda.empty_cache()

    # Initialize model
    medsam_model = SamModel.from_pretrained("flaviagiammarino/medsam-vit-base")

    # Freeze encoder layers: only the mask decoder is trained
    for name, param in medsam_model.named_parameters():
        if name.startswith("vision_encoder") or name.startswith("prompt_encoder"):
            param.requires_grad_(False)

    medsam_model.to("cuda")

    # Baseline optimizer (fixed so only the loss function varies between trials)
    optimizer = Adam(
        params=medsam_model.mask_decoder.parameters(),
        lr=0.0001,
        weight_decay=0
    )

    # Every loss receives raw logits. MONAI's FocalLoss always applies a sigmoid internally;
    # Dice/Tversky are told to do the same via sigmoid=True. Feeding sigmoid probabilities
    # instead would make FocalLoss squash its input twice.
    loss_name = config["loss"]
    if loss_name == "dice":
        loss_fn = monai.losses.DiceLoss(
            sigmoid=True,
            squared_pred=config["squared_pred"],
            reduction="mean",
            include_background=config["include_background"]
        )
    elif loss_name == "focal":
        loss_fn = monai.losses.FocalLoss(
            gamma=config["gamma"],
            reduction="mean",
            include_background=config["include_background"]
        )
    elif loss_name == "tversky":
        loss_fn = monai.losses.TverskyLoss(
            sigmoid=True,
            alpha=config["alpha"],
            beta=config["beta"],
            reduction="mean",
            include_background=config["include_background"]
        )
    else:
        raise ValueError(f"Unsupported loss {loss_name!r}; expected 'dice', 'focal' or 'tversky'")

    # Train for one pass over the actor's DataLoader with the current configuration
    epoch_losses = []
    num_batches = ray.get(dataloader_actor.get_length.remote())
    for _ in tqdm(range(num_batches)):
        batch = ray.get(dataloader_actor.get_batch.remote())  # Fetch batch using Ray actor

        # Get batch values for inference
        pixel_values = batch["pixel_values"].to("cuda")
        input_boxes = batch["input_boxes"].to("cuda")
        obj_ground_truth_masks = batch["obj_ground_truth_masks"].float().to("cuda").squeeze(1)  # Remove extra singleton dimension

        # Per-image losses; averaged below so each optimizer step uses the whole batch
        batch_loss_values = []

        # Process images one at a time so each image's padding boxes/masks can be stripped
        for image, input_box, obj_mask in zip(pixel_values, input_boxes, obj_ground_truth_masks):
            input_box, obj_mask = remove_invalid_boxes(input_box, obj_mask)

            if input_box.shape[1] == 0:
                # Debug print to catch images whose boxes were all padding
                print("No masks found for:", input_box)
                continue

            outputs = medsam_model(
                pixel_values=image.unsqueeze(0),  # Add batch dimension back
                input_boxes=input_box,
                multimask_output=False)

            # Give each box its own sample with a single channel: (num_boxes, 1, H, W).
            # With boxes stacked as channels, MONAI's include_background=False would drop
            # channel 0, i.e. silently exclude the first object from the loss. Dice/Tversky
            # (batch=False) and Focal ("mean") give the same per-object average either way.
            # Assumes predicted and ground-truth masks share shape, as the loss requires.
            mask_logits = outputs.pred_masks.reshape(-1, 1, *outputs.pred_masks.shape[-2:])
            obj_mask = (obj_mask > 0).float().reshape(-1, 1, *obj_mask.shape[-2:])

            batch_loss_values.append(loss_fn(mask_logits, obj_mask))
            print("Object Loss: ", batch_loss_values[-1])

        # Skip the update when no image in the batch had a valid box. Test against the list,
        # not tensor truthiness: a loss of exactly 0.0 is still a valid step.
        if not batch_loss_values:
            print("Batch Loss: ", None)
            continue

        loss = torch.stack(batch_loss_values).mean()
        print("Batch Loss: ", loss)

        # backward pass
        optimizer.zero_grad()
        loss.backward()

        # optimize
        optimizer.step()
        epoch_losses.append(loss.item())

    # statistics.mean raises on an empty list; report +inf so such a trial is never "best"
    mean_epoch_loss = mean(epoch_losses) if epoch_losses else float("inf")

    # Store metrics for these hyperparameters
    metrics = {
        "loss": mean_epoch_loss
    }

    tune.report(metrics)


In [None]:

# Search space for the loss-function study. The loss-specific keys (squared_pred, gamma,
# alpha, beta) are sampled for every trial but only read by the matching branch of tune_loss.
# NOTE(review): in MONAI, include_background=False drops prediction channel 0 when the
# prediction has more than one channel; it does not mask out background pixels.
search_space_loss = {
    "loss":               tune.choice(["dice", "focal", "tversky"]),
    "include_background": tune.choice([False]),
    "squared_pred":       tune.choice([True, False]),  # Dice only
    "gamma":              tune.uniform(1.0, 5.0),      # Focal only; larger gamma focuses more on hard pixels
    "alpha":              tune.uniform(0.3, 0.7),      # Tversky only: false-positive weight
    "beta":               tune.uniform(0.3, 0.8),      # Tversky only: false-negative weight
}

# Use an absolute path for Ray's storage_path
storage_path = os.path.abspath("Ray Results/Loss")

loss_result = tune.run(
    lambda config: tune_loss(config, valid_dataloader_actor),  # closure passes the DataLoader actor
    config=search_space_loss,
    num_samples=1,                          # a single random sample; increase to actually compare configs
    trial_dirname_creator=short_dirname,    # short trial directories (Windows path-length limit)
    resources_per_trial={"cpu": 8, "gpu": 1},
    storage_path=storage_path,
)


In [None]:
# Best loss-function trial by lowest reported "loss".
# NOTE(review): trials with different loss functions report values on different scales.
best_trial = loss_result.get_best_trial(metric="loss", mode="min")
best_loss_value = best_trial.last_result["loss"]

print(f"Best Loss Value: {best_loss_value}")
print(f"Best Loss Config: {best_trial.config}")

### Find best params for loss functions

In [None]:

# Get the absolute path for 'Ray Results/Loss'
results_dir = os.path.abspath("Ray Results/Loss")

# Loss function types to summarise
loss_functions = ["dice", "focal", "tversky"]

if not os.path.exists(results_dir):
    print(f"Directory {results_dir} does not exist.")
else:
    # Each sub-directory of results_dir is one tune.run experiment; load each experiment once
    for subdir in sorted(os.listdir(results_dir)):
        subdir_path = os.path.join(results_dir, subdir)

        # Only proceed if it's a directory (skip files)
        if not os.path.isdir(subdir_path):
            continue

        try:
            analysis = ExperimentAnalysis(subdir_path)
        except Exception as e:
            print(f"Error processing {subdir_path}: {e}")
            continue

        print(f"\nExperiment: {subdir}")

        # Errored/unfinished trials have no "loss" in last_result and cannot be ranked
        finished_trials = [t for t in analysis.trials if "loss" in (t.last_result or {})]

        for loss_fn in loss_functions:
            print(f"\nBest Parameters for {loss_fn.capitalize()} Loss:")

            filtered_trials = [t for t in finished_trials if t.config.get("loss") == loss_fn]
            if not filtered_trials:
                print(f"        No valid trial found for {loss_fn} loss.")
                continue

            # Best trial = minimum final reported loss (comparable only within one loss type)
            best_trial = min(filtered_trials, key=lambda t: t.last_result["loss"])
            print(f"        Config: {best_trial.config}")
            print(f"        Best Loss Value: {best_trial.last_result['loss']:.6f}")

# Fine-tuning MedSAM

## Fine-Tuning the Hugging Face SamModel (MedSAM)

In [None]:
import gc

# Remove previous models from memory. Assigning None first guarantees the name exists,
# so `del` cannot raise NameError on a fresh kernel.
medsam_model = None
del medsam_model
gc.collect()
torch.cuda.empty_cache()
# NOTE(review): `evaluator` was created with the base model and presumably keeps a reference
# to it, so that model's memory is not released until `evaluator` is deleted as well.

# Load the pretrained weights for fine-tuning
medsam_model = SamModel.from_pretrained("flaviagiammarino/medsam-vit-base")

# Only the mask decoder is trained: freeze the vision and prompt encoders
for name, param in medsam_model.named_parameters():
    if name.startswith("vision_encoder") or name.startswith("prompt_encoder"):
        param.requires_grad_(False)

# Fall back to CPU when CUDA is unavailable instead of failing on a hard-coded "cuda"
device = "cuda" if torch.cuda.is_available() else "cpu"
medsam_model.to(device)

In [None]:


# Optimisers for the trainable mask-decoder parameters. Both are built on the same
# parameters, so only one of them should be stepped during a given training run.
# (Low learning rates help avoid moving too far from the base model.)
adam_optimizer = Adam(medsam_model.mask_decoder.parameters(), lr=0.00001, weight_decay=0)
sgd_optimizer = SGD(medsam_model.mask_decoder.parameters(), lr=0.01, momentum=0.9)

# Loss functions (MONAI).
# include_background=False: MONAI drops prediction channel 0 when there is more than one
# channel, and ignores the flag (with a warning) for single-channel predictions.
# NOTE(review): FocalLoss applies a sigmoid internally and expects logits, whereas DiceLoss and
# TverskyLoss here keep the default sigmoid=False and expect probabilities; make sure the
# training loop feeds each loss the right kind of input.

# Dice loss using squared predictions/targets in the denominator
dice_loss = monai.losses.DiceLoss(
    include_background=False,
    squared_pred=True,
    reduction='mean',        # aggregate per-item losses by their mean
)

# Focal loss (suited to imbalanced pixel coverage); higher gamma focuses on hard examples
focal_loss = monai.losses.FocalLoss(
    include_background=False,
    gamma=2.0,
    reduction='mean',
)

# Tversky loss: alpha weights false positives, beta weights false negatives.
# NOTE(review): the stated intent is that false negatives matter more, but alpha (0.7) > beta (0.3)
# penalises false positives more; swap the values if missed teeth are the bigger concern.
tversky_loss = monai.losses.TverskyLoss(
    include_background=False,
    alpha=0.7,
    beta=0.3,
    reduction='mean',
)

# TODO: also try DiceFocalLoss


In [None]:
def remove_invalid_boxes(input_boxes, obj_ground_truth_masks):
    """Filter out [0, 0, 0, 0] padding boxes and the matching object masks for one image.

    Redefined here with the same behaviour as the copy in the Ray Tune section,
    presumably so the fine-tuning section can run without those cells.

    Args:
        input_boxes: Tensor whose last dimension holds 4 box coordinates.
        obj_ground_truth_masks: Tensor indexed with the same boolean mask as
            ``input_boxes`` (leading dimensions must match).

    Returns:
        ``(boxes, masks)`` without padding rows, each with a leading size-1 batch axis.
    """
    padding_row = torch.tensor([0, 0, 0, 0], dtype=input_boxes.dtype, device=input_boxes.device)
    valid = torch.logical_not(torch.all(input_boxes == padding_row, dim=-1))
    return input_boxes[valid].unsqueeze(0), obj_ground_truth_masks[valid].unsqueeze(0)




### Base Training Loop

In [None]:
torch.cuda.empty_cache()

# Training loop (unweighted Dice loss)
num_epochs = 1
batch_losses = []  # Loss of every optimised batch across all epochs (used by the loss graph)

# NOTE(review): medsam_model.train() is commented out, so the model stays in whatever mode an
# earlier cell left it in — confirm this is intentional.
#medsam_model.train()
for epoch in range(num_epochs):
    epoch_losses = []
    for batch in tqdm(train_dataloader):

        # Get batch values for inference
        pixel_values = batch["pixel_values"].to(device)
        input_boxes = batch["input_boxes"].to(device)
        obj_ground_truth_masks = batch["obj_ground_truth_masks"].float().to(device).squeeze(1)  # Remove extra singleton dimension

        # Per-image losses; their mean is the batch loss. Each image is processed on its own so the
        # zero-padded boxes/masks can be stripped before the loss is computed (padding would distort
        # the loss), while averaging over the batch keeps the gradient smoother than single-image updates.
        batch_loss_values = []

        for image, input_box, obj_mask in zip(pixel_values, input_boxes, obj_ground_truth_masks):

            # Remove the padding from these batch values
            input_box, obj_mask = remove_invalid_boxes(input_box, obj_mask)

            # If the image has no real boxes left after removing the padding, report it and skip it
            if input_box.shape[1] == 0:
                print("No masks found for:", input_box)
                with torch.no_grad():
                    plt.imshow(image.cpu().numpy().squeeze().transpose(1, 2, 0))
                continue

            # forward pass
            outputs = medsam_model(
                pixel_values=image.unsqueeze(0),  # Add batch dimension back
                input_boxes=input_box,
                multimask_output=False)

            # Remove the singleton multimask dimension from the predicted masks (shape: [1, num_boxes, 256, 256])
            predicted_masks = outputs.pred_masks.squeeze(2)

            # Convert object ground truth masks to binary to pass into the MONAI loss function
            obj_mask = (obj_mask > 0).float()

            # Convert logits to probabilities (dice_loss is configured without its own sigmoid)
            predicted_masks = torch.sigmoid(predicted_masks)

            # Calculate loss using the defined loss function
            batch_loss_values.append(dice_loss(predicted_masks, obj_mask))
            print("Object Loss: ", batch_loss_values[-1])

        # Calculate the mean of the batch for backpropagation (None when no image had any masks)
        loss = torch.stack(batch_loss_values).mean() if batch_loss_values else None
        print("\nBatch Loss: ", loss)

        # If no masks were found, continue without processing the batch.
        # `is None` is used because the truthiness of a tensor would also skip a genuine 0.0 loss.
        if loss is None:
            continue

        # backward pass
        adam_optimizer.zero_grad()
        loss.backward()

        # optimize
        adam_optimizer.step()

        epoch_losses.append(loss.item())
        batch_losses.append(loss.item())

    print(f'EPOCH: {epoch}')
    # statistics.mean raises on an empty list (every batch skipped), so report NaN instead
    print(f'Mean loss: {mean(epoch_losses) if epoch_losses else float("nan")}')


# Save the model's state dictionary to a file
torch.save(medsam_model.state_dict(), "Models/medsam_vit_b_object_masks.pth")
     


### Weighted Loss Training Loop (Do not use while undersampling!)

In [None]:

# To define class weights I have used the average pixel frequency * class freq in an inverse ratio to the total.
# I have capped the background as it has significantly more pixels than any object, and fillings as they were too common.

# Define Class Weights (index = class id in the ground-truth class mask)
class_weights = torch.tensor([
    0.02,   # Background
    0.133,  # braces
    0.184,  # bridge
    0.086,  # cavity
    0.081,  # crown
    0.05,   # filling
    0.326,  # implant
    0.172   # lesion
    ]).to(device)


# Each channel of the prediction/target is one object mask (one channel per input box), not a
# class, so there is no background channel. With include_background=False MONAI drops channel 0,
# which would silently ignore the first object of every image that has more than one box.

# Define the Dice loss. No sigmoid=True here: the training loop already converts the logits with
# torch.sigmoid, and applying the sigmoid twice would squash the predictions towards 0.5-0.73.
# With reduction="none" the loss is returned per image and per channel (not per pixel).
dice_loss = monai.losses.DiceLoss(
    squared_pred=True,        # use squared predictions/targets in the denominator
    reduction="none",         # how losses are aggregated (mean, sum, or none)
    include_background=True,  # every channel is a real object (see note above)
)


# Define the Focal Loss (good for imbalanced pixel coverage).
# NOTE: MONAI's FocalLoss expects raw logits (it applies the sigmoid internally).
focal_loss = monai.losses.FocalLoss(
    gamma=2.0,                # Focusing (higher = more focus on hard examples)
    reduction="none",         # How losses are aggregated (mean, sum, or none)
    include_background=True,  # every channel is a real object (see note above)
)

# Define Tversky Loss (false positives are not that important, but false negatives are bad).
# In MONAI, alpha weights false positives and beta weights false negatives, so penalising
# false negatives more requires beta > alpha.
tversky_loss = monai.losses.TverskyLoss(
    alpha=0.3,                # increase for fewer false positives
    beta=0.7,                 # increase for fewer false negatives
    reduction="none",         # how losses are aggregated (mean, sum, or none)
    include_background=True,  # every channel is a real object (see note above)
)


# Maybe also try DiceFocal loss?


In [None]:

# Training loop (class-weighted Dice loss)
num_epochs = 1
batch_losses = []  # Loss of every optimised batch across all epochs (used by the loss graph)

medsam_model.train()
for epoch in range(num_epochs):
    epoch_losses = []
    for batch in tqdm(train_dataloader):

        # Get batch values for inference
        pixel_values = batch["pixel_values"].to(device)
        input_boxes = batch["input_boxes"].to(device)
        obj_ground_truth_masks = batch["obj_ground_truth_masks"].float().to(device).squeeze(1)  # Remove extra singleton dimension
        # Class-index masks; each image's own mask is used to build that image's weight map
        ground_truth_masks = batch["ground_truth_mask"].to(device)

        # Per-image losses; their mean is the batch loss. Each image is processed on its own so the
        # zero-padded boxes/masks can be stripped before the loss is computed, while averaging over
        # the batch keeps the gradient smoother than single-image updates.
        batch_loss_values = []

        for image, input_box, obj_mask, ground_truth_mask in zip(
                pixel_values, input_boxes, obj_ground_truth_masks, ground_truth_masks):

            # Remove the padding from these batch values
            input_box, obj_mask = remove_invalid_boxes(input_box, obj_mask)

            # If the image has no real boxes left after removing the padding, report it and skip it
            if input_box.shape[1] == 0:
                print("No masks found for:", input_box)
                continue

            # forward pass
            outputs = medsam_model(
                pixel_values=image.unsqueeze(0),  # Add batch dimension back
                input_boxes=input_box,
                multimask_output=False)

            # Remove the singleton multimask dimension from the predicted masks (shape: [1, num_boxes, 256, 256])
            predicted_masks = outputs.pred_masks.squeeze(2)

            # Convert object ground truth masks to binary to pass into the MONAI loss function
            obj_mask = (obj_mask > 0).float()

            # Convert logits to probabilities.
            # NOTE: dice_loss must therefore be created WITHOUT sigmoid=True, otherwise the
            # predictions are passed through a sigmoid twice.
            predicted_masks = torch.sigmoid(predicted_masks)

            # Map every pixel of THIS image's class mask to its class weight
            weight_map = torch.zeros_like(ground_truth_mask, dtype=torch.float)
            for class_idx in range(len(class_weights)):
                weight_map[ground_truth_mask == class_idx] = class_weights[class_idx]

            # With reduction="none" DiceLoss returns one value per object channel (spatial dims
            # reduced to size 1), not one per pixel, so the weight map can only scale the loss by
            # the image's mean class weight. True per-pixel weighting would need a pixel-wise loss.
            loss_per_object = dice_loss(predicted_masks, obj_mask)
            weighted_loss = loss_per_object.mean() * weight_map.mean()

            # Add loss to batch list
            batch_loss_values.append(weighted_loss)
            print("Object Loss: ", weighted_loss)

        # Nothing to backpropagate if no image in the batch had any boxes
        if not batch_loss_values:
            print("Batch Loss: ", None)
            continue

        # Mean of the per-image losses: this (not the last image's loss) is what gets optimised
        loss = torch.stack(batch_loss_values).mean()
        print("Batch Loss: ", loss)

        # backward pass
        adam_optimizer.zero_grad()
        loss.backward()

        # optimize
        adam_optimizer.step()

        epoch_losses.append(loss.item())
        batch_losses.append(loss.item())

    print(f'EPOCH: {epoch}')
    # statistics.mean raises on an empty list (every batch skipped), so report NaN instead
    print(f'Mean loss: {mean(epoch_losses) if epoch_losses else float("nan")}')


# Save the model's state dictionary to a file
torch.save(medsam_model.state_dict(), "Models/medsam_vit_b_object_masks.pth")
     


In [None]:





#Training loop
num_epochs = 1

medsam_model.train()
for epoch in range(num_epochs):
    epoch_losses = []
    for batch in tqdm(train_dataloader):

        # forward pass
        outputs = medsam_model(pixel_values=batch["pixel_values"].to(device),
                        #input_masks=batch["labels"].to(device),
                        input_boxes=batch["input_boxes"].to(device),
                        multimask_output=False)


        # Get predicted masks and ground truth masks
        predicted_masks = outputs.pred_masks.squeeze(2)  # Remove extra singleton dimension from predicted masks (shape: [2, n, 256, 256])
        ground_truth_masks = batch["ground_truth_mask"].float().to(device).squeeze(1)  # Remove extra singleton dimension from ground truth (shape: [2, n, 256, 256])

       
       


    ################################### TESTING ########################################


        # # Initialize list for binary masks per class
        # num_classes = len(class_weights)  # Number of classes, including background
        # binary_masks = []


        # # Convert ground truth mask to binary mask for each class
        # for class_idx in range(1, num_classes):  # Skip background (class 0)
        #     binary_mask = (ground_truth_masks == class_idx).float()  # Binary mask for current class
        #     binary_masks.append(binary_mask)


        # # Stack the binary masks to match shape [batch_size, num_classes, height, width]
        # binary_masks = torch.stack(binary_masks, dim=1)  # Shape: [batch_size, num_classes-1, height, width]

        print("Min Class Label:", ground_truth_masks.min().item())  
        print("Max Class Label:", ground_truth_masks.max().item())


        # Reshape to merge the 'N' dimension into the batch
        predicted_masks = predicted_masks.view(-1, 1, 256, 256)  # (2*35, 1, 256, 256)
        ground_truth_masks = ground_truth_masks.view(-1, 1, 256, 256)  # (2*35, 1, 256, 256)

        # Convert ground truth to one-hot encoding for multi-class segmentation
        ground_truth_masks = F.one_hot(ground_truth_masks.squeeze(1).long(), num_classes=8).permute(0, 3, 1, 2).float()

        # Ensure predictions have the correct shape (B, C, H, W)
        predicted_masks = predicted_masks.repeat(1, 8, 1, 1)  # If needed, duplicate across 8 channels

        # Compute loss
        #loss = dice_ce_loss(predicted_masks, ground_truth_masks)






    ################################### TESTING ########################################

"""

        # Calculate loss using defined loss function
        loss = dice_ce_loss(predicted_masks, ground_truth_masks)
        print(loss)


        # backward pass
        adam_optimizer.zero_grad()
        loss.backward()

        # optimize
        adam_optimizer.step()
        epoch_losses.append(loss.item())




        # # NOTE: sigmoid and softmax have the same output range here indicating that the SAM and MedSAM models only predict one class by default. This might 
        # # Calculate the max and min values for both predicted and ground truth masks
        # pred_min_value = F.softmax(predicted_masks).min().item()
        # pred_max_value = F.softmax(predicted_masks).max().item()
        # gt_min_value = ground_truth_masks.min().item()
        # gt_max_value = ground_truth_masks.max().item()

        # Print the min and max values for both predicted and ground truth masks
        #print(f"Predicted Mask - Min Value: {pred_min_value}, Max Value: {pred_max_value}")
        #print(f"Ground Truth Mask - Min Value: {gt_min_value}, Max Value: {gt_max_value}")
        



    print(f'EPOCH: {epoch}')
    print(f'Mean loss: {mean(epoch_losses)}')


"""

# Save the model's state dictionary to a file
torch.save(medsam_model.state_dict(), "Models/medsam_vit_b_object_masks.pth")
     


### Batch Loss Graph

In [None]:

# Plot the raw per-batch training loss together with its moving average.
window_size = 10  # moving-average window, in batches

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(batch_losses, alpha=0.3, label="Raw Loss")  # Light color for raw loss

# np.convolve(..., mode='valid') needs at least `window_size` points: with fewer it swaps its
# arguments and returns a curve whose length no longer matches the x range (and it raises on an
# empty list), so the smoothed curve is only drawn once there is enough history.
if len(batch_losses) >= window_size:
    smoothed_losses = np.convolve(batch_losses, np.ones(window_size) / window_size, mode='valid')
    ax.plot(range(window_size - 1, len(batch_losses)), smoothed_losses, color='red', label="Smoothed Loss")

ax.set_xlabel("Batch")
ax.set_ylabel("Loss")
ax.set_title("Training Loss per Batch (Smoothed)")
ax.legend()
ax.grid()
plt.show()

# Fine-Tuned MedSAM Evaluation

### Load Model

In [None]:
# Load the model configuration
model_config = SamConfig.from_pretrained("flaviagiammarino/medsam-vit-base")
processor = SamProcessor.from_pretrained("flaviagiammarino/medsam-vit-base")

# Choose the device before loading so the checkpoint can be mapped straight onto it
# (falls back to CPU when no GPU is available instead of failing).
device = "cuda" if torch.cuda.is_available() else "cpu"

# Create an instance of the model architecture with the loaded configuration
my_model = SamModel(config=model_config)

# Update the model by loading the weights from the saved file.
# Forward slashes work on every OS; "Models\m..." only worked on Windows and is an invalid escape.
# NOTE(review): the training cells save to "Models/medsam_vit_b_object_masks.pth" — confirm the
# "_test" checkpoint is the one intended here.
my_model.load_state_dict(torch.load("Models/medsam_vit_b_object_masks_test.pth", map_location=device))
my_model.to(device)

# A directly constructed module starts in training mode; switch to inference mode for evaluation
my_model.eval()
     

### Test Inference

In [None]:
# Sanity-check the loaded model on the first training batch: show the input image with its prompt
# boxes, the first ground-truth object mask with its (rescaled) box, and the thresholded prediction
# for the first box.
batch = next(iter(train_dataloader))

# Get image from batch
image = batch["pixel_values"][0].detach().cpu().numpy().transpose(1, 2, 0)  # Convert to HxWxC format

# Get ground truth mask of the first object of the first image
# NOTE(review): assumes obj_ground_truth_masks indexes as [image][object] so this is an (H, W) mask — confirm
ground_truth = batch["obj_ground_truth_masks"][0][0].detach().cpu().numpy()  # Convert to numpy (H, W)

# Get model prediction
image_tensor = batch["pixel_values"][0].unsqueeze(0).to(device)  # Add batch dimension
input_boxes = batch["input_boxes"][0].unsqueeze(0).to(device)

with torch.no_grad():
    outputs = my_model(pixel_values=image_tensor, input_boxes=input_boxes, multimask_output=False)
    # [0] drops the batch dimension, leaving one probability mask per box (num_boxes, H, W)
    predicted_mask = torch.sigmoid(outputs.pred_masks.squeeze(2)).cpu().numpy()[0]  # Convert to numpy (H, W)

# Plot the results
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(24, 8))

# Display input image with every prompt box of the first image
# (zero-padding boxes, if present, are drawn as zero-size rectangles)
ax1.imshow(image)
for box in batch["input_boxes"][0]:
    rect = patches.Rectangle(
        (box[0], box[1]), 
        box[2] - box[0],  
        box[3] - box[1],  
        linewidth=2, 
        edgecolor='red', 
        facecolor='none'
    )
    ax1.add_patch(rect)

ax1.set_title("Example Input")

# Display ground truth mask
ax2.imshow(ground_truth, cmap='gray')

# Draw the first box of the first image on the ground-truth panel, rescaled from the 1024-pixel
# input frame to the 256-pixel mask frame (assumes boxes are given in 1024x1024 coordinates)
box = batch["input_boxes"][0][0]
box = (box / torch.tensor([1024,1024,1024,1024], device="cpu")) * 256

rect = patches.Rectangle(
    (box[0], box[1]),  
    box[2] - box[0],  
    box[3] - box[1],  
    linewidth=2, 
    edgecolor='green', 
    facecolor='none'
)
ax2.add_patch(rect)
ax2.set_title("Example GT (Loss)")

# Threshold the first box's probability mask to get a clearer segmentation (0.7 = probability cut-off)
threshold_prediction = (predicted_mask[0] > 0.7).astype(np.uint8)

# Display predicted mask
ax3.imshow(threshold_prediction, cmap="viridis")
ax3.set_title("Predicted Mask")

plt.show()

### Old Code (Not Working)

In [None]:

importlib.reload(model_evaluator)
evaluator = ModelEvaluator(my_model, processor, test_dataset)


%matplotlib inline

# Get correct preprocessing
test_dataset.return_as_medsam = True
test_dataset.resize_mask = False

# Load random image
#image_idx = random.randint(0, len(test_dataset)-1)

image_idx = 33

# Get tensors
img_np, box_np, gt_masks, bounding_boxes = test_dataset[image_idx].values()

# Get original image
test_dataset.return_as_medsam = False
img_original = test_dataset[image_idx]["pixel_values"]
W, H, _ = img_original.shape

# Show image
test_dataset.show_image_mask(image_idx)

# image embedding
with torch.no_grad():
    image_embedding = my_model.image_encoder(img_np)


# Run inference for all boxes in a batch
with torch.no_grad():
    seg_masks = evaluator.medsam_inference(image_embedding, box_np, H, W)  # List of 5 masks

# Plot results
fig, ax = plt.subplots(1, 2, figsize=(15, 7))


if len(seg_masks.shape) == 2:
    seg_masks = [seg_masks]


# Original image with bounding boxes
ax[0].imshow(img_original)
for box in bounding_boxes:
    show_box(box, ax[0])
ax[0].set_title("Input Image and Bounding Boxes")

# Image with segmentation masks
ax[1].imshow(img_original)
for box, mask in zip(bounding_boxes, seg_masks):  # Iterate over all boxes and their masks
    show_mask(mask, ax[1]) # random_colour = True
    show_box(box, ax[1])
ax[1].set_title("Base MedSAM Segmentation")

plt.show()

In [None]:
# Verify batch item sizes: print the shape and dtype of every entry in one training batch
# (assumes every value in the batch dict is a tensor — a value without .shape/.dtype would raise)
batch = next(iter(train_dataloader))
for k,v in batch.items():
  print(f"{k:<25} Shape: {str(v.shape):<30} Dtype: {v.dtype}")




print("Example Input:\n")

# Get image from batch
image = batch["pixel_values"][0].detach().cpu().numpy().transpose(1, 2, 0)  # Convert to HxWxC format

print(image.shape)

fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(image)  # Show the image (if pixel_values are normalised, imshow clips out-of-range values)

# Plot the input (prompt) boxes of the first image. They come from the dataset, not from the
# model, despite the 'Predicted Box' label; zero-padding boxes, if any, show up as zero-size rectangles.
for box in batch["input_boxes"][0]:
    rect = patches.Rectangle(
        (box[0], box[1]),  # x, y (top-left corner)
        box[2] - box[0],  # width
        box[3] - box[1],  # height
        linewidth=2,
        edgecolor='red',
        facecolor='none',
        label='Predicted Box'
    )
    ax.add_patch(rect)

    


ax.set_title(f"Test Inference")
plt.show()



### MedSAM Fine-Tuned Model Evaluation

In [None]:
# Configure the test split for evaluating the fine-tuned model
test_dataset.preprocess_for_fine_tuning  = True
test_dataset.resize_mask  = True
test_dataset.return_individual_objects = True


# Reload the evaluator module and rebind ModelEvaluator from it: the earlier
# `from model_evaluator import ModelEvaluator` keeps pointing at the old class object,
# so `importlib.reload` on its own would not pick up edits to model_evaluator.py.
ModelEvaluator = importlib.reload(model_evaluator).ModelEvaluator
evaluator = ModelEvaluator(my_model, processor, test_dataset)

# One image per batch, in a fixed order, so results are reproducible
test_dataloader = DataLoader(test_dataset,  batch_size=1, shuffle=False)

results = evaluator.evaluate_medsam_model(test_dataloader)
evaluator.print_results()