In [None]:
import hydra
import torch
import functools
import numpy as np 
import yaml
import sys, os
sys.path.append("/home/user/sirf/")
from src import (BrainWebOSEM, kl_div)
from omegaconf import DictConfig, OmegaConf
import time
import matplotlib.pyplot as plt
from src import PSNR, SSIM

import cupy as xp

# Module-level constant (1/30). Not read anywhere in this cell; the value matches
# the /30 scaling applied to attn_factors in the LPD evaluation later in the
# notebook — presumably a detector efficiency factor, TODO confirm.
detector_efficiency = 1./30

def get_acq_model():
    """Create the PET forward operator (acquisition model) used for evaluation.

    Builds pyparallelproj's ``GEDiscoveryMICoincidenceDescriptor`` with a single
    ring and ``'RVP'`` sinogram spatial axis order, a ``PETJosephProjector`` on a
    128 x 128 x 1 image grid with voxel size 2 (presumably mm), and attaches a
    ``GaussianImageBasedResolutionModel`` to it. Arrays use the module-level
    array namespace ``xp`` (cupy in this notebook); the resolution model is given
    ``cupyx.scipy.ndimage`` for its filtering.

    Returns:
        The configured ``PETJosephProjector`` with its image-based resolution
        model set.
    """
    # Local imports keep these GPU-only dependencies out of module import time.
    import pyparallelproj.coincidences as coincidences
    import pyparallelproj.petprojectors as petprojectors
    import pyparallelproj.resolution_models as resolution_models
    import cupyx.scipy.ndimage as ndi

    img_shape = (128, 128, 1)
    voxel_size = (2., 2., 2.)  # presumably mm — TODO confirm
    # Presumably the image origin: -(128 - 1) / 2 * 2 = -127 centres the
    # in-plane grid, and z is 0 for the single slice.
    img_origin = (-127.0, -127.0, 0.0)
    # Resolution FWHM (same units as voxel_size) converted to a per-axis
    # Gaussian sigma in voxel units; 2.35 approximates 2 * sqrt(2 * ln 2).
    fwhm = 4.5
    sigmas = tuple(fwhm / (2.35 * x) for x in voxel_size)

    coincidence_descriptor = coincidences.GEDiscoveryMICoincidenceDescriptor(
        num_rings=1,
        sinogram_spatial_axis_order=coincidences.SinogramSpatialAxisOrder['RVP'],
        xp=xp)
    acq_model = petprojectors.PETJosephProjector(
        coincidence_descriptor, img_shape, img_origin, voxel_size)
    res_model = resolution_models.GaussianImageBasedResolutionModel(
        img_shape, sigmas, xp, ndi)
    acq_model.image_based_resolution_model = res_model
    return acq_model


# Run-bookkeeping globals. NOTE(review): an earlier comment here spoke of
# generating a unique filename for an SDE; nothing in this cell does that, and
# name, results and timestr are not read anywhere in the notebook as shown.
name = ""
results = {}

# Timestamp prefix in the form YYYYMMDD_HHMMSS_ — presumably meant for output names.
timestr = time.strftime("%Y%m%d_%H%M%S_")

###### SET SEED ######
# Seeds torch and numpy only. NOTE(review): cupy (the array backend of the
# acquisition model) is not seeded here.
torch.manual_seed(42)
np.random.seed(42)

###### GET ACQUISITION MODEL AND DATA ######
# get the acquisition model
acq_model = get_acq_model()
# Test split of BrainWebOSEM at noise_level "2.5".
dataset = BrainWebOSEM(part="test",
        noise_level="2.5", 
        base_path="/home/user/sirf/src/brainweb_2d/")
# Evaluate on every 4th sample, starting at index 2.
subset = list(range(2, len(dataset), 4))
dataset = torch.utils.data.Subset(dataset, subset)
# One dataset sample per batch, in order. Presumably each sample already holds
# all noise realisations (the evaluation cells swap axes 0 and 1 to put them
# first) — verify against BrainWebOSEM.
test_loader = torch.utils.data.DataLoader(dataset, 
    batch_size=1, shuffle=False)
# as there are 10 realisations then batch = 10
# NOTE(review): informational only — the DataLoader above uses batch_size=1 and
# this variable is not read anywhere in the notebook as shown.
batch_size = 10
print("Length of test loader: ", len(test_loader))

In [None]:
import torch 
from glob import glob
import torch
import omegaconf
from normalisation import Normalisation
from lpd_modules import LPDForwardFunction2D
from lpd import get_lpd_model
from unet import get_unet_model
from src import PSNR, SSIM

# Evaluate the trained UNet model(s) on the test subset and print the mean
# PSNR, SSIM, CRC, background noise (std/mean across realisations) and
# KL divergence. Uses names defined in the first cell: test_loader, acq_model,
# kl_div, np and plt.
detector_efficiency = 1./30  # not read in this cell

models = ["data_corrected_mean"]

model_base_path = "/home/user/sirf/results/lpd_unet/UNET/"
save_idxs = [10,20,30,40,50,60]  # not read in this cell
# Set to True to print per-batch metrics and plot prediction vs. reference.
plot_debug = False
for model_type in models:
    model_path = model_base_path + model_type + "/"
    config = omegaconf.OmegaConf.load(model_path + ".hydra/config.yaml")
    model = get_unet_model(in_ch=config.benchmark.in_ch,
                           out_ch=config.benchmark.out_ch,
                           scales=config.benchmark.scales,
                           skip=config.benchmark.skip,
                           channels=config.benchmark.channels,
                           use_sigmoid=config.benchmark.use_sigmoid,
                           use_norm=config.benchmark.use_norm)
    # Use the checkpoint from the lexicographically last run sub-directory;
    # fail with a clear message rather than a bare IndexError if there is none.
    checkpoints = sorted(glob(model_path + "/*/model_min_val_loss.pt"))
    if not checkpoints:
        raise FileNotFoundError(
            f"no model_min_val_loss.pt found in any sub-directory of {model_path}")
    # map_location lets a checkpoint saved on another device load onto config.device.
    model.load_state_dict(torch.load(checkpoints[-1], map_location=config.device))
    model.eval()
    model.to(config.device)
    print(sum(p.numel() for p in model.parameters()))  # parameter count
    get_normalisation = Normalisation(config.benchmark.normalisation)
    stdf = []
    crcf = []
    psnrf = []
    ssimf = []
    kl_divf = []
    for idx, batch in enumerate(test_loader):
        # Swap axes 0 and 1 so the realisation axis leads — presumably each
        # sample stacks its noise realisations on axis 1; verify.
        ref = torch.swapaxes(batch[0], 0, 1).to(config.device)
        osem = torch.swapaxes(batch[2], 0, 1).to(config.device)
        measurements = torch.swapaxes(batch[4], 0, 1).to(config.device)
        contamination_factor = torch.swapaxes(batch[5], 0, 1)[:,[0],None].to(config.device)[:,0,:]
        attn_factors = torch.swapaxes(batch[6], 0, 1).to(config.device)
        has_rois = len(batch) > 7
        if has_rois:
            background = torch.swapaxes(batch[7], 0, 1).to(config.device)
            tumour_rois = torch.swapaxes(batch[8], 0, 1).to(config.device)
        norm = get_normalisation(osem, measurements, contamination_factor)
        # Inference only: avoid building an autograd graph.
        with torch.no_grad():
            x_pred = torch.clamp(model(osem, norm).detach(), 0)
        kldiv_r = kl_div(x=x_pred,
                         acq_model=acq_model,
                         attn_factors=attn_factors,
                         contamination=contamination_factor.unsqueeze(1),
                         measurements=measurements,
                         scale_factor=1.)[0].squeeze()
        kl_divf.append(kldiv_r.mean())

        images = x_pred.squeeze().unsqueeze(0)
        refs = ref.squeeze().unsqueeze(0)
        if has_rois:
            lesion_rois = tumour_rois.squeeze().unsqueeze(0)
            background_rois = background.squeeze().unsqueeze(0)
        else:
            # No ROI masks in the batch: use the reference's support for both
            # masks. Lesion and background then coincide, so the CRC ratio is
            # 0/0 (nan) whenever the mean is non-zero; only std is meaningful.
            lesion_rois = torch.zeros_like(refs).unsqueeze(0)
            background_rois = torch.zeros_like(refs).unsqueeze(0)
            lesion_rois[refs.unsqueeze(0)!=0] = 1
            background_rois[refs.unsqueeze(0)!=0] = 1
        psnrs = []
        ssims = []
        crcs = []
        stds = []
        for img_idx in range(images.shape[0]):
            image = images[img_idx].cpu().numpy()
            ref_np = refs[img_idx].squeeze().cpu().numpy()
            lesion_roi = lesion_rois[img_idx].cpu().numpy()
            background_roi = background_rois[img_idx].squeeze().cpu().numpy()
            psnr_r = []
            ssim_r = []
            crc_r = []
            b_bar_r = []
            for realisation in image:
                psnr_r.append(torch.asarray(PSNR(realisation, ref_np)))
                ssim_r.append(torch.asarray(SSIM(realisation, ref_np)))
                if background_roi.sum() != 0:
                    background_idx = np.nonzero(background_roi)
                    b_bar = realisation[background_idx]
                    b_t = ref_np[background_idx]
                    crc_t = []
                    # One CRC per non-empty lesion mask:
                    # (a_bar/b_bar - 1) / (a_t/b_t - 1).
                    for lesion in lesion_roi:
                        if lesion.sum() != 0:
                            tumour_roi_idx = np.nonzero(lesion)
                            a_bar = realisation[tumour_roi_idx]
                            a_t = ref_np[tumour_roi_idx]
                            if a_bar.mean() == 0 and b_bar.mean() == 0:
                                # Scalar (not a shape-(1,) array) so crc_t
                                # stays a flat list of scalars.
                                crc_t.append(0.0)
                            else:
                                crc_t.append((a_bar.mean()/b_bar.mean() - 1) / (a_t.mean()/b_t.mean() - 1))
                    crc_r.append(torch.asarray(crc_t).mean())
                    b_bar_r.append(torch.asarray(b_bar))
            if b_bar_r:
                # Per-voxel std across realisations divided by the per-voxel
                # mean, averaged over background voxels.
                b_stack = torch.stack(b_bar_r)
                std = (torch.std(b_stack, dim=0) / torch.clamp(b_stack.mean(0), 1e-9)).mean()
            else:
                # Empty background ROI: the noise metric is undefined.
                std = torch.tensor(float("nan"))
            psnrs.append(torch.asarray(psnr_r))
            ssims.append(torch.asarray(ssim_r))
            crcs.append(torch.asarray(crc_r))
            stds.append(torch.asarray(std))
        stdf.append(torch.stack(stds).mean())
        crcf.append(torch.stack(crcs).mean())
        psnrf.append(torch.stack(psnrs).mean())
        ssimf.append(torch.stack(ssims).mean())
        if plot_debug:
            print(torch.stack(psnrs).mean(), torch.stack(ssims).mean(), torch.stack(crcs).mean(), std)
            fig, ax = plt.subplots(1, 2)
            fig.colorbar(ax[0].imshow(x_pred[0, 0].cpu().numpy()))
            fig.colorbar(ax[1].imshow(ref_np))
            plt.show()
    print(torch.stack(psnrf).mean(), torch.stack(ssimf).mean(), torch.stack(crcf).mean(), torch.stack(stdf).mean(), torch.stack(kl_divf).mean())

In [None]:
import torch 
from glob import glob
import torch
import omegaconf
from normalisation import Normalisation
from lpd_modules import LPDForwardFunction2D, LPDAdjointFunction2D
from lpd import get_lpd_model
from unet import get_unet_model
from src import PSNR, SSIM

# Evaluate the trained LPD model(s) on the test subset and print the mean
# PSNR, SSIM, CRC, background noise (std/mean across realisations) and
# KL divergence. Uses names defined in the first cell: test_loader, acq_model,
# kl_div, np and plt.
# Scaling applied to attn_factors before they are passed to the LPD model.
detector_efficiency = 1./30

models = ["data_corrected_mean"]

model_base_path = "/home/user/sirf/results/lpd_unet/LPD/"
save_idxs = [10,20,30,40,50,60]  # not read in this cell
# Set to True to print per-batch metrics and plot prediction vs. reference.
plot_debug = False
for model_type in models:
    model_path = model_base_path + model_type + "/"
    config = omegaconf.OmegaConf.load(model_path + ".hydra/config.yaml")
    model = get_lpd_model(n_iter=config.benchmark.n_iter,
                          op=LPDForwardFunction2D,
                          op_adj=LPDAdjointFunction2D)
    # Use the checkpoint from the lexicographically last run sub-directory;
    # fail with a clear message rather than a bare IndexError if there is none.
    checkpoints = sorted(glob(model_path + "/*/model_min_val_loss.pt"))
    if not checkpoints:
        raise FileNotFoundError(
            f"no model_min_val_loss.pt found in any sub-directory of {model_path}")
    # map_location lets a checkpoint saved on another device load onto config.device.
    model.load_state_dict(torch.load(checkpoints[-1], map_location=config.device))
    model.eval()
    model.to(config.device)
    print(sum(p.numel() for p in model.parameters()))  # parameter count
    get_normalisation = Normalisation(config.benchmark.normalisation)
    stdf = []
    crcf = []
    psnrf = []
    ssimf = []
    kl_divf = []
    for idx, batch in enumerate(test_loader):
        # Swap axes 0 and 1 so the realisation axis leads — presumably each
        # sample stacks its noise realisations on axis 1; verify.
        ref = torch.swapaxes(batch[0], 0, 1).to(config.device)
        osem = torch.swapaxes(batch[2], 0, 1).to(config.device)
        measurements = torch.swapaxes(batch[4], 0, 1).to(config.device)
        contamination_factor = torch.swapaxes(batch[5], 0, 1)[:,[0],None].to(config.device)[:,0,:]
        attn_factors = torch.swapaxes(batch[6], 0, 1).to(config.device)
        has_rois = len(batch) > 7
        if has_rois:
            background = torch.swapaxes(batch[7], 0, 1).to(config.device)
            tumour_rois = torch.swapaxes(batch[8], 0, 1).to(config.device)
        norm = get_normalisation(osem, measurements, contamination_factor)
        # Inference only: avoid building an autograd graph. attn_factors are
        # scaled by detector_efficiency (1/30) — previously a hard-coded "/30".
        with torch.no_grad():
            x_pred = torch.clamp(
                model(osem, measurements, acq_model,
                      attn_factors * detector_efficiency,
                      norm, contamination_factor).detach(),
                0)
        # NOTE(review): kl_div receives the unscaled attn_factors, unlike the
        # model call above — confirm kl_div accounts for detector efficiency.
        kldiv_r = kl_div(x=x_pred,
                         acq_model=acq_model,
                         attn_factors=attn_factors,
                         contamination=contamination_factor.unsqueeze(1),
                         measurements=measurements,
                         scale_factor=1.)[0].squeeze()
        kl_divf.append(kldiv_r.mean())

        images = x_pred.squeeze().unsqueeze(0)
        refs = ref.squeeze().unsqueeze(0)
        if has_rois:
            lesion_rois = tumour_rois.squeeze().unsqueeze(0)
            background_rois = background.squeeze().unsqueeze(0)
        else:
            # No ROI masks in the batch: use the reference's support for both
            # masks. Lesion and background then coincide, so the CRC ratio is
            # 0/0 (nan) whenever the mean is non-zero; only std is meaningful.
            lesion_rois = torch.zeros_like(refs).unsqueeze(0)
            background_rois = torch.zeros_like(refs).unsqueeze(0)
            lesion_rois[refs.unsqueeze(0)!=0] = 1
            background_rois[refs.unsqueeze(0)!=0] = 1
        psnrs = []
        ssims = []
        crcs = []
        stds = []
        for img_idx in range(images.shape[0]):
            image = images[img_idx].cpu().numpy()
            ref_np = refs[img_idx].squeeze().cpu().numpy()
            lesion_roi = lesion_rois[img_idx].cpu().numpy()
            background_roi = background_rois[img_idx].squeeze().cpu().numpy()
            psnr_r = []
            ssim_r = []
            crc_r = []
            b_bar_r = []
            for realisation in image:
                psnr_r.append(torch.asarray(PSNR(realisation, ref_np)))
                ssim_r.append(torch.asarray(SSIM(realisation, ref_np)))
                if background_roi.sum() != 0:
                    background_idx = np.nonzero(background_roi)
                    b_bar = realisation[background_idx]
                    b_t = ref_np[background_idx]
                    crc_t = []
                    # One CRC per non-empty lesion mask:
                    # (a_bar/b_bar - 1) / (a_t/b_t - 1).
                    for lesion in lesion_roi:
                        if lesion.sum() != 0:
                            tumour_roi_idx = np.nonzero(lesion)
                            a_bar = realisation[tumour_roi_idx]
                            a_t = ref_np[tumour_roi_idx]
                            if a_bar.mean() == 0 and b_bar.mean() == 0:
                                # Scalar (not a shape-(1,) array) so crc_t
                                # stays a flat list of scalars.
                                crc_t.append(0.0)
                            else:
                                crc_t.append((a_bar.mean()/b_bar.mean() - 1) / (a_t.mean()/b_t.mean() - 1))
                    crc_r.append(torch.asarray(crc_t).mean())
                    b_bar_r.append(torch.asarray(b_bar))
            if b_bar_r:
                # Per-voxel std across realisations divided by the per-voxel
                # mean, averaged over background voxels.
                b_stack = torch.stack(b_bar_r)
                std = (torch.std(b_stack, dim=0) / torch.clamp(b_stack.mean(0), 1e-9)).mean()
            else:
                # Empty background ROI: the noise metric is undefined.
                std = torch.tensor(float("nan"))
            psnrs.append(torch.asarray(psnr_r))
            ssims.append(torch.asarray(ssim_r))
            crcs.append(torch.asarray(crc_r))
            stds.append(torch.asarray(std))
        stdf.append(torch.stack(stds).mean())
        crcf.append(torch.stack(crcs).mean())
        psnrf.append(torch.stack(psnrs).mean())
        ssimf.append(torch.stack(ssims).mean())
        if plot_debug:
            print(torch.stack(psnrs).mean(), torch.stack(ssims).mean(), torch.stack(crcs).mean(), std)
            fig, ax = plt.subplots(1, 2)
            fig.colorbar(ax[0].imshow(x_pred[0, 0].cpu().numpy()))
            fig.colorbar(ax[1].imshow(ref_np))
            plt.show()
    print(torch.stack(psnrf).mean(), torch.stack(ssimf).mean(), torch.stack(crcf).mean(), torch.stack(stdf).mean(), torch.stack(kl_divf).mean())