In [7]:
import glob
import os
import shutil
import tempfile

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import torch
from tqdm.notebook import tqdm

import monai
from monai.apps import download_and_extract
from monai.config import print_config
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
# from monai.handlers.utils import from_engine
from monai.inferers import sliding_window_inference
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.networks.layers import Norm
from monai.networks.nets import UNet
from monai.transforms import (
    AddChanneld,
    AsDiscrete,
    AsDiscreted,
    Compose,
    CropForegroundd,
    EnsureChannelFirstd,
    EnsureType,
    EnsureTyped,
    Invertd,
    KeepLargestConnectedComponent,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    RandFlipd,
    RandRotate90d,
    RandRotated,
    RandShiftIntensityd,
    RandZoomd,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd,
    ToTensord,
    Zoomd,
)
from monai.utils import first, set_determinism

### Dataset

In [3]:
# Build the list of {"images", "labels"} dicts consumed by the MONAI datasets.
# Images are globbed from <root_dir>/<case>/image/*.nii.gz and labels from ./labels/.
# The data root can be overridden via the PARSE_TRAIN_DIR environment variable;
# the default is the original cluster path.
root_dir = os.environ.get("PARSE_TRAIN_DIR", "/scratch/scratch6/akansh12/Parse_data/train/train/")
label_dir = "./labels/"
NUM_VAL_CASES = 9  # the last N (sorted) cases are held out for validation

train_images = sorted(glob.glob(os.path.join(root_dir, "*", "image", "*.nii.gz")))
train_labels = sorted(glob.glob(os.path.join(label_dir, "*.nii.gz")))

# Images and labels are paired purely by sort order, and zip() silently truncates
# if the two lists differ in length -- so fail loudly on any mismatch.
# NOTE(review): if image and label file names share a case id, pairing by that id
# would be more robust than relying on sort order -- confirm the naming scheme.
assert train_images, f"no images found under {root_dir}"
assert len(train_images) == len(train_labels), (
    f"found {len(train_images)} images but {len(train_labels)} labels"
)

data_dicts = [
    {"images": image_name, "labels": label_name}
    for image_name, label_name in zip(train_images, train_labels)
]
assert len(data_dicts) > NUM_VAL_CASES, (
    f"need more than {NUM_VAL_CASES} cases to leave a non-empty training split"
)
train_files, val_files = data_dicts[:-NUM_VAL_CASES], data_dicts[-NUM_VAL_CASES:]
set_determinism(seed=0)

### Transforms

In [9]:
# Keys of the image/label pair in each data dict (see the Dataset cell).
KEYS = ["images", "labels"]


def _preprocessing_transforms():
    """Return fresh instances of the deterministic preprocessing shared by train and val.

    Loads both volumes, moves the channel axis first, reorients to LPS, clips image
    intensities to [-1000, 1000] and rescales them to [0, 1], then crops both volumes
    to the image foreground. A new list is built on every call so the training and
    validation pipelines never share transform instances.
    """
    return [
        LoadImaged(keys=KEYS),
        EnsureChannelFirstd(keys=KEYS),
        Orientationd(keys=KEYS, axcodes="LPS"),
        # Resampling is currently disabled:
        # Spacingd(keys=KEYS, pixdim=(1.5, 1.5, 2), mode=("bilinear", "nearest")),
        ScaleIntensityRanged(
            keys=["images"],
            a_min=-1000,
            a_max=1000,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=KEYS, source_key="images"),
    ]


train_transforms = Compose(
    _preprocessing_transforms()
    + [
        # 4 random 160^3 patches per volume, balanced 1:1 between positive and
        # negative centres (label_key / image_key + image_threshold select candidates).
        RandCropByPosNegLabeld(
            keys=KEYS,
            label_key="labels",
            spatial_size=(160, 160, 160),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="images",
            image_threshold=0,
        ),
        # Independent random flip along each of the three spatial axes.
        *[RandFlipd(keys=KEYS, spatial_axis=[axis], prob=0.30) for axis in range(3)],
        RandRotate90d(keys=KEYS, prob=0.20, max_k=3),
        RandShiftIntensityd(keys=["images"], offsets=0.10, prob=0.40),
        # Random 1.3x zoom with probability 0.3. Plain Zoomd has no `prob` and would
        # zoom every sample (its extra kwargs only go to the padding function).
        RandZoomd(
            keys=KEYS,
            prob=0.3,
            min_zoom=1.3,
            max_zoom=1.3,
            mode=["area", "nearest"],
        ),
        # A scalar range_x samples the angle uniformly from (-0.4, 0.4) rad; the
        # previous [0.4, 0.4] pinned every rotation to exactly +0.4 rad.
        RandRotated(keys=KEYS, prob=0.3, range_x=0.4, mode=("bilinear", "nearest")),
        EnsureTyped(keys=KEYS),
    ]
)

val_transforms = Compose(_preprocessing_transforms() + [EnsureTyped(keys=KEYS)])

In [10]:
# Datasets: CacheDataset keeps the output of the non-random leading transforms in
# memory (cache_rate=1.0 -> every item); num_workers parallelises filling the cache.
train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=4)
val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=4)

# Loaders: training batches are shuffled; validation is one volume at a time, in order.
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=4)

Loading dataset:   0%|                                                                                                                                                       | 0/91 [00:02<?, ?it/s]


KeyboardInterrupt: 

In [11]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 3D UNet: 1 input channel, 2 output channels (presumably background + foreground
# class -- confirm against the label values), five channel levels with four
# stride-2 downsamplings, 2 residual units per level, batch normalisation.
UNET_CONFIG = dict(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
)
# Original (misspelled) name kept as an alias in case other cells reference it.
UNet_meatdata = UNET_CONFIG
model = UNet(**UNET_CONFIG).to(device)

In [12]:
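## Weight Initialization
# TODO: placeholder -- this cell contains no code yet, so no custom weight
# initialization is applied to `model`.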


In [15]:
# Archive the label masks in ./labels/ into labels_main_artery.zip.
# NOTE: `zip -r` adds/updates entries in an existing archive in place; entries for
# files since removed from ./labels/ are not dropped (zip's -FS option syncs them).
!zip -r labels_main_artery.zip ./labels/

updating: labels/ (stored 0%)
  adding: labels/PA000148.nii.gz (deflated 95%)
  adding: labels/PA000229.nii.gz (deflated 95%)
  adding: labels/PA000132.nii.gz (deflated 96%)
  adding: labels/PA000233.nii.gz (deflated 94%)
  adding: labels/PA000005.nii.gz (deflated 95%)
  adding: labels/PA000082.nii.gz (deflated 94%)
  adding: labels/PA000296.nii.gz (deflated 97%)
  adding: labels/PA000157.nii.gz (deflated 96%)
  adding: labels/PA000107.nii.gz (deflated 94%)
  adding: labels/PA000234.nii.gz (deflated 95%)
  adding: labels/PA000162.nii.gz (deflated 95%)
  adding: labels/PA000226.nii.gz (deflated 95%)
  adding: labels/PA000263.nii.gz (deflated 94%)
  adding: labels/PA000144.nii.gz (deflated 95%)
  adding: labels/PA000070.nii.gz (deflated 95%)
  adding: labels/PA000147.nii.gz (deflated 95%)
  adding: labels/PA000245.nii.gz (deflated 95%)
  adding: labels/PA000183.nii.gz (deflated 94%)
  adding: labels/PA000047.nii.gz (deflated 95%)
  adding: labels/PA000120.nii.gz (deflated 94%)
  adding: 