In [None]:
!pip install torchio

In [None]:
#Imports
import os
import random
import time

import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
import gzip
import os
import shutil


from tqdm import tqdm

# --- Training configuration ---------------------------------------------------
# Step size handed to the Adam optimizer in main(); main() also derives the
# optimizer's weight decay from it (LEARNING_RATE / NUM_EPOCHS).
LEARNING_RATE = 5E-4
# First CUDA device when one is available, otherwise the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Batch sizes of the train / validation / test DataLoaders.
TRAIN_BATCH_SIZE = 16
VAL_BATCH_SIZE = 2
TEST_BATCH_SIZE = 2

# Number of passes over the training loader performed by main().
NUM_EPOCHS = 300
# Worker processes per DataLoader (one per CPU reported by the OS).
NUM_WORKERS = os.cpu_count()
# Spatial size every volume is cropped/padded to by the transforms built in main().
IMAGE_HEIGHT = 128
IMAGE_WIDTH = 128
IMAGE_DEPTH = 128
# Size of the sub-volumes fed to the network, and how many are cut from each volume.
PATCH_SIZE = (64,64,64)
NUM_PATCHES = 4
# Forwarded to the DataLoaders' pin_memory argument.
PIN_MEMORY = True
# NOTE(review): not read anywhere in the visible notebook — confirm whether it is still needed.
LOAD_MODEL = False
# Input directories (Kaggle dataset mount) for the image channels and the label masks.
TRAIN_IMG_DIR = "/kaggle/input/full-dataset/Dataset001_ISLES22forUNET_uc/imagesTr"
TRAIN_MASK_DIR = "/kaggle/input/full-dataset/Dataset001_ISLES22forUNET_uc/labelsTr"
# Output directories for model checkpoints and saved prediction slices.
CHECKPOINT_DIR = "/kaggle/working/checkpoint"
SAVED_IMAGES_DIR = "/kaggle/working/saved_images"

In [None]:
#model
class DoubleConv(nn.Module):
    """Two stacked 3-D convolution stages, the basic building block of the UNet.

    Each stage applies, in this order: a 3x3x3 convolution (stride 1, padding 1),
    batch normalisation, ReLU, a second batch normalisation and channel-wise
    dropout with p=0.1. The convolution settings leave the spatial size unchanged.

    Parameters
    ----------
    in_channels : number of channels of the input tensor
    out_channels : number of channels produced by both stages
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # First stage maps in_channels -> out_channels, the second keeps the width.
        for stage_in in (in_channels, out_channels):
            layers.extend(self._make_stage(stage_in, out_channels))
        # The ten layers are registered under indices 0..9 of the Sequential.
        self.conv = nn.Sequential(*layers)

    @staticmethod
    def _make_stage(stage_in, stage_out):
        """Build one conv -> BN -> ReLU -> BN -> dropout stage as a list of layers."""
        return [
            nn.Conv3d(stage_in, stage_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm3d(stage_out),
            nn.ReLU(inplace=True),
            nn.BatchNorm3d(stage_out),
            nn.Dropout3d(p=0.1),
        ]

    def forward(self, x):
        """Run the input through both stages and return the result."""
        return self.conv(x)

class UNET(nn.Module):
    """3-D U-Net: an encoder of DoubleConv blocks with max-pooling, a bottleneck,
    and a decoder of transposed convolutions with skip connections.

    Parameters
    ----------
    in_channels : number of channels of the input volume
    out_channels : number of channels of the output (logits, no activation is applied)
    features : channel widths of the encoder levels, from shallowest to deepest.
        The default list is only read, never mutated.
    """

    def __init__(
            self, in_channels=3, out_channels=1, features=[64, 128, 256, 512,1024],
    ):
        super(UNET, self).__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2, padding=0)

        # Down part of UNET: one DoubleConv per level, widening the channels.
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Up part of UNET: per level, a transposed convolution that doubles the
        # spatial size, followed by a DoubleConv that fuses it with the skip connection.
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose3d(feature*2, feature, kernel_size=2, stride=2,))
            self.ups.append(DoubleConv(feature*2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1]*2)
        self.final_conv = nn.Conv3d(features[0], out_channels, kernel_size=1)

    def forward(self, img):
        """
        Forward pass of the UNet
        Parameters
        ----------
        img : The input image of shape (BATCH_SIZE, in_channels, IMAGE_DEPTH, IMAGE_HEIGHT, IMAGE_WIDTH)

        Returns: The raw logits of shape (BATCH_SIZE, out_channels, D, H, W). The spatial
        size equals the input's when every dimension is divisible by 2**len(features);
        otherwise the pooling rounds down and the output is smaller than the input.
        -------

        """
        # Outputs of the encoder levels, kept so the decoder can reuse the
        # high-resolution (local) information.
        connections = []
        for down in self.downs:
            img = down(img)
            connections.append(img)
            img = self.pool(img)

        # Link between the downsampling and the upsampling path.
        img = self.bottleneck(img)

        # Deepest encoder output first, matching the order of the decoder levels.
        connections = connections[::-1]
        for idx in range(0, len(self.ups), 2):
            # self.ups[idx] is the ConvTranspose3d upsampling layer.
            img = self.ups[idx](img)
            connection = connections[idx // 2]
            if img.shape[2:] != connection.shape[2:]:
                # Odd input sizes make the pooled/upsampled size differ from the skip
                # connection. Resample the skip connection on its own device with a
                # differentiable op, so gradients keep flowing and no CPU round trip
                # or hard-coded .cuda() call is needed.
                connection = nn.functional.interpolate(
                    connection, size=img.shape[2:], mode="trilinear", align_corners=False
                )
            # Concatenate the upsampled features with the skip connection along the channel axis.
            img = torch.cat((img, connection), dim=1)
            # self.ups[idx + 1] is the DoubleConv that fuses both inputs.
            img = self.ups[idx + 1](img)

        return self.final_conv(img)

In [None]:
#dataset
from torch.utils.data import Dataset, DataLoader
import numpy as np
import nibabel as nib
from sklearn.utils import shuffle
import copy

def remove_missing(image_paths, labels_paths, known_bad_indices=(203, 204)):
    """
    Remove images without a label, labels without an image, and known-bad cases.
    Parameters
    ----------
    image_paths : The paths to the images in the dataset of shape (NUM_IMAGES, NUM_CHANNELS) of the form "Dataset001_ISLES22forUNET/imagesTr/ISLES_x_y.nii.gz"
    labels_paths : The paths to the labels in the dataset of shape (NUM_IMAGES) of the form "Dataset001_ISLES22forUNET/labelsTr/ISLES_x.nii.gz"
    known_bad_indices : Row positions removed from BOTH arrays regardless of their ids.
        Positions outside an array are ignored, so small datasets do not raise IndexError.

    Returns : The image and label paths with the missing images removed
    -------

    """
    # The case id is the 3 characters after the last "_" of a label file name, and
    # the second-to-last "_"-separated field of the first channel's file name.
    labels_ids = [path.split('_')[-1][:3] for path in labels_paths]
    image_ids = [channels[0].split('_')[-2] for channels in image_paths]

    # Sets give O(1) membership tests instead of rescanning a list for every case.
    known_label_ids = set(labels_ids)
    known_image_ids = set(image_ids)

    # Rows of image_paths to delete: the known-bad positions plus images without a label.
    image_rows = [i for i in known_bad_indices if 0 <= i < len(image_paths)]
    image_rows += [i for i, case_id in enumerate(image_ids) if case_id not in known_label_ids]

    # Rows of labels_paths to delete: the known-bad positions plus labels without an image.
    label_rows = [i for i in known_bad_indices if 0 <= i < len(labels_paths)]
    label_rows += [i for i, case_id in enumerate(labels_ids) if case_id not in known_image_ids]

    # An explicit integer dtype keeps np.delete well-defined when a list is empty.
    image_paths = np.delete(image_paths, np.asarray(image_rows, dtype=int), axis=0)
    labels_paths = np.delete(labels_paths, np.asarray(label_rows, dtype=int), axis=0)

    return image_paths, labels_paths


class MRIImage(Dataset):
    """Dataset of 3-channel MRI volumes and their masks with a built-in train/val/test split.

    One instance holds all three splits; `set_mode` selects which one is served.
    In "train" and "test" mode every volume yields `num_patches` patches, in "val"
    mode the whole volume is returned. The most recently loaded volume is cached so
    that consecutive patches of the same volume do not reload it from disk.
    """

    def __init__(self, image_paths, labels_paths, transform=None, split_ratios=[0.5, 0.2, 0.1], mode = None , patch_size = (64, 64, 64),num_patches = 2):
        """
        Create a dataset from a directory of images and a directory of labels.
        Parameters
        ----------
        image_paths : Directory containing the image files (three channel files per case)
        labels_paths : Directory containing the label files (one per case)
        transform : Sequence of three callables (train, val, test), each may be None;
            None disables transforms and normalisation for every mode
        split_ratios : Fractions of the cases used for train and val; the test split
            receives all remaining cases (the third entry is not used for slicing)
        mode : One of "train", "val", "test" (can be set later with set_mode)
        patch_size : Spatial size of the patches returned in train/test mode
        num_patches : Number of patches cut from each volume in train/test mode
        """
        super(MRIImage, self).__init__()
        self.transform = transform
        self.split_ratios = split_ratios
        self.mode = mode

        images_list = np.array([os.path.join(image_paths, x) for x in os.listdir(image_paths)])
        labels_list = np.array([os.path.join(labels_paths, x) for x in os.listdir(labels_paths)])

        self.labels_paths = np.sort(labels_list)
        # Relies on the sorted file names placing the three channel files of a case
        # next to each other, so each row holds one case.
        self.image_paths = np.sort(images_list).reshape(-1, 3)

        self.image_paths, self.labels_paths = remove_missing(self.image_paths, self.labels_paths)

        num_training_imgs = len(self.labels_paths)
        train_val_test = [int(x * num_training_imgs) for x in split_ratios]

        # Random permutation of the case indices (unseeded: differs between runs).
        selected = np.arange(0, num_training_imgs)
        selected = shuffle(selected)

        train_end = train_val_test[0]
        val_end = train_end + train_val_test[1]
        self.train_image_path = self.image_paths[selected[:train_end]]
        self.train_label_path = self.labels_paths[selected[:train_end]]
        self.val_image_path = self.image_paths[selected[train_end:val_end]]
        self.val_label_path = self.labels_paths[selected[train_end:val_end]]
        self.test_image_path = self.image_paths[selected[val_end:]]
        self.test_label_path = self.labels_paths[selected[val_end:]]

        self.patch_size = patch_size
        self.num_patches = num_patches

        # Single-entry cache: the last prepared volume and the sample index it belongs to.
        self.previous_image = None
        self.previous_index = None

    def set_mode(self, mode):
        """Select the split to serve; raises ValueError for an unknown mode."""
        if mode != "train" and mode != "val" and mode != "test":
            raise ValueError("mode must be either train, val or test")
        self.mode = mode
        # A cached volume belongs to the previous split and must not be reused.
        self.previous_image = None
        self.previous_index = None

    def __len__(self):
        """Number of items of the active split (patches for train/test, volumes for val)."""
        if self.mode == "train":
            return len(self.train_label_path) * self.num_patches
        elif self.mode == "val":
            return len(self.val_label_path)
        elif self.mode == "test":
            return len(self.test_label_path) * self.num_patches
        raise ValueError("mode must be either train, val or test")

    def _select_sample(self, index):
        """Map an item index to (volume index, channel paths, label path, transform)."""
        if self.mode == "train":
            image_idx = index // self.num_patches
            all_images, all_labels, transform_slot = self.train_image_path, self.train_label_path, 0
        elif self.mode == "val":
            # Validation serves whole volumes, so the item index IS the volume index.
            image_idx = index
            all_images, all_labels, transform_slot = self.val_image_path, self.val_label_path, 1
        elif self.mode == "test":
            image_idx = index // self.num_patches
            all_images, all_labels, transform_slot = self.test_image_path, self.test_label_path, 2
        else:
            raise ValueError("mode must be either train, val or test")

        transform = self.transform[transform_slot] if self.transform is not None else None
        return image_idx, all_images[image_idx], all_labels[image_idx], transform

    def __getitem__(self, index):
        """
        Get an image (or a patch of it) and its mask from the active split.
        Parameters
        ----------
        index : The index of the item to get

        Returns
        -------
        "val" mode: the three image channels of shape (3, D, H, W) and the binary mask
            of shape (D, H, W), both as numpy arrays.
        "train"/"test" mode: a patch of the three channels and the matching mask patch,
            both passed through tio.Resize(self.patch_size).
        -------

        """
        image_idx, image_paths, label_path, transform = self._select_sample(index)

        dwi_path = image_paths[0]
        adc_path = image_paths[1]
        flair_path = image_paths[2]

        # Reload only when a different volume is requested. The cache is keyed on the
        # volume index actually used for the lookup, and it stores the volume AFTER
        # transform + normalisation so that these steps are not silently discarded.
        if self.previous_image is None or self.previous_index != image_idx:
            original_images = self.create_image(adc_path, dwi_path, flair_path, label_path)
            if transform is not None:
                original_images = transform(original_images)
                original_images[0:3] = normalize(original_images[0:3])
            self.previous_image = original_images
            self.previous_index = image_idx

        original_images = self.previous_image

        # Channel 3 holds the label; binarise it and add a leading axis.
        mask = np.array([original_images[3] > 0.5]).astype(np.float32)

        if self.mode == "val":
            return original_images[0:3], mask[0]

        input_data = torch.tensor(original_images[0:3])
        mask = torch.tensor(mask)

        # Patch k starts at k * (size // num_patches) along every axis, i.e. the
        # patches are placed along the diagonal of the volume.
        patch_idx = index % self.num_patches
        start_indices = [
            patch_idx * (sz // self.num_patches) for sz in input_data[0].shape
        ]

        # Calculate the end indices for each dimension; slices that run past the
        # volume edge are truncated by the indexing below.
        end_indices = [
            start_indices[dim] + self.patch_size[dim] for dim in range(len(start_indices))
        ]

        # Extract the patch from each of the image channels (everything but the label).
        channel_patch = []
        for i in range(len(original_images) - 1):
            channel_patch.append(input_data[i][
                                 start_indices[0]:end_indices[0],
                                 start_indices[1]:end_indices[1],
                                 start_indices[2]:end_indices[2],
                                 ])

        # Resize so that truncated patches also end up with patch_size.
        channel_patch = tio.Resize(self.patch_size)(torch.stack(channel_patch))

        mask_patch = mask[0][
                     start_indices[0]:end_indices[0],
                     start_indices[1]:end_indices[1],
                     start_indices[2]:end_indices[2],
                     ]

        # NOTE(review): Resize is applied with its default interpolation, so a resized
        # mask may no longer be strictly binary — confirm against the TorchIO docs.
        mask_patch = tio.Resize(self.patch_size)(mask_patch.unsqueeze(0))[0]

        return channel_patch, mask_patch

    def create_image(self, adc_path, dwi_path, flair_path, label_path):
        """Load one case into an array of shape (4, X, Y, Z): DWI, ADC, FLAIR, label."""
        dwi_image = nib.load(dwi_path).get_fdata()
        original_images = np.zeros((4, dwi_image.shape[0], dwi_image.shape[1], dwi_image.shape[2]))
        original_images[0] = dwi_image
        original_images[1] = nib.load(adc_path).get_fdata()
        original_images[2] = nib.load(flair_path).get_fdata()
        original_images[3] = nib.load(label_path).get_fdata()
        return original_images

def normalize(image):
    """
    Min-max normalize an array to the range [-1, 1]
    Parameters
    ----------
    image: an array of any shape; the minimum and maximum are taken over ALL of its
        elements (the dataset calls it on the 3 image channels of one volume)

    Returns: the normalized array, same shape as the input
    -------
    """
    eps = 1e-10  # avoids a division by zero for constant inputs
    lowest = np.min(image)
    value_range = np.max(image) - lowest + eps
    unit_scaled = (image - lowest) / value_range

    # Map [0, 1] onto [-1, 1]; the clip guards against rounding overshoot.
    return np.clip(2 * unit_scaled - 1, -1, 1)

def get_train_val_test_Dataloaders(train_transforms, val_transforms, test_transforms):
    """Build the train, validation and test DataLoaders from one shared dataset split.

    The dataset is created once (so the random split is drawn once) and deep-copied
    per mode; each copy is switched to its mode and wrapped in a DataLoader.

    Returns: (train_dataloader, val_dataloader, test_dataloader)
    """
    base_dataset = MRIImage(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        [train_transforms, val_transforms, test_transforms],
        patch_size=PATCH_SIZE,
        num_patches=NUM_PATCHES,
    )

    mode_batch_sizes = (
        ("train", TRAIN_BATCH_SIZE),
        ("val", VAL_BATCH_SIZE),
        ("test", TEST_BATCH_SIZE),
    )

    dataloaders = []
    for mode, batch_size in mode_batch_sizes:
        split_dataset = copy.deepcopy(base_dataset)
        split_dataset.set_mode(mode)
        # No shuffling: items are requested in index order.
        dataloaders.append(
            DataLoader(
                dataset=split_dataset,
                batch_size=batch_size,
                shuffle=False,
                num_workers=NUM_WORKERS,
                pin_memory=PIN_MEMORY,
                persistent_workers=True,
            )
        )

    return tuple(dataloaders)

In [None]:
#losses

def dice_coefficient(predicted, target, epsilon=1e-6):
    """Smoothed Dice score over all elements of two tensors.

    Computes (2 * sum(predicted * target) + epsilon) / (sum(predicted) + sum(target) + epsilon).
    epsilon keeps the ratio defined (and equal to 1) when both tensors are all zero.
    """
    overlap = (predicted * target).sum()
    total = predicted.sum() + target.sum()
    return (2.0 * overlap + epsilon) / (total + epsilon)

class DiceBCELoss_2(nn.Module):
    """Weighted sum of a Dice loss and a binary cross-entropy loss on raw logits.

    loss = 0.25 * (1 - dice(sigmoid(logits), target)) + 0.75 * BCEWithLogits(logits, target)
    """

    def __init__(self, device=DEVICE):
        super().__init__()
        self.device = device

    def forward(self, predicted, target):
        """Return the combined loss for logits `predicted` and binary `target`."""
        # Drop the channel axis of the logits when the target has none.
        if predicted.shape != target.shape:
            predicted = predicted.squeeze(1)

        # The Dice term works on probabilities ...
        probabilities = torch.sigmoid(predicted)
        dice_term = 1 - dice_coefficient(probabilities, target)

        # ... while the cross-entropy term consumes the logits directly.
        bce_term = nn.functional.binary_cross_entropy_with_logits(predicted, target).to(self.device)

        return 0.25*dice_term + 0.75*bce_term

In [None]:
#utils

from torch.utils.data import DataLoader
import os
from torchmetrics.classification import *
import torch.nn.functional as F


def save_checkpoint(state,checkpoint_dir, epoch):
    """Saves model and training parameters at '{checkpoint_dir}/checkpoint_epoch{epoch}.pytorch'.

    Args:
        state (dict): contains model's state_dict, optimizer's state_dict, epoch
            and best evaluation metric value so far
        checkpoint_dir (string): directory where the checkpoint are to be saved;
            created (including missing parent directories) when it does not exist
        epoch: identifier embedded in the file name
    """
    # exist_ok avoids the check-then-create race and also creates missing parents.
    os.makedirs(checkpoint_dir, exist_ok=True)

    checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch{epoch}.pytorch')
    torch.save(state, checkpoint_path)

def load_checkpoint(checkpoint_path, model, optimizer=None,
                    model_key='state_dict', optimizer_key='optimizer',
                    map_location=None):
    """Loads model and training parameters from a given checkpoint_path
    If optimizer is provided, loads optimizer's state_dict of as well.

    Args:
        checkpoint_path (string): path to the checkpoint to be loaded
        model (torch.nn.Module): model into which the parameters are to be copied
        optimizer (torch.optim.Optimizer) optional: optimizer instance into
            which the parameters are to be copied
        model_key (string): key of the model state_dict inside the checkpoint
        optimizer_key (string): key of the optimizer state_dict inside the checkpoint
        map_location optional: forwarded to torch.load; defaults to 'cuda' when
            CUDA is available and to 'cpu' otherwise

    Returns:
        state

    Raises:
        IOError: if checkpoint_path does not exist
    """
    if not os.path.exists(checkpoint_path):
        raise IOError(f"Checkpoint '{checkpoint_path}' does not exist")

    # Fall back to the CPU so checkpoints can also be loaded on machines without a GPU.
    if map_location is None:
        map_location = 'cuda' if torch.cuda.is_available() else 'cpu'

    state = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(state[model_key])

    if optimizer is not None:
        optimizer.load_state_dict(state[optimizer_key])

    return state


def train_metrics(predictions, targets, device):
    """
    Compute confusion-matrix based metrics for a batch of logits
    Parameters
    ----------
    predictions: The raw model outputs (logits); a channel axis of size 1 is squeezed
        away when the shape differs from the targets' shape
    targets: The ground truth, binarised at 0.5
    device: unused, kept for interface compatibility
    Returns: accuracy, f1, tp, tn, fp, fn, dice
    -------
    """
    if predictions.shape != targets.shape:
        predictions = predictions.squeeze(1)

    # Logits -> probabilities -> hard 0/1 labels.
    predictions = (torch.sigmoid(predictions) > 0.5).long()
    targets = (targets > 0.5).long()

    predicted_pos = predictions == 1
    predicted_neg = predictions == 0
    actual_pos = targets == 1
    actual_neg = targets == 0

    tp = (predicted_pos & actual_pos).sum().item()
    tn = (predicted_neg & actual_neg).sum().item()
    fp = (predicted_pos & actual_neg).sum().item()
    fn = (predicted_neg & actual_pos).sum().item()
    dice = dice_coefficient(predictions, targets).item()

    total = tp + tn + fp + fn
    accuracy = (tp + tn) / total
    # The small constant keeps f1 defined when there are no positives at all.
    f1 = 2 * tp / (2 * tp + fp + fn + 1e-10)
    return accuracy, f1, tp, tn, fp, fn, dice


def check_accuracy(loader, model, crop_patch_size, device="cuda"):
    """
    Evaluate the model patch by patch on every volume of the loader
    Parameters
    ----------
    loader the validation loader, yielding (images, masks) with images of shape
        (BATCH_SIZE, 3, D, H, W) and masks of shape (BATCH_SIZE, D, H, W)
    model the model to use (switched to eval mode here and back to train mode at the end)
    crop_patch_size the size of the patch to crop used before feeding the model
    device the device to use, defaults to "cuda"

    Returns : average accuracy, average f1, summed tp, tn, fp, fn, and average dice,
        where the averages are taken over all evaluated patches (all 0 for an empty loader)
    -------

    """
    model.eval()

    tp, tn, fp, fn = 0, 0, 0, 0
    f1, accuracy, dice = 0, 0, 0
    num_iter = 0
    sx, sy, sz = crop_patch_size[0], crop_patch_size[1], crop_patch_size[2]
    for x, y in loader:
        # Keep inputs and targets on the same device as the predictions.
        x = x.to(device)
        y = y.to(device)
        # Tile the volume with non-overlapping patches.
        for i in range(0, y.shape[1], sx):
            for j in range(0, y.shape[2], sy):
                for k in range(0, y.shape[3], sz):
                    crop_x = x[:, :, i:i + sx, j:j + sy, k:k + sz]
                    crop_y = y[:, i:i + sx, j:j + sy, k:k + sz]
                    binary_y = (crop_y > 0.5).float()
                    with torch.no_grad():
                        preds = compute_prediction(crop_patch_size, crop_x, model)

                    # NOTE: preds are already thresholded; train_metrics applies
                    # sigmoid + 0.5 again, which maps 0 -> 0 and any positive value -> 1.
                    accuracy_t, f1_t, tp_t, tn_t, fp_t, fn_t, dice_t = train_metrics(preds, binary_y, device)
                    tp += tp_t
                    tn += tn_t
                    fp += fp_t
                    fn += fn_t
                    f1 += f1_t
                    accuracy += accuracy_t
                    dice += dice_t

                    num_iter += 1

    # Avoid a ZeroDivisionError when the loader yields nothing.
    divisor = max(num_iter, 1)

    print(
        f"Got average Accuracy : {accuracy/divisor:.2f}"
    )
    print(f"True Positive: {tp}, True Negative: {tn}, False Positive: {fp}, False Negative: {fn}")
    print(f"Got average F1 score: {f1/divisor:.4f}")
    model.train()

    return accuracy/divisor, f1/divisor, tp, tn, fp, fn, dice/divisor


def compute_prediction(crop_patch_size, x, model):
    """
    Compute the binary prediction of the model on a crop x
    Parameters
    ----------
    crop_patch_size the spatial size the crop is resized to before it is fed to the model
    x the crop to use of shape (BATCH_SIZE, 3, CROP_DEPTH, CROP_HEIGHT, CROP_WIDTH)
    model the model to use

    Returns
    -------
    The prediction thresholded at 0.5, resized back to the spatial size of x
    """
    spatial_shape = x.shape[2:]
    # Bring the crop to the size the network was trained on.
    model_input = resize_tensor(x, crop_patch_size).float()
    # Logits -> probabilities -> hard 0/1 mask.
    probabilities = torch.sigmoid(model(model_input))
    binary_mask = (probabilities > 0.5).float()
    # Undo the resize so the mask lines up with the original crop.
    return resize_tensor(binary_mask, spatial_shape)

def resize_tensor(tensor: torch.Tensor, new_size):
    """
    Resize the tensor to the new size
    Parameters
    ----------
    tensor : The tensor to resize of shape (BATCH_SIZE, CHANNELS, IMAGE_DEPTH, IMAGE_HEIGHT, IMAGE_WIDTH)
    new_size : The new spatial size (IMAGE_DEPTH, IMAGE_HEIGHT, IMAGE_WIDTH); any
        sequence of three ints (tuple, list or torch.Size)

    Returns : The input tensor itself when it already has the requested size,
        otherwise a trilinearly interpolated copy on the same device
    -------

    """
    # Compare as plain tuples: a torch.Size never equals a list, which would make
    # the "already the right size" shortcut miss for list-valued sizes.
    target_size = tuple(new_size)
    if tuple(tensor.shape[2:]) == target_size:
        return tensor
    if tuple(tensor.shape[1:]) == target_size:
        return tensor

    return F.interpolate(tensor, size=target_size, mode='trilinear', align_corners=False)

def save_predictions_as_imgs(
        loader, model, crop_patch_size, epoch, folder="saved_images/", device="cuda"
):
    """
    Save slices of the model's predictions and of the ground truth for the first
    five batches of the loader. Only the first sample of each batch is saved.
    Parameters
    ----------
    loader : The loader to use, one iteration of the loader must return the image and the mask of shape (BATCH_SIZE, 3, IMAGE_DEPTH, IMAGE_HEIGHT, IMAGE_WIDTH) and (BATCH_SIZE, IMAGE_DEPTH, IMAGE_HEIGHT, IMAGE_WIDTH) respectively
    model : The model to use (switched to eval mode here and back to train mode at the end)
    crop_patch_size : The size of the patch to crop
    epoch : The epoch to save the images, used as the name of the sub-folder
    folder : The folder to save the images, defaults to "saved_images/"
    device : The device to use, defaults to "cuda"
    """
    model.eval()
    batch_idx = 0

    subfolder = f"{folder}/epoch_{epoch}"
    # Creates `folder` as well when it is missing; no error if both already exist.
    os.makedirs(subfolder, exist_ok=True)

    sx, sy, sz = crop_patch_size[0], crop_patch_size[1], crop_patch_size[2]
    for x, y in loader:
        x = x.to(device=device)
        y = y.to(device=device)
        true_image = y[0]
        full_pred = torch.zeros(y.shape[1], y.shape[2], y.shape[3])
        # Predict patch by patch and stitch the results into one full-size volume.
        for i in range(0, y.shape[1], sx):
            for j in range(0, y.shape[2], sy):
                for k in range(0, y.shape[3], sz):
                    crop_x = x[:, :, i:i + sx, j:j + sy, k:k + sz]
                    with torch.no_grad():
                        preds = compute_prediction(crop_patch_size, crop_x, model)

                    full_pred[i:i + sx, j:j + sy, k:k + sz] = preds[0, 0]

        # Save four evenly spaced slices; the step is at least 1 so that volumes
        # with fewer than four slices do not produce a zero range step.
        slice_step = max(full_pred.shape[0] // 4, 1)
        for slice_idx in range(0, full_pred.shape[0], slice_step):
            # Distinct prefixes keep the ground truth from overwriting the prediction.
            save_image(batch_idx, full_pred, slice_idx, subfolder, prefix="pred")
            save_image(batch_idx, true_image, slice_idx, subfolder, prefix="true")
        batch_idx += 1
        if batch_idx > 4:
            break
    model.train()


def save_image(batch_idx, img, slice, subfolder, prefix="pred"):
    """Write slice `slice` of the volume `img` to '{subfolder}/{prefix}_{batch_idx}_slice{slice}.png'.

    `prefix` distinguishes what is being saved (e.g. "pred" or "true"); the default
    keeps the file name used for predictions.
    """
    slice_image = img[slice]
    torchvision.utils.save_image(slice_image, f"{subfolder}/{prefix}_{batch_idx}_slice{slice}.png")
    
def log(metrics, index, epoch):
    """
    Store the metric histories of one split in "logs/{index}", creating the folder if needed
    Parameters
    ----------
    metrics a dictionary containing, per split, the metric histories such as f1, accuracy, tp, tn, fp, fn
    index  the split to log, either "train" or "val"
    epoch  the epoch number embedded in the file name
    -------
    """
    folder = f"logs/{index}"
    os.makedirs(folder, exist_ok=True)
    # One row per metric (in dictionary order), one column per recorded value.
    history = torch.tensor(list(metrics[index].values()))
    torch.save(history, f"{folder}/metrics_epoch{epoch}.zip")

In [None]:
#main

import os
import numpy as np
import torchio as tio
import torch
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torchvision



def train_fn_patched(loader, model, optimizer, loss_fn, scaler):
    """
    Train the model for one epoch
    Parameters
    ----------
    loader: A dataloader of the training set
    model: The model to train
    optimizer: The optimizer to use
    loss_fn: The loss function to use
    scaler: The scaler to use for mixed precision training

    Returns: mean loss, mean accuracy, mean f1, summed tp, tn, fp, fn, and mean dice,
        where the means are taken over the batches of the epoch

    Raises
    ------
    RuntimeError: if the loss becomes NaN
    ValueError: if the loader yields no batch
    -------

    """
    model.train()
    loop = tqdm(loader)
    batch_accuracy, batch_f1, batch_tp, batch_tn, batch_fp, batch_fn, batch_dice = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
    number_iter = 0
    total_loss = 0.0
    for data, targets in loop:

        data = data.to(device=DEVICE)
        targets = targets.float().to(device=DEVICE)

        # forward (mixed precision)
        with torch.cuda.amp.autocast():
            predictions = model(data.float())
            loss = loss_fn(predictions, targets)

        loss_value = loss.item()
        if np.isnan(loss_value):
            # Raise instead of calling exit(): inside a notebook exit() targets the
            # kernel, while an exception stops the run and leaves the state inspectable.
            raise RuntimeError("Nan loss encountered")

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update tqdm loop with the running mean loss of this epoch
        number_iter += 1
        total_loss += loss_value
        loop.set_postfix(loss=total_loss / number_iter)

        # Metrics only need the values, not the autograd graph.
        accuracy, f1, tp, tn, fp, fn, dice = train_metrics(predictions.detach(), targets, DEVICE)
        batch_accuracy += accuracy
        batch_f1 += f1
        batch_tp += tp
        batch_tn += tn
        batch_fp += fp
        batch_fn += fn
        batch_dice += dice

    if number_iter == 0:
        raise ValueError("The training loader did not yield any batch")

    return total_loss/number_iter, batch_accuracy/number_iter, batch_f1/number_iter, batch_tp, batch_tn, batch_fp, batch_fn, batch_dice/number_iter


# Order in which train_fn_patched (after the loss) and check_accuracy return their metrics.
_METRIC_ORDER = ("accuracy", "f1", "tp", "tn", "fp", "fn", "dice")


def _record_metrics(metrics, split, values):
    """Append one value per metric (given in _METRIC_ORDER) to the history of `split`."""
    for key, value in zip(_METRIC_ORDER, values):
        metrics[split][key].append(value)


def _validate_and_checkpoint(model, optimizer, val_loader, metrics, image_tag, log_epoch, checkpoint_epoch):
    """Save example predictions, evaluate on the validation set, log and checkpoint."""
    save_predictions_as_imgs(
        val_loader, model, PATCH_SIZE, image_tag, folder=SAVED_IMAGES_DIR, device=DEVICE)

    _record_metrics(metrics, "val", check_accuracy(val_loader, model, PATCH_SIZE, device=DEVICE))
    log(metrics, "val", log_epoch)

    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    save_checkpoint(checkpoint, CHECKPOINT_DIR, checkpoint_epoch)


def main(backup_rate = 100):
    """Train the UNet for NUM_EPOCHS epochs and return the results.

    Parameters
    ----------
    backup_rate : every `backup_rate` epochs (except epoch 0) example predictions are
        saved, the validation set is evaluated and a checkpoint is written

    Returns
    -------
    losses : numpy array with the mean training loss of every epoch
    metrics : dictionary with the metric histories of the "train" and "val" splits
    """
    # Transforms of a 3D image.
    train_transform = tio.Compose([
        #tio.RandomAffine(p=0.3),

        tio.CropOrPad((IMAGE_DEPTH, IMAGE_HEIGHT, IMAGE_WIDTH)),
        #tio.RandomAnisotropy(p=0.1),
        #tio.Blur(std=0.5, p=0.25),
        #tio.RandomMotion(degrees=15, translation=5, p=0.3),
        #tio.RandomBiasField(p=0.2),
        tio.RandomFlip(p=0.3),
        #tio.RandomElasticDeformation(max_displacement=10, p=0.05),
        tio.RandomSwap(p=0.3),
        # Normalization occurs later
    ])
    val_transform = tio.Compose([

        tio.CropOrPad((IMAGE_DEPTH, IMAGE_HEIGHT, IMAGE_WIDTH)),
    ])

    # Model definition; moved to DEVICE rather than unconditionally to CUDA.
    model = UNET(in_channels=3, out_channels=1)
    model = nn.DataParallel(model).to(DEVICE)

    loss_fn = DiceBCELoss_2(device=DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=LEARNING_RATE/NUM_EPOCHS)

    # Creating Dataloaders (the validation transform is reused for the test split).
    train_loader, val_loader, test_loader = get_train_val_test_Dataloaders(train_transform, val_transform, val_transform)

    scaler = torch.cuda.amp.GradScaler()

    losses = np.zeros(NUM_EPOCHS)
    metrics = {"train" : {"f1": [], "accuracy": [], "tp": [], "tn": [], "fp": [], "fn": [],"dice":[]},
               "val": {"f1": [], "accuracy": [], "tp": [], "tn": [], "fp": [], "fn": [],"dice":[]}}

    # Train epoch by epoch; validate and checkpoint every `backup_rate` epochs.
    for epoch in range(NUM_EPOCHS):
        losses[epoch], *train_values = train_fn_patched(train_loader, model, optimizer, loss_fn, scaler)
        _record_metrics(metrics, "train", train_values)
        print(f"train f1 : {metrics['train']['f1'][-1]}")
        print(f"train dice : {metrics['train']['dice'][-1]}")
        log(metrics, "train", epoch)

        if(epoch%backup_rate == 0 and epoch!=0):
            _validate_and_checkpoint(model, optimizer, val_loader, metrics,
                                     image_tag=epoch, log_epoch=epoch, checkpoint_epoch=epoch)

    # Final evaluation: images go to the "final" folder, the validation log is filed
    # under the last trained epoch, and the checkpoint is named after NUM_EPOCHS.
    _validate_and_checkpoint(model, optimizer, val_loader, metrics,
                             image_tag="final", log_epoch=max(NUM_EPOCHS - 1, 0),
                             checkpoint_epoch=NUM_EPOCHS)
    return losses, metrics


In [None]:
import json

# Run the full training; `loss` holds the per-epoch training loss, `metrics` the
# metric histories of the train and validation splits.
loss,metrics=main()

# Single definition of the output directory; exist_ok makes the cell re-runnable.
RESULTS_DIR = "/kaggle/working/results"
os.makedirs(RESULTS_DIR, exist_ok=True)

loss_file_path = os.path.join(RESULTS_DIR, "loss.npy")
metrics_file_path = os.path.join(RESULTS_DIR, "metrics.json")

np.save(loss_file_path,loss)

with open(metrics_file_path, 'w') as f:
    json.dump(metrics, f)

In [None]:
import gc
# Optional: zip a checkpoint and expose it as a download link (disabled).
#!zip -r file.zip /kaggle/working/checkpoint/checkpoint_epoch50.pytorch
#from IPython.display import FileLink
#FileLink(r'file.zip')
# Clean up after training: run Python's garbage collector first, then call
# torch.cuda.empty_cache() to release PyTorch's cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()