In [2]:
import matplotlib.pyplot as plt
import wandb
from tqdm.auto import tqdm
import os, sys
import time
import numpy as np
import collections
import torch
from torch import Tensor, nn
from torch.utils.data import Dataset, DataLoader
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.cuda.amp import autocast, GradScaler
from torchvision.transforms import Resize, CenterCrop
from typing import Iterable, Dict, Callable, Tuple
import torch.nn.functional as F
import matplotlib.pyplot as plt
from random import randrange

from nnunet.training.model_restore import restore_model
import batchgenerators
from batchgenerators.transforms.local_transforms import *
from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter
from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
from batchgenerators.utilities.file_and_folder_operations import *
from nnunet.paths import preprocessing_output_dir
from nnunet.training.dataloading.dataset_loading import *
from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
from nnunet.run.load_pretrained_weights import load_pretrained_weights

sys.path.append('../..')
from dataset import CalgaryCampinasDataset, ACDCDataset, MNMDataset
from augment import SingleImageMultiViewDataLoader, MultiImageSingleViewDataLoader
from utils import EarlyStopping, epoch_average, average_metrics
from model.dae import AugResDAE
from model.unet import UNet2D
from model.wrapper import ModelAdapter
from losses import MNMCriterionAE, SampleDice, UnetDice
from trainer.ae_trainer import AETrainerACDCV2
from augment import nnUNet_train_augmentations

In [3]:
# Chain the augmentation list imported from `augment` (nnUNet_train_augmentations)
# into a single batchgenerators Compose transform.
train_augmentor = batchgenerators.transforms.abstract_transforms.Compose(nnUNet_train_augmentations)


In [4]:
# Location of the reconstructed Calgary-Campinas (CC359) data, relative to this notebook.
data_path = '../../../data/conp-dataset/projects/calgary-campinas/CC359/Reconstructed/'

# Training split of site 6: normalization requested, dataset-level augmentation
# switched off, debug flag on.
trainset = CalgaryCampinasDataset(
    data_path=data_path,
    split='train',
    site=6,
    normalize=True,
    augment=False,
    debug=True,
)


In [6]:
# Batch loader over the training set: two samples per batch, `return_orig` disabled.
dataloader = MultiImageSingleViewDataLoader(data=trainset, batch_size=2, return_orig=False)

In [8]:
# Run the loader and the composed augmentations in background workers:
# four processes, two cached batches per queue, no fixed seeds.
train_gen = MultiThreadedAugmenter(
    dataloader,
    train_augmentor,
    num_processes=4,
    num_cached_per_queue=2,
    seeds=None,
)

In [10]:
# Pull a single augmented batch from the generator for inspection.
tmp = next(train_gen)

In [19]:
# Mean over the whole 'data' entry of the sampled batch.
tmp['data'].mean()

tensor(-0.2401)

In [20]:
# Shape of the 'input' entry of the first training sample.
trainset[0]['input'].shape

torch.Size([1, 256, 256])

In [23]:
# Shape of the 'data' entry of a batch taken directly from the loader
# (generate_train_batch), i.e. without going through train_gen.
dataloader.generate_train_batch()['data'].shape

(2, 1, 256, 256)

In [4]:
# Experiment configuration and data pipeline.
#
# This cell (1) defines the run configuration, (2) initializes an nnU-Net
# trainer only to obtain its training/validation transform pipelines,
# (3) selects a subset of those transforms according to cfg['augmentations'],
# and (4) wraps the ACDC datasets in multi-process batch generators that
# apply the selected transforms.

it = 0           # run index; used in the run description and the U-Net name below
residual = True  # selects the "res" / "recon" tag in the run description

cfg = {
    'debug': False,
    'log': True,
    'description': f'acdc_AugResDAE{it}_localAug_multiImgSingleView_{"res" if residual else "recon"}_balanced_same', #'mms_vae_for_nnUNet_fc3_0_bs50',
    'project': 'MICCAI2023-loose_ends',

    # Data params
    'n': 0,
    'root': '../../',
    'data_path': 'data/mnm/',
    'unet': f'acdc_unet8_{it}',
    'channel_out': 8,

    # Hyperparams
    'batch_size': 32,
    'lr': 1e-4,
    'augment': False,
    'difference': True,
    'loss': 'huber',  # huber or ce
    'target': 'output', #gt or output
    'reconstruction': True,
    'augmentations': 'local',
    'disabled_ids': ['shortcut0', 'shortcut1', 'shortcut2'], #['shortcut0', 'shortcut1', 'shortcut2']
}

nnUnet_prefix = '../../../../nnUNet/'

description = cfg['description']
root = cfg['root']

# # Unet
# unet_path = cfg['unet'] # + str(cfg['n'])
# unet = UNet2D(n_chans_in=1, n_chans_out=4, n_filters_init=cfg['channel_out']).cuda()
# model_path = f'{root}pre-trained-tmp/trained_UNets/{unet_path}_best.pt'
# state_dict = torch.load(model_path)['model_state_dict']
# unet.load_state_dict(state_dict)

### Dataloader
## Initialize trainer to get data loaders with data augmentations from training
# NOTE(review): the plans file and dataset directory point at Task500_ACDC while
# the output folder points at Task027_ACDC - confirm this mismatch is intended.
pkl_file          = nnUnet_prefix + 'data/nnUNet_preprocessed/Task500_ACDC/nnUNetPlansv2.1_plans_2D.pkl'
fold              = 0
output_folder     = nnUnet_prefix + 'results/nnUnet/nnUNet/2d/Task027_ACDC/nnUNetTrainerV2__nnUNetPlansv2.1/'
dataset_directory = nnUnet_prefix + 'data/nnUNet_preprocessed/Task500_ACDC'

# Pass `fold` rather than a second hard-coded 0 so the variable above is the
# single place where the fold is chosen.
trainer = nnUNetTrainerV2(pkl_file, fold, output_folder, dataset_directory)
trainer.initialize()

train_loader = trainer.tr_gen
valid_loader = trainer.val_gen


def _keep_only(loader, kinds):
    """Return the transforms of `loader` that are instances of `kinds`, in their original order."""
    return [t for t in loader.transform.transforms if isinstance(t, kinds)]


aug_mode = cfg['augmentations']

if aug_mode == 'all':
    # Use the trainer's pipelines unchanged (copied into fresh lists).
    train_transforms = list(train_loader.transform.transforms)
    valid_transforms = list(valid_loader.transform.transforms)

elif aug_mode == 'output_invariant':
    # Keep only the listed transform classes from the trainer's pipelines.
    data_only_transforms = (
        batchgenerators.transforms.resample_transforms.SimulateLowResolutionTransform,
        batchgenerators.transforms.noise_transforms.GaussianNoiseTransform,
        batchgenerators.transforms.noise_transforms.GaussianBlurTransform,
        batchgenerators.transforms.color_transforms.BrightnessMultiplicativeTransform,
        batchgenerators.transforms.color_transforms.ContrastAugmentationTransform,
        batchgenerators.transforms.color_transforms.GammaTransform,
        batchgenerators.transforms.utility_transforms.RemoveLabelTransform,
        batchgenerators.transforms.utility_transforms.RenameTransform,
        batchgenerators.transforms.utility_transforms.NumpyToTensor
    )

    train_transforms = _keep_only(train_loader, data_only_transforms)
    valid_transforms = _keep_only(valid_loader, data_only_transforms)

elif aug_mode == 'local':
    # Transform classes retained from the trainer's own pipelines.
    original_transforms = (
        batchgenerators.transforms.resample_transforms.SimulateLowResolutionTransform,
        batchgenerators.transforms.noise_transforms.GaussianNoiseTransform,
        batchgenerators.transforms.utility_transforms.RemoveLabelTransform,
        batchgenerators.transforms.utility_transforms.RenameTransform,
        batchgenerators.transforms.utility_transforms.NumpyToTensor
    )

    # Local transforms, each applied with probability 0.2 per sample; they are
    # placed in front of the retained trainer transforms.
    scale = 200.
    local_transforms = [
        BrightnessGradientAdditiveTransform(scale=scale, max_strength=4, p_per_sample=0.2, p_per_channel=1),
        LocalGammaTransform(scale=scale, gamma=(2, 5), p_per_sample=0.2, p_per_channel=1),
        LocalSmoothingTransform(scale=scale, smoothing_strength=(0.5, 1), p_per_sample=0.2, p_per_channel=1),
        LocalContrastTransform(scale=scale, new_contrast=(1, 3), p_per_sample=0.2, p_per_channel=1),
    ]

    # The same local transform instances are used for training and validation.
    train_transforms = local_transforms + _keep_only(train_loader, original_transforms)
    valid_transforms = local_transforms + _keep_only(valid_loader, original_transforms)

else:
    # Fail loudly: without this branch an unknown mode would leave
    # train_transforms / valid_transforms undefined, or silently reuse values
    # left in the kernel by another cell.
    raise ValueError(
        f"Unknown cfg['augmentations'] value {aug_mode!r}; "
        "expected 'all', 'output_invariant' or 'local'."
    )

train_augmentor = batchgenerators.transforms.abstract_transforms.Compose(train_transforms)
valid_augmentor = batchgenerators.transforms.abstract_transforms.Compose(valid_transforms)

### - Load dataset and init batch generator
train_data = ACDCDataset(data='train', debug=False, root='../../../../')
valid_data = ACDCDataset(data='val', debug=False, root='../../../../')

# Plain batch loaders first, then the multi-process augmenters that consume
# them. Separate names avoid rebinding train_gen / valid_gen from a loader to
# an augmenter within the same cell.
train_batches = MultiImageSingleViewDataLoader(train_data, batch_size=cfg['batch_size'], return_orig=True)
valid_batches = MultiImageSingleViewDataLoader(valid_data, batch_size=cfg['batch_size'], return_orig=True)

# Single-process alternative, e.g. for debugging:
#train_gen = SingleThreadedAugmenter(train_batches, train_augmentor)
#valid_gen = SingleThreadedAugmenter(valid_batches, valid_augmentor)
train_gen = MultiThreadedAugmenter(
    train_batches, train_augmentor, num_processes=4, num_cached_per_queue=2, seeds=None
)
valid_gen = MultiThreadedAugmenter(
    valid_batches, valid_augmentor, num_processes=4, num_cached_per_queue=2, seeds=None
)


loading dataset
loading all case properties
2023-11-27 17:23:28.671876: Using splits from existing split file: ../../../../nnUNet/data/nnUNet_preprocessed/Task500_ACDC/splits_final.pkl
2023-11-27 17:23:28.692560: The split file contains 5 splits.
2023-11-27 17:23:28.692713: Desired fold for training: 0
2023-11-27 17:23:28.693316: This split has 160 training and 40 validation cases.
unpacking dataset
done
loading dataset
loading all case properties
loading dataset
loading all case properties


In [3]:
# Display the path of the nnU-Net plans pickle assigned in the setup cell.
pkl_file

'../../../../nnUNet/data/nnUNet_preprocessed/Task500_ACDC/nnUNetPlansv2.1_plans_2D.pkl'

In [4]:
import pickle

In [8]:
# Load and display the nnU-Net plans dictionary.
# The file is opened in a `with` block so the handle is closed once the
# pickle has been read (the bare open() call left it open).
# NOTE: pickle.load executes arbitrary code embedded in the file; only use it
# on plans files you produced yourself or otherwise trust.
with open(pkl_file, "rb") as plans_file:
    plans = pickle.load(plans_file)
plans

{'num_stages': 1,
 'num_modalities': 1,
 'modalities': {0: 'MRI'},
 'normalization_schemes': OrderedDict([(0, 'nonCT')]),
 'dataset_properties': {'all_sizes': [(10, 256, 216),
   (10, 256, 216),
   (10, 256, 231),
   (10, 256, 231),
   (10, 255, 256),
   (10, 255, 256),
   (10, 256, 232),
   (10, 256, 232),
   (10, 216, 255),
   (10, 216, 255),
   (11, 256, 231),
   (11, 256, 231),
   (10, 223, 222),
   (10, 223, 222),
   (10, 256, 199),
   (10, 256, 199),
   (10, 256, 208),
   (10, 256, 208),
   (10, 256, 208),
   (10, 256, 208),
   (9, 256, 216),
   (9, 256, 216),
   (10, 256, 184),
   (10, 256, 184),
   (10, 256, 216),
   (10, 256, 216),
   (10, 216, 256),
   (10, 216, 256),
   (9, 216, 256),
   (9, 216, 256),
   (10, 256, 192),
   (10, 256, 192),
   (9, 256, 216),
   (9, 256, 216),
   (8, 256, 216),
   (8, 256, 216),
   (11, 256, 216),
   (11, 256, 216),
   (8, 256, 208),
   (8, 256, 208),
   (10, 255, 240),
   (10, 255, 240),
   (7, 256, 200),
   (7, 256, 200),
   (9, 256, 216),
 

In [7]:
# Display the plans-file path again (same variable as in the setup cell).
pkl_file

'../../../../nnUNet/data/nnUNet_preprocessed/Task500_ACDC/nnUNetPlansv2.1_plans_2D.pkl'

In [7]:
# Copy the complete transform pipelines of the trainer's loaders into plain
# lists (no filtering), overwriting any earlier train/valid transform lists.
train_transforms = list(train_loader.transform.transforms)
valid_transforms = list(valid_loader.transform.transforms)

In [8]:
# Pretty-print the validation transform list, one transform per line.
import pprint
pprint.pprint(valid_transforms)

[RemoveLabelTransform( output_key = 'seg', input_key = 'seg', replace_with = 0, remove_label = -1 ),
 SegChannelSelectionTransform( label_key = 'seg', channels = [0], keep_discarded = False ),
 RenameTransform( delete_old = True, out_key = 'target', in_key = 'seg' ),
 DownsampleSegForDSTransform2( axes = None, output_key = 'target', input_key = 'target', order = 0, ds_scales = [[1, 1, 1], [0.5, 0.5], [0.25, 0.25], [0.125, 0.125], [0.0625, 0.0625], [0.03125, 0.03125]] ),
 NumpyToTensor( keys = ['data', 'target'], cast_to = 'float' )]


In [6]:
# Display the current validation transform list via its default repr.
valid_transforms

[<batchgenerators.transforms.local_transforms.BrightnessGradientAdditiveTransform at 0x7f7ce35d8e20>,
 <batchgenerators.transforms.local_transforms.LocalGammaTransform at 0x7f7dd43df9d0>,
 <batchgenerators.transforms.local_transforms.LocalSmoothingTransform at 0x7f7ce0394af0>,
 <batchgenerators.transforms.local_transforms.LocalContrastTransform at 0x7f7ce0394a90>,
 RemoveLabelTransform( output_key = 'seg', input_key = 'seg', replace_with = 0, remove_label = -1 ),
 RenameTransform( delete_old = True, out_key = 'target', in_key = 'seg' ),
 NumpyToTensor( keys = ['data', 'target'], cast_to = 'float' )]

In [15]:
# Serialize the list of training transform objects to 'test.pt' in the
# current working directory.
torch.save(train_transforms, 'test.pt')

In [16]:
# Read the saved transform list back from 'test.pt'.
# NOTE(review): torch.load unpickles the file, so only load files you trust.
# Newer PyTorch releases may restrict loading of arbitrary Python objects by
# default - confirm against the installed version.
test = torch.load('test.pt')

In [18]:
# What looks like the printed repr of a transform pipeline, executed here as a
# list literal. It can only evaluate if every transform class named below, as
# well as `array` and `OrderedDict`, is bound in the kernel; in the recorded
# run it stopped with NameError on SegChannelSelectionTransform.
# NOTE(review): even with those names imported, this assumes each repr keyword
# is accepted by the corresponding constructor - confirm before relying on it.
[SegChannelSelectionTransform( label_key = 'seg', channels = [0], keep_discarded = False ),
 SpatialTransform( independent_scale_for_each_axis = False, p_rot_per_sample = 0.2, p_scale_per_sample = 0.2, p_el_per_sample = 0.2, data_key = 'data', label_key = 'seg', patch_size = array([256, 224]), patch_center_dist_from_border = None, do_elastic_deform = False, alpha = (0.0, 200.0), sigma = (9.0, 13.0), do_rotation = True, angle_x = (-3.141592653589793, 3.141592653589793), angle_y = (-0.0, 0.0), angle_z = (-0.0, 0.0), do_scale = True, scale = (0.7, 1.4), border_mode_data = 'constant', border_cval_data = 0, order_data = 3, border_mode_seg = 'constant', border_cval_seg = -1, order_seg = 1, random_crop = False, p_rot_per_axis = 1, p_independent_scale_per_axis = 1 ),
 GaussianNoiseTransform( p_per_sample = 0.1, data_key = 'data', noise_variance = (0, 0.1), p_per_channel = 1, per_channel = False ),
 GaussianBlurTransform( p_per_sample = 0.2, different_sigma_per_channel = True, p_per_channel = 0.5, data_key = 'data', blur_sigma = (0.5, 1.0), different_sigma_per_axis = False, p_isotropic = 0 ),
 BrightnessMultiplicativeTransform( p_per_sample = 0.15, data_key = 'data', multiplier_range = (0.75, 1.25), per_channel = True ),
 ContrastAugmentationTransform( p_per_sample = 0.15, data_key = 'data', contrast_range = (0.75, 1.25), preserve_range = True, per_channel = True, p_per_channel = 1 ),
 SimulateLowResolutionTransform( order_upsample = 3, order_downsample = 0, channels = None, per_channel = True, p_per_channel = 0.5, p_per_sample = 0.25, data_key = 'data', zoom_range = (0.5, 1), ignore_axes = None ),
 GammaTransform( p_per_sample = 0.1, retain_stats = True, per_channel = True, data_key = 'data', gamma_range = (0.7, 1.5), invert_image = True ),
 GammaTransform( p_per_sample = 0.3, retain_stats = True, per_channel = True, data_key = 'data', gamma_range = (0.7, 1.5), invert_image = False ),
 MirrorTransform( p_per_sample = 1, data_key = 'data', label_key = 'seg', axes = (0, 1) ),
 MaskTransform( dct_for_where_it_was_used = OrderedDict([(0, False)]), seg_key = 'seg', data_key = 'data', set_outside_to = 0, mask_idx_in_seg = 0 ),
 RemoveLabelTransform( output_key = 'seg', input_key = 'seg', replace_with = 0, remove_label = -1 ),
 RenameTransform( delete_old = True, out_key = 'target', in_key = 'seg' ),
 DownsampleSegForDSTransform2( axes = None, output_key = 'target', input_key = 'target', order = 0, ds_scales = [[1, 1, 1], [0.5, 0.5], [0.25, 0.25], [0.125, 0.125], [0.0625, 0.0625], [0.03125, 0.03125]] ),
 NumpyToTensor( keys = ['data', 'target'], cast_to = 'float' )]

NameError: name 'SegChannelSelectionTransform' is not defined

In [None]:
from augment import nnUNet_augmentations