In [1]:
import random
from pathlib import Path
import time
import matplotlib.pyplot as plt

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import nibabel as nib
from torch.utils.data import DataLoader
from tqdm import tqdm

from decoder_pretrain import DecoderPretrainNet
from encoder_pretrain import EncoderPretrainNet
from gloss_dminus import GlobalLossDminus
from gloss_d import GlobalLossD
from dice_loss import DiceLoss
from seg_unet import UNet_pretrained, UNet
import torch.nn.functional as F

import json
import statistics
from sklearn.metrics import f1_score
import pickle
import copy

from training_utils import *
from data_augmentation_utils import DataAugmentation

# Release unused cached GPU memory before the experiment starts.
torch.cuda.empty_cache()

# Load the three JSON configuration files used by the notebook.
config_datasets = json.loads(Path('configs/preprocessing_datasets.json').read_text())
config_seg = json.loads(Path('configs/seg_unet_Met.json').read_text())
config_encoder = json.loads(Path('configs/config_encoder.json').read_text())


In [32]:
def initialize_dataset(config_datasets, config_seg, dataset, total_n_volumes=24, n_volumes=2,
                       split_set='train', shuffle=False, idx_vols_val=None, return_idx=False):
    """Collect image/mask volume paths for one dataset split and wrap them in a DataLoader.

    Subject directories are discovered with ``rglob('subject_*/')`` under
    ``<savedir><split_set>/`` of every entry in ``config_datasets`` whose
    ``'Data'`` field equals ``dataset``. Subjects are numbered in discovery order
    (the counter is shared across all matching config entries).

    Selection rule:
      * ``dataset == 'Metastases'`` (any split) or ``split_set == 'test'``:
        the first ``n_volumes`` discovered subjects are used.
      * otherwise: only subjects whose running index is in ``idx_vols`` are used,
        where ``idx_vols`` comes from ``select_random_volumes`` and is redrawn
        until it shares no index with ``idx_vols_val``.

    Args:
        config_datasets: iterable of dicts with at least ``'Data'`` and ``'savedir'``.
        config_seg: segmentation config; ``'batch_size'`` is read here and the
            whole dict is forwarded to ``TrainDataset``.
        dataset: name of the dataset to load.
        total_n_volumes: forwarded to ``select_random_volumes``.
        n_volumes: number of volumes to select.
        split_set: sub-directory name of the split (e.g. ``'train'``, ``'test'``).
        shuffle: forwarded to the ``DataLoader``.
        idx_vols_val: optional indices that the drawn ``idx_vols`` must avoid.
        return_idx: when true, also return the drawn ``idx_vols``.

    Returns:
        The ``DataLoader``, or ``(DataLoader, idx_vols)`` when ``return_idx`` is true.
    """
    # Drawn unconditionally (even when the indices end up unused) so the random
    # stream is consumed the same way for every dataset/split.
    # NOTE(review): presumably returns `n_volumes` indices below `total_n_volumes` - confirm.
    idx_vols = select_random_volumes(total_n_volumes, n_volumes)
    # `is not None` instead of `!= None`: the latter is element-wise on a NumPy
    # array and makes the `if` raise for arrays with more than one element.
    if idx_vols_val is not None:
        assert len(idx_vols_val) + n_volumes <= total_n_volumes
        # Redraw until the selection is disjoint from the validation indices.
        while any(item in idx_vols for item in idx_vols_val):
            idx_vols = select_random_volumes(total_n_volumes, n_volumes)

    is_metastases = dataset == 'Metastases'
    take_first_n = is_metastases or split_set == 'test'

    img_dataset = []
    mask_dataset = []
    count = -1
    data = None
    for config_dataset in config_datasets:
        if config_dataset['Data'] != dataset:
            continue
        for path in Path(config_dataset['savedir'] + split_set + '/').rglob('subject_*/'):
            # We want the subject directory, not individual image files.
            if "nii.gz" in str(path) or ".png" in str(path):
                continue
            count += 1

            # Different criterion to stop / skip depending on the selection rule.
            if take_first_n:
                if count >= n_volumes:
                    break
            elif count not in idx_vols:
                continue

            # Add the image and the corresponding mask to the datasets.
            for path_image in path.rglob("img.nii.gz"):
                img_dataset.append(path_image)
                print(path_image)
            for path_mask in path.rglob("mask.nii.gz"):
                mask_dataset.append(path_mask)

            # Only the Metastases training split asks TrainDataset for its
            # dataset-specific handling (it receives None otherwise).
            if is_metastases and split_set == 'train':
                data = dataset

    # Initialize the dataset and its loader.
    slices_dataset = TrainDataset(config_seg, img_dataset, mask_dataset, data)
    dataset_loader = DataLoader(slices_dataset,
                                num_workers=1,
                                batch_size=config_seg['batch_size'],
                                shuffle=shuffle,
                                pin_memory=True,
                                drop_last=False)
    if return_idx:
        return dataset_loader, idx_vols
    return dataset_loader

In [108]:
class TrainDataset(torch.utils.data.Dataset) :
    """In-memory dataset of 2D slices taken from NIfTI volumes and their masks.

    Every volume/mask pair is loaded with nibabel, its last axis is moved to the
    front (``transpose(2, 0, 1)``) and the slices of all volumes are stacked along
    axis 0 into ``self.imgs`` / ``self.masks``.

    When ``dataset == 'Metastases'`` the slices are re-balanced: every slice whose
    mask contains the value 1 is kept, together with an equally sized random
    sample (drawn with the ``random`` module) of the remaining slices. The final
    arrays hold the sampled tumour-free slices first, then the tumour slices.

    Args:
        config: accepted for interface compatibility; not used by this class.
        volumes: sequence of image paths, one per volume.
        masks: sequence of mask paths, index-aligned with ``volumes``.
        dataset: dataset name; only ``'Metastases'`` triggers the re-balancing.
    """

    def __init__(self, config, volumes, masks, dataset = None):
        self.path_volumes = volumes
        self.path_masks = masks

        img_stacks = []
        mask_stacks = []
        for i in range(len(self.path_volumes)):
            volume = nib.load(self.path_volumes[i]).get_fdata()
            mask = nib.load(self.path_masks[i]).get_fdata()

            assert volume.shape == mask.shape

            # Move the last axis first so that axis 0 indexes the 2D slices.
            img_stacks.append(volume.transpose(2, 0, 1))
            mask_stacks.append(mask.transpose(2, 0, 1))

        if img_stacks :
            # Concatenate once instead of re-copying all previous slices per volume.
            self.imgs = np.concatenate(img_stacks, axis=0)
            self.masks = np.concatenate(mask_stacks, axis=0)
        else :
            # No volume given: expose a well-defined empty dataset so that
            # len() and iteration work instead of failing on a missing attribute.
            self.imgs = np.empty((0, 0, 0))
            self.masks = np.empty((0, 0, 0))

        if dataset == 'Metastases' :
            self._balance_tumor_slices()

    def _balance_tumor_slices(self):
        """Keep all slices containing label 1 plus an equal-sized sample of the others."""
        # A slice counts as "tumour" when any of its mask values equals 1.
        slice_axes = tuple(range(1, self.masks.ndim))
        has_tumor = (self.masks == 1).any(axis=slice_axes)
        indexes_tumors_slices = np.flatnonzero(has_tumor)
        indexes_healthy = np.flatnonzero(~has_tumor)

        self.imgs_tumors = self.imgs[indexes_tumors_slices,:,:]
        self.masks_tumors = self.masks[indexes_tumors_slices,:,:]

        self.imgs_healthy = self.imgs[indexes_healthy,:,:]
        self.masks_healthy = self.masks[indexes_healthy,:,:]

        # Cap the sample size: random.sample raises ValueError when asked for
        # more elements than the population holds (more tumour than healthy slices).
        n_healthy = self.imgs_healthy.shape[0]
        n_sampled = min(len(indexes_tumors_slices), n_healthy)
        indexes_balance_healthy = np.asarray(random.sample(range(0, n_healthy), n_sampled),
                                             dtype=np.intp)

        self.imgs_balance_healthy = self.imgs_healthy[indexes_balance_healthy,:,:]
        self.masks_balance_healthy = self.masks_healthy[indexes_balance_healthy,:,:]

        # Shapes before and after balancing, for a quick visual check.
        print(self.imgs.shape)
        self.imgs = np.concatenate((self.imgs_balance_healthy , self.imgs_tumors ), axis=0)
        self.masks = np.concatenate((self.masks_balance_healthy, self.masks_tumors), axis=0)
        print(self.imgs.shape)

    def __len__(self):
        """Number of 2D slices in the dataset."""
        return len(self.imgs)

    def __getitem__(self, idx):
        """Return the ``(image_slice, mask_slice)`` pair at position ``idx``."""
        return self.imgs[idx], self.masks[idx]
    

In [109]:
# Small integer tensor used by the next cell to try out the `in` operator on tensors.
# (The earlier `torch.randn(3, 2)` assignment was a dead store and has been dropped.)
x = torch.tensor([[[2,2,3], [14,23,3]], [[2,223,3], [1422,23,311]]])
x

tensor([[[   2,    2,    3],
         [  14,   23,    3]],

        [[   2,  223,    3],
         [1422,   23,  311]]])

In [110]:
# Membership test on a tensor: True when any element of `x` equals 223.
bool((x == 223).any())

True

In [112]:
# Number of training / validation volumes requested by the segmentation config.
n_vol_train = config_seg['n_vol_train']
n_vol_val = config_seg['n_vol_val']

lambda_ = 0.6

# Experiment grid; the commented-out values are the alternatives kept for later runs.
n_vol_trains = [10]
seeds = [0]#, 10, 20, 30, 40, 50]
options = ['option_10']#, 'option_2', 'option_3', 'option_4', 'option_5', 'option_6', 'option_7'] 
losses_unet = [config_seg['loss_unet']] 

# Hyper-parameters read from the segmentation config.
dataset = config_seg['dataset']
resize_size = config_seg['resize_size']
n_channels = config_seg['n_channels']
max_epochs = config_seg['max_epochs']
n_classes = config_seg['n_classes']
batch_size = config_seg['batch_size']
lr = config_seg['lr']
weight_pretrained = config_seg['weight_pretrained']

save_global_path = config_seg['save_global_path']

# Dataset-specific settings: class count (overrides the config value), foreground
# labels, per-class weights and volume counts.
if dataset == 'Abide':
    n_classes = 15
    lab = [1,2,3,4,5,6,7,8,9,10,11,12,13,14]
    weights = torch.tensor([0.025, 0.075, 0.075, 0.035, 0.035, 0.075, 0.075, 0.075, 
                                   0.075, 0.075, 0.075, 0.075, 0.075, 0.075, 0.075], dtype=torch.float32)
    
    total_n_volumes = 24
    n_vol_test = 12
    
elif dataset == 'CIMAS' or dataset == 'ACDC' :
    n_classes = 4
    lab = [1,2,3]
    weights = torch.tensor([0.1, 0.3, 0.3, 0.3], dtype=torch.float32)
    
    if dataset == 'CIMAS' :
        total_n_volumes = 13
        n_vol_test = 7
    else :
        total_n_volumes = 69
        n_vol_test = 30
        
elif dataset == 'Metastases':
    n_classes = 2
    lab = [1]
    weights = torch.tensor([0.1, 0.9], dtype=torch.float32)
    
    total_n_volumes =53
    # Reduced volume counts; the trailing comments keep the previously used values.
    n_vol_test = 2#24
    n_vol_val = 2#18

else :
    # Fail early with a clear message instead of a NameError on `n_vol_test` below.
    raise ValueError(f"Unsupported dataset {dataset!r}; expected 'Abide', 'CIMAS', 'ACDC' or 'Metastases'")
        
print('Test set')
dataset_loader_test = initialize_dataset(config_datasets, config_seg, dataset, total_n_volumes=n_vol_test, 
                                         n_volumes=n_vol_test, split_set = 'test', 
                                         shuffle = False)
        
        

data_aug = DataAugmentation()
for loss_unet in losses_unet :
    run = -1
    for seed in seeds :
        torch.cuda.empty_cache()
        run += 1
        torch.manual_seed(seed)
        for n_vol_train in n_vol_trains:

            # Initialization of the datasets.
            # `random` is re-seeded before each loader is built so both volume
            # draws start from the same RNG state.
            # NOTE(review): assumes select_random_volumes draws from `random` - confirm.
            print('Validation set')
            random.seed(seed)
            
            if dataset == 'Metastases' :
                # Metastases has its own 'validation' split on disk.
                dataset_loader_validation = initialize_dataset(config_datasets, config_seg, dataset, 
                                                           total_n_volumes=total_n_volumes, 
                                                           n_volumes=n_vol_val, split_set = 'validation', 
                                                           shuffle = False, return_idx = False)
                random.seed(seed)
                print('Training set')
                dataset_loader_train = initialize_dataset(config_datasets, config_seg, dataset, 
                                                          total_n_volumes=total_n_volumes, 
                                                          n_volumes=n_vol_train, split_set = 'train', 
                                                          shuffle = True)
                
            else : 
                # Other datasets carve the validation volumes out of the 'train'
                # split; their indices are passed on so the training draw avoids them.
                dataset_loader_validation, idx_vols_val = initialize_dataset(config_datasets, config_seg, dataset, 
                                                           total_n_volumes=total_n_volumes, 
                                                           n_volumes=n_vol_val, split_set = 'train', 
                                                           shuffle = False, return_idx = True)
                random.seed(seed)
                print('Training set')
                dataset_loader_train = initialize_dataset(config_datasets, config_seg, dataset, 
                                                          total_n_volumes=total_n_volumes, 
                                                          n_volumes=n_vol_train, split_set = 'train', 
                                                          shuffle = True, idx_vols_val = idx_vols_val)

Test set
../img_cropped/metastases/test/subject_567/img.nii.gz
../img_cropped/metastases/test/subject_7385/img.nii.gz
Validation set
../img_cropped/metastases/validation/subject_6578/img.nii.gz
../img_cropped/metastases/validation/subject_6376/img.nii.gz
Training set
../img_cropped/metastases/train/subject_1624/img.nii.gz
../img_cropped/metastases/train/subject_678/img.nii.gz
../img_cropped/metastases/train/subject_1220/img.nii.gz
../img_cropped/metastases/train/subject_4250/img.nii.gz
../img_cropped/metastases/train/subject_2432/img.nii.gz
../img_cropped/metastases/train/subject_1927/img.nii.gz
../img_cropped/metastases/train/subject_3442/img.nii.gz
../img_cropped/metastases/train/subject_1019/img.nii.gz
../img_cropped/metastases/train/subject_1523/img.nii.gz
../img_cropped/metastases/train/subject_412/img.nii.gz
(3360, 192, 192)
(926, 192, 192)
