In [1]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader
import torchvision.transforms as T

import matplotlib.pyplot as plt
import numpy as np
import os

In [2]:
# use the GPU when torch can see one, otherwise stay on the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [3]:
base_dir = '/home/iplab/Desktop/Shaker/Brain tumor MRI and CT scan/data(processed)'

# locations of the MR ("input") and CT ("output") arrays of every split
MR_train_address, CT_train_address = (
    os.path.join(base_dir, name) for name in ('train_input.npy', 'train_output.npy'))
MR_val_address, CT_val_address = (
    os.path.join(base_dir, name) for name in ('val_input.npy', 'val_output.npy'))
MR_test_address, CT_test_address = (
    os.path.join(base_dir, name) for name in ('test_input.npy', 'test_output.npy'))

# read the mr and ct arrays of each split
mr_train, ct_train = np.load(MR_train_address), np.load(CT_train_address)
mr_val, ct_val = np.load(MR_val_address), np.load(CT_val_address)
mr_test, ct_test = np.load(MR_test_address), np.load(CT_test_address)

# show the shape of the CT array of each split
tuple(ct.shape for ct in (ct_train, ct_val, ct_test))

((570, 256, 256), (90, 256, 256), (150, 256, 256))

In [4]:
# resizes a single 2D numpy image to an arbitrary size (returned as a float32 tensor with a leading axis of 1)
def resize(img, size):
  """Resize one image with torchvision's ``T.Resize``.

  The array is cast to float32, wrapped in a tensor and given a leading
  axis of length 1 before it is resized.

  Args:
    img: numpy array holding a single image.
    size: target size, forwarded unchanged to ``T.Resize``.

  Returns:
    The resized image as a torch tensor (the leading axis is kept).
  """
  as_tensor = torch.tensor(img.astype('float32')).unsqueeze(0)
  resizer = T.Resize(size)
  return resizer(as_tensor)

In [5]:
def _fixed_rotation(angle):
  """Build a transform that rotates its input by ``angle`` degrees."""
  return T.Compose([T.Lambda(lambda t: T.functional.rotate(t, angle=angle))])


# intensity augmentation: every pixel value is multiplied by 0.2
change_gray_level = T.Compose([T.Lambda(lambda t: t * 0.2)])

# mirror augmentations
horizontal_flip = T.Compose([T.functional.hflip])
vertical_flip = T.Compose([T.functional.vflip])

# rotation augmentations
rotate_45 = _fixed_rotation(45)
rotate_minus_45 = _fixed_rotation(-45)

In [6]:
n_train_samples = ct_train.shape[0]
n_val_samples = ct_val.shape[0]
n_test_samples = ct_test.shape[0]

# move the first 69 test samples into the training pool
n_add_from_test_to_train = 69
assert 0 <= n_add_from_test_to_train <= n_test_samples, 'cannot move more test samples than exist'

# data augmentation:
# every training sample is stored together with 5 augmented variants,
# so the training set becomes 6 times bigger
n_variants = 6

n_train_new = (n_train_samples + n_val_samples + n_add_from_test_to_train) * n_variants
n_test_new = n_test_samples - n_add_from_test_to_train


def _resized_with_variants(img, size=128):
  """Return the resized image followed by its 5 augmented variants.

  The order (original, gray level, horizontal flip, vertical flip,
  +45 rotation, -45 rotation) is the layout the training lists use.
  """
  resized = resize(img, size)
  return [
      resized,
      change_gray_level(resized),
      horizontal_flip(resized),
      vertical_flip(resized),
      rotate_45(resized),
      rotate_minus_45(resized),
  ]


# the training pool is: all train samples, then all validation samples,
# then the first `n_add_from_test_to_train` test samples
mr_train_resized = []
ct_train_resized = []

for mr_source, ct_source, n_samples in (
    (mr_train, ct_train, n_train_samples),
    (mr_val, ct_val, n_val_samples),
    (mr_test, ct_test, n_add_from_test_to_train),
):
  for i in range(n_samples):
    # MR and CT get the same deterministic augmentations, so pairs stay aligned
    mr_train_resized.extend(_resized_with_variants(mr_source[i]))
    ct_train_resized.extend(_resized_with_variants(ct_source[i]))

assert len(mr_train_resized) == len(ct_train_resized) == n_train_new

# the remaining test samples are only resized, never augmented
mr_test_resized = [resize(mr_test[i], 128) for i in range(n_add_from_test_to_train, n_test_samples)]
ct_test_resized = [resize(ct_test[i], 128) for i in range(n_add_from_test_to_train, n_test_samples)]

assert len(mr_test_resized) == len(ct_test_resized) == n_test_new



In [7]:
# convert train and test samples to numpy array
# drop the leading singleton axis of every sample and turn it into a numpy array (in place)
for samples in (mr_train_resized, ct_train_resized):
  for idx in range(n_train_new):
    samples[idx] = np.array(samples[idx].squeeze())

for samples in (mr_test_resized, ct_test_resized):
  for idx in range(n_test_new):
    samples[idx] = np.array(samples[idx].squeeze())

In [8]:
# convert lists of mr and ct to numpy arrays
# pack each list of per-sample arrays into a single numpy array
mr_train_resized, ct_train_resized = (
    np.array(samples) for samples in (mr_train_resized, ct_train_resized))

mr_test_resized, ct_test_resized = (
    np.array(samples) for samples in (mr_test_resized, ct_test_resized))


print('train images shape:', mr_train_resized.shape)
print('test images shape:', mr_test_resized.shape)

train images shape: (4374, 128, 128)
test images shape: (81, 128, 128)


In [9]:
import h5py

# creating hdf5 data from numpy arrays
# store every resized array in its own HDF5 file, under the dataset name "data"
for hdf5_path, array in (
    ('mr_train_resized.hdf5', mr_train_resized),
    ('ct_train_resized.hdf5', ct_train_resized),
    ('mr_test_resized.hdf5', mr_test_resized),
    ('ct_test_resized.hdf5', ct_test_resized),
):
    with h5py.File(hdf5_path, 'w') as f:
        dset = f.create_dataset("data", data=array)

In [10]:
# sanity check: reopen the training MR file and show what was stored.
# The file is closed again right away: a handle left open in read mode
# would get in the way of re-running the cell above, which reopens the
# same path with mode 'w'.
path = 'mr_train_resized.hdf5'
with h5py.File(path, 'r') as f:
    load_data = f['data']
    load_data_summary = repr(load_data)  # captured while the file is still open
print(load_data_summary)

<HDF5 dataset "data": shape (4374, 128, 128), type "<f4">

In [11]:
# dataset.py
import random
import numpy as np
import torch
import h5py
from torch.utils.data import Dataset
from torchvision import transforms

def random_rot(img1,img2):
    """Rotate both images by the same random number of quarter turns.

    One integer is drawn from ``np.random`` and both images are rotated
    by 1, 2 or 3 quarter turns with ``np.rot90``.

    Returns:
        Tuple ``(img1, img2)`` with the rotated arrays.
    """
    quarter_turns = np.random.randint(0, 3) + 1
    rotated_first = np.rot90(img1, quarter_turns)
    rotated_second = np.rot90(img2, quarter_turns)
    return rotated_first, rotated_second

def random_flip(img1,img2):
    """Flip both images along the same randomly chosen axis (0 or 1).

    Returns:
        Tuple ``(img1, img2)`` holding copies of the flipped arrays.
    """
    flip_axis = np.random.randint(0, 2)
    flipped_first = np.flip(img1, axis=flip_axis).copy()
    flipped_second = np.flip(img2, axis=flip_axis).copy()
    return flipped_first, flipped_second

class RandomGenerator(object):
    """Random paired augmentation for a ``{'lr': ..., 'hr': ...}`` sample.

    A random rotation and a random flip are each applied when a draw from
    ``random.random()`` exceeds 0.5; both images always receive the same
    operation.
    """

    def __init__(self, output_size):
        # stored only; not read by __call__
        self.output_size = output_size

    def __call__(self, sample):
        lr_img = sample['lr']
        hr_img = sample['hr']

        apply_rotation = random.random() > 0.5
        if apply_rotation:
            lr_img, hr_img = random_rot(lr_img, hr_img)

        apply_flip = random.random() > 0.5
        if apply_flip:
            lr_img, hr_img = random_flip(lr_img, hr_img)

        return {'lr': lr_img, 'hr': hr_img}



class Train_Data(Dataset):
    """Paired training slices read from two HDF5 files.

    ``mr_train_resized.hdf5`` provides the input ("lr") slices and
    ``ct_train_resized.hdf5`` the target ("hr") slices, both under the
    dataset name ``data``. Each item is min-max normalised per slice,
    randomly rotated/flipped by ``RandomGenerator`` and returned as a pair
    of float32 tensors of shape ``(1, h, w)``.

    Both files are opened in ``__init__`` and are not closed by this class.
    """

    def __init__(self):
        path = 'mr_train_resized.hdf5' # data in hdf5 as an example
        f = h5py.File(path,'r')
        self.lr = f['data']
        path = 'ct_train_resized.hdf5'
        f = h5py.File(path,'r')
        self.hr = f['data']
        c, self.h, self.w = self.lr.shape

        self.len = c
        self.transform=transforms.Compose([RandomGenerator(output_size=[self.h, self.w])])

    def __getitem__(self, index):
        """Return the normalised, augmented ``(input, target)`` pair at ``index``."""
        x = self.norm(self.lr[index, :, :])
        y = self.norm(self.hr[index, :, :])

        if self.transform:
            sample = self.transform({'lr': x,'hr': y})
            x, y = sample['lr'], sample['hr']

        return self._to_tensor(x), self._to_tensor(y)

    def __len__(self):
        return self.len

    def _to_tensor(self, img):
        """Copy one ``(h, w)`` image into a float32 tensor of shape ``(1, h, w)``."""
        # writing into a fresh buffer also makes rotated/flipped views contiguous
        buffer = np.zeros((1, self.h, self.w), dtype=np.float32)
        buffer[0, :, :] = img
        return torch.from_numpy(buffer)

    def norm(self, x):
        """Min-max scale ``x`` to [0, 1].

        Arrays whose maximum is not positive are returned unchanged. A
        constant positive array is mapped to zeros instead of dividing by a
        zero range, which would fill the slice with NaN.
        """
        lowest = np.amin(x)
        highest = np.amax(x)
        if highest > 0:
            value_range = highest - lowest
            centered = x - lowest
            # constant slice: `centered` is already all zeros
            x = centered / value_range if value_range > 0 else centered
        return x


class Valid_Data(Dataset):
    """Paired evaluation slices read from two HDF5 files.

    ``mr_test_resized.hdf5`` provides the input ("lr") slices and
    ``ct_test_resized.hdf5`` the target ("hr") slices, both under the
    dataset name ``data``. Each item is min-max normalised per slice (no
    augmentation) and returned as a pair of float32 tensors of shape
    ``(1, h, w)``.

    Both files are opened in ``__init__`` and are not closed by this class.
    """

    def __init__(self):
        path = 'mr_test_resized.hdf5'
        f = h5py.File(path,'r')
        self.lr = f['data']
        path = 'ct_test_resized.hdf5'
        f = h5py.File(path,'r')
        self.hr = f['data']
        c, self.h, self.w = self.lr.shape

        self.len = c

    def __getitem__(self, index):
        """Return the normalised ``(input, target)`` pair at ``index``."""
        x = self.norm(self.lr[index, :, :])
        y = self.norm(self.hr[index, :, :])
        return self._to_tensor(x), self._to_tensor(y)

    def __len__(self):
        return self.len

    def _to_tensor(self, img):
        """Copy one ``(h, w)`` image into a float32 tensor of shape ``(1, h, w)``."""
        buffer = np.zeros((1, self.h, self.w), dtype=np.float32)
        buffer[0, :, :] = img
        return torch.from_numpy(buffer)

    def norm(self, x):
        """Min-max scale ``x`` to [0, 1].

        Arrays whose maximum is not positive are returned unchanged. A
        constant positive array is mapped to zeros instead of dividing by a
        zero range, which would fill the slice with NaN.
        """
        lowest = np.amin(x)
        highest = np.amax(x)
        if highest > 0:
            value_range = highest - lowest
            centered = x - lowest
            # constant slice: `centered` is already all zeros
            x = centered / value_range if value_range > 0 else centered
        return x

class Test_Data(Dataset):
    """Paired test slices read from two HDF5 files.

    The two paths below are placeholders and must point at HDF5 files that
    store the slices under the dataset name ``data``. Each item is min-max
    normalised per slice and returned as a pair of float32 tensors of shape
    ``(1, h, w)``.

    Args:
        max_samples: upper bound on the number of slices exposed. The
            default keeps the previous fixed length of 135, but the length
            is now capped at the number of slices actually stored, so
            indexing can no longer run past the end of a shorter file.
            Pass ``None`` to expose every slice.

    Both files are opened in ``__init__`` and are not closed by this class.
    """

    def __init__(self, max_samples=135):
        path = '/path/to/test_mri_data'
        f = h5py.File(path,'r')
        self.lr = f['data']
        path = '/path/to/test_ct_data'
        f = h5py.File(path,'r')
        self.hr = f['data']
        c, self.h, self.w = self.lr.shape

        self.len = c if max_samples is None else min(max_samples, c)

    def __getitem__(self, index):
        """Return the normalised ``(input, target)`` pair at ``index``."""
        x = self.norm(self.lr[index, :, :])
        y = self.norm(self.hr[index, :, :])
        return self._to_tensor(x), self._to_tensor(y)

    def __len__(self):
        return self.len

    def _to_tensor(self, img):
        """Copy one ``(h, w)`` image into a float32 tensor of shape ``(1, h, w)``."""
        buffer = np.zeros((1, self.h, self.w), dtype=np.float32)
        buffer[0, :, :] = img
        return torch.from_numpy(buffer)

    def norm(self, x):
        """Min-max scale ``x`` to [0, 1].

        Arrays whose maximum is not positive are returned unchanged. A
        constant positive array is mapped to zeros instead of dividing by a
        zero range, which would fill the slice with NaN.
        """
        lowest = np.amin(x)
        highest = np.amax(x)
        if highest > 0:
            value_range = highest - lowest
            centered = x - lowest
            # constant slice: `centered` is already all zeros
            x = centered / value_range if value_range > 0 else centered
        return x

In [12]:
# diffusion.py
import torch
import torch.nn as nn
import numpy as np
from copy import deepcopy
from scipy import integrate

# first GPU when CUDA is available, otherwise the CPU (instead of unconditionally 'cuda:0')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

def marginal_prob_std(t, sigma):
    """ Compute standard deviation of conditional Gaussian distribution at time t.
    SDE: dx = sigma^t * dw  t belongs to [0,1]
    p_{0t}(x(t) | x(0)) = N(x(t); x(0), 1 / (2log(sigma)) * (sigma^(2t)-1) * I)

    Args:
        t: time value(s); a tensor, or anything ``torch.as_tensor`` accepts
        sigma: base of the diffusion coefficient ``sigma^t``

    Returns:
        Tensor on the module-level ``device`` with
        ``sqrt((sigma^(2t) - 1) / (2 * log(sigma)))`` for every entry of ``t``.
    """
    # as_tensor reuses `t` when it already is a tensor on `device`;
    # torch.tensor(<tensor>) would copy it and emit a UserWarning on every call
    t = torch.as_tensor(t, device=device)
    return torch.sqrt((sigma**(2 * t) - 1.) / 2. / np.log(sigma))

def diffusion_coeff(t, sigma):
    """ Compute diffusion coefficient at time t

    Args:
        t: time value(s); a tensor or a plain number
        sigma: base of the coefficient

    Returns:
        ``sigma ** t`` as a tensor on the module-level ``device``.
    """
    # when `t` is a tensor, sigma**t already is one: as_tensor avoids the extra
    # copy and the UserWarning that torch.tensor(<tensor>) produces
    return torch.as_tensor(sigma**t, device=device)

def loss_fn(score_model, condition, x, marginal_prob_std, eps=1e-5):
    """ Denoising score matching loss for a conditional, time-dependent score model

    Args:
        score_model: callable ``score_model(inputs, t)`` where ``inputs`` is the
            perturbed batch concatenated with ``condition`` along dim 1
        condition: input image as condition
        x: A mini-batch of training data (4-D)
        marginal_prob_std: A function that gives the standard deviation of
            the perturbation kernel
        eps: smallest diffusion time that can be drawn (numerical stability)

    Returns:
        Scalar tensor: per-sample sum of squared residuals, averaged over the batch.
    """
    # one diffusion time per sample, uniform in [eps, 1)
    n_samples = x.shape[0]
    random_t = torch.rand(n_samples, device=x.device) * (1. - eps) + eps

    # perturb the data with Gaussian noise of the matching scale
    noise = torch.randn_like(x)
    std = marginal_prob_std(random_t)
    std_b = std[:, None, None, None]
    perturbed_x = x + noise * std_b

    # score estimate for the perturbed batch, conditioned on `condition`
    model_input = torch.cat([perturbed_x, condition], dim=1)
    score = score_model(model_input, random_t)

    # score * std should cancel the injected noise
    residual = score * std_b + noise
    per_sample = torch.sum(residual**2, dim=(1,2,3))
    return torch.mean(per_sample)


class EMA(nn.Module):
    """Exponential moving average of another model's state dict.

    ``self.module`` is a deep copy of ``model`` kept in eval mode; ``update``
    blends the tracked model's current state into it and ``set`` overwrites it.
    """

    def __init__(self, model, decay=0.9999):
        super().__init__()
        # independent copy that accumulates the moving average of the weights
        shadow = deepcopy(model)
        shadow.eval()
        self.module = shadow
        self.decay = decay

    def _update(self, model, update_fn):
        """Write ``update_fn(ema_value, model_value)`` into every state-dict entry."""
        with torch.no_grad():
            averaged = self.module.state_dict()
            current = model.state_dict()
            for ema_v, model_v in zip(averaged.values(), current.values()):
                ema_v.copy_(update_fn(ema_v, model_v))

    def update(self, model):
        """Blend: ``decay * ema + (1 - decay) * model``."""
        def blend(e, m):
            return self.decay * e + (1. - self.decay) * m

        self._update(model, update_fn=blend)

    def set(self, model):
        """Copy the model's state into the average unchanged."""
        def take_model(e, m):
            return m

        self._update(model, update_fn=take_model)


def euler_sampler(score_model, condition, marginal_prob_std, diffusion_coeff, batch_size=64, num_steps=1000, eps=1e-3):
    """
       SDE: dx = f(x,t)dt + g(t)dw
       Reverse-time SDE: dx = [f(x,t) - g(t)**2 * score]dt + g(t)dw_bar
       In this case, omit f(x,t) and choose SDE: dx = sigma**t * dw  t belongs to [0,1]
       Reverse-time SDE: dx = -sigma**{2t} * score * dt + sigma**t * dw_bar

       To sample from time-dependent score-based model, first draw a sample from the prior
       distribution p_1 ~ N(x; 0, 0.5*(sigma**2 - 1)*I), then solve the reverse-time SDE
       via Euler-Maruyama approach. Replacing dw with z ~ N(0, g(t)**2 * dt * I),
       we can obtain the iteration rule:
           x_{t-dt} = x_t + sigma**{2t} * score * dt + sigma**t * sqrt(dt) * z_t, where
           z_t ~ N(0,I)

       Args:
           score_model: callable ``score_model(inputs, t)``; ``inputs`` is the current
               sample concatenated with ``condition`` along dim 1
           condition: conditioning batch; the sample is created on its device and with
               its spatial size (previously fixed to 128x128 on the global device)
           marginal_prob_std: gives the std of the perturbation kernel at time t
           diffusion_coeff: gives the diffusion coefficient at time t
           batch_size: number of samples; must equal ``condition.shape[0]``
           num_steps: number of Euler-Maruyama steps from t=1 down to t=eps
           eps: smallest time reached

       Returns:
           The mean (noise-free) update of the last step, one channel per sample.
    """
    # the sample is concatenated with `condition`, so it has to live on the
    # same device and share its spatial size
    dev = condition.device
    height, width = condition.shape[-2:]

    # Step 1: define start time t=1 and random samples from prior data distribution
    t = torch.ones(batch_size, device=dev)
    init_x = torch.randn(batch_size, 1, height, width, device=dev) * marginal_prob_std(t)[:, None, None, None]

    # Step 2: define reverse time grid and time intervals
    time_steps = torch.linspace(1., eps, num_steps, device=dev)
    step_size = time_steps[0] - time_steps[1]

    # Step 3: solve reverse time SDE via Euler-Maruyama approach
    x = init_x
    with torch.no_grad():
        for time_step in time_steps:
            batch_time_step = torch.ones(batch_size, device=dev) * time_step
            g = diffusion_coeff(batch_time_step)
            score = score_model(torch.cat([x, condition], dim=1), batch_time_step)
            mean_x = x + (g**2)[:, None, None, None] * score * step_size
            x = mean_x + torch.sqrt(step_size) * g[:, None, None, None] * torch.randn_like(x)

    # Step 4: select final step expectation as a sampler
    return mean_x


def pc_sampler(score_model, condition, marginal_prob_std, diffusion_coeff, batch_size=64, snr=0.16, num_steps=1000, eps=1e-3):
    """ Generate samplers from score-based models with Predictor-Corrector method.

    Parameters
    ----------
    score_model : A PyTorch model instance that represents a time-dependent
            score-based model.
    condition : conditioning batch; the sample is created on its device and
            with its spatial size (previously fixed to 128x128 on the global device)
    marginal_prob_std : A function that gives the standard deviation of
            the perturbation kernel
    diffusion_coeff : A function that gives the diffusion coefficient
    batch_size : default: 64; must equal ``condition.shape[0]``
    snr : signal-to-noise-ratio, default: 0.16
    num_steps : number of predictor steps from t=1 down to t=eps
    eps : smallest time reached

    Returns
    -------
    The mean (noise-free) predictor update of the last step.
    """
    # the sample is concatenated with `condition`, so it has to live on the
    # same device and share its spatial size
    dev = condition.device
    height, width = condition.shape[-2:]

    # Step 1: define start time t=1 and random samples from prior data distribution
    t = torch.ones(batch_size, device=dev)
    init_x = torch.randn(batch_size, 1, height, width, device=dev) * marginal_prob_std(t)[:, None, None, None]

    # Step 2: define reverse time grid and time intervals
    time_steps = np.linspace(1., eps, num_steps)
    step_size = time_steps[0] - time_steps[1]

    # norm of a standard-normal sample of this shape; the shape never changes,
    # so it is computed once (plain float so the division below stays a tensor op)
    noise_norm = float(np.sqrt(np.prod(init_x.shape[1:])))

    # Step 3: alternatively use Langevin sampling and reverse-time SDE with Euler approach to solve
    x = init_x
    with torch.no_grad():
        for time_step in time_steps:
            batch_time_step = torch.ones(batch_size, device=dev) * time_step

            # Corrector step (Langevin MCMC): 10 updates, each using the score of
            # the current x. The score is evaluated at the top of the iteration,
            # so no evaluation is computed and then thrown away after the last update.
            for _ in range(10):
                grad = score_model(torch.cat([x, condition], dim=1), batch_time_step)
                grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
                langevin_step_size = 2 * (snr * noise_norm / grad_norm)**2
                x = x + langevin_step_size * grad + torch.sqrt(2 * langevin_step_size) * torch.randn_like(x)

            # Predictor step (Euler-Maruyama)
            g = diffusion_coeff(batch_time_step)
            score = score_model(torch.cat([x, condition], dim=1), batch_time_step)
            mean_x = x + (g**2)[:, None, None, None] * score * step_size
            x = mean_x + torch.sqrt(g**2 * step_size)[:, None, None, None] * torch.randn_like(x)

        # Step 4: select final step expectation as a sampler
        return mean_x

def ode_sampler(score_model, condition, marginal_prob_std, diffusion_coeff, batch_size=64, atol=1e-5, rtol=1e-5, z=None, eps=1e-3):
    """ Generate samplers from score-based models with ODE method

    Args:
        score_model: callable ``score_model(inputs, t)``; ``inputs`` is the current
            sample concatenated with ``condition`` along dim 1
        condition: conditioning batch; tensors are created on its device and, when
            ``z`` is None, with its spatial size (previously fixed to 128x128 on
            the global device)
        marginal_prob_std: gives the std of the perturbation kernel at time t
        diffusion_coeff: gives the diffusion coefficient at time t
        batch_size: number of samples drawn when ``z`` is None
        atol, rtol: tolerances forwarded to ``scipy.integrate.solve_ivp``
        z: optional start sample used instead of a draw from the prior
        eps: end time of the integration (which starts at t=1)

    Returns:
        Tensor with the shape of the start sample holding the solution at
        t=eps; it is built from the solver's numpy output.
    """
    dev = condition.device

    # Step 1: define start time t=1 and initial x
    if z is None:
        height, width = condition.shape[-2:]
        t = torch.ones(batch_size, device=dev)
        init_x = torch.randn(batch_size, 1, height, width, device=dev) * marginal_prob_std(t)[:, None, None, None]
    else:
        init_x = z
    shape = init_x.shape

    # Step 2: define score estimation function and ODE function
    def score_eval_wrapper(sample, time_steps):
        """ A Wrapper of the score-based model for use by the ODE solver """
        # the solver works on flat float64 numpy vectors; the model gets float32 tensors
        sample = torch.tensor(sample, device=dev, dtype=torch.float32).reshape(shape)
        time_steps = torch.tensor(time_steps, device=dev, dtype=torch.float32).reshape((sample.shape[0], ))
        with torch.no_grad():
            score = score_model(torch.cat([sample, condition], dim=1), time_steps)
        return score.cpu().numpy().reshape((-1,)).astype(np.float64)

    def ode_func(t, x):
        """ The ODE function for use by the ODE solver """
        time_steps = np.ones((shape[0],)) * t
        g = diffusion_coeff(torch.tensor(t)).cpu().numpy()
        return -0.5 * (g**2) * score_eval_wrapper(x, time_steps)

    # Step 3: using ODE to solve value at t=eps
    res = integrate.solve_ivp(ode_func, (1., eps), init_x.reshape(-1).cpu().numpy(), rtol=rtol, atol=atol, method='RK45')

    # last column of the solution = state at t=eps
    x = torch.tensor(res.y[:, -1], device=dev).reshape(shape)

    return x

In [13]:
# model.py
import math
import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F
from torchvision import models
from collections import namedtuple
import numpy as np


class Swish(nn.Module):
    """Swish activation: the input multiplied by its own sigmoid."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate


class TimeEmbedding(nn.Module):
    """Random Fourier features of the time value.

    ``embed_dim // 2`` frequencies are drawn from a normal distribution and
    multiplied by ``scale`` at construction; they are registered as a
    parameter with ``requires_grad=False``, so they are not trained.
    """

    def __init__(self, embed_dim, scale=30.):
        super().__init__()
        frequencies = torch.randn(embed_dim // 2) * scale
        self.W = nn.Parameter(frequencies, requires_grad=False)

    def forward(self, x):
        """Map times of shape ``[batch]`` to ``[batch, 2 * len(W)]`` (sines, then cosines)."""
        weights = self.W[None, :].to(x.device)
        x_proj = x[:, None] * weights * 2 * np.pi
        return torch.cat((torch.sin(x_proj), torch.cos(x_proj)), dim=-1)


class DownSample(nn.Module):
    """Downsampling layer: a 3x3 convolution with stride 2 that keeps the channel count."""

    def __init__(self, in_ch):
        super().__init__()
        self.main = nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=2, padding=1)
        self.initialize()

    def initialize(self):
        """Xavier-uniform weights, zero bias."""
        conv = self.main
        init.xavier_uniform_(conv.weight)
        init.zeros_(conv.bias)

    def forward(self, x, temb):
        # `temb` is accepted so every U-Net layer can be called the same way; it is not used
        return self.main(x)


class UpSample(nn.Module):
    """Upsampling layer: nearest-neighbour x2 interpolation followed by a 3x3 convolution."""

    def __init__(self, in_ch):
        super().__init__()
        self.main = nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=1, padding=1)
        self.initialize()

    def initialize(self):
        """Xavier-uniform weights, zero bias."""
        conv = self.main
        init.xavier_uniform_(conv.weight)
        init.zeros_(conv.bias)

    def forward(self, x, temb):
        # `temb` is accepted so every U-Net layer can be called the same way; it is not used
        doubled = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.main(doubled)


class AttnBlock(nn.Module):
    """Self-attention over all spatial positions, added back onto the input.

    The input is group-normalised, projected to queries, keys and values
    with 1x1 convolutions, and every position attends to every other one.
    """

    def __init__(self, in_ch):
        super().__init__()

        def pointwise():
            return nn.Conv2d(in_ch, in_ch, kernel_size=1, stride=1, padding=0)

        self.group_norm = nn.GroupNorm(32, in_ch)
        self.proj_q = pointwise()
        self.proj_k = pointwise()
        self.proj_v = pointwise()
        self.proj = pointwise()
        self.initialize()

    def initialize(self):
        """Xavier-uniform weights and zero biases; the output projection is re-drawn with gain 1e-5."""
        for layer in (self.proj_q, self.proj_k, self.proj_v, self.proj):
            init.xavier_uniform_(layer.weight)
            init.zeros_(layer.bias)
        init.xavier_uniform_(self.proj.weight, gain=1e-5)

    def forward(self, x):
        batch, channels, height, width = x.shape
        n_tokens = height * width

        normed = self.group_norm(x)
        queries = self.proj_q(normed).flatten(2).transpose(1, 2)  # [batch, tokens, channels]
        keys = self.proj_k(normed).flatten(2)                     # [batch, channels, tokens]
        values = self.proj_v(normed).flatten(2).transpose(1, 2)   # [batch, tokens, channels]

        # scaled dot-product attention weights between all pairs of positions
        weights = torch.bmm(queries, keys) * (int(channels) ** (-0.5))
        assert list(weights.shape) == [batch, n_tokens, n_tokens]
        weights = F.softmax(weights, dim=-1)

        attended = torch.bmm(weights, values)
        assert list(attended.shape) == [batch, n_tokens, channels]
        attended = attended.view(batch, height, width, channels).permute(0, 3, 1, 2)

        return x + self.proj(attended)


class ResBlock(nn.Module):
    """Residual block conditioned on a time embedding.

    Two GroupNorm -> Swish -> 3x3 conv stages; the projected time embedding
    is added between them. The input reaches the output through a shortcut
    (a 1x1 conv when the channel count changes, identity otherwise), and an
    optional ``AttnBlock`` is applied to the result.
    """

    def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GroupNorm(32, in_ch),
            Swish(),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
        )
        self.temb_proj = nn.Sequential(
            Swish(),
            nn.Linear(tdim, out_ch),
        )
        self.block2 = nn.Sequential(
            nn.GroupNorm(32, out_ch),
            Swish(),
            nn.Dropout(dropout),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1),
        )
        # a 1x1 conv matches the channel count on the skip path when it changes
        self.shortcut = (
            nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0)
            if in_ch != out_ch else nn.Identity()
        )
        self.attn = AttnBlock(out_ch) if attn else nn.Identity()
        self.initialize()

    def initialize(self):
        """Xavier-uniform weights and zero biases; the last conv of block2 is re-drawn with gain 1e-5."""
        # self.modules() also walks nested sub-modules, so the conv layers of an
        # attached attention block are re-initialised here as well
        for layer in self.modules():
            if not isinstance(layer, (nn.Conv2d, nn.Linear)):
                continue
            init.xavier_uniform_(layer.weight)
            init.zeros_(layer.bias)
        init.xavier_uniform_(self.block2[-1].weight, gain=1e-5)

    def forward(self, x, temb):
        time_bias = self.temb_proj(temb)[:, :, None, None]  # [batch, out_ch, 1, 1]
        hidden = self.block1(x)                             # [batch, out_ch, h, w]
        hidden = hidden + time_bias
        hidden = self.block2(hidden)                        # [batch, out_ch, h, w]

        hidden = hidden + self.shortcut(x)
        return self.attn(hidden)


class UNet(nn.Module):
    """Time-conditioned U-Net used as the score model.

    The input has 2 channels (the head convolution is ``Conv2d(2, ch, ...)``)
    and the output has 1 channel, divided by ``marginal_prob_std(t)``.

    Args:
        marginal_prob_std: callable giving the perturbation std at time t;
            the network output is divided by it in ``forward``
        T: accepted but not used anywhere in this class
        ch: base channel count; the time embedding has ``ch * 4`` features
        ch_mult: per-level channel multipliers; its length is the number of
            resolution levels
        attn: only validated against ``len(ch_mult)``; every ResBlock below
            is built with ``attn=False``, so no attention block is created
        num_res_blocks: ResBlocks per level on the way down (one more per
            level on the way up)
        dropout: dropout rate passed to every ResBlock
    """
    def __init__(self, marginal_prob_std, T, ch, ch_mult, attn, num_res_blocks, dropout):
        super().__init__()
        assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'
        tdim = ch * 4
        self.time_embedding = TimeEmbedding(tdim)
        # self.time_embedding = TimeEmbedding(T, ch, tdim)

        # 2 input channels -> ch feature channels
        self.head = nn.Conv2d(2, ch, kernel_size=3, stride=1, padding=1)
        self.downblocks = nn.ModuleList()
        chs = [ch]  # channel count of every tensor forward() pushes onto `hs`, consumed by the up path
        now_ch = ch
        for i, mult in enumerate(ch_mult):
            out_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(
                    in_ch=now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=False))
                now_ch = out_ch
                chs.append(now_ch)
            # every level except the last one ends with a downsampling layer
            if i != len(ch_mult) - 1:
                self.downblocks.append(DownSample(now_ch))
                chs.append(now_ch)

        self.middleblocks = nn.ModuleList([
            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
        ])

        self.upblocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(ch_mult))):
            out_ch = ch * mult
            for _ in range(num_res_blocks + 1):
                # input = current features concatenated with one stored skip tensor
                self.upblocks.append(ResBlock(
                    in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=False))
                now_ch = out_ch
            # every level except the first one (i == 0) ends with an upsampling layer
            if i != 0:
                self.upblocks.append(UpSample(now_ch))
        # every recorded skip channel count must have been consumed
        assert len(chs) == 0

        # back to a single output channel
        self.tail = nn.Sequential(
            nn.GroupNorm(32, now_ch),
            Swish(),
            nn.Conv2d(now_ch, 1, 3, stride=1, padding=1)
        )

        self.marginal_prob_std = marginal_prob_std

        self.initialize()

    def initialize(self):
        # head: xavier-uniform weights, zero bias; last tail conv: xavier with gain 1e-5, zero bias
        init.xavier_uniform_(self.head.weight)
        init.zeros_(self.head.bias)
        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
        init.zeros_(self.tail[-1].bias)

    def forward(self, x, t):  # t [batch,], x [batch,2,h,w]
        # Timestep embedding
        temb = self.time_embedding(t) # [batch, tdim] with tdim = ch*4
        # Downsampling
        h = self.head(x) # [batch,ch,h,w]
        hs = [h]  # skip tensors, popped in reverse order on the way up
        for layer in self.downblocks: # per level: num_res_blocks ResBlocks, then DownSample (except the last level)
            h = layer(h, temb)
            hs.append(h)
        # Middle
        for layer in self.middleblocks: # [res, res]
            h = layer(h, temb)
        # Upsampling
        for layer in self.upblocks: # per level: num_res_blocks+1 ResBlocks, then UpSample (except the last one built)
            if isinstance(layer, ResBlock):
                # only ResBlocks consume a skip tensor; UpSample layers do not
                h = torch.cat([h, hs.pop()], dim=1)
            h = layer(h, temb)
        h = self.tail(h)
        h = h / self.marginal_prob_std(t)[:, None, None, None] # scale the output by 1/std of the perturbation at time t
        return h

In [14]:
# train.py
import os
import warnings
import scipy.io as sio
# from absl import app, flags
from tqdm import tqdm
from matplotlib import pyplot as plt
from matplotlib import gridspec
import functools
import numpy as np

import torch
# from torchvision.utils import make_grid
from torch.utils.data import DataLoader


# NOTE(review): `def train()` further down in this cell rebinds the name
# `train`, so this flag is overwritten once the function is defined.
train = True
continue_train = False

# UNet
ch = 64
ch_mult = [1, 2, 2, 4, 4]
attn = [1]
num_res_blocks = 2
dropout = 0.

# Gaussian Diffusion
beta_1 = 1e-4
beta_T = 0.02
# NOTE(review): this rebinds `T`, which the first cell uses as the alias of
# torchvision.transforms; re-running `resize` or the transform cell after
# this one would fail until the imports are executed again.
T = 1000

# Training
lr = 1e-4
grad_clip = 1.
img_size = 128
batch_size = 2
num_workers = 1
ema_decay = 0.9999

sample_size = 1

min_epoch = 5
max_epoch = 5000

# one slot per epoch
epoch_mean_loss = max_epoch * [None]
n_prev_epochs = 5

# name of the experiment directory; n_train_new / n_test_new come from the data preparation cell
DIREC = f'score-unet_n-train-samples_{n_train_new}n-test-samples_{n_test_new}_batch-size_{batch_size}_T_{T}_img-size_{img_size}_data_augmentation_all'

# first GPU when CUDA is available, otherwise the CPU (instead of unconditionally 'cuda:0')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def _loss_plateaued(loss_history, epoch, window, warmup, rel_tol=0.01):
    """Decide whether the loss recorded for ``epoch`` has stopped improving.

    The loss of ``epoch`` is compared with the mean of the ``window`` epochs
    immediately before it; the result is True when it is not at least
    ``rel_tol`` (relative) below that mean.

    Args:
        loss_history: per-epoch mean losses indexed by zero-based epoch;
            entries that were never recorded are ``None``.
        epoch: zero-based index of the epoch that just finished.
        window: number of preceding epochs to average.
        warmup: the check is skipped while ``epoch + 1 <= warmup``.
        rel_tol: required relative improvement over the window mean.

    Returns:
        bool: True if training should stop after this epoch. False whenever
        the window is not fully available (too early, or missing history).
    """
    # `epoch < window` prevents negative indices from wrapping around to the
    # unused tail of the history list.
    if window < 1 or epoch + 1 <= warmup or epoch < window:
        return False
    previous = [loss_history[epoch - (k + 1)] for k in range(window)]
    if any(value is None for value in previous):
        return False
    prev_mean_loss = sum(previous) / window
    return loss_history[epoch] > (prev_mean_loss - rel_tol * prev_mean_loss)


def train():
    """Train the score-based U-Net and sample once when the loss plateaus.

    Reads its configuration from module-level names (``batch_size``, ``lr``,
    ``max_epoch``, ``min_epoch``, ``n_prev_epochs``, ``DIREC``, ``device``, ...)
    and records each epoch's mean loss in the module-level ``epoch_mean_loss``.

    Per epoch: one pass over the training loader with Adam, followed by an EMA
    update after every step. When the plateau check fires, the first validation
    batch is sampled with the ODE, Euler and predictor-corrector samplers, the
    raw samples are saved as ``.npy`` files, a comparison figure is written
    under ``./current experiment/Train_Output/<DIREC>/`` and training stops.

    NOTE(review): sampling uses ``score_model``; ``ema_model`` is updated but
    never used for sampling here. ``grad_clip`` is not applied either.
    """
    sigma = 25.
    # Bind sigma so that both functions only take the time argument.
    marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)
    diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)

    # dataset
    trainloader = DataLoader(Train_Data(), batch_size=batch_size, num_workers=num_workers,
                             pin_memory=True, shuffle=True)
    validloader = DataLoader(Valid_Data(), batch_size=sample_size, num_workers=num_workers,
                             pin_memory=True, shuffle=False)

    # model setup
    score_model = UNet(marginal_prob_std=marginal_prob_std_fn,
                       T=T, ch=ch, ch_mult=ch_mult, attn=attn,
                       num_res_blocks=num_res_blocks, dropout=dropout).to(device)
    ema_model = EMA(score_model).to(device)
    optim = torch.optim.Adam(score_model.parameters(), lr=lr)

    # sampler setup: (file tag, sampler) in the order they are run and plotted
    samplers = (('od', ode_sampler), ('eu', euler_sampler), ('pc', pc_sampler))

    # show model size
    model_size = sum(param.data.nelement() for param in score_model.parameters())
    print('Model params: %.2f M' % (model_size / 1024 / 1024))

    if continue_train:
        # map_location keeps the checkpoint loadable on a machine whose device
        # layout differs from the one it was saved on.
        checkpoint = torch.load('./Save/' + DIREC + '/model_latest.pkl', map_location=device)
        score_model.load_state_dict(checkpoint['score_model'])
        ema_model.load_state_dict(checkpoint['ema_model'])
        optim.load_state_dict(checkpoint['optim'])
        restore_epoch = checkpoint['epoch']
        print('Finish loading model')
    else:
        restore_epoch = 0

    # Output directories are created once, up front.
    os.makedirs('current experiment', exist_ok=True)
    os.makedirs('./current experiment/diff_results', exist_ok=True)
    os.makedirs('./current experiment/Train_Output/' + DIREC, exist_ok=True)

    tr_ls = []
    if continue_train:
        readmat = sio.loadmat('./Loss/' + DIREC)
        load_tr_ls = readmat['loss']
        for i in range(restore_epoch):
            tr_ls.append(load_tr_ls[0][i])
            # Seed the plateau history too; otherwise the entries for the
            # restored epochs stay None and the check cannot use them.
            if i < len(epoch_mean_loss):
                epoch_mean_loss[i] = load_tr_ls[0][i]
        print('Finish loading loss!')

    for epoch in range(restore_epoch, max_epoch):
        tmp_tr_loss = 0
        tr_sample = 0
        score_model.train()
        with tqdm(trainloader, unit="batch", desc=f"Epoch {epoch+1}") as tepoch:
            for data, target in tepoch:
                # train
                condition = data.to(device)
                x_0 = target.to(device)

                loss = loss_fn(score_model, condition, x_0, marginal_prob_std_fn)
                batch_loss = loss.item()

                tmp_tr_loss += batch_loss
                tr_sample += len(data)

                optim.zero_grad()
                loss.backward()
                optim.step()
                ema_model.update(score_model)

                tepoch.set_postfix({'Loss': batch_loss})

        # Accumulated batch losses divided by the number of samples seen.
        mean_loss = tmp_tr_loss / tr_sample
        print('mean loss', mean_loss)

        epoch_mean_loss[epoch] = mean_loss
        tr_ls.append(mean_loss)
        # sio.savemat('./Loss/' + DIREC +'.mat', {'loss': tr_ls})

        last_epoch = _loss_plateaued(epoch_mean_loss, epoch, n_prev_epochs, min_epoch)

        score_model.eval()
        if last_epoch:
            with torch.no_grad():
                # Only the first validation batch is sampled; fetch just that
                # one instead of iterating over the whole loader.
                first_batch = next(iter(validloader), None)
                if first_batch is not None:
                    data, target = first_batch
                    batch_idx = 0
                    condition = data.to(device)

                    # Run all samplers first, then save the raw outputs.
                    samples = {}
                    for tag, sampler in samplers:
                        samples[tag] = sampler(score_model, condition, marginal_prob_std_fn,
                                               diffusion_coeff_fn, sample_size)
                    for tag, _ in samplers:
                        save_path = f'./current experiment/diff_results/x0_{tag}_number_{batch_idx+1}_epoch_{epoch+1}.npy'
                        np.save(save_path, np.array(samples[tag].cpu()))

                    # sample visualisation: input | ode | euler | pc | target,
                    # with the samples clamped to [0, 1] for display only
                    panels = [data[0]]
                    panels += [samples[tag].clamp(0., 1.)[0] for tag, _ in samplers]
                    panels.append(target[0])

                    fig = plt.figure()
                    fig.set_figheight(4)
                    fig.set_figwidth(20)
                    spec = gridspec.GridSpec(ncols=len(panels), nrows=1,
                                             width_ratios=[1] * len(panels), wspace=0.01,
                                             hspace=0.01, height_ratios=[1],
                                             left=0, right=1, top=1, bottom=0)
                    for slot, image in enumerate(panels):
                        ax = fig.add_subplot(spec[slot])
                        ax.imshow(image.detach().squeeze().cpu(), cmap='gray', vmin=0, vmax=1)
                        ax.axis('off')

                    plt.savefig('./current experiment/Train_Output/'+ DIREC + '/Epoch_' + str(epoch+1) + '.png',
                                bbox_inches='tight', pad_inches=0)

            # The plateau was reached: stop training after sampling.
            break

        # Checkpointing is currently disabled. NOTE(review): `continue_train`
        # expects './Save/<DIREC>/model_latest.pkl' and './Loss/<DIREC>' to
        # exist, so resuming only works once this block and the savemat call
        # above are re-enabled.
        # # save
        # if not os.path.exists('Save/' + DIREC):
        #     os.makedirs('Save/' + DIREC)
        # ckpt = {
        #     'score_model': score_model.state_dict(),
        #     'ema_model': ema_model.state_dict(),
        #     'optim': optim.state_dict(),
        #     'epoch': epoch+1,
        # }
        # if (epoch+1) % 20 == 0:
        #     torch.save(ckpt, './Save/' + DIREC + '/model_epoch_'+str(epoch+1)+'.pkl')
        # torch.save(ckpt, './Save/' + DIREC + '/model_latest.pkl')


# Start training. The loop runs until the plateau-based stop inside train()
# fires or max_epoch is reached, so this cell is long-running.
train()

  0%|          | 0/2187 [00:00<?, ?batch/s]

Model params: 24.91 M


  t = torch.tensor(t, device=device)
Epoch 1: 100%|██████████| 2187/2187 [02:24<00:00, 15.16batch/s, Loss=28.6]   
  0%|          | 0/2187 [00:00<?, ?batch/s]

mean loss 247.17520786345577


Epoch 2: 100%|██████████| 2187/2187 [02:21<00:00, 15.48batch/s, Loss=100]    
  0%|          | 0/2187 [00:00<?, ?batch/s]

mean loss 51.683993379438526


Epoch 3: 100%|██████████| 2187/2187 [02:20<00:00, 15.52batch/s, Loss=72.4]   
  0%|          | 0/2187 [00:00<?, ?batch/s]

mean loss 46.86444416076667


Epoch 4: 100%|██████████| 2187/2187 [02:21<00:00, 15.48batch/s, Loss=38.7]   
  0%|          | 0/2187 [00:00<?, ?batch/s]

mean loss 40.58279460633133


Epoch 5: 100%|██████████| 2187/2187 [02:20<00:00, 15.54batch/s, Loss=65]     
  0%|          | 0/2187 [00:00<?, ?batch/s]

mean loss 34.694274219193304


Epoch 6: 100%|██████████| 2187/2187 [02:21<00:00, 15.50batch/s, Loss=194]    
  0%|          | 0/2187 [00:00<?, ?batch/s]

mean loss 32.80211564335694


Epoch 7: 100%|██████████| 2187/2187 [02:21<00:00, 15.49batch/s, Loss=3.7]    
  0%|          | 0/2187 [00:00<?, ?batch/s]

mean loss 27.913708628395895


Epoch 8: 100%|██████████| 2187/2187 [02:20<00:00, 15.53batch/s, Loss=34.7]   
  0%|          | 0/2187 [00:00<?, ?batch/s]

mean loss 27.569816202129683


Epoch 9: 100%|██████████| 2187/2187 [02:21<00:00, 15.44batch/s, Loss=46.3]   
  0%|          | 0/2187 [00:00<?, ?batch/s]

mean loss 27.386405454959615


Epoch 10: 100%|██████████| 2187/2187 [02:21<00:00, 15.41batch/s, Loss=12]     
  0%|          | 0/2187 [00:00<?, ?batch/s]

mean loss 26.150790516295604


Epoch 11: 100%|██████████| 2187/2187 [02:21<00:00, 15.47batch/s, Loss=13.6]   
  0%|          | 0/2187 [00:00<?, ?batch/s]

mean loss 26.046797305577194


Epoch 12: 100%|██████████| 2187/2187 [02:22<00:00, 15.37batch/s, Loss=16.6]   
  0%|          | 0/2187 [00:00<?, ?batch/s]

mean loss 23.393867710217744


Epoch 13: 100%|██████████| 2187/2187 [02:22<00:00, 15.39batch/s, Loss=16.8]   
  0%|          | 0/2187 [00:00<?, ?batch/s]

mean loss 23.099694393571784


Epoch 14: 100%|██████████| 2187/2187 [02:21<00:00, 15.42batch/s, Loss=54]     
  0%|          | 0/2187 [00:00<?, ?batch/s]

mean loss 23.04272699852103


Epoch 15: 100%|██████████| 2187/2187 [02:22<00:00, 15.38batch/s, Loss=56.5]   
  0%|          | 0/2187 [00:00<?, ?batch/s]

mean loss 21.28134956518482


Epoch 16: 100%|██████████| 2187/2187 [02:21<00:00, 15.43batch/s, Loss=111]    
  0%|          | 0/2187 [00:00<?, ?batch/s]

mean loss 21.629968897060113


Epoch 17: 100%|██████████| 2187/2187 [02:21<00:00, 15.49batch/s, Loss=27.3]   
  0%|          | 0/2187 [00:00<?, ?batch/s]

mean loss 21.807750146357392


Epoch 18:  26%|██▌       | 569/2187 [00:39<01:51, 14.53batch/s, Loss=138] 


KeyboardInterrupt: 

In [None]:
# NOTE(review): file names are built with `epoch+1`, so `epoch = 100` reads the
# `_epoch_101` files -- confirm this matches the files that were saved.
epoch = 100
diff_outs = [None] * n_test_new
sampler = 'od'

# Load one saved sample per test index and shape it to (1, 1, img_size, img_size).
for i in range(n_test_new):
  # The file number follows the loop index. The previous `batch_idx` is not
  # defined at notebook level (it is a local of train()), so it either raised
  # NameError or re-read the same stale file on every iteration.
  diff_out = np.load(f'./current experiment/diff_results/x0_{sampler}_number_{i+1}_epoch_{epoch+1}.npy')
  diff_out = np.reshape(diff_out, (1, img_size, img_size))
  diff_out = torch.tensor(diff_out)
  diff_outs[i] = diff_out.unsqueeze(0)

In [None]:
from ignite.metrics import PSNR, SSIM
from collections import OrderedDict

import torch
from torch import nn, optim

from ignite.engine import *
from ignite.handlers import *
from ignite.metrics import *
from ignite.utils import *
from ignite.contrib.metrics.regression import *
from ignite.contrib.metrics import *

# create default evaluator for doctests

def eval_step(engine, batch):
    """Pass-through step function: hand ``batch`` back unchanged.

    ``engine`` is accepted to match the call signature but is not used.
    """
    return batch

# Engine driven by eval_step, which returns every batch it is given unchanged.
default_evaluator = Engine(eval_step)

In [None]:
# Ground-truth CT slices as tensors, one list entry per test sample; each entry
# is reshaped to (1, img_size, img_size) and then given a leading axis, i.e.
# (1, 1, img_size, img_size).
targets = []

for i in range(n_test_new):
  ct_sample = torch.tensor(np.reshape(ct_test_resized[i], (1, img_size, img_size)))
  targets.append(ct_sample.unsqueeze(0))


In [None]:
# SSIM between each generated sample and its target, assuming intensities in [0, 1].
metric = SSIM(data_range=1.0)
metric.attach(default_evaluator, 'ssim')


# Per-sample SSIM values (index = test sample) and their running sum.
ssims_epoch_80 = [None] * n_test_new
sum_ssims = 0

for i in range(n_test_new):
  # One run per sample: a single-batch "dataset" holding (prediction, target).
  state = default_evaluator.run([[diff_outs[i], targets[i]]])
  ssim_value = state.metrics['ssim']
  # Keep the per-sample value; the list was previously allocated but never filled.
  ssims_epoch_80[i] = ssim_value
  sum_ssims += ssim_value

avg_ssim = sum_ssims / n_test_new

avg_ssim