## Test Pretraining

In [1]:
import random
from pathlib import Path
import time
import matplotlib.pyplot as plt
%matplotlib inline

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import nibabel as nib
from torch.utils.data import DataLoader
from tqdm import tqdm
import copy

from decoder_pretrain import DecoderPretrainNet
from encoder_pretrain import EncoderPretrainNet
from gloss_dminus import GlobalLossDminus
from gloss_d import GlobalLossD
from demeaned_gloss_d import DemeanedGlobalLossD
from pretraining_utils import *

import argparse
import json
import statistics


In [2]:
torch.cuda.empty_cache()

with open('configs/preprocessing_datasets.json') as config_file:
    config_datasets = json.load(config_file)
with open('configs/config_encoder-Copy1.json') as config_file:
    config_encoder = json.load(config_file)

# init W&B
#print("Using W&B in %s mode" % 'online')
#wandb.init(project=config_encoder["model"], mode='online')
seed = config_encoder['seed']
torch.manual_seed(seed)
np.random.seed(seed)

# choose model from config
model = EncoderPretrainNet(config_encoder)

# choose specified optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=config_encoder["lr"])
#criterion = {"global_dminus": lambda: GlobalLossDminus(config_encoder),
#             "global_d": lambda: GlobalLossD(config_encoder),
#             "representation_loss_global_d" : lambda : GlobalLossD(config_encoder),
#             "2_steps_global_d": lambda : GlobalLossD(config_encoder),
#             "demeaned_representation_loss": lambda : GlobalLossD(config_encoder),
             #"demeaned_representation_loss_2": lambda : GlobalLossD(config_encoder)
#            }[config_encoder['loss']]()

criterion = GlobalLossD(config_encoder)

if config_encoder['loss'] == '2_steps_global_d' :
    criterion2 = GlobalLossD(config_encoder, within_dataset = True)
if config_encoder['loss'] == 'demeaned_representation_loss' :
    criterion2 = DemeanedGlobalLossD(config_encoder)

In [3]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print("Using device: ", device)
model.to(device)
print("Running model %s" % config_encoder["model"])

n_parts = config_encoder['n_parts']
n_datasets = config_encoder['n_datasets']
n_volumes = config_encoder['n_volumes']
n_transforms = config_encoder['n_transforms']
resize_size = config_encoder['resize_size']
n_channels = config_encoder['n_channels']
loss_pretraining = config_encoder['loss'] 
lambda_ = config_encoder['lambda']
batch_size = n_parts * n_datasets * n_volumes
perp_val = 80
max_epochs = 2#config_encoder['max_epochs']
weight_loss = config_encoder['weight_loss']
save_global_path = config_encoder['save_global_path']
temp_fac = config_encoder["temp_fac"]

date = str(time.strftime("%Y%h%d_%Hh%M"))
if loss_pretraining == 'representation_loss' or loss_pretraining == '2_steps_global_d':
    save_directory = save_global_path + loss_pretraining + 'teeeest/' + str(n_datasets) +'datasets_' + str(n_volumes) \
                    + 'volumesPerBatch_' + str(n_transforms) + 'transforms'+'_lb' + str(lambda_)+'_tempfac' + str(temp_fac)
else :
    save_directory = save_global_path + loss_pretraining + 'teeeest/' + str(n_datasets) +'datasets_' + str(n_volumes) \
                    + 'volumesPerBatch_' + str(n_transforms) + 'transforms_tempfac' + str(temp_fac)
save_models = save_directory + '/save_models/'
Path(save_models).mkdir(parents=True, exist_ok=True)

print('Parameters used : ')
print('Full traing dataset, no validation set')
print('loss : ', loss_pretraining)
print('weight_loss : ', weight_loss)
print('n_volumes : ', n_volumes)
print('n_transforms : ', n_transforms)
print('n_datasets : ', n_datasets)
print('max_epochs : ', max_epochs)
print('batch size : ', batch_size)
print('temp_fac : ', temp_fac)
print('save folder : ', save_directory)


Using device:  cpu
Running model encoder_pretrain
Parameters used : 
Full traing dataset, no validation set
loss :  demeaned_representation_loss_2
weight_loss :  1
n_volumes :  3
n_transforms :  2
n_datasets :  4
max_epochs :  2
batch size :  48
temp_fac :  1
save folder :  ./trained_models/Cardiac_only/pretraining_FullTrainSet/demeaned_representation_loss_2teeeest/4datasets_3volumesPerBatch_2transforms_tempfac1


In [4]:
class PreTrainDatasetDemeaned(torch.utils.data.Dataset) :
    def __init__(self, volumes):
        self.path_volumes = volumes
        
        for i in range(len(self.path_volumes)):
            vol_file = self.path_volumes[i]
            volume = nib.load(vol_file).get_fdata()
            new_img = volume.transpose(2, 0, 1)
            
            if i == 0: 
                self.imgs = new_img
            else :
                self.imgs = np.concatenate((self.imgs, new_img), axis=0)
                
    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        return self.imgs[idx]


In [5]:
count_datasets = 0
for config_dataset in config_datasets:
    if config_dataset['experiment'] == 'pretraining' :
        if count_datasets >= n_datasets :
            break
        count_datasets += 1
assert count_datasets == n_datasets

# Dataset Initialization
datasets_train = []
datasets_validation = []
datasets_loaders_demeaned = []
        
print('Datasets used for pretraining :')

count_datasets = 0
for config_dataset in config_datasets:
    
    if config_dataset['experiment'] == 'pretraining' :
        if count_datasets >= n_datasets :
            break
        count_datasets += 1
        
        print(config_dataset['Data'])
        
        current_datasets_train = []
        current_datasets_validation = []
        count = 0
        for path in Path( config_dataset["savedir"]).rglob( "train/*/img.nii.gz"):
            if count > 0 :
                break
            current_datasets_train.append(path)
            count += 1
        for path in Path( config_dataset["savedir"]).rglob( "validation/*/img.nii.gz"):
            if count > 0 :
                break
            current_datasets_train.append(path)
            count += 1
        for path in Path( config_dataset["savedir"]).rglob( "test/*/img.nii.gz"):
            current_datasets_validation.append(path)
            
        print('n vol in train : ', len(current_datasets_train))
        print('n vol in validation : ', len(current_datasets_validation))
        datasets_train.append(current_datasets_train)
        datasets_validation.append(current_datasets_validation)
        
        dataset_demeaned = PreTrainDatasetDemeaned(current_datasets_train)
        datasets_loader_demeaned = DataLoader(dataset_demeaned,
                            num_workers=1,
                            batch_size = 32,
                            pin_memory=True,
                            shuffle=False,
                            drop_last=False)
        datasets_loaders_demeaned.append(datasets_loader_demeaned)
        print(len(datasets_loaders_demeaned))
        
        

trans = custom_transforms(config_encoder)



Datasets used for pretraining :
ACDC
n vol in train :  1
n vol in validation :  50
1
Chaos
n vol in train :  1
n vol in validation :  5
2
HCP
n vol in train :  1
n vol in validation :  13
3
Medical Decathelon Prostate
n vol in train :  1
n vol in validation :  8
4


In [6]:
count_datasets = 0
for config_dataset in config_datasets:
    if config_dataset['experiment'] == 'pretraining' :
        if count_datasets >= n_datasets :
            break
        count_datasets += 1
assert count_datasets == n_datasets

# Dataset Initialization
datasets_train = []
datasets_validation = []
#datasets_loaders_demeaned = []
        
print('Datasets used for pretraining :')

count_datasets = 0
for config_dataset in config_datasets:
    
    if config_dataset['experiment'] == 'pretraining' :
        if count_datasets >= n_datasets :
            break
        count_datasets += 1
        
        print(config_dataset['Data'])
        
        current_datasets_train = []
        current_datasets_validation = []
        
        for path in Path( config_dataset["savedir"]).rglob( "train/*/img.nii.gz"):
            
            current_datasets_train.append(path)
            
        for path in Path( config_dataset["savedir"]).rglob( "validation/*/img.nii.gz"):
            
            current_datasets_train.append(path)
            
        for path in Path( config_dataset["savedir"]).rglob( "test/*/img.nii.gz"):
            current_datasets_validation.append(path)
            
        print('n vol in train : ', len(current_datasets_train))
        print('n vol in validation : ', len(current_datasets_validation))
        datasets_train.append(current_datasets_train)
        datasets_validation.append(current_datasets_validation)
        
        dataset_demeaned = PreTrainDatasetDemeaned(current_datasets_train)
        datasets_loader_demeaned = DataLoader(dataset_demeaned,
                            num_workers=1,
                            batch_size = 32,
                            pin_memory=True,
                            shuffle=False,
                            drop_last=False)
        #datasets_loaders_demeaned.append(datasets_loader_demeaned)
        print(len(datasets_loaders_demeaned))
        
        
dataset_train = PreTrainDataset(config_encoder, datasets_train)
dataset_loader_train = DataLoader(dataset_train,
                            num_workers=1,
                            batch_size= n_volumes,
                            pin_memory=True,
                            shuffle=True,
                            drop_last=True)

dataset_validation = PreTrainDataset(config_encoder, datasets_validation)
dataset_loader_validation = DataLoader(dataset_validation,
                            num_workers=1,
                            batch_size=n_volumes,
                            shuffle=False,
                            drop_last=True)

trans = custom_transforms(config_encoder)



Datasets used for pretraining :
ACDC
n vol in train :  49
n vol in validation :  50
4
Chaos
n vol in train :  15
n vol in validation :  5
4
HCP
n vol in train :  36
n vol in validation :  13
4
Medical Decathelon Prostate
n vol in train :  24
n vol in validation :  8
4


if loss_pretraining == 'demeaned_representation' :
    # Calculate the mean representation of each datasets
    all_volumes_per_dataset = []
    for i_dataset in range(n_datasets) :

        for count, path_vol in enumerate(datasets_train[i_dataset]) :
            vol_file = path_vol
            volume = torch.tensor(nib.load(vol_file).get_fdata().transpose(2, 0, 1)).view((-1, n_channels, *resize_size))

            if count == 0 :
                volumes = volume
            else :
                volumes = torch.cat([volumes, volume])
        print(volumes.shape)
        all_volumes_per_dataset.append(volumes)
print(len(all_volumes_per_dataset))

In [7]:
torch.zeros((2,2))

tensor([[0., 0.],
        [0., 0.]])

In [9]:
steps = 0
losses = pd.DataFrame(columns = ['epoch', 'train loss', 'validation loss'])

# Training
for epoch in range(max_epochs) :
    print("Doing Train...")
    print("Epoch {:03d}".format(epoch))

    model.train()
    batch_train_loss = []
    
    # Calculate mean representation of each dataset from current model
    if loss_pretraining == 'demeaned_representation_loss' or loss_pretraining == 'demeaned_representation_loss_2' :
        with torch.no_grad():
            model.eval()
            mean_representations = []
            for i_dataset in range(n_datasets):
                if epoch == 0 :
                    mean_representation = torch.zeros((128, 6, 6))
                else :
                    for id_batch, batch_x in enumerate(tqdm(datasets_loaders_demeaned[i_dataset])):
                        batch = batch_x.float().to(device)
                        batch = batch.view((-1, n_channels, *resize_size))
                        if id_batch == 0 :
                            pred_representation = model.enc(batch).squeeze()
                        else :
                            pred_representation = torch.cat([pred_representation, model.enc(batch).squeeze()])
                    mean_representation = torch.mean(pred_representation, dim = 0).to(device)
                mean_representations.append(mean_representation)
          
    
    for id_batch, batch_x in enumerate(tqdm(dataset_loader_train)):

        optimizer.zero_grad()
        
        batch = batch_x.float().to(device)
        batch = batch.view((-1, n_channels, *resize_size))
        train_batch = batch
        
        
        if loss_pretraining == 'demeaned_representation_loss' :
            # similar as representation_loss but with demeaned version loss
            # subtract the mean of the representation of the repectif dataset of last epoch to compute the loss
            for i in range(n_transforms) :
                train_batch = torch.cat([train_batch, trans(batch, dtrans ='option_2')])
            pred = model(train_batch).squeeze()
            pred_representation = model.enc(train_batch).squeeze()
            
            loss_1 = criterion(pred)
            loss_2 = criterion2(pred_representation, mean_representations)
            
            loss = lambda_ * loss_1 + (1 - lambda_) * loss_2
            
        elif loss_pretraining == 'demeaned_representation_loss_2' :
            for i in range(n_transforms) :
                train_batch = torch.cat([train_batch, trans(batch, dtrans ='option_2')])
            pred_representation = model.enc(train_batch).squeeze()
            
            # Demeaned reg_pred with respective mean of the dataset 
            pred_demeaned_representation = pred_representation.clone()
            for i in range(pred_representation.shape[0]) :
                pred_demeaned_representation[i,:,:,:] -=  mean_representations[((i // n_parts) % n_datasets)] 
                #print((i // n_parts) % n_datasets)
            pred = model.g1(pred_demeaned_representation).squeeze()
            
            loss = criterion(pred)
            
        if id_batch == 0 :
            pred_train = pred
        else :
            pred_train = torch.cat([pred_train, pred.detach()])
        
        batch_train_loss.append(loss.item())

        loss.backward()
        optimizer.step()
        
    if (epoch+1) % 2 == 0 :
        directory_predictions = str(save_directory) + "/predictions/" + str(epoch) +"/" 
        Path(directory_predictions).mkdir(parents=True, exist_ok=True)
        torch.save(pred_train, directory_predictions+ "pred_train.pt")
        
    train_loss = statistics.mean(batch_train_loss)
    print("Current train loss: %f" % train_loss)  
    
    if (epoch+1) % 2 == 0 :
        
        # Validation 
        batch_val_loss = []

        print("Doing Validation...")
        with torch.no_grad():
            model.eval()
            predictions = pd.DataFrame(columns = ['type', 'prediction'])

            for id_batch, batch_x in enumerate(tqdm(dataset_loader_validation)):

                batch = batch_x.float().to(device)
                batch = batch.view((-1, n_channels, *resize_size))
                val_batch = batch

                
                if loss_pretraining == 'demeaned_representation_loss' :
                    for i in range(n_transforms) :
                        val_batch = torch.cat([val_batch, trans(batch, dtrans ='option_2')])
                        
                    #val_batch = torch.cat([val_batch, trans(batch, dtrans ='option_2')])
                    pred_representation = model.enc(val_batch).squeeze()
                    pred = model(val_batch).squeeze()
                    
                    loss_1 = criterion(pred)
                    loss_2 = criterion2(pred_representation, mean_representations)
                    
                    loss = lambda_ * loss_1 + (1 - lambda_) * loss_2
                    
                elif loss_pretraining == 'demeaned_representation_loss_2' :
                    for i in range(n_transforms) :
                        val_batch = torch.cat([val_batch, trans(batch, dtrans ='option_2')])
                    pred_representation = model.enc(val_batch).squeeze()

                    # Demeaned reg_pred with respective mean of the dataset 
                    pred_demeaned_representation = pred_representation.clone()
                    for i in range(pred_representation.shape[0]) :
                        pred_demeaned_representation[i,:,:,:] -=  mean_representations[((i // n_parts) % n_datasets)] 
                        #print((i // n_parts) % n_datasets)
                    pred = model.g1(pred_demeaned_representation).squeeze()

                    loss = criterion(pred)
                    
                if id_batch == 0 :
                    pred_validation = pred
                else :
                    pred_validation = torch.cat([pred_validation, pred])

                batch_val_loss.append(loss.item())

        torch.save(pred_validation, directory_predictions+"pred_validation.pt")
    
        validation_loss = statistics.mean(batch_val_loss)

        print("Current validation loss: %f" % validation_loss)  
    
    if (epoch+1) % 2 == 0 :
        losses = losses.append([{'epoch': epoch, 'train loss' : train_loss, 'validation loss' : validation_loss}])
        steps += 1
    else : 
        losses = losses.append([{'epoch': epoch, 'train loss' : train_loss}])
        steps += 1
        
    # Save loss and model at each 50 epoch
    if (epoch+1) % 2 == 0 :
        torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                },str(save_models) + "checkpoints_" + str(epoch)+".pt")

        losses.to_pickle(str(save_directory) + "/losses.pkl")
    
losses.to_pickle(str(save_directory) + "/losses.pkl")
print('Done')


  0%|          | 0/8 [00:00<?, ?it/s]

Doing Train...
Epoch 000


100%|██████████| 8/8 [01:12<00:00,  9.05s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

Current train loss: 10.970399
Doing Train...
Epoch 001


100%|██████████| 1/1 [00:00<00:00,  4.82it/s]
100%|██████████| 4/4 [00:01<00:00,  3.03it/s]
100%|██████████| 6/6 [00:02<00:00,  2.42it/s]
100%|██████████| 1/1 [00:00<00:00,  2.91it/s]
100%|██████████| 8/8 [01:12<00:00,  9.09s/it]
  0%|          | 0/2 [00:00<?, ?it/s]

Current train loss: 10.624954
Doing Validation...


100%|██████████| 2/2 [00:12<00:00,  6.17s/it]

Current validation loss: 9.993177
Done



