In [26]:
import torch
import monai
from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd, 
    EnsureTyped,
    EnsureType,
    Invertd,
    MeanEnsembled,
    Activationsd
)
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
import nibabel as nib
import numpy as np
from tqdm.notebook import tqdm
import nibabel as nib
import warnings
warnings.filterwarnings('ignore')
from monai.inferers import SimpleInferer, SlidingWindowInferer

In [15]:
from monai.networks.nets import UNet
from monai.networks.layers import Norm
import torch
from collections import OrderedDict
# Use the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Constructor kwargs shared by every UNet ensemble member below
# (name keeps its original spelling because later cells refer to it).
UNet_meatdata = {
    "spatial_dims": 3,
    "in_channels": 1,
    "out_channels": 2,
    "channels": (16, 32, 64, 128, 256),
    "strides": (2, 2, 2, 2),
    "num_res_units": 2,
    "norm": Norm.BATCH,
}

In [30]:
# Build (image, label) file pairs; the last `n_val` pairs (after sorting) are
# held out for validation.
root_dir = "/scratch/scratch6/akansh12/Parse_data/train/train/"
n_val = 9

train_images = sorted(glob.glob(os.path.join(root_dir, "*", 'image', "*.nii.gz")))
train_labels = sorted(glob.glob(os.path.join(root_dir, "*", 'label', "*.nii.gz")))

# zip() below silently truncates or mis-pairs if the two globs disagree, so
# verify the pairing explicitly before building the dataset.
if not train_images:
    raise FileNotFoundError(f"No '*.nii.gz' images found under {root_dir!r}")
if len(train_images) != len(train_labels):
    raise ValueError(
        f"Found {len(train_images)} images but {len(train_labels)} labels under {root_dir!r}"
    )
for image_path, label_path in zip(train_images, train_labels):
    # Paths look like <case>/image/<file> and <case>/label/<file>; both must share <case>.
    if os.path.dirname(os.path.dirname(image_path)) != os.path.dirname(os.path.dirname(label_path)):
        raise ValueError(f"Image/label from different cases: {image_path!r} vs {label_path!r}")

data_dicts = [{"image": images_name, "label": label_name} for images_name, label_name in zip(train_images, train_labels)]
train_files, val_files = data_dicts[:-n_val], data_dicts[-n_val:]
set_determinism(seed = 0)

In [31]:
# Deterministic validation preprocessing (no random augmentation).
val_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Orientationd(keys=["image", "label"], axcodes="LPS"),
    # Clip intensities to [-1000, 1000] (presumably CT HU) and rescale linearly to [0, 1].
    ScaleIntensityRanged(keys=["image"], a_min=-1000, a_max=1000, b_min=0.0, b_max=1.0, clip=True),
    # Crop both volumes to the foreground of `image` (MONAI default selector: values > 0).
    CropForegroundd(keys=["image", "label"], source_key="image"),
    EnsureTyped(keys=["image", "label"]),
])

# Cache every preprocessed validation case in memory; cases are evaluated one at a time, in order.
val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=4)

Loading dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:34<00:00,  3.80s/it]


In [32]:
# Metric: mean Dice over the non-background channels only.
from monai.metrics import DiceMetric

dice_metric = DiceMetric(reduction="mean", include_background=False)

In [33]:
WEIGHTS_DIR = "/scratch/scratch6/akansh12/challenges/parse2022/temp/selected_models"


def _load_unet(weights_path):
    """Build a UNet from ``UNet_meatdata`` on ``device`` and load ``weights_path`` into it.

    Checkpoint keys are remapped positionally: the i-th checkpoint entry is
    stored under the i-th key of the freshly built model's ``state_dict()``.
    Presumably the checkpoints were saved with different parameter names
    (e.g. a wrapper prefix) — TODO confirm.

    Args:
        weights_path: path to a saved ``state_dict`` file.

    Returns:
        The UNet with weights loaded, in eval mode.

    Raises:
        ValueError: if the checkpoint and the model have a different number of
            entries, because a positional remap would then misalign tensors.
    """
    model = UNet(**UNet_meatdata).to(device)
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
    checkpoint = torch.load(weights_path, map_location='cpu')
    target_keys = list(model.state_dict())
    if len(checkpoint) != len(target_keys):
        raise ValueError(
            f"{weights_path}: checkpoint has {len(checkpoint)} entries, "
            f"model expects {len(target_keys)}"
        )
    # Single O(n) pass. The old per-key rebuild was O(n^2) and could overwrite
    # entries when a new key equalled a not-yet-renamed old key.
    remapped = OrderedDict(zip(target_keys, checkpoint.values()))
    model.load_state_dict(remapped)  # strict: raises if names/shapes still mismatch
    model.eval()
    return model


model_1 = _load_unet(os.path.join(WEIGHTS_DIR, "unet_1000_hu_160_0853.pth"))
model_2 = _load_unet(os.path.join(WEIGHTS_DIR, "unet_1000_hu_160_8550.pth"))
model_3 = _load_unet(os.path.join(WEIGHTS_DIR, "unet_1000_hu_160_w_augmentations_8551.pth"))

# Models trained with more data
model_4 = _load_unet(os.path.join(WEIGHTS_DIR, "unet_1000_hu_160_w_augmentations_more_data_8838.pth"))
model_5 = _load_unet(os.path.join(WEIGHTS_DIR, "unet_1000_hu_160_w_augmentations_more_data_focal_8860.pth"))



<All keys matched successfully>

In [34]:
# Ensemble members, in the order their checkpoints were loaded.
models = [
    model_1,
    model_2,
    model_3,
    model_4,
    model_5,
]

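### Ensemble majority voting

Each model's prediction is argmax'ed and one-hot encoded. The one-hot maps are averaged across models, and a voxel is assigned to a class only when more than half of the models agree.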

In [43]:
# Predictions: argmax over channels, then one-hot back to 2 channels.
post_pred = Compose([EnsureType(), AsDiscrete(to_onehot=2, argmax=True)])
# Labels: one-hot encode into the same 2 channels.
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])
# Fresh metric instance for the ensemble evaluation below.
dice_metric = DiceMetric(reduction="mean", include_background=False)

In [44]:
# Majority-vote ensemble evaluation over the validation set.
# post_pred one-hot encodes each model's argmax. Votes are summed per sample and
# averaged, and a channel is kept only where strictly more than half of the
# models agree. With an even number of models, a tied voxel gets no channel.
roi_size = (160, 160, 160)
sw_batch_size = 8
metric_values = []  # per-case Dice tensors, in val_loader order

if not models:
    raise ValueError("`models` is empty; nothing to ensemble")
for mod in models:
    mod.eval()
dice_metric.reset()  # drop state left over from an earlier, possibly interrupted, run

with torch.no_grad():
    for val_data in tqdm(val_loader):
        val_inputs = val_data['image'].to(device)
        val_labels = val_data['label'].to(device)

        # Running sum of one-hot votes, one tensor per sample in the batch
        # (works for any batch size, not just the first sample).
        vote_sums = None
        for mod in tqdm(models, leave=False):
            logits = sliding_window_inference(val_inputs, roi_size, sw_batch_size, mod)
            votes = [post_pred(i) for i in decollate_batch(logits)]
            if vote_sums is None:
                vote_sums = votes
            else:
                vote_sums = [acc + v for acc, v in zip(vote_sums, votes)]

        val_output = [(acc / len(models) > 0.5).float() for acc in vote_sums]
        val_labels = [post_label(i) for i in decollate_batch(val_labels)]
        case_dice = dice_metric(y_pred=val_output, y=val_labels)
        metric_values.append(case_dice)
        print(case_dice)
    metric = dice_metric.aggregate().item()
    dice_metric.reset()
print(metric)

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

tensor([[0.9033]])


  0%|          | 0/5 [00:00<?, ?it/s]

tensor([[0.8959]])


  0%|          | 0/5 [00:00<?, ?it/s]

tensor([[0.7872]])


  0%|          | 0/5 [00:00<?, ?it/s]

tensor([[0.8136]])


  0%|          | 0/5 [00:00<?, ?it/s]

tensor([[0.7729]])


  0%|          | 0/5 [00:00<?, ?it/s]

tensor([[0.8838]])


  0%|          | 0/5 [00:00<?, ?it/s]

tensor([[0.9096]])


  0%|          | 0/5 [00:00<?, ?it/s]

tensor([[0.8726]])


  0%|          | 0/5 [00:00<?, ?it/s]

tensor([[0.8696]])
0.8564942479133606


In [None]:
# Per-model baseline: mean foreground Dice of each ensemble member on its own,
# computed with the same preprocessing/post-processing as the ensemble above.
# Every model is bound explicitly, so no loop variable leaks in from an earlier cell.
roi_size = (160, 160, 160)
sw_batch_size = 8
metric_values = []  # one mean Dice per model, in `models` order

with torch.no_grad():
    for mod in tqdm(models):
        mod.eval()
        dice_metric.reset()
        for val_data in tqdm(val_loader, leave=False):
            val_inputs, val_labels = val_data['image'].to(device), val_data['label'].to(device)
            val_output_ = sliding_window_inference(
                                val_inputs, roi_size, sw_batch_size, mod)
            val_output_ = [post_pred(i) for i in decollate_batch(val_output_)]
            val_labels = [post_label(i) for i in decollate_batch(val_labels)]
            dice_metric(y_pred=val_output_, y=val_labels)
        metric_values.append(dice_metric.aggregate().item())
        dice_metric.reset()
print(metric_values)

In [None]:
# Standalone run of the ensemble's inner loop on the first validation case,
# for inspecting the summed one-hot votes (`val_output`) directly.
# All inputs are defined here, so the cell does not depend on leftovers from
# earlier cells (after the ensemble cell, `val_output` is a list, not a tensor).
val_inputs = first(val_loader)['image'].to(device)
roi_size = (160, 160, 160)
sw_batch_size = 8
val_output = 0  # running sum of one-hot votes across models

with torch.no_grad():  # inference only; avoids building autograd graphs
    for mod in tqdm(models):
        mod.eval()
        val_output_ = sliding_window_inference(
                        val_inputs, roi_size, sw_batch_size, mod)
        val_output_ = [post_pred(i) for i in decollate_batch(val_output_)]

        val_output = val_output + val_output_[0]