In [31]:
import math
import numpy as np
from enum import Enum
from typing import ( 
    List, 
    Tuple,
    Optional,
    Dict
)
import random
import torch
from torch import Tensor
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import functional as F, InterpolationMode
from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances_argmin_min
from batchgenerators.dataloading.data_loader import SlimDataLoaderBase
from batchgenerators.transforms.spatial_transforms import (
    SpatialTransform, 
    MirrorTransform
)
from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform
from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform
from batchgenerators.transforms.color_transforms import (
    BrightnessMultiplicativeTransform, 
    ContrastAugmentationTransform, 
    GammaTransform
)
from batchgenerators.transforms.utility_transforms import (
    RemoveLabelTransform, 
    RenameTransform, 
    NumpyToTensor
)
from batchgenerators.dataloading.data_loader import SlimDataLoaderBase
from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
from batchgenerators.transforms.local_transforms import (
    BrightnessGradientAdditiveTransform,
    LocalGammaTransform,
    LocalSmoothingTransform,
    LocalContrastTransform
)
from batchgenerators.transforms.abstract_transforms import (
    Compose,
    AbstractTransform
)

from omegaconf import OmegaConf
import sys
sys.path.append('..')
from dataset import *
from data_utils import *

In [51]:
# Load the shared base experiment config, then force debug mode on for this
# notebook session (this assignment overrides any `debug` value from the YAML).
conf = OmegaConf.load('../configs/basic_config.yaml')
conf['debug'] = True

In [52]:
# Show the effective config as YAML. print() is used deliberately: a bare
# expression would display the escaped string repr instead of readable lines.
print(OmegaConf.to_yaml(conf, resolve=False))

debug: true
wandb:
  log: true
  project: MICCAI2023-extension
fs:
  root: ../../
data:
  calgary:
    data_path: data/conp-dataset/projects/calgary-campinas/CC359/Reconstructed/
  mnm:
    data_path: data/mnm/
model:
  unet:
    calgary:
      pre: calgary_unet
      n_chans_in: 1
      n_filters_in: 8
      n_chans_out: 1
      training:
        save_loc: pre-trained
        train_site: 6
        augment: true
        validation: true
        batch_size: 32
        num_batches_per_epoch: 250
        num_val_batches_per_epoch: 50
        epochs: 250
        patience: 4
        lr: 0.001
    acdc:
      pre: acdc_unt8_
      n_chans_in: 1
      n_filters_in: 8
      n_chans_out: 4
      training:
        save_loc: pre-trained
        augment: true
        validation: true
        batch_size: 32
        num_batches_per_epoch: 250
        num_val_batches_per_epoch: 50
        epochs: 250
        patience: 4
        lr: 0.001



In [53]:
def get_brain_train_loader(
    training: str,  # 'unet' or 'dae'
    cfg: OmegaConf,
    num_processes: int = 4,
    num_cached_per_queue: int = 2,
):
    """Build training and validation generators for the Calgary-Campinas brain data.

    Args:
        training: Which model the generators are for: ``'unet'`` or ``'dae'``.
        cfg: Config loaded with ``OmegaConf.load``. Reads ``cfg.debug``,
            ``cfg.fs.root``, ``cfg.data.calgary.data_path`` and
            ``cfg.model.unet.calgary.training`` (``train_site``, ``batch_size``).
        num_processes: Worker processes for each ``MultiThreadedAugmenter``.
        num_cached_per_queue: ``num_cached_per_queue`` passed to each
            ``MultiThreadedAugmenter``.

    Returns:
        ``(train_gen, valid_gen)``.

        ``train_gen`` is a ``MultiThreadedAugmenter`` over the training split.
        It uses the ``'local_transforms'`` pipeline for ``'dae'`` and
        ``'all_transforms'`` for ``'unet'``.

        For ``'unet'``, ``valid_gen`` is a torch ``DataLoader`` (batch size 1,
        ``volume_collate``) over the ``volume_wise`` validation split. For
        ``'dae'``, it is a ``MultiThreadedAugmenter`` over the validation split
        using ``'local_val_transforms'``.

    Raises:
        ValueError: If ``training`` is not ``'unet'`` or ``'dae'``.
    """
    if training not in ('unet', 'dae'):
        # Fail fast: otherwise `valid_gen` is never assigned and the mistake only
        # surfaces as an UnboundLocalError after the datasets have been built.
        raise ValueError(f"training must be 'unet' or 'dae', got {training!r}")

    is_dae = training == 'dae'
    # NOTE(review): return_orig presumably makes the loader also yield the
    # un-augmented image (DAE target) - confirm in MultiImageSingleViewDataLoader.
    return_orig = is_dae
    transform_key = 'local_transforms' if is_dae else 'all_transforms'

    data_path = cfg.fs.root + cfg.data.calgary.data_path
    train_cfg = cfg.model.unet.calgary.training

    train_set = CalgaryCampinasDataset(
        data_path=data_path,
        site=train_cfg.train_site,
        augment=False,
        normalize=True,
        split='train',
        debug=cfg.debug,
    )
    train_loader = MultiImageSingleViewDataLoader(
        data=train_set,
        batch_size=train_cfg.batch_size,
        return_orig=return_orig,
    )

    transforms = Transforms()
    train_augmentor = transforms.get_transforms(transform_key)
    train_gen = MultiThreadedAugmenter(
        data_loader=train_loader,
        transform=train_augmentor,
        num_processes=num_processes,
        num_cached_per_queue=num_cached_per_queue,
        seeds=None,
    )

    if is_dae:
        valid_set = CalgaryCampinasDataset(
            data_path=data_path,
            site=train_cfg.train_site,
            augment=False,
            normalize=True,
            split='validation',
            debug=cfg.debug,
        )
        valid_augmentor = transforms.get_transforms('local_val_transforms')
        valid_loader = MultiImageSingleViewDataLoader(
            data=valid_set,
            batch_size=train_cfg.batch_size,
            return_orig=True,
        )
        valid_gen = MultiThreadedAugmenter(
            data_loader=valid_loader,
            transform=valid_augmentor,
            num_processes=num_processes,
            num_cached_per_queue=num_cached_per_queue,
            seeds=None,
        )
    else:
        # NOTE(review): unlike the train split and the DAE branch, this split is
        # not given debug=cfg.debug - confirm the full validation set is intended
        # in debug runs.
        valid_set = CalgaryCampinasDataset(
            data_path=data_path,
            site=train_cfg.train_site,
            normalize=True,
            volume_wise=True,
            split='validation',
        )
        valid_gen = DataLoader(
            valid_set,
            batch_size=1,
            shuffle=False,
            drop_last=False,
            collate_fn=volume_collate,
        )

    return train_gen, valid_gen

In [56]:
def get_heart_train_loader(
    training: str,  # 'unet' or 'dae'
    cfg: OmegaConf,
    num_processes: int = 4,
    num_cached_per_queue: int = 2,
):
    """Build training and validation generators for the ACDC heart data.

    Args:
        training: Which model the generators are for: ``'unet'`` or ``'dae'``.
        cfg: Config loaded with ``OmegaConf.load``. Reads ``cfg.debug`` and
            ``cfg.model.unet.acdc.training.batch_size``.
        num_processes: Worker processes for each ``MultiThreadedAugmenter``.
        num_cached_per_queue: ``num_cached_per_queue`` passed to each
            ``MultiThreadedAugmenter``.

    Returns:
        ``(train_gen, valid_gen)``, both ``MultiThreadedAugmenter`` instances.

        For ``'unet'``, training uses ``'all_transforms'`` and validation uses
        ``'io_transforms'``, with ``return_orig=False`` on both loaders.

        For ``'dae'``, training uses ``'local_transforms'`` and validation uses
        ``'local_val_transforms'``, with ``return_orig=True`` on both loaders.
        This mirrors ``get_brain_train_loader``.

    Raises:
        ValueError: If ``training`` is not ``'unet'`` or ``'dae'``.
    """
    if training not in ('unet', 'dae'):
        raise ValueError(f"training must be 'unet' or 'dae', got {training!r}")

    is_dae = training == 'dae'
    # Previously computed but ignored: both loaders hard-coded return_orig=False,
    # so DAE training never received original images.
    return_orig = is_dae
    train_transform_key = 'local_transforms' if is_dae else 'all_transforms'
    # DAE validation uses the same key as the brain DAE path; UNet keeps io-only transforms.
    valid_transform_key = 'local_val_transforms' if is_dae else 'io_transforms'

    train_cfg = cfg.model.unet.acdc.training
    transforms = Transforms()

    train_set = ACDCDataset(
        data="train",
        debug=cfg['debug'],
    )
    train_loader = MultiImageSingleViewDataLoader(
        data=train_set,
        batch_size=train_cfg.batch_size,
        return_orig=return_orig,
    )
    train_augmentor = transforms.get_transforms(train_transform_key)
    train_gen = MultiThreadedAugmenter(
        data_loader=train_loader,
        transform=train_augmentor,
        num_processes=num_processes,
        num_cached_per_queue=num_cached_per_queue,
        seeds=None,
    )

    val_set = ACDCDataset(
        data="val",
        debug=cfg['debug'],
    )
    valid_loader = MultiImageSingleViewDataLoader(
        data=val_set,
        batch_size=train_cfg.batch_size,
        return_orig=return_orig,
    )
    valid_augmentor = transforms.get_transforms(valid_transform_key)
    valid_gen = MultiThreadedAugmenter(
        data_loader=valid_loader,
        transform=valid_augmentor,
        num_processes=num_processes,
        num_cached_per_queue=num_cached_per_queue,
        seeds=None,
    )

    return train_gen, valid_gen

In [59]:
# Dataset/mode-specific names instead of a shared `t, v`, so running another
# loader cell (in any order) cannot silently replace these generators.
brain_dae_train_gen, brain_dae_valid_gen = get_brain_train_loader('dae', conf)

In [58]:
# Dataset/mode-specific names instead of a shared `t, v`, so running another
# loader cell (in any order) cannot silently replace these generators.
heart_unet_train_gen, heart_unet_valid_gen = get_heart_train_loader('unet', conf)

loading dataset
loading all case properties
loading dataset
loading all case properties
