In [1]:
# Required to import other notebooks (e.g. `dataset`) as Python modules
!pip install import-ipynb
import import_ipynb



In [None]:
# Required libraries and notebook imports
import os

import matplotlib.pyplot as plt
import torch
import torchvision
from dataset import (
    get_plane_datasets,
    get_plane_dataloader
)

In [None]:
# Use this when training a model on a single plane type instead of all three combined.
# It returns the training, validation and test dataloaders for the given plane type.
def get_plane_loader(plane_type="axial"):
    """Build the train/validation/test dataloaders for a single plane type.

    Args:
        plane_type: plane name forwarded to ``get_plane_datasets``
            (defaults to "axial").

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)`` as produced by
        ``get_plane_dataloader`` for that plane's three splits.
    """
    plane_splits = get_plane_datasets(plane_type)
    train_split, val_split, test_split = plane_splits

    loaders = get_plane_dataloader(train_split, val_split, test_split)
    train_loader, val_loader, test_loader = loaders
    return (train_loader, val_loader, test_loader)

In [2]:
# Function to save predicted masks as image files
def save_predictions_as_imgs(mask, mask_name, folder="saved_images/", device="cuda"):
    """Save a tensor of predicted masks as ``<folder>/<mask_name>.jpg``.

    Args:
        mask: tensor of predictions; a dimension is inserted at dim 1 before
            saving. Presumably (N, H, W) -> (N, 1, H, W) — TODO confirm
            with callers.
        mask_name: file name stem (the ".jpg" extension is appended).
        folder: output directory; created if it does not exist yet.
        device: unused; kept only so existing callers keep working.

    NOTE(review): ``torchvision.utils.save_image`` maps values in [0, 1] to
    pixel intensities, so integer class-index masks may need scaling first.
    """
    if folder:
        os.makedirs(folder, exist_ok=True)
    # os.path.join avoids the "saved_images//name.jpg" double slash that the
    # default folder (which already ends in "/") produced with an f-string.
    output_path = os.path.join(folder, f"{mask_name}.jpg")
    torchvision.utils.save_image(mask.unsqueeze(1), output_path)

In [None]:
# Concatenate datasets; used to merge the datasets of all planes into one
def concat_datasets(dataset_1, dataset_2, *more_datasets):
    """Concatenate two or more map-style datasets into one dataset.

    Samples are indexed in argument order: all of ``dataset_1`` first, then
    ``dataset_2``, then each dataset in ``more_datasets``.

    Args:
        dataset_1: first dataset.
        dataset_2: second dataset.
        *more_datasets: optional further datasets appended after ``dataset_2``.
            Passing more than two datasets avoids nesting ConcatDataset calls.

    Returns:
        A ``torch.utils.data.ConcatDataset`` over all given datasets.
    """
    return torch.utils.data.ConcatDataset([dataset_1, dataset_2, *more_datasets])

In [None]:
def get_all_planes_dataloaders(plane_types=("axial", "coronal", "sagittal")):
    """Build train/validation/test dataloaders over several planes combined.

    For each plane in ``plane_types`` (in order), the train, validation and
    test datasets are fetched with ``get_plane_datasets``. Datasets of the
    same split are then concatenated across planes. For example, the combined
    training dataset indexes all axial training samples first, then coronal,
    then sagittal.

    Args:
        plane_types: plane names to combine; defaults to all three planes.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)`` produced by
        ``get_plane_dataloader`` from the combined datasets.

    Raises:
        ValueError: if ``plane_types`` is empty.
    """
    if not plane_types:
        raise ValueError("plane_types must contain at least one plane type")

    # Collect each split across planes, preserving plane order.
    train_parts, val_parts, test_parts = [], [], []
    for plane in plane_types:
        plane_train, plane_val, plane_test = get_plane_datasets(plane)
        train_parts.append(plane_train)
        val_parts.append(plane_val)
        test_parts.append(plane_test)

    train_ds = torch.utils.data.ConcatDataset(train_parts)
    val_ds = torch.utils.data.ConcatDataset(val_parts)
    test_ds = torch.utils.data.ConcatDataset(test_parts)

    # Build dataloaders over the combined data.
    train_loader, val_loader, test_loader = get_plane_dataloader(train_ds, val_ds, test_ds)
    return train_loader, val_loader, test_loader

In [None]:
# Calculate accuracy for multiclass segmentation; this function is called once per patch
def calc_accuracy(pred, label):
    """Return the fraction of positions whose predicted class equals ``label``.

    ``pred`` holds per-class scores along dim 1. The predicted class at each
    position is the arg-max over that dim, taken after ``log_softmax``, which
    is monotonic and so does not change the arg-max. ``label`` is compared
    element-wise (with broadcasting) against the resulting class-index tensor.

    Returns:
        Accuracy as a Python float.
    """
    log_probs = torch.log_softmax(pred, dim=1)
    predicted_class = log_probs.max(dim=1).indices
    hits = (predicted_class == label).float()
    accuracy = hits.sum() / hits.numel()
    return accuracy.item()

In [None]:
# Calculate mean IoU for multiclass segmentation; this function is called once per patch
def calc_iou(label, pred, classes=7):
    """Mean intersection-over-union (IoU) for one multiclass segmentation patch.

    Args:
        label: ground-truth class indices; flattened before comparison.
        pred: raw class scores with classes along dim 1 (presumably
            (N, C, H, W) — TODO confirm). The predicted class at each
            position is the arg-max over dim 1; it is flattened to match
            ``label``.
        classes: number of classes to evaluate (class ids ``0..classes-1``).

    Returns:
        Mean IoU over the classes that occur in ``label``. Classes missing
        from the ground truth are skipped because their IoU is undefined.
        The previous version counted them as 0 while still dividing by
        ``classes``, which dragged the mean down for patches containing only
        a few classes. Returns 0.0 if no class in ``range(classes)`` occurs
        in ``label``.
    """
    # Softmax is monotonic, so the arg-max of the raw scores picks the same class.
    pred_classes = torch.argmax(pred, dim=1).reshape(-1)
    # reshape (not view) also works for non-contiguous label tensors.
    label = label.reshape(-1)

    class_ious = []
    for class_idx in range(classes):
        target_mask = label == class_idx
        target_count = target_mask.sum().item()
        if target_count == 0:
            continue  # IoU is undefined for a class absent from the ground truth
        pred_mask = pred_classes == class_idx
        intersection = (pred_mask & target_mask).sum().item()
        union = pred_mask.sum().item() + target_count - intersection
        class_ious.append(intersection / union)

    if not class_ious:
        return 0.0
    return sum(class_ious) / len(class_ious)

In [None]:
# Plot training and validation curves for one metric (accuracy, IoU, or loss)
def plot_metric(X, y1, y2, y1_label, y2_label, x_label, y_label, title):
    """Plot two series (e.g. training and validation) against a shared x-axis.

    Both curves are drawn with circle markers on the current matplotlib axes.
    The axes are then given axis labels, a title and a legend, and the figure
    is displayed with ``plt.show()``.
    """
    axes = plt.gca()
    for series, series_label in ((y1, y1_label), (y2, y2_label)):
        axes.plot(X, series, label=series_label, marker='o')
    axes.set_xlabel(x_label)
    axes.set_ylabel(y_label)
    axes.set_title(title)
    axes.legend()
    plt.show()