In [None]:
from pathlib import Path
from typing import List

import pydicom

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

from kedro.extras.datasets.pickle import PickleDataSet
from kedro.config import ConfigLoader

from monai.transforms import (
    HistogramNormalized,
    Compose,
    RandSpatialCropSamplesd,
    RandAxisFlipd,
    RandAffined,
    Rand2DElasticd,
    RandBiasFieldd,
    RandGaussianNoised,
    RandGaussianSmoothd,
    RandGaussianSharpend,
    RandKSpaceSpikeNoised,
    EnsureTyped
)
from monai.data import list_data_collate, CacheDataset

import torch
from torch.utils.data import TensorDataset, random_split, DataLoader

import kornia.augmentation as K
import plotly.express as px

from PIL import Image

import torch
from torch.utils.data import DataLoader, TensorDataset, random_split
from torchvision.utils import save_image
from torchvision import transforms

from torch.autograd import Variable

In [None]:
import os
import sys

# The notebook is executed from a sub-directory of the repository, so both the
# parent directory and its `src/` folder are added to the import search path
# before the project modules below are imported.
sys.path.extend(os.path.abspath(location) for location in ('..', '../src/'))

from src.tagseg.data.acdc_dataset import AcdcDataSet
from src.tagseg.data.dmd_dataset import (
    DmdDataSet,
    DmdTimeDataSet,
    Slice,
)
from src.tagseg.pipelines.data_splitting.nodes import split_data

In [None]:
# Configuration directories handed to Kedro's ConfigLoader, relative to the
# notebook's working directory.
# NOTE(review): presumably `local` entries override `base` ones — confirm
# against the ConfigLoader documentation for the installed Kedro version.
conf_paths = [
    "../conf/base",
    "../conf/local",
]
conf_loader = ConfigLoader(conf_paths)

# Fetch the configuration matching the two catalog glob patterns; the result is
# indexed by catalog entry name further down the notebook.
conf_catalog = conf_loader.get(
    "catalog*",
    "catalog*/**",
)

In [None]:
# Load the pickled `model_input` catalog entry. The catalog stores the path
# relative to the project root, hence the '../' prefix when running from here.
# NOTE: unpickling executes arbitrary code — only load files from a trusted source.
dataset = PickleDataSet(
    filepath='../' + conf_catalog['model_input']['filepath']
).load()

In [None]:
# Rebuild the dataset as a list of 2-tuples whose first two elements have been
# moved to the CPU via `.cpu()`.
# NOTE(review): presumably each entry is an (image, label) tensor pair — confirm.
dataset = [(entry[0].cpu(), entry[1].cpu()) for entry in dataset]

In [None]:
# Fraction of samples assigned to the training split, and the training batch size.
train_val_split: float = .75
batch_size: int = 8


def _as_sample_dict(sample):
    """Return `sample` as a ``{"image": ..., "label": ...}`` dictionary.

    Samples that are already dictionaries are returned unchanged, so running
    this cell more than once does not corrupt the data. Anything else is
    unpacked as an ``(image, label)`` pair.

    Raises:
        ValueError: if a non-dict sample does not unpack into exactly two items.
    """
    if isinstance(sample, dict):
        return sample
    image, label = sample
    return dict(image=image, label=label)


# MONAI's dictionary transforms (the `...d` classes below) address their inputs
# by key, so every sample is stored as a dict. The conversion is idempotent:
# unpacking a dict with `image, label = sample` would silently yield its *keys*.
dataset = [_as_sample_dict(sample) for sample in dataset]

n: int = len(dataset)
n_train: int = round(n * train_val_split)
split: List[int] = [n_train, n - n_train]

# Fixed generator seed keeps the train/val membership reproducible across runs.
train_set, val_set = random_split(
    dataset, split, generator=torch.Generator().manual_seed(42)
)

print(f"Split dataset into {split} train/val.")
print(f"Dataset length is {len(train_set)}/{len(val_set)} train/val.")

# Probability with which each random augmentation below is applied.
probability: float = 0.15

train_transforms = Compose(
    [
        HistogramNormalized(keys=["image"]),
        # Draws `num_samples` fixed-size 128x128 crops per input sample.
        RandSpatialCropSamplesd(
            keys=["image", "label"],
            roi_size=(128, 128),
            num_samples=4,
            random_center=True,
            random_size=False,
        ),
        RandAxisFlipd(keys=["image", "label"], prob=probability),
        # Spatial resampling: interpolate the image bilinearly but the label
        # with nearest-neighbour, so mask values are not blended at boundaries.
        # NOTE(review): assumes `label` is a discrete segmentation mask — confirm.
        RandAffined(
            keys=["image", "label"],
            prob=probability,
            mode=("bilinear", "nearest"),
        ),
        Rand2DElasticd(
            keys=["image", "label"],
            prob=probability,
            spacing=(16, 16),
            magnitude_range=(1, 2),
            rotate_range=0.25,
            padding_mode='zeros',
            mode=("bilinear", "nearest"),
        ),
        # Intensity augmentations only touch the image, never the label.
        RandBiasFieldd(keys=["image"], prob=probability),
        RandGaussianNoised(keys=["image"], prob=probability),
        RandGaussianSmoothd(keys=["image"], prob=probability),
        RandGaussianSharpend(keys=["image"], prob=probability),
        RandKSpaceSpikeNoised(keys=["image"], prob=probability),
        EnsureTyped(keys=["image", "label"]),
    ]
)

# Validation uses the same normalisation but no cropping or random augmentation.
val_transforms = Compose(
    [
        HistogramNormalized(keys=["image"]),
        EnsureTyped(keys=["image", "label"]),
    ]
)

train_ds = CacheDataset(data=train_set, transform=train_transforms, cache_rate=1.0)
val_ds = CacheDataset(data=val_set, transform=val_transforms, cache_rate=1.0)

# Each training item is a list of crops (see RandSpatialCropSamplesd above),
# which is why MONAI's `list_data_collate` is used instead of the default collate.
loader_train = DataLoader(
    train_ds, batch_size=batch_size, shuffle=True, collate_fn=list_data_collate)

loader_val = DataLoader(val_ds, batch_size=1)

In [None]:
# Pull a single batch from the training loader for inspection in the cells below.
# `loader_train` shuffles, so a different batch may be returned on every run.
batch = next(iter(loader_train))

In [None]:
# Display the shape of the collated image batch.
# NOTE(review): the leading dimension is presumably batch_size * num_samples,
# because each item contributes several crops — confirm against the output.
batch['image'].shape

In [None]:
# Display the values stored under the batch's 'image' key.
# NOTE(review): this prints the whole batch; consider a summary (min/max/mean)
# if the output gets too large.
batch['image']

In [None]:
# Grid of the first M * N samples: the image in grayscale with its label
# overlaid in semi-transparent red. Samples fill the grid column by column.
M, N = 20, 5
fig, ax = plt.subplots(M, N, figsize=(20, 100))

for i in range(M * N):
    # `row`/`col` rather than `m`/`n`, so the dataset length `n` defined
    # earlier in the notebook is not overwritten by this loop.
    row, col = i % M, i // M
    ax[row, col].axis('off')

    # Leave the remaining panels blank when there are fewer than M * N samples.
    if i >= len(dataset):
        continue

    # `dataset` holds dicts once the data-loader cell has run, and
    # (image, label) pairs before that; accept either representation.
    sample = dataset[i]
    if isinstance(sample, dict):
        image, label = sample['image'], sample['label']
    else:
        image, label = sample[0], sample[1]

    # The first index of the image is dropped before plotting
    # (presumably the channel axis — TODO confirm).
    ax[row, col].imshow(image[0].cpu(), cmap='gray')
    ax[row, col].imshow(label.cpu(), cmap='Reds', alpha=0.3)