### Imports

In [11]:
import torch
import monai

from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd, 
    EnsureTyped,
    EnsureType,
    Invertd
)

# from monai.handlers.utils import from_engine
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
import nibabel as nib
import numpy as np
from tqdm.notebook import tqdm

### Dataset

In [13]:
root_dir = "/scratch/scratch6/akansh12/Parse_data/train/train/"
train_images = sorted(glob.glob(os.path.join(root_dir, "*", 'image', "*.nii.gz")))
train_labels = sorted(glob.glob(os.path.join(root_dir, "*", 'label', "*.nii.gz")))

data_dicts = [{"images": images_name, "labels": label_name} for images_name, label_name in zip(train_images, train_labels)]
train_files, val_files = data_dicts[:-9], data_dicts[-9:]
set_determinism(seed = 0)

### Transforms

In [12]:
train_transforms = Compose(
    [
     LoadImaged(keys=['images', 'labels']),
     EnsureChannelFirstd(keys = ["images", "labels"]),
     Orientationd(keys=['images', 'labels'], axcodes = 'LPS'),
     Spacingd(keys=['images', 'labels'], pixdim = (1.5,1.5,2), mode = ("bilinear", 'nearest')),
     ScaleIntensityRanged(
            keys=["images"], a_min=-700, a_max=300,
            b_min=0.0, b_max=1.0, clip=True,
        ),
     CropForegroundd(keys=['images', 'labels'], source_key="images"),
     RandCropByPosNegLabeld(
            keys=['images', 'labels'],
            label_key="labels",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="images",
            image_threshold=0,
        ),
     EnsureTyped(keys=['images', 'labels']),     
          
    ]
)

val_transforms = Compose(
    [
        LoadImaged(keys=["images", "labels"]),
        EnsureChannelFirstd(keys=["images", "labels"]),
        Orientationd(keys=["images", "labels"], axcodes="LPS"),
        Spacingd(keys=["images", "labels"], pixdim=(
            1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        ScaleIntensityRanged(
            keys=["images"], a_min=-700, a_max=300,
            b_min=0.0, b_max=1.0, clip=True,
        ),
        CropForegroundd(keys=["images", "labels"], source_key="images"),
        EnsureTyped(keys=["images", "labels"]),
    ]
)

### DataLoader

In [17]:
train_ds = CacheDataset(
    data = train_files, transform = train_transforms,
    cache_rate = 1.0, num_workers = 4
)

train_loader = DataLoader(train_ds, batch_size = 2, shuffle = True, num_workers=2)
val_ds = CacheDataset(
    data = val_files, transform = val_transforms,
    cache_rate = 1.0, num_workers = 4
)
val_loader = DataLoader(val_ds, batch_size = 1, shuffle = False, num_workers=2)

Loading dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [04:26<00:00,  2.93s/it]
Loading dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:42<00:00,  4.71s/it]


### Model defination

In [18]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 

UNet_meatdata = dict(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH
)
model = UNet(**UNet_meatdata).to(device)

### Loss function and Optimizers

In [19]:
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
loss_type = "DiceLoss"
optimizer = torch.optim.Adam(model.parameters(), 1e-4)
dice_metric = DiceMetric(include_background=False, reduction="mean")

In [20]:
Optimizer_metadata = {}
for ind, param_group in enumerate(optimizer.param_groups):
    optim_meta_keys = list(param_group.keys())
    Optimizer_metadata[f'param_group_{ind}'] = {key: value for (key, value) in param_group.items() if 'params' not in key}

In [28]:
max_epochs = 600
val_interval = 10
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])


slice_to_track = 80

for epoch in tqdm(range(max_epochs)):
    model.train()
    epoch_loss = 0
    step = 0
    
    for batch_data in train_loader:
        step += 1
        inputs, labels = (
            batch_data['images'].to(device),
            batch_data['labels'].to(device)
        )
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss +=loss.item()
        print(f"{step}/{len(train_ds) // train_loader.batch_size}, "
                f"train_loss: {loss.item():.4f}")

    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)

    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:

        model.eval()
        with torch.no_grad():
            for index, val_data in enumerate(val_loader):

                val_inputs, val_labels = val_data['images'].to(device), val_data['labels'].to(device)
                roi_size = (160, 160, 160)
                sw_batch_size = 4
                val_outputs = sliding_window_inference(
                            val_inputs, roi_size, sw_batch_size, model)

                output = torch.argmax(val_outputs, dim=1)[0, :, :, slice_to_track].float()

                val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                dice_metric(y_pred=val_outputs, y=val_labels)

                metric = dice_metric.aggregate().item()
                dice_metric.reset()

            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), os.path.join(
                    root_dir, "/scratch/scratch6/akansh12/challenges/parse2022/temp/best_metric_model.pth"))

                best_model_log_message = f"saved new best metric model at the {epoch+1}th epoch"
                print(best_model_log_message)

                message1 = f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                message2 = f"\nbest mean dice: {best_metric:.4f} "
                message3 = f"at epoch: {best_metric_epoch}"
                print(message1, message2, message3)

np.save("/scratch/scratch6/akansh12/challenges/parse2022/temp/epoch_loss.npy", epoch_loss_values)
np.save("/scratch/scratch6/akansh12/challenges/parse2022/temp/metric_values.npy", metric_values)

  0%|          | 0/600 [00:00<?, ?it/s]

1/45, train_loss: 0.6867
2/45, train_loss: 0.6869
3/45, train_loss: 0.6846
4/45, train_loss: 0.6794
5/45, train_loss: 0.6792
6/45, train_loss: 0.6813
7/45, train_loss: 0.6818
8/45, train_loss: 0.6701
9/45, train_loss: 0.6757
10/45, train_loss: 0.6768
11/45, train_loss: 0.6724
12/45, train_loss: 0.6710
13/45, train_loss: 0.6733
14/45, train_loss: 0.6716
15/45, train_loss: 0.6678
16/45, train_loss: 0.6688
17/45, train_loss: 0.6686
18/45, train_loss: 0.6642
19/45, train_loss: 0.6686
20/45, train_loss: 0.6621
21/45, train_loss: 0.6631
22/45, train_loss: 0.6591
23/45, train_loss: 0.6569
24/45, train_loss: 0.6623
25/45, train_loss: 0.6572
26/45, train_loss: 0.6486
27/45, train_loss: 0.6568
28/45, train_loss: 0.6553
29/45, train_loss: 0.6562
30/45, train_loss: 0.6584
31/45, train_loss: 0.6564
32/45, train_loss: 0.6550
33/45, train_loss: 0.6566
34/45, train_loss: 0.6514
35/45, train_loss: 0.6466
36/45, train_loss: 0.6510
37/45, train_loss: 0.6481
38/45, train_loss: 0.6611
39/45, train_loss: 0.

KeyboardInterrupt: 