In [53]:
import os
import json
import shutil
import tempfile
import time

import matplotlib.pyplot as plt
import numpy as np
import nibabel as nib

from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai import transforms
from monai.transforms import (
    AsDiscrete,
    Activations,
)

from monai.metrics import DiceMetric
from monai.utils.enums import MetricReduction
from monai.networks.nets import SwinUNETR, SegResNet, UNet
from monai import data
from monai.data import decollate_batch
from functools import partial

import torch


In [2]:
def datafold_read(datalist, basedir, fold=0, key="training"):
    """Read a JSON datalist and split its entries into training/validation lists.

    The file at ``datalist`` is parsed as JSON and the sequence stored under
    ``key`` is used. Inside every entry, each list value gets ``basedir``
    joined onto each of its elements, and each non-empty string value gets
    ``basedir`` joined onto it. Empty strings and values of any other type
    are left as they are.

    Args:
        datalist: Path of the JSON file to read.
        basedir: Directory joined in front of the list/string values.
        fold: Entries whose ``"fold"`` field equals this value are validation.
        key: Top-level key of the JSON document holding the entries.

    Returns:
        ``(tr, val)``: two lists of entries. An entry without a ``"fold"``
        field, or with a different fold value, ends up in the first list.
    """
    with open(datalist) as handle:
        document = json.load(handle)

    entries = document[key]

    # Pass 1: prefix path-like values with ``basedir`` (entries are edited in place).
    for entry in entries:
        for field in entry:
            value = entry[field]
            if isinstance(value, list):
                entry[field] = [os.path.join(basedir, item) for item in value]
            elif isinstance(value, str) and len(value) > 0:
                entry[field] = os.path.join(basedir, value)

    # Pass 2: route each entry by its fold tag, preserving the original order.
    train_entries, val_entries = [], []
    for entry in entries:
        is_validation = "fold" in entry and entry["fold"] == fold
        (val_entries if is_validation else train_entries).append(entry)

    return train_entries, val_entries

In [46]:
# Dataset locations. The defaults are the original author's absolute paths;
# export IMAGECAS_DATA_DIR / IMAGECAS_DATALIST to run the notebook on another
# machine without editing this cell.
data_dir = os.environ.get(
    "IMAGECAS_DATA_DIR",
    '/home/ikboljon.sobirov/data/fs1_research/Ikboljon.Sobirov/imagecas/imagecas/resampled_space/',
)
json_file = os.environ.get(
    "IMAGECAS_DATALIST",
    '/home/ikboljon.sobirov/data/nas/ikboljon.sobirov/image_cas/chuqur_organish_asoslari/module_6/train_data.json',
)
fold = 1  # datalist entries whose "fold" equals this value form the validation split
roi = (96, 96, 96)  # spatial size passed to the padding and random-crop transforms
batch_size = 1  # batch size of the training DataLoader


In [47]:
# Split the datalist: entries tagged with `fold` become the validation set,
# every other entry goes to the training set.
train_files, validation_files = datafold_read(
    datalist=json_file,
    basedir=data_dir,
    fold=fold,
)


In [48]:
# Z-score normalization: z_norm = (img - mean) / std
# Min-max normalization: min_max = (img - min) / (max - min)

In [49]:
# Training pipeline, applied in order to each datalist entry:
#   1. LoadImaged   - load the files under "image_path" / "mask_path", channel first
#   2. SpatialPadd  - pad both with spatial_size=roi
#   3. RandCropByPosNegLabeld - random roi-sized crops guided by the mask, num_samples=4
train_transform = transforms.Compose([
    transforms.LoadImaged(
        keys=["image_path", "mask_path"],
        ensure_channel_first=True,
    ),
    transforms.SpatialPadd(
        keys=["image_path", "mask_path"],
        spatial_size=roi,
    ),
    transforms.RandCropByPosNegLabeld(
        keys=["image_path", "mask_path"],
        image_key="image_path",
        label_key="mask_path",
        spatial_size=roi,
        num_samples=4,
    ),
    # --- Disabled transforms, kept for reference --------------------------
    # NOTE: these still use the keys "image" / "label", not the
    # "image_path" / "mask_path" keys of the active transforms above, so the
    # keys must be renamed before any of them is re-enabled.
    # transforms.CropForegroundd(
    #     keys=["image", "label"],
    #     source_key="image",
    #     k_divisible=[roi[0], roi[1], roi[2]],
    # ),
    # transforms.RandSpatialCropd(
    #     keys=["image", "label"],
    #     roi_size=[roi[0], roi[1], roi[2]],
    #     random_size=False,
    # ),
    # transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    # transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
    # transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
    # transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    # transforms.RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
    # transforms.RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
])
# Validation pipeline: the same loading and padding steps as training, but no
# random cropping is applied.
val_transform = transforms.Compose([
    transforms.LoadImaged(
        keys=["image_path", "mask_path"],
        ensure_channel_first=True,
    ),
    transforms.SpatialPadd(
        keys=["image_path", "mask_path"],
        spatial_size=roi,
    ),
])

# Training set: datalist entries paired with the training transform pipeline.
train_ds = data.Dataset(transform=train_transform, data=train_files)

# Shuffled loader over the training set, 8 worker processes, pinned memory.
train_loader = data.DataLoader(
    train_ds, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True
)

# Validation set: held-out fold entries paired with the validation transforms.
val_ds = data.Dataset(transform=val_transform, data=validation_files)

# Loader over the validation set: one item per batch, fixed order (no shuffling).
val_loader = data.DataLoader(
    val_ds, batch_size=1, shuffle=False, num_workers=8, pin_memory=True
)



In [50]:
# Draw a single batch from the training loader and keep it in `a`.
# NOTE(review): presumably a quick sanity check that the transform pipeline and
# the loader run end to end — confirm; nothing in this view reads `a` afterwards.
a = next(iter(train_loader))

In [54]:
# Segmentation network: SegResNet with one input channel and one output
# channel; every other constructor option is left at its library default.
net = SegResNet(in_channels=1, out_channels=1)