In [None]:
%load_ext autoreload
import sys
import os
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import pandas as pd
import nibabel as nib
import matplotlib.pyplot as plt
import scipy
import torch
from tensorboard.backend.event_processing import event_accumulator

sys.path.append('../')
sys.path.append('/workspace/MRI-inpainting-project')

from data_scripts.datasets import PathologicalMRIDataset, HealthyMRIDataset, TrainPatchesDataset
from data_scripts.visualization_utils import ImageSliceViewer3D

%autoreload 2
from dataset import NiftiImageGenerator, NiftiPairImageGenerator, TrainFCDDataset, HealthyFCDDataset, TrainFCDPatchesDataset
from torchvision.transforms import RandomCrop, Compose, ToPILImage, Resize, ToTensor, Lambda
import torch
from diffusion_model.unet import create_model
from diffusion_model.trainer import GaussianDiffusion

from dataset import reconstruct_patch

from skimage.metrics import peak_signal_noise_ratio 
from skimage.metrics import structural_similarity 

In [23]:
def inpaint_train_patch(medddpm_dataset, orig_dataset, model, sample_id, device='cuda:2'):
    """Inpaint the masked region of one patch with the diffusion model.

    Parameters
    ----------
    medddpm_dataset : indexable dataset
        Items are dicts whose ``'input'`` tensor is used as the model's
        conditioning input (called with ``TrainFCDPatchesDataset`` below).
    orig_dataset : indexable dataset
        Items are dicts with ``'patch'`` and ``'mask'`` entries (called with
        ``TrainPatchesDataset`` below). Must be index-aligned with
        ``medddpm_dataset``.
    model : object
        Exposes ``sample(batch_size=..., condition_tensors=...)`` returning a
        tensor (called with ``GaussianDiffusion`` below).
    sample_id : int
        Index of the patch in both datasets.
    device : str or torch.device, default 'cuda:2'
        Device the conditioning tensor is moved to; should match the device
        the model lives on.

    Returns
    -------
    tuple
        ``(recon_patch, orig_patch, orig_fcd_mask)``. Here ``recon_patch`` is
        the result of ``reconstruct_patch(orig_patch, gen_patch, orig_fcd_mask)``.
    """
    sample = medddpm_dataset[sample_id]
    # Index the dataset that was passed in (not a notebook global), and only once,
    # so the patch and the mask come from the same item.
    orig_item = orig_dataset[sample_id]
    orig_patch, orig_fcd_mask = orig_item['patch'], orig_item['mask']

    # unsqueeze(0) adds the batch dimension expected by model.sample(batch_size=1).
    condition = sample['input'].unsqueeze(0).to(device)
    gen_patch = model.sample(batch_size=1, condition_tensors=condition)
    # Drop singleton (batch/channel) dims; presumably leaves a 3-D volume — TODO confirm.
    gen_patch = gen_patch.cpu().numpy().squeeze()

    recon_patch = reconstruct_patch(orig_patch, gen_patch, orig_fcd_mask)
    return recon_patch, orig_patch, orig_fcd_mask

In [8]:
input_size = 40
depth_size = 40


def _to_float_tensor(array):
    """Wrap an array-like in a float32 torch tensor."""
    return torch.tensor(array).float()


def _rescale_affine(values):
    """Apply x -> 2x - 1 (presumably mapping [0, 1] intensities to [-1, 1] — TODO confirm)."""
    return values * 2 - 1


def _prepend_axis(values):
    """Insert a new leading axis of size 1."""
    return values.unsqueeze(0)


def _last_axis_first(values):
    """Move axis 3 to the front (expects a 4-D tensor)."""
    return values.permute(3, 0, 1, 2)


def _swap_axes_1_3(values):
    """Exchange axes 1 and 3."""
    return values.transpose(3, 1)


# Target transform: float tensor, rescaled, new leading axis, then axes 1 and 3 swapped.
transform = Compose([
    Lambda(step)
    for step in (_to_float_tensor, _rescale_affine, _prepend_axis, _swap_axes_1_3)
])

# Conditioning/mask transform: float tensor, rescaled, last axis moved first, axes 1 and 3 swapped.
input_transform = Compose([
    Lambda(step)
    for step in (_to_float_tensor, _rescale_affine, _last_axis_first, _swap_axes_1_3)
])

In [16]:
ckpt_path = '../scripts/results/train_fcd_inpainting_data_l1_masked_patches_split0_500_000/model-2.pt'

channel_mult = "1,2,4,4"
# GPU used for the model; must match the device the conditioning tensors are sent to.
device = 'cuda:2'

with torch.cuda.device(device):
    model = create_model(input_size, num_channels=64, num_res_blocks=1, in_channels=3,
                         out_channels=1, channel_mult=channel_mult).cuda()

    diffusion = GaussianDiffusion(
        model,
        image_size=input_size,
        depth_size=depth_size,
        timesteps=250,   # number of steps
        loss_type='l1_masked',
        with_condition=True,
    ).cuda()

    # Load the checkpoint on CPU: load_state_dict copies the tensors into the
    # already-on-GPU parameters, so the full checkpoint dict (which may hold more
    # than the 'ema' weights) never has to occupy GPU memory.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    diffusion.load_state_dict(checkpoint['ema'])
    del checkpoint

    # Inference only: disable training-mode behaviour (e.g. dropout) before sampling.
    diffusion.eval()
    print("Model Loaded!")

Model Loaded!


In [33]:
# Split 0 (train=False) of the patch data, loaded twice: once with the model's
# transforms (provides the conditioning 'input') and once as-is (provides the
# original 'patch' and 'mask'). Both are indexed with the same sample_id later.
patches_dir = '../../data/train_patches_v3'
split_kwargs = dict(splits_filename='stratified_8_cv_filtered_2.npy', split_id=0, train=False)

train_patches_dataset = TrainFCDPatchesDataset(
    patches_dir,
    input_size,
    depth_size,
    mask_transform=input_transform,
    target_transform=transform,
    **split_kwargs,
)
orig_patches_dataset = TrainPatchesDataset(patches_dir, **split_kwargs)

In [36]:
# Quick check on a single patch before running the full evaluation loop.
first_sample_id = 0
recon_patch, orig_patch, orig_fcd_mask = inpaint_train_patch(
    train_patches_dataset,
    orig_patches_dataset,
    diffusion,
    sample_id=first_sample_id,
)

sampling loop time step: 100%|██████████| 250/250 [00:19<00:00, 12.92it/s]


In [51]:
# Score only the voxels inside the FCD mask (boolean indexing flattens them to 1-D).
fcd_region = orig_fcd_mask > 0.5
fcd = orig_patch[fcd_region]
gen_fcd = recon_patch[fcd_region]
# Intensity range of the ground-truth region, shared by both metrics.
fcd_data_range = fcd.max() - fcd.min()

psnr = peak_signal_noise_ratio(fcd, gen_fcd, data_range=fcd_data_range)
ssim = structural_similarity(fcd, gen_fcd, data_range=fcd_data_range)

In [53]:
# Evaluate every patch of the split: inpaint it, then score the voxels inside the
# FCD mask against the original patch.
EVAL_SEED = 0
# Fix the RNG so the sampled inpaintings, and hence the saved metrics, are
# reproducible (up to nondeterministic GPU kernels).
torch.manual_seed(EVAL_SEED)

# Both datasets are indexed with the same sample_id, so they must line up.
assert len(train_patches_dataset) == len(orig_patches_dataset), (
    "train_patches_dataset and orig_patches_dataset must be index-aligned"
)

psnr_list = []
ssim_list = []

for sample_id in range(len(orig_patches_dataset)):
    recon_patch, orig_patch, orig_fcd_mask = inpaint_train_patch(train_patches_dataset, orig_patches_dataset,
                                                                 diffusion, sample_id=sample_id)
    fcd_region = orig_fcd_mask > 0.5
    gen_fcd = recon_patch[fcd_region]
    fcd = orig_patch[fcd_region]
    data_range = fcd.max() - fcd.min()

    # NOTE(review): boolean indexing flattens the region to 1-D, so SSIM is computed
    # over a voxel sequence rather than 3-D neighbourhoods, and very small regions
    # may be rejected by SSIM's sliding window — confirm this is intended.
    psnr = peak_signal_noise_ratio(fcd, gen_fcd, data_range=data_range)
    ssim = structural_similarity(fcd, gen_fcd, data_range=data_range)

    psnr_list.append(psnr)
    ssim_list.append(ssim)

sampling loop time step: 100%|██████████| 250/250 [00:18<00:00, 13.74it/s]
sampling loop time step: 100%|██████████| 250/250 [00:18<00:00, 13.65it/s]
sampling loop time step: 100%|██████████| 250/250 [00:18<00:00, 13.52it/s]
sampling loop time step: 100%|██████████| 250/250 [00:19<00:00, 13.14it/s]
sampling loop time step: 100%|██████████| 250/250 [00:19<00:00, 12.79it/s]
sampling loop time step: 100%|██████████| 250/250 [00:19<00:00, 12.57it/s]
sampling loop time step: 100%|██████████| 250/250 [00:20<00:00, 12.50it/s]
sampling loop time step: 100%|██████████| 250/250 [00:20<00:00, 12.44it/s]
sampling loop time step: 100%|██████████| 250/250 [00:20<00:00, 12.49it/s]
sampling loop time step: 100%|██████████| 250/250 [00:20<00:00, 12.42it/s]
sampling loop time step: 100%|██████████| 250/250 [00:20<00:00, 12.43it/s]
sampling loop time step: 100%|██████████| 250/250 [00:20<00:00, 12.40it/s]
sampling loop time step: 100%|██████████| 250/250 [00:20<00:00, 12.38it/s]
sampling loop time step: 

In [None]:
# Persist per-sample metrics for split 0; create the output folder first so
# np.save does not fail on a fresh checkout.
stats_dir = '../stats'
os.makedirs(stats_dir, exist_ok=True)
np.save(os.path.join(stats_dir, 'psnr_3ddiffusion_split0.npy'), psnr_list)
np.save(os.path.join(stats_dir, 'ssim_3ddiffusion_split0.npy'), ssim_list)