# Jaws Segmentation Task

In [2]:
! pip install --user torch torchvision matplotlib numpy progressbar

In [4]:
# important libraries
import urllib.request
import zipfile
import os
import progressbar
from math import ceil
import torch
import gzip
import numpy as np
import glob
import matplotlib.pyplot as plt
from torchvision import transforms
import torch
import torch.nn as nn
import torchvision
import numpy as np
import torch.optim as optim
from tqdm import tqdm

## Dataset
#### Download_Data (data flag):
1. Set `Download_Data = True` to download the dataset from the URLs below.
2. Set `Download_Data = False` if the dataset has already been downloaded.

In [5]:
# Default root directory that download_data() extracts the archives into.
LOCAL_DATASET_PATH = 'dataset'
# Mini-batch size used by get_plane_dataloader() for every DataLoader it builds.
BATCH_SIZE = 16
# Download flag checked once below: download_data() is only called when it is True.
Download_Data = False
## URLs of the train/test archives, one pair per plane (axial, coronal, sagittal)
AXIAL_TRAINING_DATASET = 'https://cvml-datasets.s3.eu-west-3.amazonaws.com/jaws-segmentation/v1/public/2d/axial/train.zip'
AXIAL_TESTING_DATASET = 'https://cvml-datasets.s3.eu-west-3.amazonaws.com/jaws-segmentation/v1/public/2d/axial/test.zip'
CORONAL_TRAINING_DATASET = 'https://cvml-datasets.s3.eu-west-3.amazonaws.com/jaws-segmentation/v1/public/2d/coronal/train.zip'
CORONAL_TESTING_DATASET = 'https://cvml-datasets.s3.eu-west-3.amazonaws.com/jaws-segmentation/v1/public/2d/coronal/test.zip'
SAGITTAL_TRAINING_DATASET = 'https://cvml-datasets.s3.eu-west-3.amazonaws.com/jaws-segmentation/v1/public/2d/sagittal/train.zip'
SAGITTAL_TESTING_DATASET = 'https://cvml-datasets.s3.eu-west-3.amazonaws.com/jaws-segmentation/v1/public/2d/sagittal/test.zip'
## Glob patterns that locate the extracted image files (the dataset class derives each label path from its image path)
# NOTE: these patterns hard-code the 'dataset' prefix instead of using LOCAL_DATASET_PATH,
# so the two have to be kept in sync by hand.
down_axial_training_data = 'dataset/axial/train/**/*.dicom.npy.gz'
down_axial_testing_data = 'dataset/axial/test/**/*.dicom.npy.gz'
down_coronal_training_data = 'dataset/coronal/train/**/*.dicom.npy.gz'
down_coronal_testing_data = 'dataset/coronal/test/**/*.dicom.npy.gz'
down_sagittal_training_data = 'dataset/sagittal/train/**/*.dicom.npy.gz'
down_sagittal_testing_data = 'dataset/sagittal/test/**/*.dicom.npy.gz'

#### Downloading Dataset

In [6]:
# Progress bar shared between successive show_progress() calls of one download.
download_progress_bar = None


def show_progress(block_num, block_size, total_size):
    """Report hook for ``urllib.request.urlretrieve`` that drives a console progress bar.

    Args:
        block_num: Number of blocks transferred so far.
        block_size: Size of one block in bytes.
        total_size: Total size of the file in bytes; urlretrieve may pass a
            non-positive value when the server does not report a size.

    The bar is created on the first call of a download and discarded once the
    transferred byte count reaches ``total_size``, so the next download starts
    with a fresh bar.
    """
    global download_progress_bar
    if total_size <= 0:
        # Unknown size: a bar cannot be given a valid maximum, so skip reporting
        # instead of failing the download.
        return

    if download_progress_bar is None:
        download_progress_bar = progressbar.ProgressBar(maxval=total_size)
        download_progress_bar.start()

    downloaded = block_num * block_size
    if downloaded < total_size:
        download_progress_bar.update(downloaded)
    else:
        # The last block usually overshoots total_size; treat that as completion.
        download_progress_bar.finish()
        download_progress_bar = None

def download_file(url, disk_path):
    """Download the zip archive at ``url`` and extract it into ``disk_path``.

    Args:
        url: Address of the zip archive.
        disk_path: Directory to extract into; created if it does not exist.
    """
    print(f'downloading {url}')
    # No target filename is given, so urlretrieve stores the archive in a temporary file.
    filename, _ = urllib.request.urlretrieve(url, reporthook=show_progress)
    try:
        # exist_ok lets the download be re-run without failing on an existing folder.
        os.makedirs(disk_path, exist_ok=True)
        with zipfile.ZipFile(filename, 'r') as archive:
            archive.extractall(disk_path)
    finally:
        # Remove the temporary file(s) left behind by urlretrieve.
        urllib.request.urlcleanup()

def download_data(to=LOCAL_DATASET_PATH):
    """Download and extract all six archives under ``to``.

    Each archive is extracted into ``<to>/<plane>/<split>``, with plane in
    (axial, coronal, sagittal) and split in (train, test).
    """
    archives = (
        (AXIAL_TRAINING_DATASET, 'axial', 'train'),
        (AXIAL_TESTING_DATASET, 'axial', 'test'),
        (CORONAL_TRAINING_DATASET, 'coronal', 'train'),
        (CORONAL_TESTING_DATASET, 'coronal', 'test'),
        (SAGITTAL_TRAINING_DATASET, 'sagittal', 'train'),
        (SAGITTAL_TESTING_DATASET, 'sagittal', 'test'),
    )
    for archive_url, plane, split in archives:
        download_file(archive_url, os.path.join(to, plane, split))

In [26]:
# Fetch the archives only when the notebook is configured to download them.
if Download_Data:
    download_data()

#### Classes and functions required to read the data, split it, wrap it in dataloaders, and plot it

In [7]:
class JawsDataset(torch.utils.data.Dataset):
    """Dataset of gzip-compressed NumPy image files and their label files.

    Every sample is a pair of files sharing a stem: ``<stem>.dicom.npy.gz``
    holds the image and ``<stem>.label.npy.gz`` holds the matching label array.
    """

    def __init__(self, dicom_file_list, transforms, label_transforms=None):
        """
        Args:
            dicom_file_list: Paths of the ``.dicom.npy.gz`` image files.
            transforms: Callable applied to every loaded image array.
            label_transforms: Optional callable applied to every loaded label
                array. Defaults to ``transforms``, which is what this class has
                always done.
                NOTE(review): image-style transforms (value scaling,
                interpolating resizes) can alter class ids - confirm that
                sharing one pipeline is intended for the labels.
        """
        self.dicom_file_list = dicom_file_list
        self.transforms = transforms
        self.label_transforms = transforms if label_transforms is None else label_transforms

    def __len__(self):
        """Return the number of image files."""
        return len(self.dicom_file_list)

    @staticmethod
    def _load_array(path):
        """Load one gzip-compressed ``.npy`` file and close the file handle."""
        with gzip.open(path, 'rb') as handle:
            return np.load(handle)

    def __getitem__(self, idx):
        """Return ``(transformed image, transformed label)`` for sample ``idx``."""
        dicom_path = self.dicom_file_list[idx]
        # The label file sits next to the image file and differs only in its suffix.
        label_path = dicom_path.replace('.dicom.npy.gz', '.label.npy.gz')
        dicom = self._load_array(dicom_path)
        label = self._load_array(label_path)
        return self.transforms(dicom), self.label_transforms(label)

def _list_slice_files(pattern):
    """Return the files matching ``pattern`` in sorted order.

    glob returns matches in arbitrary order, so sorting keeps the
    train/validation split identical between runs and machines.
    NOTE(review): glob is called without ``recursive=True``, so ``**`` in the
    patterns matches exactly one directory level - confirm that is intended.
    """
    files = sorted(glob.glob(pattern))
    assert len(files) > 0, f'no dataset files match {pattern!r}'
    return files


def _train_validation_datasets(pattern, transforms, validation_ratio):
    """Split the files matching ``pattern`` into (train, validation) datasets.

    The first ``ceil(len(files) * validation_ratio)`` files form the validation
    set and the remaining files form the training set.
    """
    files = _list_slice_files(pattern)
    validation_files_count = ceil(len(files) * validation_ratio)
    return (JawsDataset(files[validation_files_count:], transforms),
            JawsDataset(files[:validation_files_count], transforms))


def axial_dataset_train(transforms, validation_ratio = 0.1):
    """Return the (train, validation) datasets of the axial plane."""
    return _train_validation_datasets(down_axial_training_data, transforms, validation_ratio)


def coronal_dataset_train(transforms, validation_ratio = 0.1):
    """Return the (train, validation) datasets of the coronal plane."""
    return _train_validation_datasets(down_coronal_training_data, transforms, validation_ratio)


def sagittal_dataset_train(transforms, validation_ratio = 0.1):
    """Return the (train, validation) datasets of the sagittal plane."""
    return _train_validation_datasets(down_sagittal_training_data, transforms, validation_ratio)


def axial_dataset_test(transforms):
    """Return the test dataset of the axial plane."""
    return JawsDataset(_list_slice_files(down_axial_testing_data), transforms)


def coronal_dataset_test(transforms):
    """Return the test dataset of the coronal plane."""
    return JawsDataset(_list_slice_files(down_coronal_testing_data), transforms)


def sagittal_dataset_test(transforms):
    """Return the test dataset of the sagittal plane."""
    return JawsDataset(_list_slice_files(down_sagittal_testing_data), transforms)


In [8]:
# Shared preprocessing pipeline: convert the array to a tensor, resize it to 128x128,
# then normalize with mean 0 / std 1 (which leaves the values unchanged).
# NOTE(review): JawsDataset applies this same pipeline to the label arrays as well, so the
# tensor conversion and the resize also act on class ids - confirm both are appropriate for masks.
dataset_transforms = transforms.Compose([transforms.ToTensor(), transforms.Resize((128, 128)), transforms.Normalize(mean=[0.0], std=[1.0])])

In [9]:
def get_plane_datasets(plane_type="axial"):
    """Return the (train, validation, test) datasets of one plane.

    ``plane_type`` is compared case-insensitively with "sagittal" and
    "coronal"; every other value selects the axial plane.
    """
    factories = {
        "sagittal": (sagittal_dataset_train, sagittal_dataset_test),
        "coronal": (coronal_dataset_train, coronal_dataset_test),
    }
    make_train, make_test = factories.get(
        plane_type.lower(), (axial_dataset_train, axial_dataset_test))

    train_split, validation_split = make_train(dataset_transforms)
    test_split = make_test(dataset_transforms)
    return train_split, validation_split, test_split

In [10]:
def get_plane_dataloader(train_ds, val_ds, test_ds):
    """Wrap the three datasets in DataLoaders with batch size ``BATCH_SIZE``.

    Only the training loader shuffles. Validation and test loaders keep the
    dataset order so that evaluation batches - and any output numbered by
    batch index - are the same on every run.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    train_loader = torch.utils.data.DataLoader(train_ds, shuffle=True, batch_size=BATCH_SIZE)
    val_loader = torch.utils.data.DataLoader(val_ds, shuffle=False, batch_size=BATCH_SIZE)
    test_loader = torch.utils.data.DataLoader(test_ds, shuffle=False, batch_size=BATCH_SIZE)
    return train_loader, val_loader, test_loader

In [11]:
def get_images_labels(data_loader):
    """Return the first ``(images, labels)`` batch produced by ``data_loader``.

    Uses the built-in ``next()`` rather than an iterator's ``.next()`` method,
    which is not part of the Python 3 iterator protocol.
    """
    images, labels = next(iter(data_loader))
    return images, labels

In [13]:
def plot_images_labels(images, labels):
    """Plot up to 16 images of a batch on a 2x8 grid with their masks overlaid.

    Plotting starts at the first sample so that both rows of the grid are
    filled and batches with fewer than nine samples still produce output.

    Args:
        images: Batch of image tensors.
        labels: Batch of mask tensors, index-aligned with ``images``.
    """
    plt.figure(figsize=(16, 4))
    for index in range(min(16, len(images))):
        plt.subplot(2, 8, index + 1)
        plt.axis('off')
        plt.imshow(images[index].numpy().squeeze(), cmap='bone')
        # Semi-transparent mask drawn over the image.
        plt.imshow(labels[index].numpy().squeeze(), alpha=0.3)

## Explore the dataset

In [None]:
## Build the train / validation / test datasets of each of the three planes
axial_train_dataset, axial_validation_dataset, axial_test_dataset = get_plane_datasets("axial")
coronal_train_dataset, coronal_validation_dataset, coronal_test_dataset = get_plane_datasets("coronal")
sagittal_train_dataset, sagittal_validation_dataset, sagittal_test_dataset = get_plane_datasets("sagittal")

In [None]:
## Print the number of samples in the train, validation and test dataset of each plane
print(f'axial training dataset: {len(axial_train_dataset)} slice')
print(f'axial validation dataset: {len(axial_validation_dataset)} slice')
print(f'axial testing dataset: {len(axial_test_dataset)} slice \n')

print(f'coronal training dataset: {len(coronal_train_dataset)} slice')
print(f'coronal validation dataset: {len(coronal_validation_dataset)} slice')
print(f'coronal testing dataset: {len(coronal_test_dataset)} slice \n')

print(f'sagittal training dataset: {len(sagittal_train_dataset)} slice')
print(f'sagittal validation dataset: {len(sagittal_validation_dataset)} slice')
print(f'sagittal testing dataset: {len(sagittal_test_dataset)} slice')

In [34]:
## Wrap the datasets of each plane in train / validation / test dataloaders
axial_train_loader, axial_val_loader, axial_test_loader = get_plane_dataloader(axial_train_dataset, axial_validation_dataset, axial_test_dataset)
coronal_train_loader, coronal_val_loader, coronal_test_loader = get_plane_dataloader(coronal_train_dataset, coronal_validation_dataset, coronal_test_dataset)
sagittal_train_loader, sagittal_val_loader, sagittal_test_loader = get_plane_dataloader(sagittal_train_dataset, sagittal_validation_dataset, sagittal_test_dataset)

In [35]:
## Take one batch of images and masks from every dataloader (used for the plots below)

# 1. axial plane: one batch each from the train, validation and test loaders
axial_train_images, axial_train_labels = get_images_labels(axial_train_loader)
axial_val_images, axial_val_labels = get_images_labels(axial_val_loader)
axial_test_images, axial_test_labels = get_images_labels(axial_test_loader)

# 2. coronal plane: one batch each from the train, validation and test loaders
coronal_train_images, coronal_train_labels = get_images_labels(coronal_train_loader)
coronal_val_images, coronal_val_labels = get_images_labels(coronal_val_loader)
coronal_test_images, coronal_test_labels = get_images_labels(coronal_test_loader)

# 3. sagittal plane: one batch each from the train, validation and test loaders
sagittal_train_images, sagittal_train_labels = get_images_labels(sagittal_train_loader)
sagittal_val_images, sagittal_val_labels = get_images_labels(sagittal_val_loader)
sagittal_test_images, sagittal_test_labels = get_images_labels(sagittal_test_loader)

In [36]:
# Plot sample images with overlaid masks from the axial train, validation and test batches
plot_images_labels(axial_train_images, axial_train_labels)
plot_images_labels(axial_val_images, axial_val_labels)
plot_images_labels(axial_test_images, axial_test_labels)

In [37]:
# Plot sample images with overlaid masks from the coronal train, validation and test batches
plot_images_labels(coronal_train_images, coronal_train_labels)
plot_images_labels(coronal_val_images, coronal_val_labels)
plot_images_labels(coronal_test_images, coronal_test_labels)

In [38]:
# Plot sample images with overlaid masks from the sagittal train, validation and test batches
plot_images_labels(sagittal_train_images, sagittal_train_labels)
plot_images_labels(sagittal_val_images, sagittal_val_labels)
plot_images_labels(sagittal_test_images, sagittal_test_labels)

In [61]:
# Print the shape of the image at index 2 of the sagittal training batch and of its mask
print(sagittal_train_images[2].shape)
print(sagittal_train_labels[2].shape)

In [62]:
# Distinct values present in the image at index 2 of the sagittal training batch
sagittal_train_images[2].unique()

In [None]:
# Two-bin histogram over [0, 1] of the image at index 2 of the sagittal training batch.
bins = 2
hist = torch.histc(sagittal_train_images[2], bins=bins, min=0, max=1)

x = range(bins)
plt.bar(x, hist, align='center')
plt.xlabel('Bins')

## Build UNET Model
model structure

In [65]:
def crop_img(old_tensor, current_tensor):
    """Center-crop ``old_tensor`` to the spatial size of ``current_tensor``.

    Both tensors are indexed as 4-D with height at dim 2 and width at dim 3.
    Height and width are handled independently and the result always has
    exactly the target size, also when the size difference is odd (the extra
    row/column is dropped from the bottom/right).

    Args:
        old_tensor: Tensor to crop; expected to be at least as large as
            ``current_tensor`` in both spatial dims.
        current_tensor: Tensor whose height and width define the crop size.

    Returns:
        A view of ``old_tensor`` with the height and width of ``current_tensor``.
    """
    target_height, target_width = current_tensor.size()[2:4]
    top = (old_tensor.size(2) - target_height) // 2
    left = (old_tensor.size(3) - target_width) // 2
    return old_tensor[:, :, top:top + target_height, left:left + target_width]

In [66]:
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by batch norm and ReLU.

    The convolutions use stride 1, padding 1 and no bias. The first maps
    ``in_channels`` to ``out_channels``; the second keeps ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # Same layer order as always (conv, bn, relu, conv, bn, relu), so the
        # state_dict keys under ``conv.<index>`` do not change.
        for channels_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(channels_in, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both convolution stages to ``x``."""
        return self.conv(x)

In [67]:
class UNET(nn.Module):
    """U-Net style encoder/decoder built from DoubleConv blocks.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of output channels (one per class).
        features: Channel widths of the encoder levels, shallowest first. The
            decoder mirrors them in reverse order and the bottleneck uses
            twice the last width.

    Sub-modules are registered in the same order as before, so existing model
    and optimizer checkpoints stay loadable.
    """

    def __init__(
            self, in_channels=1, out_channels=2, features=(64, 128, 256, 512),
    ):
        super(UNET, self).__init__()
        self.up_sampling = nn.ModuleList()
        self.down_sampling = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: one DoubleConv per level; pooling is applied in forward().
        for feature in features:
            self.down_sampling.append(DoubleConv(in_channels, feature))
            in_channels = feature
        self.base_layer = DoubleConv(features[-1], features[-1]*2)

        # Decoder: per level a 2x transposed-conv upsample followed by a DoubleConv
        # that consumes the concatenated (skip, upsampled) tensor.
        for feature in reversed(features):
            self.up_sampling.append(
                nn.ConvTranspose2d(
                    feature*2, feature, kernel_size=2, stride=2,
                )
            )
            self.up_sampling.append(DoubleConv(feature*2, feature))
        self.out_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Return per-class scores for ``x``.

        When a skip tensor is spatially larger than the upsampled tensor
        (input size not divisible by ``2 ** len(features)``), the skip tensor
        is center-cropped to match, so the output can be smaller than the input.
        """
        # Encoder path: keep the pre-pooling activation of every level as a skip connection.
        skips = []
        for down in self.down_sampling:
            x = down(x)
            skips.append(x)
            x = self.pool(x)
        x = self.base_layer(x)

        # Decoder path: up_sampling alternates [upsample, DoubleConv] per level.
        for level, skip in enumerate(reversed(skips)):
            x = self.up_sampling[2 * level](x)

            if x.shape[2:] != skip.shape[2:]:
                # Crop the skip connection, not x: x carries the decoder features.
                skip = crop_img(skip, x)

            x = self.up_sampling[2 * level + 1](torch.cat((skip, x), dim=1))

        return self.out_conv(x)

## Utils
Helper functions used for training, validating, and testing the models.

In [69]:
def get_plane_loader(plane_type="axial"):
    """Return the (train, validation, test) dataloaders of a single plane.

    Use this to train a model on one plane type instead of all three together.
    """
    plane_datasets = get_plane_datasets(plane_type)
    return get_plane_dataloader(*plane_datasets)

In [70]:
def save_predictions_as_imgs(mask, mask_name, folder="saved_images/", device="cuda"):
    """Save a batch of masks as the image file ``<folder>/<mask_name>.jpg``.

    Args:
        mask: Batch of masks without a channel dimension; one is inserted at dim 1.
        mask_name: File name without extension.
        folder: Target directory; created if it does not exist.
        device: Unused; kept so existing calls that pass it keep working.
    """
    os.makedirs(folder, exist_ok=True)
    # os.path.join avoids the doubled separator that the default folder's trailing slash produced.
    torchvision.utils.save_image(mask.unsqueeze(1), os.path.join(folder, f"{mask_name}.jpg"))

In [71]:
def concat_datasets(dataset_1, dataset_2):
    """Return a ConcatDataset holding ``dataset_1`` followed by ``dataset_2``."""
    return torch.utils.data.ConcatDataset([dataset_1, dataset_2])

In [72]:
def get_all_planes_dataloaders():
    """Return (train, validation, test) dataloaders over all three planes.

    For each split, the axial, coronal and sagittal datasets are concatenated
    in that order. Used to train one model on the data of all planes.
    """
    # One (train, validation, test) triple per plane, in concatenation order.
    per_plane = [get_plane_datasets(plane) for plane in ("axial", "coronal", "sagittal")]

    merged_splits = []
    # zip(*per_plane) regroups by split: all train sets, then all validation sets, then all test sets.
    for axial_ds, coronal_ds, sagittal_ds in zip(*per_plane):
        axial_coronal = concat_datasets(axial_ds, coronal_ds)
        merged_splits.append(concat_datasets(axial_coronal, sagittal_ds))

    return get_plane_dataloader(*merged_splits)

In [73]:
def calc_accuracy(pred, label):
    """Return the fraction of positions whose predicted class equals ``label``.

    Called once per batch. ``pred`` holds per-class scores along dim 1; the
    predicted class is the index of the largest log-softmax value along it.
    """
    log_probs = torch.log_softmax(pred, dim=1)
    predicted = torch.max(log_probs, dim=1)[1]
    matches = (predicted == label).float()
    return (matches.sum() / matches.numel()).item()

In [74]:
def calc_iou(label, pred, classes=7):
    """Return the mean intersection-over-union of one batch.

    Args:
        label: Tensor of ground-truth class ids.
        pred: Tensor of per-class scores with the class axis at dim 1.
        classes: Number of classes; ids ``0 .. classes - 1`` are evaluated.

    Returns:
        The IoU averaged over the classes that occur in the prediction or in
        the label. Classes that occur in neither have an undefined IoU and are
        left out instead of being counted as 0, so a perfect prediction scores
        1.0 even when the batch contains only some of the classes. Returns 0.0
        when no class could be evaluated.
    """
    # argmax over the raw scores selects the same class as argmax over their softmax.
    predicted = torch.argmax(pred, dim=1).reshape(-1)
    target = label.reshape(-1)

    iou_sum = 0.0
    evaluated_classes = 0
    for class_ in range(classes):
        pred_mask = predicted == class_
        target_mask = target == class_
        class_union = (pred_mask | target_mask).sum().item()
        if class_union == 0:
            # Class absent from both prediction and label: IoU is undefined, skip it.
            continue
        class_intersection = (pred_mask & target_mask).sum().item()
        iou_sum += class_intersection / class_union
        evaluated_classes += 1

    if evaluated_classes == 0:
        return 0.0
    return iou_sum / evaluated_classes

In [75]:
def plot_metric(X, y1, y2, y1_label, y2_label, x_label, y_label, title):
    """Plot a training curve and a validation curve of one metric on one chart.

    Args:
        X: Epoch numbers used as x values.
        y1: Training values of the metric (accuracy, loss or IoU), one per epoch.
        y2: Validation values of the metric, one per epoch.
        y1_label: Legend label of the training curve.
        y2_label: Legend label of the validation curve.
        x_label: Label of the x-axis, e.g. "number of epochs".
        y_label: Label of the y-axis (accuracy, loss or IoU).
        title: Title of the chart.
    """
    for series, series_label in ((y1, y1_label), (y2, y2_label)):
        plt.plot(X, series, label=series_label, marker='o')
    plt.title(title)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.legend()
    plt.show()

## Train a model on the entire 3 system planes' training data

In [108]:
# Pick the compute device: the first GPU when CUDA is available, otherwise the CPU
if torch.cuda.is_available():
    DEVICE = 'cuda:0'
    print('Running on the GPU')
else:
    DEVICE = "cpu"
    print('Running on the CPU')

# Training configuration
# MODEL_PATH: checkpoint file that main() loads from and saves to after every epoch
# LOAD_MODEL: when True, main() restores the model/optimizer state and metric history from MODEL_PATH
# LEARNING_RATE: learning rate of the Adam optimizer created in main()
# EPOCHS: upper bound of the epoch loop in main(); a resumed run starts at the checkpoint's stored epoch
MODEL_PATH = 'model_all_data_10.pth.tar'
LOAD_MODEL = True
LEARNING_RATE = 0.001
EPOCHS = 11

In [115]:
def train_model(train_data, model, optimizer, loss_fn, device):
    """Train ``model`` for one pass over ``train_data``; called once per epoch.

    Args:
        train_data: Iterable of ``(images, labels)`` batches.
        model: Network to train; switched to training mode.
        optimizer: Optimizer that updates the parameters of ``model``.
        loss_fn: Loss called as ``loss_fn(predictions, labels)``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(avg_loss, avg_acc, avg_iou, model)``: the per-batch loss,
        accuracy and mean IoU averaged over the epoch, and the same ``model``
        object that was passed in.
    """
    train_loss = []  # loss of every batch
    acc = []         # accuracy of every batch
    ious = []        # mean IoU of every batch

    # The mode does not change inside the loop, so set it once.
    model.train()
    for X, y in tqdm(train_data):
        X = X.to(device)
        # Labels: cast to integer class ids and drop dim 1 (when it has size 1).
        y = y.long().squeeze(1).to(device)

        preds = model(X)
        loss = loss_fn(preds, y)

        train_loss.append(loss.item())
        # Metrics are for reporting only; no autograd bookkeeping is needed for them.
        with torch.no_grad():
            acc.append(calc_accuracy(preds, y))
            ious.append(calc_iou(y, preds, 7))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    avg_loss = np.mean(train_loss)
    avg_acc = np.mean(acc)
    avg_iou = np.mean(ious)
    return avg_loss, avg_acc, avg_iou, model

In [116]:
def evaluae_validaion(val_data, model, loss_fn, device):
    """Evaluate ``model`` on ``val_data`` without updating it; called after every training epoch.

    Args:
        val_data: Iterable of ``(images, labels)`` batches.
        model: Network to evaluate; switched to evaluation mode.
        loss_fn: Loss called as ``loss_fn(predictions, labels)``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(avg_loss, avg_acc, avg_iou)``: the per-batch loss, accuracy
        and mean IoU averaged over all validation batches.
    """
    val_loss = []  # loss of every batch
    acc = []       # accuracy of every batch
    ious = []      # mean IoU of every batch

    # The mode does not change inside the loop, so set it once.
    model.eval()
    with torch.no_grad():
        for X, y in tqdm(val_data):
            X = X.to(device)
            # Labels: cast to integer class ids and drop dim 1 (when it has size 1).
            y = y.long().squeeze(1).to(device)

            preds = model(X)

            val_loss.append(loss_fn(preds, y).item())
            acc.append(calc_accuracy(preds, y))
            ious.append(calc_iou(y, preds, 7))

    avg_loss = np.mean(val_loss)
    avg_acc = np.mean(acc)
    avg_iou = np.mean(ious)
    return avg_loss, avg_acc, avg_iou

In [117]:
def train_val_model(train_data, val_data, model, optimizer, loss_fn, device):
    """Run one training epoch followed by one validation pass.

    Returns:
        Tuple ``(train_loss, train_acc, train_iou, val_loss, val_acc, val_iou)``
        of epoch averages.
    """
    train_loss, train_acc, train_iou, trained_model = train_model(
        train_data, model, optimizer, loss_fn, device)
    val_metrics = evaluae_validaion(val_data, trained_model, loss_fn, device)
    return (train_loss, train_acc, train_iou, *val_metrics)

In [118]:
# Build the dataloaders over all three planes; main() and the final test call read them as globals
train_loader, val_loader, test_loader = get_all_planes_dataloaders()

In [119]:
def main():
    """Train the U-Net on all planes, checkpointing to ``MODEL_PATH`` after every epoch.

    Reads the notebook-level ``train_loader`` and ``val_loader``. When
    ``LOAD_MODEL`` is set and ``MODEL_PATH`` exists, the model, the optimizer
    state, the finished epoch count and the metric history are restored and
    training resumes from there; when the file is missing, training starts
    from scratch instead of failing.

    Returns:
        Tuple of six lists with one entry per finished epoch:
        ``(avg_loss_train, avg_acc_train, avg_iou_train,
        avg_loss_val, avg_acc_val, avg_iou_val)``.
    """
    # ``epoch`` stays a module-level name, as in earlier versions of this notebook.
    global epoch
    epoch = 0

    # Checkpoint keys of the metric history, in the order train_val_model returns the values.
    history_keys = ('avg_loss_train', 'avg_acc_train', 'avg_iou_train',
                    'avg_loss_val', 'avg_acc_val', 'avg_iou_val')
    history = {key: [] for key in history_keys}

    print('Data Loaded Successfully!')

    # Model, optimizer and loss function
    unet = UNET(in_channels=1, out_channels=7).to(DEVICE).train()
    optimizer = optim.Adam(unet.parameters(), lr=LEARNING_RATE)
    loss_function = nn.CrossEntropyLoss()

    # Resume from a previously stored checkpoint when requested and available
    if LOAD_MODEL:
        if os.path.isfile(MODEL_PATH):
            checkpoint = torch.load(MODEL_PATH, map_location='cpu')
            unet.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optim_state_dict'])
            epoch = checkpoint['epoch']
            for key in history_keys:
                history[key] = list(checkpoint[key])
            print("Model successfully loaded!")
        else:
            print(f"No checkpoint found at {MODEL_PATH}; training from scratch.")

    # One training + validation pass per epoch, starting after the last finished epoch
    for e in range(epoch, EPOCHS):
        print('Epoch ' + str(e+1))
        avg_train_loss, avg_train_acc, avg_train_iou, avg_val_loss, avg_val_acc, avg_val_iou =\
         train_val_model(train_loader, val_loader, unet, optimizer, loss_function, DEVICE)

        epoch_metrics = (avg_train_loss, avg_train_acc, avg_train_iou,
                         avg_val_loss, avg_val_acc, avg_val_iou)
        for key, value in zip(history_keys, epoch_metrics):
            # Store plain Python floats so the checkpoint holds only basic types.
            history[key].append(float(value))

        print("epoch: " + str(e+1), " train_loss: " + str(avg_train_loss), " train_acc: " + str(avg_train_acc), " train_iou: ", str(avg_train_iou), \
              " validation_loss: " + str(avg_val_loss), "val_acc: " + str(avg_val_acc), " val_iou: ", str(avg_val_iou))

        # Save after every epoch so an interrupted run can be resumed
        torch.save({
            'model_state_dict': unet.state_dict(),
            'optim_state_dict': optimizer.state_dict(),
            'epoch': e+1,
            **history,
        }, MODEL_PATH)
        print("Epoch completed and model successfully saved!")

    return tuple(history[key] for key in history_keys)

In [120]:
# Run the training and keep the per-epoch metric histories for the plots below
if __name__ == '__main__':
    avg_loss_train, avg_acc_train, avg_iou_train, avg_loss_val, avg_acc_val, avg_iou_val = main()

In [121]:
# Epoch numbers 1..N, one per recorded entry of the loss history (shared x-axis of the metric plots)
X = list(range(1, len(avg_loss_train)+1))
# 1. plot the training and validation loss per epoch
plot_metric(X, avg_loss_train, avg_loss_val, "train_loss", "val_loss", "number of epochs", "loss", "Train and Validation Loss")

In [122]:
# 2. plot the training and validation accuracy per epoch
plot_metric(X, avg_acc_train, avg_acc_val, "train_acc", "val_acc", "number of epochs", "accuracy", "Train and Validation Accuracy")

In [123]:
# 3. plot the training and validation mean IoU per epoch
plot_metric(X, avg_iou_train, avg_iou_val, "train_iou", "val_iou", "number of epochs", "iou", "Train and Validation IoU")

## Testing

In [127]:
## Create the two folders that hold the ground-truth masks and the predicted masks.
## exist_ok=True lets this cell be re-run without failing on folders that already exist.
os.makedirs('predicted_masks', exist_ok=True)
os.makedirs('true_masks', exist_ok=True)

In [128]:
def get_predictions(data, model, num_class=7):
    """Evaluate ``model`` on ``data`` and save predicted and ground-truth masks as images.

    For batch number ``idx``, the predicted masks are written to
    ``predicted_masks/pred_<idx>.jpg`` and the ground-truth masks to
    ``true_masks/true_<idx>.jpg``.

    Args:
        data: Iterable of ``(images, labels)`` batches.
        model: Trained network; switched to evaluation mode.
        num_class: Number of classes, used for the IoU and to scale the saved masks.

    Returns:
        Tuple ``(avg_loss, avg_acc, avg_iou)`` averaged over all batches.
    """
    loss = []
    acc = []
    ious = []
    loss_function = nn.CrossEntropyLoss()
    # The image writer expects values in [0, 1]; unscaled class ids above 1 would all
    # saturate to white, so map ids 0 .. num_class - 1 onto [0, 1] before saving.
    label_scale = float(max(num_class - 1, 1))

    model.eval()
    with torch.no_grad():
        for idx, (X, y) in enumerate(tqdm(data)):
            X, y = X.to(DEVICE), y.long().squeeze(1).to(DEVICE)
            predictions = model(X)

            loss.append(loss_function(predictions, y).item())
            acc.append(calc_accuracy(predictions, y))
            ious.append(calc_iou(y, predictions, num_class))

            # argmax over the raw scores picks the same class as argmax over their softmax.
            pred_labels = torch.argmax(predictions, dim=1).float()

            mask_name = "pred_" + str(idx)
            save_predictions_as_imgs(pred_labels / label_scale, mask_name, "predicted_masks")
            img_name = "true_" + str(idx)
            save_predictions_as_imgs(y.float() / label_scale, img_name, "true_masks")

    avg_loss = np.mean(loss)
    avg_acc = np.mean(acc)
    avg_iou = np.mean(ious)
    return avg_loss, avg_acc, avg_iou

In [129]:
def test(path, test_data, num_class=7):
    """Load the checkpoint at ``path`` and report loss, accuracy and IoU on ``test_data``.

    Args:
        path: Checkpoint file containing a ``'model_state_dict'`` entry.
        test_data: Iterable of ``(images, labels)`` batches.
        num_class: Number of classes; sets the model's output channels and is
            forwarded to the evaluation so both always agree.
    """
    net = UNET(in_channels=1, out_channels=num_class).to(DEVICE)
    checkpoint = torch.load(path, map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])
    print(f'{path} has been loaded and initialized')
    avg_loss, avg_acc, avg_iou = get_predictions(test_data, net, num_class)
    print("Model testing completed...")
    print("loss: ", str(avg_loss), "  accuracy: ", str(avg_acc), " Iou: ", str(avg_iou))

In [130]:
# Evaluate the checkpoint stored at MODEL_PATH on the combined test loader, using 7 classes
test(MODEL_PATH, test_loader, 7)