In [None]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader
import torchvision.transforms as T

import matplotlib.pyplot as plt
import numpy as np
import os
import copy
from model import UNet
import h5py

from utils import get_ssim, get_psnr, get_mse, get_mae, save_metrics, pre_process

In [None]:
img_size = 128

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

In [None]:
# Open the held-out CT file read-only. The handle `f` is deliberately left open:
# `ct_test` is an h5py dataset backed by this file, and later cells index it
# (ct_test[i]) and pass it to the sampler as `conditions`.
path = 'ct_test_resized.hdf5'
f = h5py.File(path, mode='r')
ct_test = f['data']
ct_test

In [None]:
# Number of test samples: the length of the dataset's first axis.
test_size = len(ct_test)
test_size

In [None]:
# configs

class data(object):
  """Dataset / preprocessing settings for the sampler.

  In the code shown, only `image_size`, `channels`, `logit_transform` and
  `rescaled` are read. The remaining fields appear to be carried over from a
  DDIM-style config and are unused here.
  NOTE(review): `dataset` says "CELEBA" although the data are CT images;
  presumably a leftover label.
  """

  def __init__(self):
    vars(self).update(
        dataset="CELEBA",
        image_size=128,
        channels=1,
        logit_transform=False,
        uniform_dequantization=False,
        gaussian_dequantization=False,
        random_flip=True,
        rescaled=True,
        num_workers=4,
    )

class model(object):
  """Model-level settings; `var_type` is read by `Diffusion` via `config_model`."""

  def __init__(self):
    vars(self).update(var_type="fixedlarge")

class diffusion(object):
  """Noise-schedule settings forwarded to `get_beta_schedule`.

  NOTE(review): the name `diffusion` is later rebound to a `Diffusion()`
  instance in the sampling cell, so this class is unreachable afterwards.
  """

  def __init__(self):
    vars(self).update(
        beta_schedule="linear",
        beta_start=0.0001,
        beta_end=0.02,
        num_diffusion_timesteps=1000,
    )


class sampling(object):
  """Sampling-loop settings: images per round, plus a `last_only` flag
  (the flag is not read in the code shown)."""

  def __init__(self):
    vars(self).update(batch_size=1, last_only=True)

In [None]:
class args(object):
  """DDIM sampler options, read through `Diffusion.args`.

  eta: 0 gives deterministic DDIM; 1 gives one type of DDPM.
  skip: not effective with skip_type "uniform". `sample_image` derives the
      stride as num_timesteps // timesteps instead.
  skip_type: "uniform" or "quad" timestep spacing.
  sample_type: only "generalized" is implemented.
  timesteps: number of sampling steps.
  """

  def __init__(self):
    vars(self).update(
        eta=0,  # 0 is DDIM, and 1 is one type of DDPM
        skip=2,
        skip_type="uniform",
        sample_type="generalized",
        timesteps=1000,
    )

In [None]:
# Instantiate the config objects read (as globals) by the sampling code.
# NOTE(review): `args = args()` rebinds the class name to its instance, and the
# later `import torch.utils.data as data` rebinds `data` to a module, so
# re-running this cell on its own raises TypeError. Restart the kernel or re-run
# the class cells first.
config_sampling = sampling()
config_diffusion = diffusion()
config_model = model()
config_data = data()
args = args()

In [None]:
# diffusion.py
import os
import logging
import time
import glob

import numpy as np
import tqdm
import torch
import torch.utils.data as data
import torchvision.utils as tvu

# Epoch tag embedded in the saved sample filenames
# (x0_number_<k>_epoch_<max_epoch>.npy) and reused when reloading them.
# NOTE(review): the checkpoint loaded in the sampling cell is
# ddpm-unet_epoch_102.pt while this tag is 101; confirm that is intentional.
max_epoch = 101

# utils


def compute_alpha(beta, t):
    """Return cumulative alpha-bar values for the timesteps in ``t``.

    A zero is prepended to ``beta``, so index ``t + 1`` selects
    prod_{s<=t} (1 - beta_s), and ``t = -1`` maps to the empty product 1.
    The result is reshaped to (-1, 1, 1, 1) so it broadcasts over image batches.
    """
    padded = torch.cat((torch.zeros(1, device=beta.device), beta), dim=0)
    alpha_bar = torch.cumprod(1 - padded, dim=0)
    return alpha_bar.index_select(0, t + 1).reshape(-1, 1, 1, 1)



def inverse_data_transform(X):
    """Map model-space samples back into [0, 1].

    Inverts the forward transform selected by the global ``config_data``. A
    logit transform is undone with a sigmoid; otherwise, when ``rescaled`` is
    set, the [-1, 1] scaling is undone with (X + 1) / 2. The output is always
    clamped to [0, 1].
    """
    cfg = config_data
    if cfg.logit_transform:
        restored = torch.sigmoid(X)
    elif cfg.rescaled:
        restored = (X + 1.0) / 2.0
    else:
        restored = X
    return torch.clamp(restored, 0.0, 1.0)


def torch2hwcuint8(x, clip=False):
    """Rescale ``x`` from [-1, 1] to [0, 1], optionally clipping to [-1, 1] first.

    NOTE(review): despite the name, this neither permutes to HWC nor casts to
    uint8; it only applies the affine rescale. It is not called in the code shown.
    """
    source = torch.clamp(x, -1, 1) if clip else x
    return (source + 1.0) / 2.0


def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
    """Build the per-step noise variances beta_1..beta_T as a 1-D float64 array.

    Schedules:
      "quad"    - linearly spaced in sqrt space between the endpoints, then squared.
      "linear"  - linearly spaced from beta_start to beta_end.
      "const"   - every entry equals beta_end.
      "jsd"     - 1/T, 1/(T-1), ..., 1.
      "sigmoid" - sigmoid of linspace(-6, 6) rescaled into [beta_start, beta_end].

    Raises:
        NotImplementedError: for any other schedule name.
    """
    steps = num_diffusion_timesteps

    def quad():
        root = np.linspace(beta_start ** 0.5, beta_end ** 0.5, steps, dtype=np.float64)
        return root ** 2

    def linear():
        return np.linspace(beta_start, beta_end, steps, dtype=np.float64)

    def const():
        return beta_end * np.ones(steps, dtype=np.float64)

    def jsd():
        # 1/T, 1/(T-1), 1/(T-2), ..., 1
        return 1.0 / np.linspace(steps, 1, steps, dtype=np.float64)

    def sigmoid_schedule():
        grid = np.linspace(-6, 6, steps)
        squashed = 1 / (np.exp(-grid) + 1)
        return squashed * (beta_end - beta_start) + beta_start

    schedules = (
        ("quad", quad),
        ("linear", linear),
        ("const", const),
        ("jsd", jsd),
        ("sigmoid", sigmoid_schedule),
    )
    for name, build in schedules:
        if beta_schedule == name:
            betas = build()
            break
    else:
        raise NotImplementedError(beta_schedule)

    assert betas.shape == (steps,)
    return betas



def generalized_steps(x, seq, model, b, condition, eta):
    """Run the generalized (DDIM) reverse process starting from noise ``x``.

    All computation happens on ``x.device``, so the sampler works on CPU as
    well as GPU.

    Args:
        x: starting noise tensor; its batch size and device drive every step.
        seq: increasing timestep indices to visit (traversed in reverse).
        model: noise predictor, called as ``model(input, t)``.
        b: beta schedule tensor, looked up through ``compute_alpha``.
        condition: per-sample conditioning array/tensor (see NOTE below).
        eta: 0 gives deterministic DDIM; 1 gives one type of DDPM.

    Returns:
        (xs, x0_preds): ``xs[0]`` is the input ``x``. Every later entry of
        ``xs`` is the iterate after one step, and ``x0_preds`` holds the
        predicted clean image at each step. Both are moved to CPU.
    """
    with torch.no_grad():
        device = x.device
        n = x.size(0)
        seq_next = [-1] + list(seq[:-1])
        x0_preds = []
        xs = [x]

        # NOTE(review): `condition` is prepared, but the network input below is
        # torch.cat([xt, xt]), not the condition. The caller in this notebook
        # passes the evaluation targets (ct_test) as conditions, so wiring it in
        # would leak the targets into sampling. Confirm the intended
        # conditioning input before changing this.
        condition = torch.as_tensor(condition).unsqueeze(0).unsqueeze(0).to(device)

        # Move the network once rather than on every denoising step.
        model = model.to(device)

        for i, j in zip(reversed(seq), reversed(seq_next)):
            t = (torch.ones(n) * i).to(device)
            next_t = (torch.ones(n) * j).to(device)
            at = compute_alpha(b, t.long())
            at_next = compute_alpha(b, next_t.long())
            xt = xs[-1].to(device)

            et = model(torch.cat([xt, xt], dim=1), t)

            # Predicted clean image implied by the noise estimate.
            x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
            x0_preds.append(x0_t.to('cpu'))
            # c1 scales fresh noise (zero when eta == 0); c2 scales the predicted noise.
            c1 = eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
            c2 = ((1 - at_next) - c1 ** 2).sqrt()
            xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(xt) + c2 * et
            xs.append(xt_next.to('cpu'))

    return xs, x0_preds


class Diffusion(object):
    """DDIM sampler configured from the module-level config objects.

    Reads the globals ``args``, ``config_model``, ``config_diffusion``,
    ``config_data`` and ``config_sampling``. ``sample_fid`` also reads
    ``test_size`` and ``max_epoch``.
    """

    def __init__(self, device=None):
        """Build the beta schedule and log-variance on ``device``.

        Args:
            device: torch device; defaults to CUDA when available, else CPU.
        """
        self.args = args
        if device is None:
            device = (
                torch.device("cuda")
                if torch.cuda.is_available()
                else torch.device("cpu")
            )
        self.device = device

        self.model_var_type = config_model.var_type
        betas = get_beta_schedule(
            beta_schedule=config_diffusion.beta_schedule,
            beta_start=config_diffusion.beta_start,
            beta_end=config_diffusion.beta_end,
            num_diffusion_timesteps=config_diffusion.num_diffusion_timesteps,
        )
        betas = self.betas = torch.from_numpy(betas).float().to(self.device)
        self.num_timesteps = betas.shape[0]

        alphas = 1.0 - betas
        alphas_cumprod = alphas.cumprod(dim=0)
        alphas_cumprod_prev = torch.cat(
            [torch.ones(1).to(device), alphas_cumprod[:-1]], dim=0
        )
        posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        if self.model_var_type == "fixedlarge":
            self.logvar = betas.log()
        elif self.model_var_type == "fixedsmall":
            self.logvar = posterior_variance.clamp(min=1e-20).log()

    def sample_fid(self, model, conditions):
        """Sample images for the test set and save each one as a .npy file.

        Runs ``test_size // batch_size`` rounds. Each round draws Gaussian noise
        of shape (batch, channels, image_size, image_size) and samples with
        ``conditions[round]``. The result is mapped to [0, 1], and every image is
        written to
        ``./current experiment/ddim_results/x0_number_<k>_epoch_<max_epoch>.npy``,
        where ``k`` is the 1-based running image index.
        """
        out_dir = './current experiment/ddim_results'
        os.makedirs(out_dir, exist_ok=True)

        img_id = 0
        print(f"starting from image {img_id}")
        n = config_sampling.batch_size
        n_rounds = (test_size - img_id) // n

        with torch.no_grad():
            for i in tqdm.tqdm(
                range(n_rounds), desc="Generating image samples for FID evaluation."
            ):
                x = torch.randn(
                    n,
                    config_data.channels,
                    config_data.image_size,
                    config_data.image_size,
                    device=self.device,
                )

                # NOTE(review): one condition per round; with batch_size > 1
                # every image in the batch shares conditions[i].
                x = self.sample_image(x, model, conditions[i])
                x = inverse_data_transform(x)

                for j in range(n):
                    ddim_result = x[j].unsqueeze(0).detach().cpu().numpy()
                    # Name files by running image index so that batch_size > 1
                    # yields one file per image rather than one per batch.
                    save_path = f'{out_dir}/x0_number_{img_id+1}_epoch_{max_epoch}.npy'
                    np.save(save_path, ddim_result)
                    img_id += 1

    def sample_image(self, x, model, condition, last=True):
        """Denoise ``x`` with the generalized (DDIM) sampler.

        Returns:
            The final x0 prediction (a CPU tensor) when ``last`` is True,
            otherwise the ``(xs, x0_preds)`` trajectory lists.

        Raises:
            NotImplementedError: for an unsupported ``sample_type`` or ``skip_type``.
        """
        if self.args.sample_type != "generalized":
            # Only the generalized sampler is implemented. Fail loudly rather
            # than indexing into the raw noise tensor.
            raise NotImplementedError(self.args.sample_type)

        if self.args.skip_type == "uniform":
            # The stride is derived from args.timesteps; args.skip is not used.
            skip = self.num_timesteps // self.args.timesteps
            seq = range(0, self.num_timesteps, skip)
        elif self.args.skip_type == "quad":
            seq = (
                np.linspace(
                    0, np.sqrt(self.num_timesteps * 0.8), self.args.timesteps
                )
                ** 2
            )
            seq = [int(s) for s in list(seq)]
        else:
            raise NotImplementedError

        xs = generalized_steps(x, seq, model, self.betas, condition, eta=self.args.eta)
        if last:
            # Last x0 prediction. At the final step alpha_next == 1, so this
            # numerically matches the last iterate xs[0][-1].
            return xs[1][-1]
        return xs

In [None]:
# UNet hyper-parameters (must match the checkpoint being loaded)
ch = 64
ch_mult = [1, 2, 2, 4, 4]
attn = [1]
num_res_blocks = 2
dropout = 0.

# Gaussian Diffusion
# NOTE(review): beta_1 / beta_T are not read in this cell; the sampler's
# schedule comes from config_diffusion.
beta_1 = 1e-4
beta_T = 0.02
T = 1000

net_model = UNet(
        T=T, ch=ch, ch_mult=ch_mult, attn=attn,
        num_res_blocks=num_res_blocks, dropout=dropout)

ema_model = copy.deepcopy(net_model)

# NOTE(review): checkpoint epoch 102 differs from the max_epoch tag (101) used in
# result filenames; confirm this is intended.
model_path = './current experiment/Saved_model/ddpm-unet_epoch_102.pt'

# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
ema_model.load_state_dict(torch.load(model_path, map_location=device))
ema_model.eval()

# NOTE(review): this rebinds `diffusion`, shadowing the config class of the same name.
diffusion = Diffusion()
diffusion.sample_fid(ema_model, conditions=ct_test)

In [None]:
# NOTE(review): hand-entered total sampling time from an earlier run
# (presumably seconds, read off the sampling progress bar). Update it if the
# sampling cell is re-run.
sum_time_ddim = 1231

# Mean sampling time per test image.
avg_time_ddim = sum_time_ddim / test_size
avg_time_ddim

In [None]:
# Reload every saved DDIM sample (files are numbered from 1) and preprocess it
# for scoring.
diff_outs = []
for i in range(test_size):
    path = f'./current experiment/ddim_results/x0_number_{i+1}_epoch_{max_epoch}.npy'
    diff_out = np.load(path)
    diff_outs.append(pre_process(diff_out, img_size))

In [None]:
# Preprocess the ground-truth CT slices the same way as the generated samples.
targets = []
for i in range(test_size):
    ct_sample = ct_test[i]
    targets.append(pre_process(ct_sample, img_size))


In [None]:
# Score every generated sample against its target. Judging by the unpacking,
# each get_* helper presumably returns (best value, index of best sample,
# mean value); confirm in utils.
max_ssim, argmax_ssim, avg_ssim = get_ssim(diff_outs, targets)
max_psnr, argmax_psnr, avg_psnr = get_psnr(diff_outs, targets)

# For the error metrics, the "best" value is the minimum.
min_mse, argmin_mse, avg_mse = get_mse(diff_outs, targets)
min_mae, argmin_mae, avg_mae = get_mae(diff_outs, targets)

# Hand the mean sampling time, the averages and the best values to save_metrics (utils).
save_metrics(avg_time_ddim, avg_ssim, avg_psnr, avg_mse, avg_mae, max_ssim, max_psnr, min_mse, min_mae)

In [None]:
from PIL import Image


# Create the output folders. exist_ok=True keeps this cell re-runnable instead
# of raising FileExistsError when the folders already exist.
os.makedirs('./current experiment/Train_Output/', exist_ok=True)


best_ssim_folder_path = './current experiment/best ssim'
best_psnr_folder_path = './current experiment/best psnr'
best_mse_folder_path = './current experiment/best mse'
best_mae_folder_path = './current experiment/best mae'

for folder in (best_ssim_folder_path, best_psnr_folder_path,
               best_mse_folder_path, best_mae_folder_path):
    os.makedirs(folder, exist_ok=True)

In [None]:
# Select the (generated, target) pair that scored best under each metric.
best_ssim_diff_out, best_ssim_target = diff_outs[argmax_ssim], targets[argmax_ssim]
best_psnr_diff_out, best_psnr_target = diff_outs[argmax_psnr], targets[argmax_psnr]
best_mse_diff_out, best_mse_target = diff_outs[argmin_mse], targets[argmin_mse]
best_mae_diff_out, best_mae_target = diff_outs[argmin_mae], targets[argmin_mae]

# Save each pair into that metric's folder as diff_out_<idx> / target_<idx>
# (np.save appends the .npy suffix).
for folder_path, best_idx, generated_arr, target_arr in (
    (best_ssim_folder_path, argmax_ssim, best_ssim_diff_out, best_ssim_target),
    (best_psnr_folder_path, argmax_psnr, best_psnr_diff_out, best_psnr_target),
    (best_mse_folder_path, argmin_mse, best_mse_diff_out, best_mse_target),
    (best_mae_folder_path, argmin_mae, best_mae_diff_out, best_mae_target),
):
    np.save(os.path.join(folder_path, f'diff_out_{best_idx}'), generated_arr)
    np.save(os.path.join(folder_path, f'target_{best_idx}'), target_arr)