In [None]:
# Enumerate GPUs in PCI bus order and expose only GPU 0 to this process.
# These variables only take effect if set before CUDA is initialised, hence before `import torch`.
import os
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

import os
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import torchvision
from torchvision.utils import make_grid
import time

from csgm import ConditionalScoreModel2D, ConditionalScoreModel2Dy
from csgm.utils import CustomLRScheduler

from utils_ours import uncond_loss_fn, our_loss_fn
from samplers import *

from torchvision.datasets import MNIST
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import functools
import math

In [None]:
# Experiment configuration. The "#@param" annotations are Colab form fields.
T = 1  # passed as `T` to the samplers and as `nt` to every score model

size = 64  # passed to torchvision Resize below; also the side length used when reshaping images

dataset_str = "LIDC" #@param ['LIDC', 'MNIST'] {'type':'string'}


eps = 5e-3  # passed through to the samplers as `eps`
device = "cuda"

inv_prob = "ct"#@param ['ct', 'deblurr'] {'type':'string'}


batch_size = 1000
num_samples =  1000#@param {"type":"integer"}

## Parameters below are fixed for our experiment. Changing them will require training a new model for our approach.


sig_blurr = 5  # passed to build_blurring_operator (deblurring problem only)

max_angle = 45  # passed to build_CT_operator (CT problem only)
color_channels = 1

#angles = int(size * max_angle / 180.)
angles = size  # passed to build_CT_operator; also the last axis of the reshaped CT measurement

num_steps = 1000  # passed to the samplers as `num_steps`


if dataset_str == "LIDC":
    # NOTE(review): lr / lr_final are not used anywhere in this notebook (training settings?).
    lr= 0.002 #@param {'type':'number'}
    lr_final = 0.0005
    from utils_ours import LungDataset

    # Test split, resized, with the channel axis moved last (input assumed CHW -> HWC).
    dataset = LungDataset(train = False,
                       transform=torchvision.transforms.Compose(
                           [torchvision.transforms.Resize(size, antialias=True), torchvision.transforms.Lambda(lambda x: x.permute(1, 2, 0))]))

    val_data_loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4)
    sig_obs = 0.05  # std of the Gaussian noise added to the simulated measurement
else:
    # 'MNIST' is offered by the form above, but no loader or sig_obs is defined for it here.
    # Fail immediately rather than with a NameError several cells later.
    raise NotImplementedError(
        f"dataset_str={dataset_str!r} is not supported by this notebook; only 'LIDC' is implemented."
    )
    

In [None]:
# @title Build (or load the cached) forward operator A
# The result is a float tensor on `device`.

if inv_prob == "deblurr":
    from utils_ours import build_blurring_operator
    A = build_blurring_operator(size, sig_blurr)
elif inv_prob == "ct":
    from utils_ours import build_CT_operator
    ct_cache_path = (
        "CT_forward" + str(size) + "angles" + str(angles) + "max_angle" + str(max_angle) + ".pt"
    )
    try:
        A = torch.load(ct_cache_path)
    except Exception as exc:
        # Best-effort cache: rebuild on any load failure, but say why instead of failing silently.
        print(f"Could not load cached CT operator from {ct_cache_path!r} ({exc!r}); rebuilding it.")
        A = build_CT_operator(size, angles, max_angle)
    # The cache may contain a numpy array or a tensor. as_tensor accepts both, whereas
    # from_numpy rejects tensors.
    A = torch.as_tensor(A)
else:
    raise ValueError(f"Unknown inv_prob {inv_prob!r}; expected 'ct' or 'deblurr'.")

A = A.to(device).float()

## Generate the noisy measurement

In [None]:
%matplotlib inline
torch.manual_seed(13)

if dataset_str == 'MNIST':
    true, _ = next(iter(val_data_loader))
else:
    true = next(iter(val_data_loader))


plt.imshow(true[0, :, :, 0], cmap = 'gray')
plt.colorbar()
plt.grid(None) 
plt.axis('off')
#plt.savefig("truth.jpg")
plt.show()

true = true.cuda()

In [None]:
# Simulate the measurement y = A x + sig_obs * noise, with noise ~ N(0, I), for the ground truth x.
# The ground-truth image is flattened to a (size**2, 1) column before applying A.
y = A.float() @ true.reshape(size**2, -1).to(A.device).float()
if inv_prob == "ct":
    y = y.reshape((1, size, angles))
elif inv_prob == "deblurr":
    y = y.reshape((1, size, size))

y = y + sig_obs * torch.randn_like(y)
# Use the configured device rather than a hard-coded .cuda().
y = y.to(device).float()

# Display the measurement, scaled by its maximum (display only; y itself is unscaled).
sample_grid = make_grid(y / torch.max(y), nrow=1)

plt.figure(figsize=(2,2))
plt.axis('off')
plt.imshow(sample_grid.permute(1, 2, 0).cpu())
plt.show()

## Generate posterior samples

In [None]:
# Build both families of score models at several widths and load their trained weights.
# Index i of `cond_models` (ConditionalScoreModel2Dy, "conditional" checkpoints) and of
# `our_models` (ConditionalScoreModel2D, "ours" checkpoints) corresponds to HIDDEN_DIMS[i].
HIDDEN_DIMS = [8, 16, 32, 64, 100, 128, 150]
# `modes` used for each width. This must match the architecture the checkpoints were trained with.
MODES_FOR_HIDDEN_DIM = {8: 5, 16: 5, 32: 15, 64: 15, 100: 25, 128: 25, 150: 25}
N_LAYERS = 4

# Checkpoint names encode the operator parameter: max_angle for CT, sig_blurr for deblurring.
if inv_prob == "ct":
    operator_tag = str(max_angle)
elif inv_prob == "deblurr":
    operator_tag = str(sig_blurr)
else:
    raise ValueError(f"Unknown inv_prob {inv_prob!r}; expected 'ct' or 'deblurr'.")


def _build_score_model(model_cls, hidden_dim):
    """Instantiate `model_cls` with the fixed architecture for `hidden_dim`, moved to `device`."""
    return model_cls(
        modes=MODES_FOR_HIDDEN_DIM[hidden_dim],
        hidden_dim=hidden_dim,
        nlayers=N_LAYERS,
        nt=T,
    ).to(device)


def _checkpoint_path(prefix, hidden_dim):
    """Return cpts/ckpt_trained_<prefix>_T=<T>_<dataset><inv_prob><tag><size><hidden_dim>.pth."""
    filename = (
        prefix + "_" + "T=" + str(T) + "_" + dataset_str + inv_prob
        + operator_tag + str(size) + str(hidden_dim)
    )
    return os.path.join("cpts", "ckpt_trained_" + filename + ".pth")


# Construct every conditional model before any of ours, each in HIDDEN_DIMS order. Construction
# presumably draws random initial weights from the global RNG, so this order determines the
# RNG state the samplers see later.
cond_models = [_build_score_model(ConditionalScoreModel2Dy, h) for h in HIDDEN_DIMS]
our_models = [_build_score_model(ConditionalScoreModel2D, h) for h in HIDDEN_DIMS]

# Same loading for both inverse problems. The deblurring case previously referenced undefined
# names and skipped .eval().
for hidden_dim, our_model, cond_model in zip(HIDDEN_DIMS, our_models, cond_models):
    our_model.load_state_dict(torch.load(_checkpoint_path("ours", hidden_dim), map_location=device))
    our_model.eval()
    cond_model.load_state_dict(torch.load(_checkpoint_path("conditional", hidden_dim), map_location=device))
    cond_model.eval()

# Per-width aliases, kept for backward compatibility.
(cond_score_model_8, cond_score_model_16, cond_score_model_32, cond_score_model_64,
 cond_score_model_100, cond_score_model_128, cond_score_model_150) = cond_models
(our_score_model_8, our_score_model_16, our_score_model_32, our_score_model_64,
 our_score_model_100, our_score_model_128, our_score_model_150) = our_models

In [None]:
## Generate samples using the specified sampler: conditional baseline, one run per model width.
# Per-model metrics, each aggregated as an RMS over pixels:
#   std  - per-pixel standard deviation across samples
#   bias - per-pixel sample mean minus ground truth
#   l2   - every sample minus ground truth
n_cond = len(cond_models)
l2_conds = np.zeros(n_cond)
std_conds = np.zeros(n_cond)
bias_conds = np.zeros(n_cond)
t_cond = np.zeros(n_cond)

truth_img = true[0, :, :, 0]  # the single (size, size) ground-truth image
# CUDA kernels run asynchronously. Synchronise around the timed region so it includes them.
sync_cuda = str(device).startswith("cuda") and torch.cuda.is_available()

for i, cond_score_model in enumerate(cond_models):
    if sync_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    samples = cond_sampler(cond_score_model,
                      T = T,
                      y = y.repeat((batch_size, 1, 1))[:, :, :, None],
                      A = A,
                      sig_obs = sig_obs,
                      nsamples = num_samples,
                      batch_size = batch_size,
                      size = size,
                      num_steps = num_steps,
                      color_channels = 1,
                      device=device,
                      eps = eps)
    if sync_cuda:
        torch.cuda.synchronize()
    stop = time.perf_counter()
    t_cond[i] = stop - start
    print("time elapsed:", stop - start)

    ## Sample visualization.
    samples = samples.clamp(0.0, 1.0)
    sample_grid = make_grid(samples, nrow=int(np.sqrt(num_samples)))

    plt.figure(figsize=(15,15))
    plt.axis('off')
    plt.imshow(sample_grid.permute(1, 2, 0).cpu(), vmin=0., vmax=1.)
    plt.show()

    # samples is 4-D, presumably (num_samples, 1, size, size). The [0] drops the channel axis.
    truth = truth_img.to(samples.device)
    std_cond = torch.std(samples, dim=0)[0, :, :]
    bias_cond = (torch.mean(samples, dim=0) - truth)[0, :, :]
    l2_cond = torch.mean((samples - truth[None, None, :, :])**2)**0.5

    # .item() converts explicitly (and works for CUDA tensors) before storing into numpy.
    std_conds[i] = (torch.mean(std_cond**2)**0.5).item()
    bias_conds[i] = (torch.mean(bias_cond**2)**0.5).item()
    l2_conds[i] = l2_cond.item()

    print("std:" , std_conds[i], "bias:", bias_conds[i], "l2:", l2_conds[i])

In [None]:
## Generate samples using the specified sampler: our approach, one run per model width.
# Per-model metrics, each aggregated as an RMS over pixels:
#   std  - per-pixel standard deviation across samples
#   bias - per-pixel sample mean minus ground truth
#   l2   - every sample minus ground truth
n_ours = len(our_models)  # sized by the list this loop iterates over
l2_ourss = np.zeros(n_ours)
std_ourss = np.zeros(n_ours)
bias_ourss = np.zeros(n_ours)
t_ours = np.zeros(n_ours)

truth_img = true[0, :, :, 0]  # the single (size, size) ground-truth image
# CUDA kernels run asynchronously. Synchronise around the timed region so it includes them.
sync_cuda = str(device).startswith("cuda") and torch.cuda.is_available()

for i, our_score_model in enumerate(our_models):
    if sync_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    samples = Our_sampler(our_score_model,
                      T = T,
                      # one flattened copy of y per sample: (num_samples, 1, A.shape[0])
                      y = y.repeat((num_samples, 1, 1)).reshape((num_samples, 1, A.shape[0])),
                      A = A,
                      sig_obs = sig_obs,
                      nsamples = num_samples,
                      batch_size = batch_size,
                      size = size,
                      num_steps = num_steps,
                      color_channels = 1,
                      device=device,
                      eps = eps)
    if sync_cuda:
        torch.cuda.synchronize()
    stop = time.perf_counter()
    t_ours[i] = stop - start
    print("time elapsed:", stop - start)

    ## Sample visualization.
    samples = samples.clamp(0.0, 1.0)
    sample_grid = make_grid(samples, nrow=int(np.sqrt(num_samples)))

    plt.figure(figsize=(15,15))
    plt.axis('off')
    plt.imshow(sample_grid.permute(1, 2, 0).cpu(), vmin=0., vmax=1.)
    plt.show()

    # samples is 4-D, presumably (num_samples, 1, size, size). The [0] drops the channel axis.
    truth = truth_img.to(samples.device)
    std_ours = torch.std(samples, dim=0)[0, :, :]
    bias_ours = (torch.mean(samples, dim=0) - truth)[0, :, :]
    l2_ours = torch.mean((samples - truth[None, None, :, :])**2)**0.5

    # .item() converts explicitly (and works for CUDA tensors) before storing into numpy.
    std_ourss[i] = (torch.mean(std_ours**2)**0.5).item()
    bias_ourss[i] = (torch.mean(bias_ours**2)**0.5).item()
    l2_ourss[i] = l2_ours.item()

    print("std:" , std_ourss[i], "bias:", bias_ourss[i], "l2:", l2_ourss[i])

## Plot and export the metrics versus model width

In [None]:
# L2 error versus model width for both approaches, saved to images/CT/graphl2.png.
# NOTE(review): the output directory is "CT" regardless of inv_prob.
hidden_dims = [8, 16, 32, 64, 100, 128, 150]  # widths of the models evaluated above
plt.plot(hidden_dims, l2_ourss, linestyle='--', marker='o', color='red', label='ours', linewidth=3, markersize=10)
plt.plot(hidden_dims, l2_conds, linestyle='--', marker='o', color = 'blue', label='conditional', linewidth=3, markersize=10)

plt.legend()
plt.xlabel("nodes per layer")
plt.ylabel("L2")
os.makedirs("images/CT", exist_ok=True)  # savefig does not create missing directories
plt.savefig("images/CT/graphl2.png", bbox_inches='tight')
plt.show()


In [None]:
# Sample standard deviation versus model width, saved to images/CT/graphstd.png.
hidden_dims = [8, 16, 32, 64, 100, 128, 150]  # widths of the models evaluated above
plt.plot(hidden_dims, std_ourss, linestyle='--', marker='o', color = 'red', label = 'ours', linewidth=3, markersize=10)
plt.plot(hidden_dims, std_conds, linestyle='--', marker='o', color = 'blue', label = 'conditional', linewidth=3, markersize=10)

# NOTE(review): the provenance of these hard-coded values is undocumented (another run/setting?).
# They get distinct legend labels so they are not confused with the curves above.
plt.plot([32, 64, 128], [0.0344, 0.0319, 0.0298], linestyle='', marker='x', color='red', label='ours (hard-coded values)', linewidth=3, markersize=10)
plt.plot([32, 64, 128], [0.0952, 0.0683, 0.0584], linestyle='', marker='x', color='blue', label='conditional (hard-coded values)', linewidth=3, markersize=10)
plt.legend()
plt.xlabel("nodes per layer")
plt.ylabel("Std")
os.makedirs("images/CT", exist_ok=True)  # savefig does not create missing directories
plt.savefig("images/CT/graphstd.png", bbox_inches='tight')
plt.show()

In [None]:
# Bias of the sample mean versus model width, saved to images/CT/graphbias.png.
hidden_dims = [8, 16, 32, 64, 100, 128, 150]  # widths of the models evaluated above
plt.plot(hidden_dims, bias_ourss, linestyle='--', marker='o', color = 'red', label = 'ours', linewidth=3, markersize=10)
plt.plot(hidden_dims, bias_conds, linestyle='--', marker='o', color = 'blue', label = 'conditional', linewidth=3, markersize=10)

# NOTE(review): the provenance of these hard-coded values is undocumented (another run/setting?).
# They get distinct legend labels so they are not confused with the curves above.
plt.plot([32, 64, 128], [0.0489, 0.0477, 0.0443], linestyle='', marker='x', color='red', label='ours (hard-coded values)', linewidth=3, markersize=10)
plt.plot([32, 64, 128], [0.2683, 0.0579, 0.0421], linestyle='', marker='x', color='blue', label='conditional (hard-coded values)', linewidth=3, markersize=10)
plt.legend()
plt.xlabel("nodes per layer")
plt.ylabel("Bias")
os.makedirs("images/CT", exist_ok=True)  # savefig does not create missing directories
plt.savefig("images/CT/graphbias.png", bbox_inches='tight')
plt.show()