In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import torch as tc
import json
from torch import nn
from torchvision import models
from monai.networks.nets import UNet
from monai.metrics import DiceMetric, HausdorffDistanceMetric

import sys
sys.path.append('/content/drive/MyDrive/Aorta_Segmentation')
from Dataset_loader import Dataset

In [None]:
def get_info_from_path(path):
    """Load a saved training-history dict and unpack the fields used for plotting.

    Args:
        path: file written with ``torch.save`` holding a dict with the keys
            ``"train_loss_lst"``, ``"val_loss_lst"`` and ``"time"``.

    Returns:
        Tuple ``(train_loss_lst, val_loss_lst, time)`` read from those keys.
    """
    history = tc.load(path)
    wanted_keys = ("train_loss_lst", "val_loss_lst", "time")
    return tuple(history[key] for key in wanted_keys)

# Directory holding the per-model training histories (loss curves + training time).
# Building every path from one constant avoids typos such as a missing "/".
TRAINING_INFO_DIR = '/content/drive/MyDrive/Aorta_Segmentation/Training_info'

Unet_train_losses, Unet_val_losses, Unet_time = get_info_from_path(f'{TRAINING_INFO_DIR}/UNet_losses.pt')
FCN_train_losses, FCN_val_losses, FCN_time = get_info_from_path(f'{TRAINING_INFO_DIR}/FCN_losses.pt')
FCN_pretrained_train_losses, FCN_pretrained_val_losses, FCN_pretrained_time = get_info_from_path(f'{TRAINING_INFO_DIR}/FCN_pretrained_losses.pt')
Deeplabv3_train_losses, Deeplabv3_val_losses, Deeplabv3_time = get_info_from_path(f'{TRAINING_INFO_DIR}/deeplabv3_losses.pt')
Deeplabv3_pretrained_train_losses, Deeplabv3_pretrained_val_losses, Deeplabv3_pretrained_time = get_info_from_path(f'{TRAINING_INFO_DIR}/deeplabv3_pretrained_losses.pt')

In [None]:
def plot_losses(rows, cols, idx, title, train_losses, val_losses):
    """Draw training and validation loss curves into one panel of a subplot grid.

    Args:
        rows, cols: shape of the subplot grid.
        idx: 1-based panel index passed to ``plt.subplot``.
        title: panel title.
        train_losses, val_losses: per-epoch loss sequences (plotted against their index).

    Returns:
        The Axes that was drawn into (callers may ignore it).
    """
    ax = plt.subplot(rows, cols, idx)
    plt.plot(train_losses, "r-", label="Training Loss")
    plt.plot(val_losses, "b-", label="Validation Loss")
    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.grid(True)
    if idx == 1:
        # The line colours are identical in every panel, so one legend (on the
        # first panel) is enough. `plt.legend` must be *called* to draw it.
        plt.legend()
    return ax

In [None]:
# (panel title, training losses, validation losses) for every trained model.
data_tuples = [
    ("UNet - Losses", Unet_train_losses, Unet_val_losses),
    ("FCN Losses", FCN_train_losses, FCN_val_losses),
    ("FCN Pretrained Losses", FCN_pretrained_train_losses, FCN_pretrained_val_losses),
    ("DLv3 Losses", Deeplabv3_train_losses, Deeplabv3_val_losses),
    ("DLv3 Pretrained Losses", Deeplabv3_pretrained_train_losses, Deeplabv3_pretrained_val_losses)
]

# Three panels per row; the number of rows grows with the number of models
# (ceil division), so adding a model does not overflow the grid.
n_cols = 3
n_rows = -(-len(data_tuples) // n_cols)

plt.figure(figsize=(15, 10), dpi=300)
# Unpack each entry directly instead of binding it to `tuple`, which would
# shadow the builtin for the rest of the notebook.
for panel_idx, (title, train_losses, val_losses) in enumerate(data_tuples, start=1):
    plot_losses(n_rows, n_cols, panel_idx, title, train_losses, val_losses)

plt.tight_layout()
plt.show()

In [None]:
# Read the saved data-split dict. It must be opened read-only: mode "w" would
# truncate (erase) the file and json.load would then fail with "not readable".
with open("/content/drive/MyDrive/Aorta_Segmentation/Data_names/data_names_dict.txt", "r") as file:
    data_names_dict = json.load(file)

test_data_names = data_names_dict["test"]
Testing_dataset = Dataset(test_data_names)
Testing_dataset.Preprocess_Data()
# No shuffling for evaluation: the metrics don't need it, and a fixed order keeps
# the collected predictions/masks aligned with Testing_dataset across runs.
testing_loader = tc.utils.data.DataLoader(Testing_dataset, batch_size=64, shuffle=False)

In [None]:
class TestModel:
    """Evaluate a trained binary-segmentation model and summarise Dice / HD95.

    Typical use: ``tester.test_model(loader)`` followed by
    ``tester.print_model_metrics()``. After a run:

    * ``predictions`` - binarised predictions for the whole loader (CPU, int,
      concatenated along dim 0),
    * ``mask_lst`` - matching ground-truth masks (CPU, with a channel axis added),
    * ``dice_score`` / ``hd95`` - finite per-sample scores (NaN/inf removed).

    Args:
        model: segmentation network. It may return a tensor or a dict with the
            logits under ``"out"`` (the torchvision segmentation convention).
        model_name: label used when printing metrics.
        threshold: probability cut-off applied after the sigmoid (default 0.5).
    """

    def __init__(self, model, model_name, threshold=0.5):
        self.model_name = model_name
        self.predictions = []
        self.mask_lst = []
        self.threshold = threshold
        self.device = tc.device("cuda" if tc.cuda.is_available() else "cpu")
        self.model = model.to(self.device)

        # Metrics. percentile=95 is what makes this an HD95: MONAI's
        # HausdorffDistanceMetric defaults to percentile=None, i.e. the maximum
        # (HD100) surface distance.
        self.dice_func = DiceMetric(include_background=True, reduction="mean", get_not_nans=True)
        self.hd95_func = HausdorffDistanceMetric(include_background=True, percentile=95, reduction="mean", get_not_nans=True)
        self.dice_score = []
        self.hd95 = []

    def pred_loop(self, testing_loader):
        """Predict every batch of ``testing_loader`` and compute the metrics.

        Each batch is ``(images, masks)``. A channel axis is inserted at dim 1
        for both, so they presumably arrive as (B, H, W) - TODO confirm against
        Dataset. Results replace (not extend) any previous run, so the method
        can be called repeatedly.

        Raises:
            ValueError: if the loader yields no batches.
        """
        batch_predictions = []
        batch_masks = []
        for images, masks in testing_loader:
            output = self.model(images.to(self.device).unsqueeze(1))
            if isinstance(output, dict):
                output = output["out"]
            binary = (tc.sigmoid(output) > self.threshold).int()
            batch_predictions.append(binary.cpu())
            # Masks are only used on CPU, so they never need to visit the device.
            batch_masks.append(masks.unsqueeze(1).cpu())
        if not batch_predictions:
            raise ValueError(f"{self.model_name}: testing_loader yielded no batches")
        self.predictions = tc.cat(batch_predictions, dim=0)
        self.mask_lst = tc.cat(batch_masks, dim=0)
        self.calculate_dice_score()
        self.calculate_hd95()

    def test_model(self, loader=None):
        """Run ``pred_loop`` in eval mode with gradients disabled.

        Args:
            loader: iterable of ``(images, masks)`` batches. When omitted, the
                notebook-level ``testing_loader`` is used so that existing
                ``test_model()`` calls keep working.
        """
        if loader is None:
            loader = testing_loader
        self.model.eval()
        with tc.no_grad():
            self.pred_loop(loader)

    @staticmethod
    def _finite(scores):
        """Return the finite entries of ``scores`` as a flat tensor.

        MONAI can emit NaN/inf (e.g. when a prediction or ground truth is empty);
        such entries would poison mean/std, so they are dropped.
        """
        return scores[tc.isfinite(scores)]

    def calculate_dice_score(self):
        """Per-sample Dice between ``predictions`` and ``mask_lst`` (finite values only)."""
        self.dice_score = self._finite(self.dice_func(self.predictions, self.mask_lst))

    def calculate_hd95(self):
        """Per-sample 95th-percentile Hausdorff distance (finite values only)."""
        self.hd95 = self._finite(self.hd95_func(self.predictions, self.mask_lst))

    def get_predictions(self):
        return self.predictions

    def get_masks(self):
        return self.mask_lst

    @staticmethod
    def _print_stats(label, scores):
        """Print mean/min/max/std of ``scores``, or a notice when there are none."""
        if len(scores) == 0:
            print(f"No finite {label} scores available (has test_model been run?)")
            return
        print(f"Average {label} score: {scores.mean().item():.4f}")
        print(f"Minimum {label} score: {scores.min().item():.4f}")
        print(f"Maximum {label} score: {scores.max().item():.4f}")
        print(f"Std {label} score: {scores.std().item():.4f}")

    def print_model_metrics(self):
        """Print summary statistics of the Dice and HD95 scores."""
        print(f"------- {self.model_name} metrics -------")
        self._print_stats("Dice", self.dice_score)
        print("----")
        self._print_stats("HD95", self.hd95)

In [None]:
# Directory with the trained state dicts produced by the training notebooks.
MODELS_DIR = '/content/drive/MyDrive/Aorta_Segmentation/Models'


def _load_weights(model, filename):
    """Load ``MODELS_DIR/filename`` into ``model`` and return ``model``.

    ``map_location="cpu"`` lets weights saved from a GPU session load on a
    CPU-only runtime; TestModel moves the model to the right device afterwards.
    """
    state_dict = tc.load(f'{MODELS_DIR}/{filename}', map_location="cpu")
    model.load_state_dict(state_dict)
    return model


def _fcn_resnet50_1ch():
    """FCN-ResNet50 with one output class and a 1-channel (grayscale) stem conv."""
    model = models.segmentation.fcn_resnet50(weights=None, num_classes=1)
    model.backbone.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    return model


def _deeplabv3_mnv3_1ch():
    """DeepLabV3-MobileNetV3-Large with one output class and a 1-channel stem conv."""
    model = models.segmentation.deeplabv3_mobilenet_v3_large(weights=None, num_classes=1)
    # backbone["0"] is the first feature block; its [0] is the stem convolution.
    model.backbone["0"][0] = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1, bias=False)
    return model


model_Unet = UNet(
        spatial_dims = 2,
        in_channels=1,
        out_channels=1,
        channels=[16, 32, 64, 128, 256],
        strides=(2, 2, 2, 2),
        dropout=0.16454162080022391  # presumably the tuned value used at training time
    )

# The "pretrained" variants are also built with weights=None: their trained state
# dicts overwrite every parameter, so downloading initial weights is pointless.
model_FCN = _fcn_resnet50_1ch()
model_FCN_pretrained = _fcn_resnet50_1ch()
model_deeplabv3 = _deeplabv3_mnv3_1ch()
model_deeplabv3_pretrained = _deeplabv3_mnv3_1ch()

_load_weights(model_Unet, 'UNet_trained.pth')
_load_weights(model_FCN, 'FCN_trained.pth')
_load_weights(model_FCN_pretrained, 'FCN_pretrained_trained.pth')
_load_weights(model_deeplabv3, 'deeplabv3_trained.pth')
_load_weights(model_deeplabv3_pretrained, 'deeplabv3_pretrained_trained.pth')


In [None]:
# (dict key, trained network, display name) for every model under evaluation.
_model_specs = [
    ("Unet", model_Unet, "UNet"),
    ("FCN", model_FCN, "FCN"),
    ("FCN pretrained", model_FCN_pretrained, "FCN pretrained"),
    ("DLv3", model_deeplabv3, "Deeplabv3"),
    ("DLv3 pretrained", model_deeplabv3_pretrained, "Deeplabv3 pretrained"),
]
models_dict = {name_key: TestModel(net, display_name) for name_key, net, display_name in _model_specs}

# Evaluate each model on the test set and report its metrics, in insertion order.
for tester in models_dict.values():
    tester.test_model()
    tester.print_model_metrics()

In [None]:
# One summary value per model (same order as models_dict) for the bar charts below.
model_names = [tester.model_name for tester in models_dict.values()]
dices_lst = [tester.dice_score.mean().item() for tester in models_dict.values()]
hd95_lst = [tester.hd95.mean().item() for tester in models_dict.values()]


# Bar chart of the mean Dice score per model, each bar annotated with its value.
fig, ax = plt.subplots(figsize=(10, 5))
sns.barplot(x=model_names, y=dices_lst, palette=sns.color_palette(palette='cool'), ax=ax)
for bar_pos, mean_dice in enumerate(dices_lst):
    ax.text(bar_pos, mean_dice + 0.02, f'{mean_dice:.2f}', ha='center', va='bottom', fontsize=12)
ax.set_title("Dice Comparison")
ax.set_ylabel("Average Dice score")
ax.set_ylim(0, 1)
plt.setp(ax.get_xticklabels(), rotation=45)
fig.tight_layout()
plt.show()

In [None]:
# Bar chart of the mean HD95 per model. The y-limit is derived from the data
# (15% headroom for the value labels) instead of a fixed 23, so no bar is
# silently clipped when a model's HD95 is large.
hd95_ymax = max(max(hd95_lst, default=0.0) * 1.15, 1.0)
label_offset = 0.01 * hd95_ymax

plt.figure(figsize=(10,5))
ax_2 = sns.barplot(x=model_names, y=hd95_lst, palette=sns.color_palette(palette='cool'))
for i, value in enumerate(hd95_lst):
    ax_2.text(i, value + label_offset, f'{value:.2f}', ha='center', va='bottom', fontsize=12)
plt.title("HD95 Comparison")
plt.ylabel("Average HD95 score")
plt.ylim(0, hd95_ymax)
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()