In [None]:
import os
import torch
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
from utils.oct_dataset import OCTDataset
from utils.lossfunctions import DiceLoss
from utils.models import ResNetUNetWithAttention, MedSAM
from segment_anything import sam_model_registry
import torchmetrics

In [None]:
def save_image_with_prediction_and_mask(image, predicted, mask, image_id, save_dir, model_name):
    """Save a side-by-side figure of an input image, its predicted mask and the ground-truth mask.

    Args:
        image: image tensor in channel-first layout (C, H, W); it is transposed to (H, W, C) for plotting.
        predicted: predicted mask tensor; singleton dimensions are squeezed before plotting.
        mask: ground-truth mask tensor; singleton dimensions are squeezed before plotting.
        image_id: identifier used in the output filename.
        save_dir: existing directory the PNG is written into.
        model_name: model name used in the output filename.

    The figure is written to ``os.path.join(save_dir, f"{image_id}_prediction_{model_name}.png")``
    and closed afterwards, even when saving fails.
    """
    # detach() lets tensors that still track gradients be converted to numpy as well.
    image_np = image.detach().cpu().numpy().transpose(1, 2, 0)
    predicted_np = predicted.detach().cpu().numpy().squeeze()
    mask_np = mask.detach().cpu().numpy().squeeze()

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        # (data, colormap, title) for each of the three panels, left to right.
        panels = (
            (image_np, None, "Original Image"),
            (predicted_np, "gray", "Predicted Mask"),
            (mask_np, "gray", "Ground Truth Mask"),
        )
        for ax, (data, cmap, title) in zip(axes, panels):
            ax.imshow(data, cmap=cmap)
            ax.set_title(title)
            ax.axis("off")

        save_path = os.path.join(save_dir, f"{image_id}_prediction_{model_name}.png")
        fig.savefig(save_path)
    finally:
        # Always release the figure so repeated calls (one per test image) don't accumulate open figures.
        plt.close(fig)


In [None]:

def _build_model(model_config):
    """Instantiate the network described by `model_config` and load its checkpoint (on CPU).

    Args:
        model_config: dict with key "model" (one of "Unet", "DeepLabV3+", "MedSam",
            "AttentionUnet") and key "checkpoint_path".

    Returns:
        The model with its checkpoint weights loaded; not yet moved to a device or set to eval.

    Raises:
        ValueError: if model_config["model"] is not one of the supported types.
    """
    model_type = model_config["model"]
    checkpoint_path = model_config["checkpoint_path"]

    if model_type == "MedSam":
        sam_model = sam_model_registry["vit_b"](checkpoint="utils/medsam_vit_b.pth")
        net = MedSAM(
            image_encoder=sam_model.image_encoder,
            mask_decoder=sam_model.mask_decoder,
            prompt_encoder=sam_model.prompt_encoder,
        )
        checkpoint = torch.load(checkpoint_path, weights_only=True, map_location=torch.device("cpu"))
        # MedSAM checkpoints keep the weights under the "model" key.
        net.load_state_dict(checkpoint["model"])
        return net

    # encoder_weights=None: the strict load_state_dict below overwrites every parameter,
    # so downloading ImageNet weights first would be wasted work (and needs network access).
    if model_type == "Unet":
        net = smp.Unet(
            encoder_name="resnet50",
            encoder_weights=None,
            in_channels=3,
            classes=1,
        )
    elif model_type == "DeepLabV3+":
        net = smp.DeepLabV3Plus(
            encoder_name="resnet50",
            encoder_weights=None,
            in_channels=3,
            classes=1,
        )
    elif model_type == "AttentionUnet":
        net = ResNetUNetWithAttention()
    else:
        raise ValueError(f"Unsupported model type: {model_type!r}")

    # These checkpoints were saved as a (model_state, optimizer_state) pair; only the model part is used.
    model_state, _optimizer_state = torch.load(checkpoint_path, weights_only=True, map_location=torch.device("cpu"))
    net.load_state_dict(model_state)
    return net


def _run_model(net, model_type, images, device):
    """Forward `images` through `net`; MedSAM additionally receives a full-image box prompt."""
    if model_type == "MedSam":
        batch_size, _, height, width = images.size()
        # One box per image spanning the whole frame, [0, 0, width, height], shape (B, 1, 4).
        bboxes = torch.tensor([[0, 0, width, height]] * batch_size, dtype=torch.float32).unsqueeze(1).to(device)
        return net(images, bboxes)
    return net(images)


def test_models(models_list, root_dir, save_dir):
    """Evaluate each configured segmentation model on the Gentuity OCT test set.

    For every (model_name, model_config) pair this builds the model, loads its checkpoint,
    thresholds its output into a binary mask for each test image, computes a Dice score per
    image with torchmetrics, and saves an image / prediction / ground-truth figure.

    Args:
        models_list: iterable of (model_name, model_config) tuples. model_config needs
            "model" and "checkpoint_path"; the optional "threshold" (default 0.5) is the
            cutoff applied to the raw network output.
        root_dir: dataset root passed to OCTDataset. When empty, the previously hard-coded
            Gentuity data directory is used instead.
        save_dir: existing directory where the per-image prediction figures are written.

    Returns:
        (model_names, dice_coeffs): parallel lists with one entry per evaluated image.

    Raises:
        ValueError: if a model_config names an unsupported model type.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model_names = []
    dice_coeffs = []

    # NOTE(review): absolute, user-specific fallback path; pass `root_dir` explicitly to avoid it.
    data_root = root_dir or "/Users/studiesamuel/Library/CloudStorage/OneDrive-Aarhusuniversitet/Deep Learning/data_gentuity"

    # Nearest-neighbour resizing avoids interpolating between label values.
    transform = transforms.Compose([
        transforms.Resize((1024, 1024), interpolation=Image.NEAREST),
        transforms.ToTensor(),
    ])
    test_dataset = OCTDataset(
        data_root,
        transform=transform,
        train=False,
        is_gentuity=True,
    )
    # batch_size=1 so each Dice score and each saved figure corresponds to exactly one image.
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4)

    for model_name, model_config in models_list:
        print(f"Testing model: {model_name}")
        model_type = model_config["model"]
        # NOTE(review): smp.Unet / DeepLabV3Plus are built without an activation, so their
        # outputs are logits and a 0.5 cutoff means sigmoid > ~0.62; threshold 0.0 would be
        # sigmoid > 0.5. Confirm against how the models were validated before changing it.
        threshold = model_config.get("threshold", 0.5)

        net = _build_model(model_config)
        net.to(device)
        net.eval()

        # One metric per model, on the same device as the predictions. Calling it returns
        # the Dice of the current batch (here a single image).
        dice_metric = torchmetrics.Dice().to(device)
        model_dice_scores = []

        with torch.no_grad():
            for image_id, (images, masks, _, _) in enumerate(test_loader):
                images, masks = images.to(device), masks.to(device)

                outputs = _run_model(net, model_type, images, device)
                predicted = (outputs > threshold).float()

                dice_score = dice_metric(predicted, masks.int()).item()
                model_dice_scores.append(dice_score)
                print(f"Dice score: {dice_score:.4f}")

                save_image_with_prediction_and_mask(images[0], predicted[0], masks[0], image_id, save_dir, model_name)

                print(f"Progress: {len(model_dice_scores)} / {len(test_loader)}", end="\r")

        model_names.extend([model_name] * len(model_dice_scores))
        dice_coeffs.extend(model_dice_scores)

        if model_dice_scores:
            mean_dice = sum(model_dice_scores) / len(model_dice_scores)
            print(f"{model_name} - Average Dice accuracy: {mean_dice:.4f}")
        else:
            print(f"{model_name} - no test images were evaluated")

        # Free the model before the next one is built (relevant for GPU memory).
        del net
        if device.type == "cuda":
            torch.cuda.empty_cache()

    return model_names, dice_coeffs

# Models to evaluate, listed as (display name, model type, checkpoint path) triples.
# The display name labels plots and output files; the model type selects the architecture.
models_list = [
    (display_name, {"model": model_type, "checkpoint_path": checkpoint_path})
    for display_name, model_type, checkpoint_path in (
        ("MedSAM", "MedSam", "models_local/checkpoint_bs=6_medsam_frozen_DiceBCELoss.pth"),
        ("AttentionUnet", "AttentionUnet", "models_local/checkpoint_bs=6_AttentionUnetUnFrozen_DICEBCELoss.pt"),
        ("U-Net", "Unet", "models_local/checkpoint_bs=6_Unet_unfrozen_DiceBCELoss.pt"),
        ("DeepLabV3+", "DeepLabV3+", "models_local/checkpoint_bs=6_Deeplabv3_unfrozen_DiceBCELoss.pt"),
    )
]

# Dataset root forwarded to test_models.
root_dir = ""
# Output directory for the image / prediction / ground-truth figures; created if missing.
save_dir = "output_images"
os.makedirs(save_dir, exist_ok=True)


In [None]:
# Evaluate every configured model; yields parallel lists of model names and per-image Dice scores.
model_names, dice_coeffs = test_models(
    models_list=models_list,
    root_dir=root_dir,
    save_dir=save_dir,
)


In [None]:

# Long-format table: one row per (model, test image) Dice score.
results_df = pd.DataFrame({
    "Model": model_names,
    "Dice Score": dice_coeffs,
})

# Box plot of the per-image Dice distribution for each model.
fig, ax = plt.subplots(figsize=(12, 8))
sns.boxplot(x="Model", y="Dice Score", data=results_df, hue="Model", legend=False, ax=ax)
# The evaluation loader uses the Gentuity data (is_gentuity=True), so label the plot accordingly.
ax.set_title("Performance comparison on Gentuity test set")
ax.set_ylabel("Dice Similarity Coefficient")
ax.tick_params(axis="x", labelrotation=45)
plt.show()

In [None]:
# Preview one saved figure (test image 0, MedSAM) inline.
# Built the same way save_image_with_prediction_and_mask builds its output path.
image_path = os.path.join(save_dir, "0_prediction_MedSAM.png")

# Context manager closes the PNG file handle; imshow converts the image to an array immediately.
with Image.open(image_path) as image:
    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.axis("off")  # hide axes: this is a picture, not a plot
plt.show()
