# Required imports

In [None]:
import SimpleITK as sitk
import os
import glob
import numpy as np
import pydicom
import nibabel as nib
import pandas as pd
import time
import shutil

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

import skimage
from skimage import data, measure
from skimage.io import imread
from skimage.filters import threshold_otsu
from skimage.segmentation import clear_border

from skimage.measure import regionprops
from skimage.morphology import closing, square

from skimage.color import label2rgb

from exp_utils import *
from model_utils import *
from utils import l2_regularisation



from data import *

from StitchingDeTr import *

import torchvision.transforms as T

import pickle
import torch
import torch.nn.functional as F
%matplotlib inline

In [None]:
# Show the index of the currently selected CUDA device.
# NOTE(review): presumably fails on a machine without CUDA — confirm before running on CPU.
torch.cuda.current_device()

# Data processing and loading

In [None]:
# Root folder of the dataset (hard-coded absolute path; several later cells re-assign this name).
data_dir='/media/hmn-mednuc/InternalDisk_1/datasets/GAINED/resampled_croped/'

In [None]:
from torch.utils.data import Dataset, DataLoader, random_split

In [None]:
def prepare_loader(dataset,batch_size=10,shuffle=True):
    """Randomly split ``dataset`` 80/20 and wrap both parts in DataLoaders.

    Args:
        dataset: any object accepted by ``random_split`` (needs ``__len__`` and ``__getitem__``).
        batch_size: batch size used for both loaders.
        shuffle: forwarded to both loaders.

    Returns:
        ``(train_loader, val_loader)``.
    """
    n_total = len(dataset)
    n_train = int(n_total * 0.8)
    # The validation part takes whatever is left so that the two lengths always
    # sum to len(dataset); random_split raises otherwise (the previous
    # int(0.2 * n) + 1 overshoots, e.g. 8 + 3 for n = 10).
    n_valid = n_total - n_train

    train_set, valid_set = random_split(dataset, [n_train, n_valid])

    train_loader = DataLoader(train_set, batch_size=batch_size, num_workers=0, shuffle=shuffle)
    val_loader = DataLoader(valid_set, batch_size=batch_size, num_workers=0, shuffle=shuffle)
    return train_loader, val_loader

# Model definition and training

In [None]:
from HPUnet_torch import HierarchicalProbUNet

In [None]:
base_channels = 24
num_convs_per_block = 3

# Channel width of each of the nine levels: 1x, 2x, 4x, then 8x the base width.
default_channels_per_block = (
    base_channels,
    2 * base_channels,
    4 * base_channels,
    8 * base_channels,
    8 * base_channels,
    8 * base_channels,
    8 * base_channels,
    8 * base_channels,
    8 * base_channels,
)
# One input channel for the image, followed by the per-level widths.
input_channels = (1,) + default_channels_per_block

channels_per_block = default_channels_per_block
# Integer division: channel counts must be ints ("/" produced floats such as 12.0).
down_channels_per_block = tuple(c // 2 for c in default_channels_per_block)

net = HierarchicalProbUNet(
    dim=2,
    latent_dims=[1, 1, 1, 1],
    input_channels=list(input_channels),
    channels_per_block=list(channels_per_block),
    num_classes=2,
    down_channels_per_block=list(down_channels_per_block),
    convs_per_block=num_convs_per_block,
    blocks_per_level=3,
)

In [None]:
checkpoint_path = './chkpoint_withgen_with_size256x256_beta1_30epochs'
best_model_path = './bestmodel_withgen_with_size256x256_beta1_30epochs.pt'
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
net.to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, weight_decay=0)
epochs = 30
beta=1.0  # weight of the KL term inside the ELBO

data_dir='/media/hmn-mednuc/InternalDisk_1/datasets/GAINED/resampled_croped/'

valid_loss_min=float('inf')

# Per-epoch logs (TODO: find a more elegant way to capture them).
train_loss,val_loss=[],[]
dice_score_train,dice_score_val=[],[]
kls_loss_train,kls_loss_val=[],[]
recons_loss_train,recons_loss_val=[],[]
detection_loss_train,detection_loss_val=[],[]

# Element-wise BCE on logits; the caller sums it. `reduction='none'` replaces the
# deprecated size_average/reduce arguments, and `torch.nn` avoids relying on an
# `nn` name leaking in through a star import.
loss_bce = torch.nn.BCEWithLogitsLoss(reduction='none')


def _elbo_losses(target, data):
    """Compute the loss terms for one batch.

    Returns (kl_loss_per_levels, kl_loss, reconstruction_loss, total_loss) where
    total_loss = reconstruction_loss + beta * kl_loss + 1e-5 * L2 regularisation.
    """
    # KL divergence part of the ELBO (total and per latent level)
    kl_loss_per_levels, kl_loss = net.kl_divergence_(target, data)

    # Binary cross-entropy reconstruction part of the ELBO
    reconstruction = net.reconstruct(target, data, mean=False)
    reconstruction_loss = torch.sum(loss_bce(input=reconstruction, target=target))

    elbo = -(reconstruction_loss + beta * kl_loss)
    reg_loss = l2_regularisation(net._prior)+l2_regularisation(net._posterior)+l2_regularisation(net._f_comb)
    total_loss = -elbo + 1e-5*reg_loss
    return kl_loss_per_levels, kl_loss, reconstruction_loss, total_loss


# Build the dataset and split it ONCE. Re-splitting inside the epoch loop draws a new
# random partition every epoch, so validation samples end up being trained on.
dataset = MRI2DSegmentationDataset(data_dir=data_dir, slice_axis=1)
train_loader,val_loader=prepare_loader(dataset)

for epoch in range(epochs):

    running_train_reconstruction,running_train_kl_loss,running_train_total_loss,running_train_score = [[] for _ in range(4)]

    print('Numbers of epoch:{}/{}'.format(epoch+1,epochs))
    started = time.time()

    net.train()
    for batch_idx, (train_batch_input , train_batch_gt) in enumerate(train_loader):
        target,data=train_batch_gt.to(device),train_batch_input.to(device)

        kl_loss_per_levels, kl_loss, reconstruction_loss, total_loss = _elbo_losses(target, data)

        # The Dice score is only monitored, so do not build an autograd graph for it.
        with torch.no_grad():
            score = batch_dice(F.softmax(net.sample(data,mean=False),dim=1),target)

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        running_train_total_loss.append(total_loss.item())
        running_train_kl_loss.append(kl_loss_per_levels)
        running_train_reconstruction.append(reconstruction_loss.item())
        running_train_score.append(score.item())

        print(' KL divergence loss over one batch: {} ---- Reconstruction loss over one batch: {} ---- Overall loss batch: {} ---- Overall score batch: {} ---- Batch idx: {}'.format(kl_loss.item(),reconstruction_loss.item(),total_loss.item(),score.item(),batch_idx))

    # ---- validation ----
    running_valid_reconstruction,running_valid_kl_loss,running_valid_total_loss,running_valid_score = [[] for _ in range(4)]

    net.eval()
    with torch.no_grad():
        for batch_idx, (valid_batch_input, valid_batch_gt) in enumerate(val_loader):
            target,data=valid_batch_gt.to(device),valid_batch_input.to(device)

            kl_loss_per_levels, kl_loss, reconstruction_loss, total_loss = _elbo_losses(target, data)
            score = batch_dice(F.softmax(net.sample(data,mean=False),dim=1),target)

            running_valid_total_loss.append(total_loss.item())
            running_valid_kl_loss.append(kl_loss_per_levels)
            running_valid_reconstruction.append(reconstruction_loss.item())
            running_valid_score.append(score.item())

    # ---- epoch summaries ----
    epoch_train_loss = np.mean(running_train_total_loss)
    epoch_train_kl = np.mean(running_train_kl_loss,axis=0)
    epoch_train_score = np.mean(running_train_score)
    epoch_train_reconstruction = np.mean(running_train_reconstruction)
    print('Train total loss epoch : {} Dice score epoch : {}'.format(epoch_train_loss,epoch_train_score))
    train_loss.append(epoch_train_loss)
    dice_score_train.append(epoch_train_score)
    kls_loss_train.append(epoch_train_kl)
    recons_loss_train.append(epoch_train_reconstruction)

    epoch_val_loss = np.mean(running_valid_total_loss)
    epoch_val_kl = np.mean(running_valid_kl_loss,axis=0)
    epoch_val_score = np.mean(running_valid_score)
    epoch_val_reconstruction = np.mean(running_valid_reconstruction)
    print('Valid total loss epoch: {} Dice score epoch : {}'.format(epoch_val_loss,epoch_val_score))
    val_loss.append(epoch_val_loss)
    dice_score_val.append(epoch_val_score)
    kls_loss_val.append(epoch_val_kl)
    recons_loss_val.append(epoch_val_reconstruction)

    checkpoint = {'epoch': epoch +1,
                  'valid_loss_min':epoch_val_loss,
                  'state_dict':net.state_dict(),
                  'optimizer':optimizer.state_dict(),
    }
    # Always save the latest state; save again as "best" when validation improved.
    save_ckp(checkpoint, False,checkpoint_path,best_model_path)

    if epoch_val_loss <= valid_loss_min:
        print('Validation loss decreased ({:.6f} =======> {:.6f}). Saving model ...'.format(valid_loss_min,epoch_val_loss))

        save_ckp(checkpoint, True,checkpoint_path,best_model_path)
        valid_loss_min = epoch_val_loss

    time_passed = time.time() - started
    print('{:.0f}m {:.0f}s'.format(time_passed//60, time_passed%60))


In [None]:
checkpoint_path = './chkpoint_withgen_with_detection_zer_test'
best_model_path = './bestmodel_withgen_with_detection_zer_test.pt'
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
net.to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, weight_decay=0)
epochs = 30
beta=1.0  # weight of the KL term inside the ELBO

data_dir='/media/hmn-mednuc/InternalDisk_1/datasets/GAINED/resampled_croped/'

valid_loss_min=float('inf')

# Per-epoch logs (TODO: find a more elegant way to capture them).
train_loss,val_loss=[],[]
dice_score_train,dice_score_val=[],[]
kls_loss_train,kls_loss_val=[],[]
recons_loss_train,recons_loss_val=[],[]
detection_loss_train,detection_loss_val=[],[]

# Element-wise BCE on logits; the caller sums it. `reduction='none'` replaces the
# deprecated size_average/reduce arguments.
loss_bce = torch.nn.BCEWithLogitsLoss(reduction='none')

# NOTE(review): `criterion` (the detection loss) is not defined in the visible notebook;
# presumably it comes from one of the star imports — confirm.


def _elbo_and_detection_losses(target, data, raw_targets):
    """Compute the loss terms for one batch, including the weighted detection losses.

    Returns (detr_terms, losses_detr, kl_loss_per_levels, kl_loss,
    reconstruction_loss, total_loss); ``detr_terms`` is the list of weighted
    detection losses as Python floats, ``losses_detr`` is their tensor sum.
    """
    _, outputs = net.sample_and_detect(data, mean=True, z_q=None)
    det_targets = [{"labels": l.to(device),"boxes":b.to(device)} for l,b in zip(raw_targets["labels"],raw_targets["boxes"])]

    # Detection losses, weighted by the criterion's own weight_dict
    loss_dict = criterion(outputs, det_targets)
    weight_dict = criterion.weight_dict
    losses_detr = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
    detr_terms = [loss_dict[k].item() * weight_dict[k] for k in loss_dict.keys() if k in weight_dict]

    # KL divergence part of the ELBO (total and per latent level)
    kl_loss_per_levels, kl_loss = net.kl_divergence_(target, data)

    # Binary cross-entropy reconstruction part of the ELBO
    reconstruction = net.reconstruct(target, data, mean=False)
    reconstruction_loss = torch.sum(loss_bce(input=reconstruction, target=target))

    elbo = -(reconstruction_loss + beta * kl_loss)
    reg_loss = l2_regularisation(net._prior)+l2_regularisation(net._posterior)+l2_regularisation(net._f_comb)
    total_loss = -elbo + losses_detr + 1e-5*reg_loss
    return detr_terms, losses_detr, kl_loss_per_levels, kl_loss, reconstruction_loss, total_loss


# Build the dataset and split it ONCE. Re-splitting inside the epoch loop draws a new
# random partition every epoch, so validation samples end up being trained on.
dataset = MRI2DSegmentationDataset(data_dir=data_dir, slice_axis=1)
train_loader,val_loader=prepare_loader(dataset)

for epoch in range(epochs):

    running_train_Detr_loss,running_train_reconstruction,running_train_kl_loss,running_train_total_loss,running_train_score = [[] for _ in range(5)]

    print('Numbers of epoch:{}/{}'.format(epoch+1,epochs))
    started = time.time()

    net.train()
    for batch_idx, (train_batch_input , train_batch_gt , targets) in enumerate(train_loader):
        target,data=train_batch_gt.to(device),train_batch_input.to(device)

        detr_terms, losses_detr, kl_loss_per_levels, kl_loss, reconstruction_loss, total_loss = _elbo_and_detection_losses(target, data, targets)

        # The Dice score is only monitored, so do not build an autograd graph for it.
        with torch.no_grad():
            score = batch_dice(F.softmax(net.sample(data,mean=False),dim=1),target)

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        running_train_Detr_loss.append(detr_terms)
        running_train_total_loss.append(total_loss.item())
        running_train_kl_loss.append(kl_loss_per_levels)
        running_train_reconstruction.append(reconstruction_loss.item())
        running_train_score.append(score.item())

        print('Loss DeTr loss over one batch: {} ---- KL divergence loss over one batch: {} ---- Reconstruction loss over one batch: {} ---- Overall loss batch: {} ---- Overall score batch: {} ---- Batch idx: {}'.format(losses_detr.item(),kl_loss.item(),reconstruction_loss.item(),total_loss.item(),score.item(),batch_idx))

    # ---- validation ----
    running_valid_Detr_loss,running_valid_reconstruction,running_valid_kl_loss,running_valid_total_loss,running_valid_score = [[] for _ in range(5)]

    net.eval()
    with torch.no_grad():
        for batch_idx, (valid_batch_input , valid_batch_gt , targets) in enumerate(val_loader):
            target,data=valid_batch_gt.to(device),valid_batch_input.to(device)

            detr_terms, losses_detr, kl_loss_per_levels, kl_loss, reconstruction_loss, total_loss = _elbo_and_detection_losses(target, data, targets)
            score = batch_dice(F.softmax(net.sample(data,mean=False),dim=1),target)

            running_valid_Detr_loss.append(detr_terms)
            running_valid_total_loss.append(total_loss.item())
            running_valid_kl_loss.append(kl_loss_per_levels)
            running_valid_reconstruction.append(reconstruction_loss.item())
            running_valid_score.append(score.item())

    # ---- epoch summaries ----
    epoch_train_loss = np.mean(running_train_total_loss)
    epoch_train_kl = np.mean(running_train_kl_loss,axis=0)
    epoch_train_score = np.mean(running_train_score)
    epoch_train_reconstruction = np.mean(running_train_reconstruction)
    epoch_train_detr = np.mean(running_train_Detr_loss,axis=0)
    print('Train total loss epoch : {} Dice score epoch : {}'.format(epoch_train_loss,epoch_train_score))
    train_loss.append(epoch_train_loss)
    dice_score_train.append(epoch_train_score)
    kls_loss_train.append(epoch_train_kl)
    recons_loss_train.append(epoch_train_reconstruction)
    detection_loss_train.append(epoch_train_detr)

    epoch_val_loss = np.mean(running_valid_total_loss)
    epoch_val_kl = np.mean(running_valid_kl_loss,axis=0)
    epoch_val_score = np.mean(running_valid_score)
    epoch_val_reconstruction = np.mean(running_valid_reconstruction)
    epoch_val_detr = np.mean(running_valid_Detr_loss,axis=0)
    print('Valid total loss epoch: {} Dice score epoch : {}'.format(epoch_val_loss,epoch_val_score))
    val_loss.append(epoch_val_loss)
    dice_score_val.append(epoch_val_score)
    kls_loss_val.append(epoch_val_kl)
    recons_loss_val.append(epoch_val_reconstruction)
    detection_loss_val.append(epoch_val_detr)

    checkpoint = { 'epoch': epoch +1,
                  'valid_loss_min':epoch_val_loss,
                  'state_dict':net.state_dict(),
                  'optimizer':optimizer.state_dict(),
    }
    # Always save the latest state; save again as "best" when validation improved.
    save_ckp(checkpoint, False,checkpoint_path,best_model_path)

    if epoch_val_loss <= valid_loss_min:
        print('Validation loss decreased ({:.6f} =======> {:.6f}). Saving model ...'.format(valid_loss_min,epoch_val_loss))

        save_ckp(checkpoint, True,checkpoint_path,best_model_path)
        valid_loss_min = epoch_val_loss

    time_passed = time.time() - started
    print('{:.0f}m {:.0f}s'.format(time_passed//60, time_passed%60))


In [None]:
# checkpoint_path = './chkpoint_withgen_valid'
# best_model_path = './bestmodel_withgen_valid.pt'
# device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
# net.to(device)
# net.train()
# optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, weight_decay=0)
# epochs = 20

# valid_loss_min=float('inf')
# train_loss,val_loss=[],[]
# dice_score_train,dice_score_val=[],[]


# for epoch in range(epochs):
    
#     dataset = MRI2DSegmentationDataset(data_dir=data_dir, slice_axis=1)
#     train_loader,val_loader=prepare_loader(dataset)
    
    
    
#     running_train_loss = []
#     running_train_score = []
#     print('Numbers of epoch:{}/{}'.format(epoch+1,epochs))
#     started = time.time()
          
#     for batch_idx, train_batch in enumerate(train_loader):
#         #print('Batch idx {}, data shape {}, target shape {}'.format(batch_idx, data.shape, target.shape))
#         target,data=train_batch['gt'].to(device),train_batch['input'].to(device)
#         elbo = net.elbo(target,data)
#         reg_loss = l2_regularisation(net._prior)+l2_regularisation(net._posterior)+l2_regularisation(net._f_comb)
#         loss = -elbo + 1e-5*reg_loss
#         score = batch_dice(F.softmax(net.sample(data,mean=False),dim=1),target)
#         #running_loss += loss.item() * inputs.size(0) 
#         #print(loss) 
#         optimizer.zero_grad() 
#         loss.backward() 
#         optimizer.step() 
#         running_train_loss.append(loss.item())
#         running_train_score.append(score.item())
#         print('loss batch: {},Dice score batch: {}, batch_idx: {}'.format(loss.item(),score.item(),batch_idx))
#     else:
#         running_val_loss=[]
#         running_val_score=[]
          
#         with torch.no_grad():
#             for valid_batch in val_loader:
#                 target,data=valid_batch['gt'].to(device),valid_batch['input'].to(device)
#                 elbo = net.elbo(target.to(device),data.to(device))
#                 reg_loss = l2_regularisation(net._prior)+l2_regularisation(net._posterior)+l2_regularisation(net._f_comb)
#                 loss = -elbo + 1e-5*reg_loss
#                 score = batch_dice(F.softmax(net.sample(data,mean=False),dim=1),target)
#                 running_val_loss.append(loss.item())
#                 running_val_score.append(score.item())
        
#     epoch_train_loss,epoch_train_score = np.mean(running_train_loss),np.mean(running_train_score)
#     print('Train loss epoch : {} Dice score epoch : {}'.format(epoch_train_loss,epoch_train_score))
#     train_loss.append(epoch_train_loss)
#     dice_score_train.append(epoch_train_score)
        
#     epoch_val_loss,epoch_val_score = np.mean(running_val_loss),np.mean(running_val_score)
#     print('Valid loss epoch: {} Dice score epoch : {}'.format(epoch_val_loss,epoch_val_score))
#     val_loss.append(epoch_val_loss)
#     dice_score_val.append(epoch_val_score)
          
#     checkpoint = { 'epoch': epoch +1,
#                   'valid_loss_min':epoch_val_loss,
#                   'state_dict':net.state_dict(),
#                   'optimizer':optimizer.state_dict(),
        
#     }
#     save_ckp(checkpoint, False,checkpoint_path,best_model_path)
     
#     if epoch_val_loss <= valid_loss_min:
#           print('Validation loss decreased ({:.6f} =======> {:.6f}). Saving model ...'.format(valid_loss_min,epoch_val_loss))
          
#           save_ckp(checkpoint, True,checkpoint_path,best_model_path)
#           valid_loss_min = epoch_val_loss
          
#     time_passed = time.time() - started
#     print('{:.0f}m {:.0f}s'.format(time_passed//60, time_passed%60))
# #net.eval()
# #sample_1 = net.sample(torch.from_numpy(all_pt_img[25550][np.newaxis][np.newaxis]/100).cuda(),mean=True,z_q=None)
# #sample_2 = net.sample(torch.from_numpy(all_pt_img[25550][np.newaxis]/100).cuda(),mean=True,z_q=None)

# # print(sample,sample.shape,"Sample shape")
# # prekd = torch.argmax(sample,axis=1)
# # print(pred,pred.shape)

In [None]:
# train_loss,val_loss=[],[]
# dice_score_train,dice_score_val=[],[]
# kls_loss_train,kls_loss_val=[],[]
# recons_loss_train,recons_loss_val=[],[]
# detection_loss_train,detection_loss_val=[],[]


# Collect the per-epoch training curves and persist them with pickle.
training_logs = {"train_loss": train_loss, "dice_score_train": dice_score_train,"kls_loss_train":kls_loss_train,"recons_loss_train":recons_loss_train}

# The context manager closes the file even if pickling fails.
with open("training_logs_run_128128Beta1_epochs60.pkl", "wb") as training_file:
    pickle.dump(training_logs, training_file)


#validation_logs = {"val_loss": val_loss, "dice_score_val": dice_score_val,"kls_loss_train":kls_loss_val,"recons_loss_val":recons_loss_val,"detection_loss_val":detection_loss_val}




In [None]:
# Collect the per-epoch validation curves and persist them with pickle.
# The KL entry is stored under "kls_loss_val"; it was previously keyed
# "kls_loss_train" (copy-paste slip) although it holds the validation values.
validation_logs = {"val_loss": val_loss, "dice_score_val": dice_score_val,"kls_loss_val":kls_loss_val,"recons_loss_val":recons_loss_val}

# The context manager closes the file even if pickling fails.
with open("validation_logs_128128Beta1_epochs60.pkl", "wb") as validation_file:
    pickle.dump(validation_logs, validation_file)



# Reload the model

In [None]:
# Checkpoint files to reload (128x128 / 60-epoch run).
checkpoint_path = './chkpoint_withgen_with_size128x128_beta1_60epochs'
best_model_path = './bestmodel_withgen_with_size128x128_beta1_60epochs.pt'

In [None]:
#path=r"C:\Users\youve\Dossier Thése\model_chkpointbest"
# Overrides the best_model_path chosen in the previous cell.
best_model_path ="bestmodel_withgen_valid.pt" 

In [None]:
# Rebuild the architecture, then restore weights and optimizer state from the checkpoint.
model=HierarchicalProbUNet(dim=2,latent_dims=[1,1,1,1],input_channels=list(input_channels),channels_per_block=list(channels_per_block),num_classes=2,
               down_channels_per_block=list(down_channels_per_block), convs_per_block=3,
               blocks_per_level=3)
# The optimizer must track the parameters of the model being restored; it was
# previously built over `net`, so resuming training would have updated the wrong network.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=0)
#model, optimizer, start_epoch, valid_loss_min = load_ckp(checkpoint_path, model, optimizer)
model, optimizer, start_epoch, valid_loss_min = load_ckp(best_model_path, model, optimizer)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()

In [None]:
# Keep a handle on the model's `_prior` attribute (this name is re-bound to its output in a later cell).
prior = model._prior

In [None]:
# Display the value returned by load_ckp for the checkpoint's validation loss.
valid_loss_min

In [None]:
# Display the epoch value returned by load_ckp.
start_epoch

In [None]:
# Windows location of the dataset; note there is no trailing path separator here,
# unlike the Linux path assigned to data_dir in other cells.
data_dir=r'C:\GAINED\resampled_croped'

In [None]:
def image_to_test(slices):
    """Load a range of volumes and return their 2-D slices along axis 1.

    ``slices`` is a ``[start, stop]`` pair used to slice the module-level
    ``mask_path`` and ``pt_path`` lists. Every selected file is read with
    SimpleITK and cut into ``volume[:, k, :]`` planes.

    Returns:
        ``(test_pt_img, test_mask_img)``: two lists of 2-D arrays.
    """
    first, last = slices[0], slices[1]

    def _planes_along_axis1(paths):
        # Read each volume and append its planes in file order, then plane order.
        planes = []
        for path in paths:
            volume = sitk.GetArrayFromImage(sitk.ReadImage(path))
            planes.extend(volume[:, k, :] for k in range(volume.shape[1]))
        return planes

    # Masks are read before the PET volumes, as in the original ordering.
    test_mask_img = _planes_along_axis1(mask_path[first:last])
    test_pt_img = _planes_along_axis1(pt_path[first:last])

    return test_pt_img, test_mask_img
        
        

## Find a good slice to show

In [None]:
# Load the slices of the files at positions 100..119 of pt_path / mask_path.
test_pt_img,test_mask_img=image_to_test([100,120])

In [None]:
# Number of 2-D mask slices that were loaded.
len(test_mask_img)

In [None]:
# Foreground/background pixel ratio of every mask slice (higher means more lesion).
good_slices = []
for ary in test_mask_img:
    # Count non-zero pixels directly: the previous counts[1:] trick silently assumed
    # that 0 is present and is the smallest value in the slice.
    n_foreground = np.count_nonzero(ary)
    n_background = ary.size - n_foreground
    # A slice without any background would otherwise divide by zero.
    p = n_foreground / n_background if n_background else float('inf')
    good_slices.append(p)


In [None]:
# Map ratio -> slice index, then order the mapping by ascending ratio.
# Slices sharing the same ratio collapse onto one key (the last index wins).
x = {}
for slice_idx, ratio in zip(range(len(test_mask_img)), good_slices):
    x[ratio] = slice_idx
best_slices_sorted = dict(sorted(x.items(), key=lambda pair: float(pair[0])))

In [None]:
# Display the ratio -> slice-index mapping, lowest ratio first.
best_slices_sorted

In [None]:
# Show the PET slice that matches the mask displayed in the next cell.
# The index was missing (`test_pt_img[]` is a syntax error); 310 is the
# slice the next cell shows for the mask.
plt.imshow(test_pt_img[310])

In [None]:
# Show mask slice 310.
plt.imshow(test_mask_img[310])

In [None]:
# Pick slice 2500, pad it to desired_shape and add two leading singleton axes.
# NOTE(review): the next cell re-assigns i and npa_pet without padding — run only one of the two.
i=2500
desired_shape=(256,256)
n_classes=2
npa_pet=padding(desired_shape,test_pt_img[i])[np.newaxis,np.newaxis]
npa_pet.shape

In [None]:
# Alternative selection: slice 12 without padding, with two leading singleton axes.
# Indexing with np.newaxis yields a view, so in-place edits of npa_pet also change test_pt_img[i].
i=12
n_classes=2
npa_pet=test_pt_img[i][np.newaxis,np.newaxis]
npa_pet.shape

In [None]:
n_classes=2
# Binarise the mask (any label >= 1 counts as foreground). The builtin `int` replaces
# the `np.int` alias, which was deprecated in NumPy 1.20 and removed in 1.24.
true_mask = (test_mask_img[i]>= 1).astype(int)
pad_true_mask = padding(desired_shape,true_mask).astype(int)
#y_ohe = one_hot_encoding(true_mask,n_classes)[np.newaxis]

# One-hot encode the padded mask and add a leading batch axis.
y_ohe = one_hot_encoding(pad_true_mask,n_classes)[np.newaxis]
y_ohe.shape

In [None]:
# npa_pet/=100
# image = torch.from_numpy(npa_pet).float()
# mask = torch.from_numpy(y_ohe).float()
# reconstruction=model.reconstruct(mask.to(device),image.to(device),mean=True)
# sample_full,detection_full=model.sample_and_detect(image.to(device))
# sample_local,detection_local=model.sample_and_detect(image.to(device), mean=[1, 1, 1, 0])
# sample_global,detection_global=model.sample_and_detect(image.to(device), mean=[0, 1, 1, 1])

In [None]:
# Scale the PET slice by 1/100 WITHOUT mutating npa_pet: the in-place `/=` changed the
# array on every re-run of this cell (and, when npa_pet is a view, the source slice too).
image = torch.from_numpy(npa_pet / 100).float()
mask = torch.from_numpy(y_ohe).float()
image_ = image.to(device)
mask_ = mask.to(device)

# Inference only, so no autograd graph is needed.
with torch.no_grad():
    reconstruction=model.reconstruct(mask_,image_,mean=True)
    sample_full=model.sample(image_)
    sample_local=model.sample(image_, mean=[1, 1, 1, 0])
    sample_global=model.sample(image_, mean=[0, 1, 1, 1])

In [None]:
# Call the prior network on the image; this re-binds `prior` from the module to its output.
prior = model._prior(image_,mean=False, z_q=None)

In [None]:
# Extract the "used_latents" entry of the prior output.
used_latent = prior["used_latents"]

In [None]:
# Display the latents extracted in the previous cell.
used_latent

In [None]:
# NOTE(review): `detection_global` is only assigned in the commented-out sample_and_detect
# cell above, so on a fresh kernel this presumably raises NameError — confirm, then
# remove this cell or restore that one.
detection_global


In [None]:
def _logits_to_labels(logits):
    """Arg-max over axis 1, returned as a NumPy array on the CPU."""
    return torch.argmax(logits,axis=1).clone().cpu().detach().numpy()


# Hard label maps for each sampling mode and for the posterior reconstruction.
pred_local=_logits_to_labels(sample_local)
pred_global=_logits_to_labels(sample_global)
pred_full=_logits_to_labels(sample_full)
pred_reconstruction=_logits_to_labels(reconstruction)

In [None]:
# Which labels appear in the full-sample prediction.
np.unique(pred_full)

In [None]:
# Side-by-side comparison: input, ground truth and the four predictions.
plt.figure(figsize=(10,10))

panels = [
    ("TEP", test_pt_img[i]),
    ("MASK", true_mask),
    ("sample_full", pred_full[0]),
    ("sample_local", pred_local[0]),
    ("sample_global", pred_global[0]),
    ("reconstruction", pred_reconstruction[0]),
]
for position, (panel_title, panel_image) in enumerate(panels, start=1):
    plt.subplot(1,6,position)
    plt.title(panel_title)
    plt.imshow(panel_image)

In [None]:
# num_sample=30
# samples=[]
# for i in range(num_sample):
#     sample,_=model.sample_and_detect(image.to(device))
#     pred=sample.clone().cpu().detach().numpy()
#     samples.append(pred)

# samples_global=[]
# for i in range(num_sample):
#     sample,_=model.sample_and_detect(image.to(device), mean=[0, 1, 1, 1])
#     pred=sample.clone().cpu().detach().numpy()
#     samples_global.append(pred)


# samples_local=[]
# for i in range(num_sample):
#     sample,_=model.sample_and_detect(image.to(device), mean=[1, 1, 1, 0])
#     pred=sample.clone().cpu().detach().numpy()
#     samples_local.append(pred)

In [None]:
num_sample=30
image_on_device = image.to(device)


def _draw_samples(**sample_kwargs):
    """Draw `num_sample` segmentations from the model and return them as NumPy arrays.

    Keyword arguments are forwarded to `model.sample`.
    """
    # `_` as loop variable: the previous `for i in ...` overwrote the slice index `i`
    # that the plotting cells rely on. no_grad: inference only.
    with torch.no_grad():
        return [
            model.sample(image_on_device, **sample_kwargs).clone().cpu().detach().numpy()
            for _ in range(num_sample)
        ]


samples = _draw_samples()
samples_global = _draw_samples(mean=[0, 1, 1, 1])
samples_local = _draw_samples(mean=[1, 1, 1, 0])

In [None]:
# Stack the full-prior samples along a new leading axis.
# This previously read from `samples_local`, which discarded the full-prior samples and
# duplicated the local-sample analysis done further down.
samples=torch.from_numpy(np.array(samples))
samples.shape

In [None]:
# Element-wise sigmoid of the stacked samples. Despite the `_sm` suffix this is not a
# softmax; sigmoid is monotonic, so the arg-max taken in the next cell is unaffected.
sample_sm=torch.sigmoid(samples)
sample_sm.shape


In [None]:
# Arg-max over axis 2 — presumably the class axis of (sample, batch, class, H, W); TODO confirm.
pred=torch.argmax(sample_sm,axis=2)
pred.shape

In [None]:
# Convert the label maps to a NumPy array on the CPU.
pred=pred.clone().cpu().detach().numpy()

In [None]:
# Shape of the label maps.
pred.shape

In [None]:
# Standard deviation across axis 0, the axis along which the samples were stacked.
std=np.std(pred,axis=0)

In [None]:
# Scale by 100. Not idempotent: re-running this cell multiplies again.
std=100*std

In [None]:
# Indices where the standard deviation is non-zero.
np.where(std)

In [None]:
# Number of distinct standard-deviation values.
len(np.unique(std))

In [None]:
# The distinct standard-deviation values.
np.unique(std)

In [None]:
# Largest standard deviation.
std.max()

In [None]:
# Mean label across axis 0, the axis along which the samples were stacked.
mean=np.mean(pred,axis=0)

In [None]:
# The distinct mean values.
np.unique(mean)

In [None]:
# Standard-deviation map next to the mean map.
for position, (panel_image, panel_title) in enumerate(
        [(std[0], "std images"), (mean[0], "mean images")], start=1):
    plt.subplot(1,2,position)
    plt.imshow(panel_image)
    plt.title(panel_title)

In [None]:
# Show the first 30 sampled label maps one by one, in grayscale.
# Note: this loop re-binds the notebook-level variable `i`.
for i in range(30):
    plt.imshow(pred[i,0,:,:],cmap='gray')
    plt.show()

In [None]:
# Stack the local samples along a new leading axis and convert them to a tensor.
samples_local=torch.from_numpy(np.array(samples_local))
samples_local.shape

In [None]:
# Element-wise sigmoid of the local samples (overwrites sample_sm from the earlier cell).
sample_sm=torch.sigmoid(samples_local)
sample_sm.shape

In [None]:
# Arg-max over axis 2 — presumably the class axis; TODO confirm.
pred=torch.argmax(sample_sm,axis=2)
pred.shape

In [None]:
# Convert the label maps to a NumPy array on the CPU.
pred=pred.clone().cpu().detach().numpy()

In [None]:
# Show the first 10 label maps one by one.
for i in range(10):
    plt.imshow(pred[i,0,:,:])
    
    plt.show()

In [None]:
from PIL import Image

In [None]:
# Load the second PET file with nibabel and read its voxel data.
# This re-binds `image`, which earlier held the input tensor.
# NOTE(review): `pt_path` is not assigned at notebook level in the visible cells — confirm where it comes from.
image = nib.load(pt_path[1])
image = image.get_fdata()

In [None]:
# NOTE(review): `pet` is not assigned anywhere in the visible notebook — presumably left
# over from an earlier session; confirm.
plt.imshow(pet)

In [None]:
from torchvision import transforms

In [None]:
# Zero-pad by 128 pixels on one side, then convert to a tensor.
# (which side the first tuple entry refers to follows torchvision's Pad convention — TODO confirm)
trans=transforms.Compose([transforms.Pad((128,0,0,0), fill=0, padding_mode='constant'),
   transforms.ToTensor(),])

In [None]:
# Apply the padding transform to `pet`.
test=trans(pet)

In [None]:
# Shape after the transform.
test.shape

In [None]:
# Convert the transformed tensor to a NumPy array on the CPU.
test_np=test.clone().cpu().detach().numpy()

In [None]:
# Show the first entry along the leading axis of the transformed image.
plt.imshow(test_np[0])

In [None]:
from data_loader import MRI2DSegmentationDataset

In [None]:
# Separate transform pipelines for inputs and masks: convert to a PIL image
# (mode 'LA' for inputs, 'L' for masks), zero-pad by 128 pixels on one side, convert to tensor.
# This re-binds `trans`, which previously held a single Compose.
trans={'inputs':transforms.Compose(
     [transforms.ToPILImage(mode='LA'), transforms.Pad((0,0,0,128), fill=0, padding_mode='constant'),transforms.ToTensor(),]),
    
     'mask':transforms.Compose(
     [transforms.ToPILImage(mode='L'), transforms.Pad((0,0,0,128), fill=0, padding_mode='constant'),transforms.ToTensor(),])}
# Build the dataset with these transforms.
train_dataset = MRI2DSegmentationDataset(data_dir,cache=True,slice_axis=1,transform=trans)
#train_dataset = MRI2DSegmentationDataset(data_dir,transform=trans)
#train_dataset.__getitem__(55)

# Sanity-check the first item and the dataset length.
set_data=train_dataset[0]
print(set_data['input'].shape)
print(set_data['gt'].shape)
print(len(train_dataset))

In [None]:
data = set_data['input'].clone().cpu().detach().numpy()
mask = set_data['gt'].clone().cpu().detach().numpy()

In [None]:
mask.max()

In [None]:
data.max()

In [None]:
plt.subplot(1,2,1)
plt.imshow(data[0])
plt.subplot(1,2,2)
plt.imshow(mask[0])

In [None]:
import os
import re
import collections


#from medicaltorch import datasets as mt_datasets
from medicaltorch import transforms as mt_transforms
from tqdm import tqdm
import numpy as np
import nibabel as nib
import glob
from torch.utils.data import Dataset
import torch
from torch._six import string_classes, int_classes

from PIL import Image

# Root folder of the dataset; get_path_modality_and_mask looks for `PET0` and
# `PET0_mask*` sub-folders inside it.
# NOTE(review): absolute, machine-specific path, assigned again after the class
# definitions below.
data_dir='/media/hmn-mednuc/InternalDisk_1/datasets/GAINED/resampled_croped/'

def get_path_modality_and_mask(data_dir):
    """Return matching PET and mask file paths for every patient that has both.

    Patient ids are read from the file names: the part before the first ``_``
    for files matching ``PET0/*00001.nii*`` and the first 14 characters for
    files found in any ``PET0_mask*`` directory. Only ids present in both sets
    are kept.

    :param data_dir: root directory containing the ``PET0`` and ``PET0_mask*``
                     folders (with or without a trailing separator).
    :return: tuple ``(pt_path, mask_path)`` of two equally long lists. Entry
             ``i`` of both lists belongs to the same patient, and the lists are
             sorted by patient id so the order is identical between runs.
    """
    pet_files = glob.glob(os.path.join(data_dir, 'PET0', '*00001.nii*'))
    mask_files = glob.glob(os.path.join(data_dir, 'PET0_mask*', '*nii*'))

    pet_ids = {os.path.basename(path).split('_')[0] for path in pet_files}
    mask_ids = {os.path.basename(path)[:14] for path in mask_files}

    # Sorting removes the run-to-run ordering differences of a plain set, which
    # would otherwise change which slice a given dataset index refers to.
    ls_ids = sorted(pet_ids & mask_ids)

    pt_path = [os.path.join(data_dir, 'PET0', ids + '_00001.nii') for ids in ls_ids]
    # Mask paths always point at the 'PET0_masks' folder, whichever
    # 'PET0_mask*' folder the id was discovered in.
    mask_path = [os.path.join(data_dir, 'PET0_masks', ids + '_mask.nii') for ids in ls_ids]

    return pt_path, mask_path


def padding(desired_shape,npa,value=0):
    """Place ``npa`` in the top-left corner of a new array filled with ``value``.

    :param desired_shape: shape of the returned array.
    :param npa: array to copy; its first two dimensions must not exceed those
                of ``desired_shape``.
    :param value: fill value for the cells not covered by ``npa`` (default 0).
    :return: new ``float64`` array of shape ``desired_shape``.
    :raises ValueError: if ``npa`` is larger than ``desired_shape``.
    """
    # float64 matches what np.zeros(...) (+ value) produced before.
    new_npa = np.full(desired_shape, value, dtype=np.float64)

    rows, cols = npa.shape[0], npa.shape[1]
    if rows > new_npa.shape[0] or cols > new_npa.shape[1]:
        raise ValueError(
            'Cannot pad array of shape {} into smaller shape {}.'.format(
                npa.shape, new_npa.shape))

    new_npa[:rows, :cols] = npa

    return new_npa


def one_hot_encoding(y,n_classes):
    """One-hot encode an integer label map along a new leading axis.

    Channel ``c`` of the result is 1 where ``y == c`` and 0 elsewhere, for
    ``c`` in ``0 .. n_classes - 1``. The channel therefore always corresponds
    to the class id, regardless of which classes happen to be present in ``y``.
    Values outside ``0 .. n_classes - 1`` leave every channel at 0.

    :param y: label array of any dimensionality.
    :param n_classes: number of channels to produce.
    :return: ``float64`` array of shape ``(n_classes,) + y.shape``.
    """
    y = np.asarray(y)
    one_hot = np.zeros((n_classes,) + y.shape)
    for class_idx in range(n_classes):
        one_hot[class_idx][y == class_idx] = 1
    return one_hot



class SegmentationPair2D(object):
    """This class is used to build 2D segmentation datasets. It represents
    a pair of two data volumes (the input data and the ground truth data).

    :param input_filename: the input filename (supported by nibabel).
    :param gt_filename: the ground-truth filename, or None for unlabeled data.
    :param cache: if the data should be cached in memory or not.
    :param canonical: canonical reordering of the volume axes.
    :raises RuntimeError: if the input volume has more than 3 dimensions or the
                          input and ground-truth shapes differ.
    """
    def __init__(self, input_filename, gt_filename, cache=True,
                 canonical=False):
        self.input_filename = input_filename
        self.gt_filename = gt_filename
        self.canonical = canonical
        self.cache = cache

        self.input_handle = nib.load(self.input_filename)

        # Unlabeled data (inference time)
        if self.gt_filename is None:
            self.gt_handle = None
        else:
            self.gt_handle = nib.load(self.gt_filename)

        if len(self.input_handle.shape) > 3:
            raise RuntimeError("4-dimensional volumes not supported.")

        # Sanity check for dimensions, should be the same. Shapes are integer
        # tuples, so compare them exactly rather than with a float tolerance.
        input_shape, gt_shape = self.get_pair_shapes()

        if self.gt_handle is not None:
            if tuple(input_shape) != tuple(gt_shape):
                raise RuntimeError('Input and ground truth with different dimensions.')

        if self.canonical:
            self.input_handle = nib.as_closest_canonical(self.input_handle)

            # Unlabeled data
            if self.gt_handle is not None:
                self.gt_handle = nib.as_closest_canonical(self.gt_handle)

    def get_pair_shapes(self):
        """Return the tuple (input, ground truth) representing both the input
        and ground truth shapes. The ground-truth shape is None for unlabeled
        data."""
        input_shape = self.input_handle.header.get_data_shape()

        # Handle unlabeled data
        if self.gt_handle is None:
            gt_shape = None
        else:
            gt_shape = self.gt_handle.header.get_data_shape()

        return input_shape, gt_shape

    def get_pair_data(self):
        """Return the tuple (input, ground truth) with the data content as
        float32 numpy arrays. The ground truth is None for unlabeled data."""
        cache_mode = 'fill' if self.cache else 'unchanged'
        input_data = self.input_handle.get_fdata(cache_mode, dtype=np.float32)

        # Handle unlabeled data
        if self.gt_handle is None:
            gt_data = None
        else:
            gt_data = self.gt_handle.get_fdata(cache_mode, dtype=np.float32)

        return input_data, gt_data

    @staticmethod
    def _take_slice(dataobj, slice_index, slice_axis):
        """Return slice ``slice_index`` of ``dataobj`` along ``slice_axis`` as float32."""
        if slice_axis == 2:
            index = (Ellipsis, slice_index)
        elif slice_axis == 1:
            index = (slice(None), slice_index, Ellipsis)
        else:
            index = (slice_index, Ellipsis)
        return np.asarray(dataobj[index], dtype=np.float32)

    def get_pair_slice(self, slice_index, slice_axis=2):
        """Return the specified slice from (input, ground truth).

        :param slice_index: the slice number.
        :param slice_axis: axis to make the slicing (0, 1 or 2).
        :return: dict with keys ``"input"`` and ``"gt"``; ``"gt"`` is None for
                 unlabeled data.
        :raises RuntimeError: if ``slice_axis`` is not 0, 1 or 2.
        """
        # Validate first so an invalid axis never triggers loading the volume.
        if slice_axis not in [0, 1, 2]:
            raise RuntimeError("Invalid axis, must be between 0 and 2.")

        if self.cache:
            input_dataobj, gt_dataobj = self.get_pair_data()
        else:
            # use dataobj to avoid caching
            input_dataobj = self.input_handle.dataobj

            if self.gt_handle is None:
                gt_dataobj = None
            else:
                gt_dataobj = self.gt_handle.dataobj

        input_slice = self._take_slice(input_dataobj, slice_index, slice_axis)

        # Handle the case for unlabeled data
        if self.gt_handle is None:
            gt_slice = None
        else:
            gt_slice = self._take_slice(gt_dataobj, slice_index, slice_axis)

        dreturn = {
            "input": input_slice,
            "gt": gt_slice,
        }

        return dreturn


class MRI2DSegmentationDataset(Dataset):
    """Slice-wise 2D segmentation dataset built from the PET volumes and masks
    found under ``data_dir``.

    Each item is a dict with an ``'input'`` tensor (one channel, zero-padded to
    ``PAD_SHAPE``) and a ``'gt'`` tensor (one-hot mask with ``N_CLASSES``
    channels, zero-padded to ``PAD_SHAPE``).

    :param data_dir: root directory passed to ``get_path_modality_and_mask``;
                     it must also contain at least one directory whose name
                     contains ``'masks'``.
    :param slice_axis: axis to make the slicing (default axial).
    :param cache: if the data should be cached in memory or not.
    :param transform: optional callable applied to the item dict.
    :param slice_filter_fn: optional callable receiving a slice pair dict and
                            returning a falsy value to drop that slice.
    :param canonical: canonical reordering of the volume axes.
    """
    # Spatial size every slice is padded to.
    PAD_SHAPE = (256, 256)
    # Divisor applied to the input intensities.
    # NOTE(review): the rationale for 100 is not visible in this file -- confirm.
    INTENSITY_SCALE = 100
    # Number of one-hot channels of the ground truth (background / foreground).
    N_CLASSES = 2

    def __init__(self, data_dir, slice_axis=2, cache=False,
                 transform=None, slice_filter_fn=None, canonical=False):

        self.data_dir = data_dir
        self.pt_path,self.mask_path = get_path_modality_and_mask(self.data_dir)
        # Every folder whose name contains 'masks'; one of them is picked at
        # random per volume in _load_filenames.
        self.masks_dir = [dir_mask for dir_mask in os.listdir(self.data_dir) if 'masks' in dir_mask]

        self.filename_pairs = list(zip(self.pt_path, self.mask_path))

        self.handlers = []
        self.indexes = []
        self.transform = transform
        self.cache = cache
        self.slice_axis = slice_axis
        self.slice_filter_fn = slice_filter_fn
        self.canonical = canonical

        self._load_filenames()
        self._prepare_indexes()

    def _load_filenames(self):
        """Open one SegmentationPair2D per (input, mask) pair.

        The mask is read from a randomly chosen ``*masks*`` folder by replacing
        ``'PET0_masks'`` in the mask path, so the selection differs between
        runs unless NumPy's global random state is seeded.
        """
        for input_filename, gt_filename in self.filename_pairs:
            chosen_dir = str(self.masks_dir[np.random.randint(len(self.masks_dir))])
            segpair = SegmentationPair2D(input_filename,
                                         gt_filename.replace('PET0_masks', chosen_dir),
                                         self.cache, self.canonical)
            self.handlers.append(segpair)

    def _prepare_indexes(self):
        """Build the flat list of (volume, slice number) items."""
        for segpair in self.handlers:
            input_data_shape, _ = segpair.get_pair_shapes()
            # The number of slices is the extent along the slicing axis.
            for segpair_slice in range(input_data_shape[self.slice_axis]):

                # Check if slice pair should be used or not
                if self.slice_filter_fn:
                    slice_pair = segpair.get_pair_slice(segpair_slice,
                                                        self.slice_axis)

                    filter_fn_ret = self.slice_filter_fn(slice_pair)
                    if not filter_fn_ret:
                        continue

                item = (segpair, segpair_slice)
                self.indexes.append(item)

    def set_transform(self, transform):
        """This method will replace the current transformation for the
        dataset.

        :param transform: the new transformation
        """
        self.transform = transform

    def __len__(self):
        """Return the dataset size."""
        return len(self.indexes)

    def __getitem__(self, index):
        """Return the specific index pair slices (input, ground truth).

        :param index: slice index.
        :return: dict with a float ``'input'`` tensor of shape
                 ``(1,) + PAD_SHAPE`` and a float ``'gt'`` tensor of shape
                 ``(N_CLASSES,) + PAD_SHAPE`` (``None`` for unlabeled data),
                 after ``transform`` if one is set.
        """
        segpair, segpair_slice = self.indexes[index]
        pair_slice = segpair.get_pair_slice(segpair_slice,
                                            self.slice_axis)

        input_img = padding(self.PAD_SHAPE, pair_slice["input"])
        input_img /= self.INTENSITY_SCALE

        # Handle unlabeled data before touching the ground truth.
        if pair_slice["gt"] is None:
            gt_tensor = None
        else:
            # Binarise (any label >= 1 is foreground), pad, then one-hot encode.
            gt_img = (pair_slice["gt"] >= 1).astype(int)
            gt_img = padding(self.PAD_SHAPE, gt_img)
            gt_img = one_hot_encoding(gt_img, self.N_CLASSES)
            gt_tensor = torch.from_numpy(gt_img).float()

        data_dict = {
            'input': torch.from_numpy(input_img[np.newaxis]).float(),
            'gt': gt_tensor,
        }

        if self.transform is not None:
            data_dict = self.transform(data_dict)

        return data_dict

# NOTE(review): absolute, machine-specific dataset root.
data_dir='/media/hmn-mednuc/InternalDisk_1/datasets/GAINED/resampled_croped/'
#train_dataset = MRI2DSegmentationDataset(data_dir=data_dir, slice_axis=1,transform=mt_transforms.ToTensor())
# train_dataset = MRI2DSegmentationDataset(data_dir=data_dir, slice_axis=1)

# print(len(train_dataset))

# data = train_dataset[71]
# #print(data["input"].shape)
# #print(data["gt"].shape)

In [None]:
# Build the dataset with the class defined in the cell above and split it with
# prepare_loader (defined near the top of the notebook): batch_size=1, shuffle=False.
dataset = MRI2DSegmentationDataset(data_dir=data_dir, slice_axis=1)
train_loader,val_loader=prepare_loader(dataset,1,False)

In [None]:
len(dataset)

In [None]:
# NOTE(review): hard-coded index; raises IndexError when the dataset has fewer
# than 8001 slices.
data = dataset[8000]

In [None]:
# Size check for an 80/20 split.
# NOTE(review): this uses `train_dataset` (created earlier from the imported
# data_loader class), not the `dataset` built above -- confirm which is intended.
sum([int(len(train_dataset)*0.8),int(len(train_dataset)*0.2)])

In [None]:
# random_split requires the lengths to add up to len(train_dataset) exactly.
# Taking the validation size as the remainder always satisfies that, whereas
# int(n*0.8) + int(n*0.2) + 1 overshoots for some n (e.g. n = 10 gives 11).
n_train=int(len(train_dataset)*0.8)
random_split(train_dataset,[n_train,len(train_dataset)-n_train])

In [None]:
# Input channel of the sample fetched above.
plt.imshow(data['input'][0])

In [None]:
# Channel 1 of the one-hot ground truth of the same sample.
plt.imshow(data["gt"][1])

In [None]:
# NOTE(review): `all_pt_img` is not defined in this part of the notebook -- it must
# already exist in the kernel.
plt.imshow(all_pt_img[99])

In [None]:
# NOTE(review): `all_mask_img` is not defined in this part of the notebook either.
plt.imshow(all_mask_img[71])

In [None]:
# Un-split loader over `train_dataset` (not over `dataset`), one sample per batch.
dataloader = DataLoader(train_dataset, batch_size=1)

In [None]:
len(dataloader) 

In [None]:
len(train_loader)

In [None]:
# Print the batch shapes of the first 76 batches.
for i,batch in enumerate(dataloader):
    print(batch['gt'].shape)
    print(batch['input'].shape)
    if i == 75:
        break

In [None]:
# Show the input image of every sample in the first 76 training batches.
# NOTE(review): the variable is called `gt` but holds batch['input'].
for i,batch in enumerate(train_loader):
    gt=batch['input']
    for item in gt:
        item = item.clone().cpu().detach().numpy()
        
        plt.imshow(item[0])
        plt.show()
    if i == 75:
        break

In [None]:
# Show channel 0 of the ground truth of every sample in the first 76 training batches.
for i,batch in enumerate(train_loader):
    gt=batch['gt']
    for item in gt:
        plt.imshow(item.squeeze(0)[0],cmap = 'gray')
        plt.show()
    if i == 75:
        break

In [None]:
# NOTE(review): exact copy of the previous cell.
for i,batch in enumerate(train_loader):
    gt=batch['gt']
    for item in gt:
        plt.imshow(item.squeeze(0)[0],cmap = 'gray')
        plt.show()
    if i == 75:
        break

In [None]:
# NOTE(review): exact copy of the previous cell.
for i,batch in enumerate(train_loader):
    gt=batch['gt']
    for item in gt:
        plt.imshow(item.squeeze(0)[0],cmap = 'gray')
        plt.show()
    if i == 75:
        break

In [None]:
# NOTE(review): exact copy of the previous cell.
for i,batch in enumerate(train_loader):
    gt=batch['gt']
    for item in gt:
        plt.imshow(item.squeeze(0)[0],cmap = 'gray')
        plt.show()
    if i == 75:
        break