In [None]:
from pathlib import Path
from tqdm.notebook import tqdm

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
import imgaug as ia
import imgaug.augmenters as iaa
import numpy as np
import matplotlib.pyplot as plt

from dataset import BrainDataset
from model import UNet

In [None]:
# Training-time augmentation pipeline: a random affine zoom/rotation,
# followed by a random elastic deformation (imgaug default parameters).
seq = iaa.Sequential([
    iaa.Affine(
        scale=(0.85, 1.15),  # zoom in or out by up to 15%
        rotate=(-45, 45),    # rotate by up to 45 degrees in either direction
    ),
    iaa.ElasticTransformation(),  # random elastic deformations
])

In [None]:
# Create the dataset objects, one per preprocessed split.
# Only the training split receives the augmentation pipeline; val/test are
# given None (presumably meaning "no augmentation" — see BrainDataset).
train_path, val_path, test_path = (
    Path("preprocessed") / split for split in ("train", "val", "test")
)

train_dataset = BrainDataset(train_path, seq)
val_dataset, test_dataset = (
    BrainDataset(split_path, None) for split_path in (val_path, test_path)
)

In [None]:
# Preview: fetch the same training sample nine times. Assuming BrainDataset
# applies `seq` on every access (TODO confirm), each panel shows a different
# random augmentation of that slice, with the tumor mask overlaid.
fig, axis = plt.subplots(3, 3, figsize=(9, 9))

for ax in axis.flat:  # row-major order, same as a nested i/j loop
    slice_mri, label = train_dataset[4]
    # Hide background pixels (label == 0) so only the tumor is colored
    tumor_overlay = np.ma.masked_where(label == 0, label)
    ax.imshow(slice_mri[0], cmap="bone")
    ax.imshow(tumor_overlay[0], cmap="autumn")
    ax.axis("off")

fig.suptitle("Sample augmentations")
plt.tight_layout()

In [None]:
# DataLoaders for the three splits. They share batch size and worker count;
# only the training split is shuffled.
# NOTE(review): the imgaug augmentations run inside the worker processes. Check
# that each worker gets its own imgaug RNG state — otherwise workers may repeat
# identical augmentations (confirm how BrainDataset seeds them).
batch_size = 32
num_workers = 4
loader_kwargs = {"batch_size": batch_size, "num_workers": num_workers}

train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [None]:
class DiceLoss(torch.nn.Module):
    """
    Dice loss for binary segmentation: ``1 - 2 * sum(P * M) / (sum(P) + sum(M) + eps)``.

    Both inputs are flattened first, so a single Dice coefficient is computed
    over the whole batch at once rather than per sample.

    Args:
        eps: small constant added to the denominator so the result is not NaN
            when prediction and mask are both all zeros. Defaults to 1e-8.
    """
    def __init__(self, eps=1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, pred, mask):
        """
        Compute the Dice loss.

        Args:
            pred: predicted probabilities, any shape.
            mask: ground-truth mask with the same number of elements as ``pred``
                (the shape may differ; both are flattened). May be float,
                integer or bool.

        Returns:
            Scalar tensor ``1 - dice``.

        Raises:
            ValueError: if ``pred`` and ``mask`` have different numbers of
                elements. Without this check, a size-1 tensor would broadcast
                silently after flattening and give a meaningless score.
        """
        pred = torch.flatten(pred)
        # Cast the mask to the prediction dtype so integer/bool masks behave
        # exactly like float masks.
        mask = torch.flatten(mask).to(pred.dtype)
        if pred.numel() != mask.numel():
            raise ValueError(
                f"pred and mask must have the same number of elements, "
                f"got {pred.numel()} and {mask.numel()}"
            )

        intersection = (pred * mask).sum()  # numerator (before the factor 2)
        denominator = pred.sum() + mask.sum() + self.eps  # eps prevents 0/0 -> NaN
        dice = (2 * intersection) / denominator
        return 1 - dice

In [None]:
class BrainTumorSegmentation(pl.LightningModule):
    """
    LightningModule that wraps a UNet for binary tumor segmentation.

    ``forward`` returns per-pixel sigmoid probabilities. Training and
    validation minimise the Dice loss (``1 - Dice``). That loss is what gets
    logged as "Train Dice" / "Val Dice", so lower is better.

    Args:
        lr: learning rate for the Adam optimizer. Defaults to 1e-5, the
            previously hard-coded value, so old checkpoints load unchanged.
    """
    def __init__(self, lr=1e-5):
        super().__init__()
        # Store `lr` in the checkpoint so load_from_checkpoint restores it
        self.save_hyperparameters()

        self.model = UNet()
        self.lr = lr
        self.loss_fn = DiceLoss()

    def forward(self, data):
        # Squash the raw UNet outputs into [0, 1] probabilities
        return torch.sigmoid(self.model(data))

    def _shared_step(self, batch):
        """Run the model on one (mri, mask) batch and return the Dice loss."""
        mri, mask = batch
        pred = self(mri)
        return self.loss_fn(pred, mask.float())

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        # NOTE: this logs the Dice *loss*, not the Dice score. The key name is
        # kept because the checkpoint callback monitors "Val Dice" (mode="min").
        self.log("Train Dice", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("Val Dice", loss)
        return loss

    def configure_optimizers(self):
        # Build the optimizer here rather than in __init__. Each call returns a
        # fresh Adam instance (no state carried over from a previous fit), and
        # loading a checkpoint for inference does not construct an unused one.
        return [torch.optim.Adam(self.model.parameters(), lr=self.lr)]

In [None]:
# Seed every RNG that affects training, then instantiate the model.
# pl.seed_everything seeds Python's `random`, NumPy and torch (e.g. weight
# initialisation and DataLoader shuffling). The imgaug augmentations use
# imgaug's own global RNG, so that is seeded separately.
# NOTE(review): with num_workers > 0, confirm that the workers' imgaug RNG
# state is handled as intended (reproducible, but not identical across workers).
pl.seed_everything(0)
ia.seed(0)
model = BrainTumorSegmentation()

In [None]:
# Checkpointing: keep the 10 best checkpoints ranked by the logged "Val Dice".
# That metric is the Dice *loss* (1 - Dice), hence mode="min".
checkpoint_callback = ModelCheckpoint(monitor="Val Dice", mode="min", save_top_k=10)

In [None]:
# Create the trainer.
# Use one GPU when CUDA is available and fall back to the CPU (gpus=0)
# otherwise, matching the device selection used for inference later on.
# NOTE(review): the `gpus` argument was removed in Lightning 2.0; on newer
# versions use accelerator=... / devices=... instead.
gpus = 1 if torch.cuda.is_available() else 0
trainer = pl.Trainer(
    gpus=gpus,
    logger=TensorBoardLogger(save_dir="./logs"),
    log_every_n_steps=20,
    callbacks=[checkpoint_callback],  # Trainer documents `callbacks` as a list
    max_epochs=75,
)

In [None]:
# Train the model. Validation on val_loader produces the "Val Dice" metric
# that the checkpoint callback monitors.
trainer.fit(model, train_loader, val_loader)

In [None]:
import nibabel as nib
from tqdm.notebook import tqdm
from celluloid import Camera
from IPython.display import HTML

In [None]:
# Restore trained weights from a checkpoint of one specific training run.
# NOTE(review): this path (version_4, epoch 59) is tied to a particular run.
# Point it at your own checkpoint, e.g. checkpoint_callback.best_model_path
# after trainer.fit.
model = BrainTumorSegmentation.load_from_checkpoint(
    checkpoint_path="logs/lightning_logs/version_4/checkpoints/epoch=59-step=116580.ckpt"
)

In [None]:
# Inference setup: prefer the GPU when present, and switch the model to eval
# mode (this matters for layers such as dropout/batch-norm, if UNet has them).
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device).eval()

In [None]:
# Run the model over the whole test split in batches (reusing test_loader
# instead of one slice at a time) and collect the sigmoid probabilities and
# ground-truth masks as numpy arrays, one entry per test slice along axis 0.
pred_batches = []
label_batches = []

with torch.no_grad():
    for slice_batch, label_batch in tqdm(test_loader):
        pred_batches.append(model(slice_batch.to(device)).cpu().numpy())
        label_batches.append(label_batch.numpy())

preds = np.concatenate(pred_batches)
labels = np.concatenate(label_batches)

In [None]:
# Binarize the ground truth: any label value >= 1 is treated as tumor (1).
new_labels = (labels >= 1).astype(int)
# Soft Dice score on the test split, using the model's own loss function
1 - model.loss_fn(torch.from_numpy(preds), torch.from_numpy(new_labels))

In [None]:
# Test Dice score over the whole test split. It is computed on the raw sigmoid
# probabilities (no 0.5 threshold) against the binarized labels.
# NOTE(review): a "hard" Dice would first threshold preds > 0.5, as the
# per-subject visualisation below does — decide which one to report.
# DiceLoss flattens both inputs, so no extra unsqueeze is needed.
dice_score = 1 - DiceLoss()(torch.from_numpy(preds), torch.from_numpy(new_labels).float())
# .item() prints the number itself instead of a "tensor(...)" repr
print(f"The Test Dice Score is: {dice_score.item():.4f}")

In [None]:
# Load one example subject's T1 volume (NIfTI) for the qualitative demo below
subject = Path("UCSF-PDGM-nifti") / "UCSF-PDGM-0004_T1.nii.gz"
subject_mri = nib.load(subject).get_fdata()

In [None]:
# Helper functions for normalization and standardization
def normalize(full_volume):
    """
    Z-score normalize a whole subject volume.

    Subtracts the mean and divides by the standard deviation, both computed
    over all voxels of the volume.

    Args:
        full_volume: numpy array of voxel intensities.

    Returns:
        Array of the same shape with zero mean and unit standard deviation.
        A constant volume (std == 0) is returned as all zeros instead of the
        NaNs a 0/0 division would produce.
    """
    mu = full_volume.mean()
    std = np.std(full_volume)
    if std == 0:
        # Every voxel equals the mean, so the centred volume is all zeros
        return np.zeros_like(full_volume, dtype=float)

    return (full_volume - mu) / std

def standardize(normalized_data):
    """
    Min-max rescale data into the [0, 1] range.

    Args:
        normalized_data: numpy array (typically the output of ``normalize``).

    Returns:
        Array of the same shape where the minimum maps to 0 and the maximum to 1.
        If every value is identical (max == min), an all-zero array is returned
        instead of the NaNs a 0/0 division would produce.
    """
    lowest = normalized_data.min()
    value_range = normalized_data.max() - lowest
    if value_range == 0:
        return np.zeros_like(normalized_data, dtype=float)

    return (normalized_data - lowest) / value_range

In [None]:
# Preprocess the demo subject: z-score normalization, then min-max rescaling
# into [0, 1].
# NOTE(review): this should match the preprocessing used to build
# preprocessed/ — confirm.
standardized_scan = standardize(normalize(subject_mri))
np.shape(standardized_scan)

In [None]:
# Predict a binary mask for every 2-D slice along the volume's last axis.
# NOTE: this rebinds `preds` (previously the test-set probabilities). Re-run
# the test inference cell before re-running the test Dice cells.
preds = []
with torch.no_grad():
    for i in range(standardized_scan.shape[-1]):
        # `slice_2d`, not `slice`, to avoid shadowing the builtin in the kernel
        slice_2d = standardized_scan[:, :, i]
        # Add batch and channel axes -> (1, 1, H, W), cast to float32
        slice_tensor = torch.tensor(slice_2d).unsqueeze(0).unsqueeze(0).float().to(device)
        # Take the single output map and threshold the probabilities at 0.5
        pred = model(slice_tensor)[0][0] > 0.5
        preds.append(pred.cpu())

In [None]:
# Animate the subject slice by slice: the MRI in "bone" with the predicted
# tumor mask (from the previous cell's `preds`) overlaid at 50% opacity.
fig, ax = plt.subplots()
camera = Camera(fig)  # celluloid records one frame per camera.snap()

for idx in range(standardized_scan.shape[-1]):
    pred_slice = preds[idx]
    ax.imshow(standardized_scan[:, :, idx], cmap="bone")
    # Mask out background pixels so only predicted tumor is colored
    ax.imshow(np.ma.masked_where(pred_slice == 0, pred_slice), alpha=0.5, cmap="autumn")
    ax.axis("off")
    camera.snap()  # capture the current frame

animation = camera.animate()  # assemble the captured frames into an animation

In [None]:
# Render the animation inline as an HTML5 <video> element.
# NOTE(review): to_html5_video needs an ffmpeg movie writer available to
# matplotlib — confirm it is installed in the environment.
HTML(animation.to_html5_video())