In [1]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader
import torchvision.transforms as T

import matplotlib.pyplot as plt
import numpy as np
import os
import time
import h5py

from diffusion import loss_fn, marginal_prob_std, diffusion_coeff, EMA, euler_sampler, pc_sampler, ode_sampler
from utils import get_ssim, get_psnr, get_mse, get_mae, save_best_samples, save_metrics, pre_process

In [2]:
# Pick the compute device once: CUDA when PyTorch can see a GPU, CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [3]:
# Resized MR volumes: each HDF5 file is opened read-only and its 'data' dataset is kept.
# Neither file handle is closed in this cell; `f` ends up pointing at the test file.
f = h5py.File('mr_train_resized.hdf5', 'r')
mr_train = f['data']

path = 'mr_test_resized.hdf5'
f = h5py.File(path, 'r')
mr_test = f['data']

# The length of each dataset's leading axis is used as the sample count.
train_size, test_size = mr_train.shape[0], mr_test.shape[0]

train_size, test_size

(4374, 81)

In [4]:
# Ground-truth CT slices for the test split. `path` and `f` are rebound here, so from
# this point on they refer to the CT file rather than the MR files opened in the
# previous cell.
path = 'ct_test_resized.hdf5'
f = h5py.File(path,'r')
# `ct_test` is indexed again later when the metric targets are built; the file is
# not closed in this cell.
ct_test = f['data']

ct_test

<HDF5 dataset "data": shape (81, 128, 128), type "<f4">

In [5]:
import os
import warnings
import scipy.io as sio
# from absl import app, flags
from tqdm import tqdm
from matplotlib import pyplot as plt
from matplotlib import gridspec
import functools
import numpy as np

import torch
# from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from dataset import Train_Data, Test_Data
from diffusion import ode_sampler, euler_sampler, pc_sampler
from model import UNet


# Run-mode flags. Neither is read by any cell shown in this notebook, and the name
# `train` is later rebound to the training function by `def train()`.
train = True
continue_train = False

# UNet (passed straight through to the UNet constructor by train() and test())
ch = 64
ch_mult = [1, 2, 2, 4, 4]
attn = [1]
num_res_blocks = 2
dropout = 0.

# Gaussian Diffusion
beta_1 = 1e-4
beta_T = 0.02
# NOTE: this rebinds `T`, which the first cell imported as the alias for
# torchvision.transforms; that alias is no longer reachable after this line.
T = 1000

# Training
lr = 1e-4
# NOTE(review): train() as written never applies gradient clipping, so this value
# is currently unused.
grad_clip = 1.
img_size = 128
batch_size = 2
num_workers = 1
ema_decay = 0.9999

# Batch size of the test DataLoader; also passed to the sampler.
sample_size = 1

# Early stopping is only considered after `min_epoch` epochs; `max_epoch` is the hard cap.
min_epoch = 100
max_epoch = 110

# Per-epoch mean training loss, filled in by train(); the early-stopping check
# compares the current epoch against the average of the previous `n_prev_epochs`.
epoch_mean_loss = max_epoch * [None]
n_prev_epochs = 20

# Name of the output sub-directory used by test() and save_best_samples().
DIREC = f'score-unet_min-epoch_{min_epoch}_n-train-samples_{train_size}n-test-samples_{test_size}_batch-size_{batch_size}_T_{T}_img-size_{img_size}_data_augmentation_all'

# Use the first GPU when one is visible and fall back to the CPU otherwise, instead
# of hard-coding 'cuda:0' and discarding the availability check made in the earlier
# device cell.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [6]:

def train():
    """Train the conditional score model and save its final weights.

    All hyper-parameters are read from notebook-level names (``batch_size``,
    ``num_workers``, ``lr``, ``max_epoch``, ``min_epoch``, ``n_prev_epochs``,
    the UNet settings and ``device``). The mean loss of every epoch is written
    into the notebook-level ``epoch_mean_loss`` list.

    Early stopping: once more than ``min_epoch`` epochs have run, training
    stops as soon as the current epoch's mean loss is not at least 1% below
    the average mean loss of the preceding (up to) ``n_prev_epochs`` epochs.

    Returns:
        int: 1-based number of the last epoch that ran. The weights are saved
        to ``./current experiment/Saved_model/score-unet_epoch_<number>.pt``.
    """
    sigma = 25.
    # Bind sigma up front; the resulting callable is handed to the model and the loss.
    marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)

    # dataset
    train_loader = DataLoader(Train_Data(), batch_size=batch_size, num_workers=num_workers,
                              pin_memory=True, shuffle=True)

    # model setup
    score_model = UNet(T=T, ch=ch, ch_mult=ch_mult, attn=attn,
                       num_res_blocks=num_res_blocks, dropout=dropout,
                       marginal_prob_std=marginal_prob_std_fn).to(device)

    # NOTE(review): the EMA copy is updated after every optimizer step, but only
    # score_model's weights are saved at the end of this function.
    ema_model = EMA(score_model).to(device)

    # NOTE(review): the notebook-level `grad_clip` setting is not applied here.
    optim = torch.optim.Adam(score_model.parameters(), lr=lr)

    # show model size (element count divided by 1024*1024)
    model_size = sum(param.data.nelement() for param in score_model.parameters())
    print('Model params: %.2f M' % (model_size / 1024 / 1024))

    # Creates 'current experiment' as well when it does not exist yet.
    os.makedirs('./current experiment/Saved_model', exist_ok=True)

    for epoch in range(max_epoch):
        tmp_tr_loss = 0
        tr_sample = 0
        score_model.train()

        with tqdm(train_loader, unit="batch") as tepoch:
            tepoch.set_description(f"Epoch {epoch+1}")
            for data, target in tepoch:
                condition = data.to(device)
                x_0 = target.to(device)

                loss = loss_fn(score_model, condition, x_0, marginal_prob_std_fn)

                optim.zero_grad()
                loss.backward()
                optim.step()
                ema_model.update(score_model)

                # Read the scalar once and reuse it for the running sum and the progress bar.
                loss_value = loss.item()
                tmp_tr_loss += loss_value
                tr_sample += len(data)
                tepoch.set_postfix({'Loss': loss_value})

        # Sum of the per-batch loss values divided by the number of samples seen.
        # NOTE(review): if loss_fn already averages over the batch this is scaled by
        # roughly 1/batch_size - confirm. A constant scale does not affect the
        # relative comparison used for early stopping below.
        mean_loss = tmp_tr_loss / tr_sample
        print('mean loss:', mean_loss)

        epoch_mean_loss[epoch] = mean_loss

        if epoch + 1 > min_epoch:
            # Clamp the window at index 0 so that a `min_epoch` smaller than
            # `n_prev_epochs` cannot wrap around to negative indices and read
            # unfilled (None) or stale entries from the end of the list.
            recent = epoch_mean_loss[max(0, epoch - n_prev_epochs):epoch]
            if recent:
                prev_mean_loss = sum(recent) / len(recent)
                # Stop once the improvement over the recent average is below 1%.
                if mean_loss > (prev_mean_loss - 0.01 * prev_mean_loss):
                    break

    torch.save(score_model.state_dict(), f'./current experiment/Saved_model/score-unet_epoch_{epoch+1}.pt')

    return epoch + 1

In [None]:
# Runs the full training loop. The returned 1-based epoch number is what test()
# reads (as a notebook-level name) to locate the saved checkpoint.
last_epoch_num = train()

In [None]:
def _save_comparison_figure(images, out_path):
    """Save a borderless 1x3 grayscale strip of ``images`` to ``out_path``.

    Args:
        images: three 2-D image arrays/tensors, drawn left to right with a
            fixed display range of [0, 1].
        out_path: destination file path handed to ``Figure.savefig``.
    """
    fig = plt.figure(figsize=(20, 4))
    try:
        spec = gridspec.GridSpec(ncols=3, nrows=1,
                                 width_ratios=[1, 1, 1], wspace=0.01,
                                 hspace=0.01, height_ratios=[1],
                                 left=0, right=1, top=1, bottom=0)
        for column, image in enumerate(images):
            ax = fig.add_subplot(spec[column])
            ax.imshow(image, cmap='gray', vmin=0, vmax=1)
            ax.axis('off')

        fig.savefig(out_path, bbox_inches='tight', pad_inches=0)
    finally:
        # Close this specific figure even if drawing or saving fails, so figures
        # do not pile up over the test loop.
        plt.close(fig)


def test(sampler_type):
    """Sample every test item with the chosen sampler and save the results.

    For each batch of the test loader the raw sampler output is stored as
    ``./current experiment/diff_results_<sampler_type>/x0_number_<idx+1>_epoch_<last_epoch_num>.npy``
    (1-based index), and a condition / sample / target comparison image is
    written to ``./current experiment/Train_Output_<sampler_type>/<DIREC>/sample_<idx>.png``
    (0-based index). The checkpoint is located through the notebook-level
    ``last_epoch_num``.

    Args:
        sampler_type: ``'od'`` (ode_sampler), ``'eu'`` (euler_sampler) or
            ``'pc'`` (pc_sampler).

    Returns:
        float: total wall-clock seconds spent inside the sampler calls.

    Raises:
        ValueError: if ``sampler_type`` is not one of the three valid keys.
    """
    samplers = {'od': ode_sampler, 'eu': euler_sampler, 'pc': pc_sampler}
    # Validate before touching the filesystem or loading the model; previously an
    # invalid value only printed a message and crashed later on the unbound `sampler`.
    if sampler_type not in samplers:
        raise ValueError(
            f'invalid sampler_type {sampler_type!r}; valid values: od, eu, pc')
    sampler = samplers[sampler_type]

    figure_dir = f'./current experiment/Train_Output_{sampler_type}/' + DIREC
    results_dir = f'./current experiment/diff_results_{sampler_type}'
    os.makedirs(figure_dir, exist_ok=True)
    os.makedirs(results_dir, exist_ok=True)

    sigma = 25.
    # Bind sigma up front; both callables are passed to the sampler.
    marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)
    diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)

    test_loader = DataLoader(Test_Data(), batch_size=sample_size, num_workers=num_workers,
                             pin_memory=True, shuffle=False)

    score_model = UNet(T=T, ch=ch, ch_mult=ch_mult, attn=attn,
                       num_res_blocks=num_res_blocks, dropout=dropout,
                       marginal_prob_std=marginal_prob_std_fn).to(device)

    model_path = f'./current experiment/Saved_model/score-unet_epoch_{last_epoch_num}.pt'
    # map_location lets a checkpoint saved on the GPU be loaded on whatever
    # device this session is using.
    score_model.load_state_dict(torch.load(model_path, map_location=device))
    # Inference mode for layers that behave differently during training.
    score_model.eval()

    # CUDA work is queued asynchronously; synchronise around the sampler call so
    # the measured interval covers the actual computation.
    sync_cuda = torch.cuda.is_available()
    sum_time = 0

    with torch.no_grad():
        for idx, (data, target) in enumerate(test_loader):
            condition = data.to(device)

            if sync_cuda:
                torch.cuda.synchronize()
            tic = time.perf_counter()
            samples = sampler(score_model, condition, marginal_prob_std_fn, diffusion_coeff_fn, sample_size)
            if sync_cuda:
                torch.cuda.synchronize()
            sum_time += time.perf_counter() - tic

            # Persist the unclamped sampler output; the metrics cells reload these files.
            save_path = f'./current experiment/diff_results_{sampler_type}/x0_number_{idx+1}_epoch_{last_epoch_num}.npy'
            np.save(save_path, samples.cpu().numpy())

            # sample visualization: clamp to the display range only for the figure
            samples = samples.clamp(0., 1.)
            _save_comparison_figure(
                (data[0].squeeze().cpu(),
                 samples[0].squeeze().cpu(),
                 target[0].squeeze().cpu()),
                figure_dir + '/sample_' + str(idx) + '.png')

    return sum_time


In [None]:
# Run all three samplers over the test set, one after another, keeping the total
# sampling time (seconds) of each. Kept as separate statements so that a total
# already computed stays bound if a later run fails.
od_sum_time = test(sampler_type='od')
eu_sum_time = test(sampler_type='eu')
pc_sum_time = test(sampler_type='pc')

In [22]:
# Hard-coded total sampling times in seconds. This overrides the `pc_sum_time`
# produced by the previous cell - presumably so the metrics cells can be re-run
# without repeating the sampling; confirm the values match the saved results.
# Un-comment the line for the sampler being evaluated.
# od_sum_time = 536.8354086875916
# eu_sum_time = 1018.3237380981445
pc_sum_time = 12047.802762508392

In [23]:
# Average sampling time per test item (seconds); only the PC-sampler variant is
# active, matching the hard-coded `pc_sum_time` above.
# avg_time = od_sum_time / test_size
# avg_time = eu_sum_time / test_size
avg_time = pc_sum_time / test_size
avg_time

148.73830570998015

In [24]:
# Which saved run to evaluate. Both values are hard-coded here, so `last_epoch_num`
# no longer holds whatever train() returned earlier in the notebook.
sampler_type = 'pc'
last_epoch_num = 101

# Reload every sampler output written by test() (file numbers are 1-based) and
# run it through pre_process before converting to float.
diff_outs = []
for i in range(test_size):
  path = f'./current experiment/diff_results_{sampler_type}/x0_number_{i+1}_epoch_{last_epoch_num}.npy'
  diff_outs.append(pre_process(np.load(path), img_size).float())

In [25]:
# Ground-truth CT slices, pre-processed the same way as the sampler outputs.
# Indexes the first `test_size` entries of `ct_test`.
targets = [pre_process(ct_test[i], img_size).float() for i in range(test_size)]


In [26]:
# Image-quality metrics of the generated samples against the CT targets.
# Judging by the names the results are unpacked into, each helper returns
# (best value, index of the best sample, average) - confirm against utils.
max_ssim, argmax_ssim, avg_ssim = get_ssim(diff_outs, targets)
max_psnr, argmax_psnr, avg_psnr = get_psnr(diff_outs, targets)

min_mse, argmin_mse, avg_mse = get_mse(diff_outs, targets)
min_mae, argmin_mae, avg_mae = get_mae(diff_outs, targets)

# Persist the best-scoring samples and the summary numbers for the sampler selected
# above (`sampler_type`); output locations are decided inside the utils helpers.
save_best_samples(sampler_type, DIREC, diff_outs, targets, argmax_ssim, argmax_psnr, argmin_mse, argmin_mae)
save_metrics(sampler_type, avg_time, avg_ssim, avg_psnr, avg_mse, avg_mae, max_ssim, max_psnr, min_mse, min_mae)