In [1]:
import matplotlib.pyplot as plt

import collections
import os, sys
import time
from typing import Iterable, Dict, Callable, Tuple

import numpy as np
import torch
from torch import Tensor, nn
from torch.utils.data import Dataset, DataLoader
from torch import optim
import torch.nn.functional as F
from torchvision.transforms import Resize, CenterCrop
import wandb
from tqdm.auto import tqdm
from batchgenerators.utilities.file_and_folder_operations import *
from nnunet.training.dataloading.dataset_loading import *
from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2

sys.path.append('..')
from utils import EarlyStopping, epoch_average, average_metrics
from dataset import CalgaryCampinasDataset
from model.ae import AE
from model.unet import UNet2D
from model.wrapper import Frankenstein
from losses import MNMCriterionAE, CalgaryCriterionAE, SampleDice, UnetDice
from trainer.ae_trainer import AETrainerCalgary, AETrainerACDC


# Relative location of the nnU-Net checkout (preprocessed data + results), resolved from this notebook's directory.
nnUnet_prefix = "../../../nnUNet/"



Please cite the following paper when using nnUNet:

Isensee, F., Jaeger, P.F., Kohl, S.A.A. et al. "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation." Nat Methods (2020). https://doi.org/10.1038/s41592-020-01008-z


If you have questions or suggestions, feel free to open an issue at https://github.com/MIC-DKFZ/nnUNet



In [2]:
# Calgary-Campinas (CC359): train an AE on the pre-trained U-Net's `up3` features.
# Pipeline: config -> datasets/loaders -> frozen U-Net -> AE / identity modules
# -> Frankenstein wrapper -> AETrainerCalgary.fit().
cfg = {
    'debug': True,
    'log': False,
    'description': 'calgary_ae_test',
    'project': 'MICCAI2023',

    # Data params
    'n': 0,
    'root': '../../',
    'data_path': 'data/conp-dataset/projects/calgary-campinas/CC359/Reconstructed/',
    'train_site': 6,
    'unet': 'calgary_unet',
    'channel_out': 8,
    'num_workers': 10,

    # Hyperparams
    'batch_size': 64,
    'augment': False,
    'difference': True,  # NOTE(review): not consumed in this cell (the ACDC cell passes it to its criterion as `diff=`)
    'loss': 'huber',
    'target': 'output',
    'identity_layers': ['shortcut0', 'shortcut1', 'shortcut2'],
    'lr': 1e-4,
    'n_epochs': 1,
    'patience': 4,  # a previous comment suggested 20 — presumably for full (non-debug) runs; confirm

    # outputs
    'plot_dir': '../experiments/unet/calgary/logs/',  # NOTE(review): not used in this cell
}

# Run name: description + run index, tagged when augmentation is on.
description = cfg['description'] + str(cfg['n'])
if cfg['augment']:
    description += 'augment'

### data loading
root      = cfg['root']
data_path = root + cfg['data_path']
train_set = CalgaryCampinasDataset(data_path=data_path,
                                   site=cfg['train_site'],
                                   augment=cfg['augment'],
                                   normalize=True,
                                   split='train',
                                   debug=cfg['debug'])

# Validation data is never augmented.
valid_set = CalgaryCampinasDataset(data_path=data_path,
                                   site=cfg['train_site'],
                                   normalize=True,
                                   split='validation',
                                   debug=cfg['debug'])

# Shared loader settings; only shuffling differs between train and validation.
loader_kwargs = dict(batch_size=cfg['batch_size'],
                     drop_last=False,
                     num_workers=cfg['num_workers'])
train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
valid_loader = DataLoader(valid_set, shuffle=False, **loader_kwargs)


### Unet: load the pre-trained segmentation model for this run index
unet_path  = cfg['unet'] + str(cfg['n'])
seg_model  = UNet2D(n_chans_in=1, n_chans_out=1, n_filters_init=cfg['channel_out']).to(0)
model_path = f'{root}pre-trained-tmp/trained_UNets/{unet_path}_best.pt'
state_dict = torch.load(model_path)['model_state_dict']
seg_model.load_state_dict(state_dict)


### AE Params
layer_ids = ['shortcut0', 'shortcut1', 'shortcut2', 'up3']

#                 channel, spatial, latent, depth
ae_map   = {'up3': [   64,      32,     64,     2]}
cfg['ae_map'] = ae_map

# Every layer that is not an identity layer needs an AE configuration.
missing_ae_cfg = [l for l in layer_ids if l not in cfg['identity_layers'] and l not in ae_map]
assert not missing_ae_cfg, f'no ae_map entry for layers {missing_ae_cfg}'

# Start a W&B run before building the trainer, mirroring the ACDC cell;
# without this, `log=True` would hand the trainer an uninitialised wandb.
if cfg['log']:
    run = wandb.init(reinit=True,
                     name=description,
                     project=cfg['project'],
                     config=cfg)
    cfg = wandb.config

# One AE per active layer; disabled layers pass their features through unchanged.
AEs = nn.ModuleDict()
for layer_id in layer_ids:
    if layer_id in cfg['identity_layers']:
        continue
    channels, spatial, latent, depth = ae_map[layer_id]
    AEs[layer_id] = AE(in_channels=channels,
                       in_dim=spatial,
                       latent_dim=latent,
                       depth=depth,
                       block_size=4)

for layer_id in cfg['identity_layers']:
    AEs[layer_id] = nn.Identity()

model = Frankenstein(seg_model,
                     AEs,
                     disabled_ids=cfg['identity_layers'],
                     copy=True)
# The AEs are built on the CPU; move the wrapper to the U-Net's device
# (the ACDC cell does the same with `model.cuda()`).
model.to(0)

criterion = CalgaryCriterionAE(loss=cfg['loss'])

eval_metrics = {'Sample Volumetric Dice': SampleDice(data='calgary'),
                'UNet Volumetric Dice': UnetDice(data='calgary')}

trainer = AETrainerCalgary(model=model,
                           unet=seg_model,
                           criterion=criterion,
                           train_loader=train_loader,
                           valid_loader=valid_loader,
                           root=root,
                           target=cfg['target'],
                           description=description,
                           lr=cfg['lr'],
                           eval_metrics=eval_metrics,
                           log=cfg['log'],
                           n_epochs=cfg['n_epochs'],
                           patience=cfg['patience'])
trainer.fit()

  0%|          | 0/1 [00:00<?, ?it/s]

../../pre-trained-tmp/trained_AEs/calgary_ae_test0_best.pt


In [3]:
# ACDC: train an AE on the pre-trained U-Net's `up3` features, using
# nnU-Net's 2D data generators (with its training augmentations) as data source.
cfg = {
    'debug': True,
    'log': False,
    'description': 'acdc_ae_test',
    'project': 'MICCAI2023',

    # Data params
    'n': 0,
    'root': '../../',
    'data_path': 'data/mnm/',  # NOTE(review): not used below — data comes from the nnU-Net generators
    'train_vendor': 'B',       # NOTE(review): not used below
    'unet': 'acdc_unet8_0',
    'channel_out': 8,
    'fold': 0,

    # Hyperparams
    'batch_size': 32,   # NOTE(review): not used below; batch size presumably comes from the nnU-Net plans
    'augment': False,
    'difference': True,
    'loss': 'huber',    # huber or ce
    'target': 'output', # gt or output
    'disabled_ids': ['shortcut0', 'shortcut1', 'shortcut2'],
    'lr': 1e-4,
    'n_epochs': 1,
    'patience': 8,
}

description = cfg['description']
root = cfg['root']

### Unet (4 output channels)
# The configured name already carries its run suffix ('_0'), so cfg['n'] is not appended here.
unet_path  = cfg['unet']
unet       = UNet2D(n_chans_in=1, n_chans_out=4, n_filters_init=cfg['channel_out']).to(0)
model_path = f'{root}pre-trained-tmp/trained_UNets/{unet_path}_best.pt'
state_dict = torch.load(model_path)['model_state_dict']
unet.load_state_dict(state_dict)

### Dataloader
## Initialize an nnU-Net trainer only to reuse its (augmenting) data generators.
pkl_file          = nnUnet_prefix + 'data/nnUNet_preprocessed/Task500_ACDC/nnUNetPlansv2.1_plans_2D.pkl'
fold              = cfg['fold']
# NOTE(review): results folder is Task027_ACDC while plans/data are Task500_ACDC — confirm this is intended.
output_folder     = nnUnet_prefix + 'results/nnUnet/nnUNet/2d/Task027_ACDC/nnUNetTrainerV2__nnUNetPlansv2.1/'
dataset_directory = nnUnet_prefix + 'data/nnUNet_preprocessed/Task500_ACDC'

# Pass `fold` (previously a literal 0, so changing `fold` had no effect).
trainer = nnUNetTrainerV2(pkl_file, fold, output_folder, dataset_directory)
trainer.initialize()

train_loader = trainer.tr_gen
valid_loader = trainer.val_gen


### AE Params
layer_ids = ['shortcut0', 'shortcut1', 'shortcut2', 'up3']

#                 channel, spatial, latent, depth, block
ae_map   = {'up3': [   64,      32,    128,     2,     4]}
cfg['ae_map'] = ae_map

missing_ae_cfg = [l for l in layer_ids if l not in cfg['disabled_ids'] and l not in ae_map]
assert not missing_ae_cfg, f'no ae_map entry for layers {missing_ae_cfg}'

if cfg['log']:
    run = wandb.init(reinit=True,
                     name=cfg['description'],
                     project=cfg['project'],
                     config=cfg)
    cfg = wandb.config

# One AE per active layer (same construction as the Calgary cell);
# disabled layers pass their features through unchanged.
AEs = nn.ModuleDict()
for layer_id in layer_ids:
    if layer_id in cfg['disabled_ids']:
        continue
    channels, spatial, latent, depth, block = ae_map[layer_id]
    AEs[layer_id] = AE(in_channels=channels,
                       in_dim=spatial,
                       latent_dim=latent,
                       depth=depth,
                       block_size=block)

for layer_id in cfg['disabled_ids']:
    AEs[layer_id] = nn.Identity()


model = Frankenstein(unet,
                     AEs,
                     disabled_ids=cfg['disabled_ids'],
                     copy=True)
model.cuda()

criterion    = MNMCriterionAE(loss=cfg['loss'], diff=cfg['difference'])
eval_metrics = {'Sample Volumetric Dice': SampleDice(data='MNM'),
                'UNet Volumetric Dice': UnetDice(data='MNM')}

ae_trainer = AETrainerACDC(model=model,
                           unet=unet,
                           criterion=criterion,
                           train_loader=train_loader,
                           valid_loader=valid_loader,
                           num_batches_per_epoch=trainer.num_batches_per_epoch,
                           num_val_batches_per_epoch=trainer.num_val_batches_per_epoch,
                           root=root,
                           target=cfg['target'],
                           description=description,
                           lr=cfg['lr'],
                           eval_metrics=eval_metrics,
                           log=cfg['log'],
                           n_epochs=cfg['n_epochs'],
                           patience=cfg['patience'])


ae_trainer.fit()

loading dataset
loading all case properties
2023-10-10 08:18:13.187382: Using splits from existing split file: ../../../nnUNet/data/nnUNet_preprocessed/Task500_ACDC/splits_final.pkl
2023-10-10 08:18:13.202848: The split file contains 5 splits.
2023-10-10 08:18:13.203122: Desired fold for training: 0
2023-10-10 08:18:13.203817: This split has 160 training and 40 validation cases.
unpacking dataset
done



  0%|          | 0/1 [00:00<?, ?it/s]

using pin_memory on device 0
using pin_memory on device 0
../../pre-trained-tmp/trained_AEs/acdc_ae_test_best.pt
