### If running on Google Colab, run the following cell first

In [None]:
!pip install simpleitk
!pip install pytorch-lightning

from google.colab import drive
drive.mount('/content/drive')

In [1]:
#all imports 
%load_ext tensorboard

from pathlib import Path
from typing import List, Union, Tuple, Dict, Callable, Optional

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as ptl
import SimpleITK as sitk
import torch
import torch.nn as nn
import torchmetrics
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import normalize
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

from unet_model import *
from dataset_dataloading import * 
from helper_functions import * 

### If running on Google Colab, use the following cell instead of the local imports and parameters

In [None]:
from drive.MyDrive.Colab_Notebooks.MRI_WMH_project.code.unet_model import *
from drive.MyDrive.Colab_Notebooks.MRI_WMH_project.code.dataset_dataloading import *
from drive.MyDrive.Colab_Notebooks.MRI_WMH_project.code.helper_functions import *

# Experiment parameters (Google Colab variant, paths relative to the working directory).

# Seed for the random number generators; also becomes part of the TensorBoard run name.
seed = 15
# Name of the training subset handed to the training data module.
train_subset = "Amsterdam"

# Project folder on the mounted drive; the dataset folder is looked up inside it.
base_path = Path("drive/MyDrive/Colab_Notebooks/MRI_WMH_project")
# Folder that receives the TensorBoard logs and the model checkpoints.
tensorboard_path = Path("drive/MyDrive/Colab_Notebooks/MRI_WMH_project/code/Tensorboard_logs")

In [3]:
# Experiment parameters (local machine variant).
# NOTE: base_path is a machine-specific absolute path; adjust it when running elsewhere.
# It is the project folder; the dataset is later read from `base_path / "Data"`.
base_path = Path(r"/home/odysseas/Desktop/UU/AI for medical imaging/MRI_segmentation")
# Derived from base_path so logs always land inside the project folder. The previous
# literal had no leading "/" and therefore resolved relative to the working directory.
tensorboard_path = base_path / "code" / "Tensorboard_logs"
# Name of the training subset handed to the training data module.
train_subset = "Amsterdam"
# Seed for the random number generators; also becomes part of the TensorBoard run name.
seed = 15

In [4]:
# The dataset lives in the "Data" folder inside the project folder.
dataset_path = base_path.joinpath("Data")

# Run label for the TensorBoard logger, built from the training subset and the seed.
train_log = f"trained_with_{train_subset}_{seed}"
print(train_log)

trained_with_Amsterdam_15


In [5]:
# Seed everything with the same experiment seed before the model is constructed.
ptl.seed_everything(seed, workers=True)

model = UNet(
    # --- architecture ---
    is_3d=False,
    input_channels=2,
    input_features=16,
    num_layers=5,
    num_classes=3,
    # --- optimisation / output ---
    loss_function=DiceLoss(classes=3),
    final_activation=torch.nn.Softmax2d(),
    lr=1e-3,
)

train_datamodule = WMHTrainDataModule(
    train_dataset_directory=dataset_path,
    selected_train_subset=train_subset,
    # Fraction passed as the validation split.
    val_split=0.1,
    batch_size=16,
    # Prefetching costs a few extra minutes at startup but makes training
    # itself much faster; switch it off if memory is too tight.
    use_prefetch=True,
    # Number of data-loading workers. When training raises errors, use 0
    # here to get clearer error messages.
    num_workers=16,
)

trainer = ptl.Trainer(
    # Writes the information that a TensorBoard instance (next cell) can read and plot.
    logger=ptl.loggers.TensorBoardLogger(save_dir=str(tensorboard_path), version=train_log),
    # Run on the GPU when CUDA is usable and fall back to the CPU otherwise, so this
    # cell works unchanged on both kinds of machine (previously "cpu" was hard-coded
    # although the intent was to train on the GPU).
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    # Number of devices to use: a single GPU, or a single CPU process.
    devices=1,
    callbacks=[
        # Prints the model structure to the console on start.
        ptl.callbacks.ModelSummary(max_depth=2),
        # Stops training if the validation loss doesn't improve for `patience=n` epochs.
        ptl.callbacks.EarlyStopping(monitor="val_loss", patience=15, verbose=True),
        # Saves a copy of the model whenever the validation loss improves.
        # NOTE(review): the file name is the same for every subset/seed combination;
        # consider including `train_log` if runs should not share a checkpoint name.
        ptl.callbacks.ModelCheckpoint(
            dirpath=tensorboard_path / "checkpoints",
            filename="checkpoint",
            monitor="val_loss",
        ),
    ],
    # Upper bound on the number of epochs, even if early stopping never triggers.
    max_epochs=100,
    # Lower bound, so training is not stopped too soon.
    min_epochs=20,
    # Request 16-bit Automatic Mixed Precision (AMP) to reduce memory use compared
    # with 32-bit floats. On a CPU-only run Lightning reported falling back to
    # bfloat16 AMP for this setting.
    precision=16,
    # Log after every batch. The default (50) can exceed the number of steps in an
    # epoch when there are few samples, so the value is kept low.
    log_every_n_steps=1,
    # The default `tqdm` progress bar can be a little buggy in notebooks;
    # set this to False if that is the case.
    enable_progress_bar=True,
)

Global seed set to 15
  rank_zero_warn(
Using bfloat16 Automatic Mixed Precision (AMP)
Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
  return torch._C._cuda_getDeviceCount() > 0
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [7]:
# Launch TensorBoard inside the notebook.
# It reads from the "lightning_logs" folder inside tensorboard_path.
# NOTE(review): presumably "lightning_logs" is the run name TensorBoardLogger appends
# to its save_dir by default — confirm for the installed Lightning version.
experiment_dir = tensorboard_path / "lightning_logs"
# The argument is quoted because the path may contain spaces; IPython expands
# `$experiment_dir` to the value of the Python variable before running the magic.
%tensorboard --logdir "$experiment_dir"

In [19]:
# Train the model using the training data module configured above.
trainer.fit(model, datamodule=train_datamodule)


   | Name             | Type       | Params | In sizes                               | Out sizes        
--------------------------------------------------------------------------------------------------------------
0  | final_activation | Softmax2d  | 0      | ?                                      | ?                
1  | loss_function    | DiceLoss   | 0      | ?                                      | ?                
2  | layers           | ModuleList | 1.9 M  | ?                                      | ?                
3  | layers.0         | DoubleConv | 2.7 K  | [1, 2, 240, 240]                       | [1, 16, 240, 240]
4  | layers.1         | DownBlock  | 14.0 K | [1, 16, 240, 240]                      | [1, 32, 120, 120]
5  | layers.2         | DownBlock  | 55.6 K | [1, 32, 120, 120]                      | [1, 64, 60, 60]  
6  | layers.3         | DownBlock  | 221 K  | [1, 64, 60, 60]                        | [1, 128, 30, 30] 
7  | layers.4         | DownBlock  | 885 K  | [1

Sanity Checking: 0it [00:00, ?it/s]

Prefetching dataset:   0%|          | 0/166 [00:00<?, ?it/s]

  return (x - min_val) / (max_val - min_val)


Prefetching dataset:   0%|          | 0/1162 [00:00<?, ?it/s]

Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


# Testing

In [53]:
# Test parameters.
# Name of the test subset handed to the test data module below.
# NOTE(review): presumably this selects the combined Utrecht/Amsterdam/Singapore
# data — confirm against the subset names WMHTestDataModule accepts.
test_subset = "Utr_Ams_Sing"

In [54]:
test_datamodule = WMHTestDataModule(
    test_dataset_directory=dataset_path,
    selected_test_subset=test_subset,
    batch_size=16,
    # Prefetching costs a few extra minutes at startup but makes the run
    # itself much faster; switch it off if memory is too tight.
    use_prefetch=True,
    # Number of data-loading workers. When errors occur, use 0 here to get
    # clearer error messages.
    num_workers=16,
)

In [55]:
# Evaluate the model object currently in memory on the selected test subset.
# NOTE(review): the checkpoint written by ModelCheckpoint is not explicitly reloaded
# here — confirm that testing the in-memory weights is intended.
trainer.test(model, datamodule=test_datamodule)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

Prefetching dataset:   0%|                                                                     | 0/179 [00:00<?, ?it/s][A
Prefetching dataset:   2%|█                                                            | 3/179 [00:00<00:06, 27.91it/s][A
Prefetching dataset:   3%|██                                                           | 6/179 [00:00<00:07, 23.38it/s][A
Prefetching dataset:   5%|███                                                          | 9/179 [00:00<00:07, 21.34it/s][A
Prefetching dataset:   7%|████                                                        | 12/179 [00:00<00:09, 17.71it/s][A
Prefetching dataset:   8%|████▋                                                       | 14/179 [00:00<00:10, 15.65it/s][A
Prefetching dataset:   9%|█████▎                                                      | 16/179 [00:00<00:11, 13.81it/s][A
Prefetching dataset:  10%|██████                                                      | 18/179 [

Prefetching dataset:  96%|████████████████████████████████████████████████████████▎  | 171/179 [00:27<00:01,  4.61it/s][A
Prefetching dataset:  96%|████████████████████████████████████████████████████████▋  | 172/179 [00:27<00:01,  4.57it/s][A
Prefetching dataset:  97%|█████████████████████████████████████████████████████████  | 173/179 [00:28<00:01,  4.49it/s][A
Prefetching dataset:  97%|█████████████████████████████████████████████████████████▎ | 174/179 [00:28<00:01,  4.43it/s][A
Prefetching dataset:  98%|█████████████████████████████████████████████████████████▋ | 175/179 [00:28<00:00,  4.42it/s][A
Prefetching dataset:  98%|██████████████████████████████████████████████████████████ | 176/179 [00:28<00:00,  4.42it/s][A
Prefetching dataset:  99%|██████████████████████████████████████████████████████████▎| 177/179 [00:28<00:00,  4.27it/s][A
Prefetching dataset:  99%|██████████████████████████████████████████████████████████▋| 178/179 [00:29<00:00,  4.19it/s][A
Prefetching data

Testing DataLoader 0: 100%|████████████████████████████████████████████████████████████| 45/45 [00:04<00:00, 10.20it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           0.11364488303661346
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.11364488303661346}]