In [None]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader
import torchvision.transforms as T

import matplotlib.pyplot as plt
import numpy as np
import os
import time

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"
device

In [None]:
# Folder holding the pre-processed MR (input) and CT (output) arrays.
# The location can be overridden with the MR_CT_DATA_DIR environment variable
# so the notebook also runs on machines other than the original workstation;
# the default is the original path.
base_dir = os.environ.get(
    'MR_CT_DATA_DIR',
    '/home/iplab/Desktop/Shaker/Brain tumor MRI and CT scan/data(processed)',
)

# read mr and ct datasets
MR_train_address = os.path.join(base_dir, 'train_input.npy')
CT_train_address = os.path.join(base_dir, 'train_output.npy')

MR_val_address = os.path.join(base_dir, 'val_input.npy')
CT_val_address = os.path.join(base_dir, 'val_output.npy')

MR_test_address = os.path.join(base_dir, 'test_input.npy')
CT_test_address = os.path.join(base_dir, 'test_output.npy')

mr_train = np.load(MR_train_address)
ct_train = np.load(CT_train_address)

mr_val = np.load(MR_val_address)
ct_val = np.load(CT_val_address)

mr_test = np.load(MR_test_address)
ct_test = np.load(CT_test_address)

# Quick look at the number of samples (first axis) and image size per split.
ct_train.shape, ct_val.shape, ct_test.shape

In [None]:
def resize(img, size):
  """Resize a numpy image with torchvision and return it as a float32 tensor.

  The array is cast to float32, turned into a tensor, and given a leading
  singleton dimension before `T.Resize(size)` is applied, so the result
  keeps that extra leading dimension.

  Args:
    img: numpy array holding the image.
    size: target size forwarded unchanged to `T.Resize`.

  Returns:
    The resized image as a torch tensor (not converted back to numpy).
  """
  as_tensor = torch.tensor(img.astype('float32')).unsqueeze(0)
  return T.Resize(size)(as_tensor)

In [None]:
# Augmentation transforms. Each one is a single-step T.Compose pipeline that
# takes an image tensor and returns the transformed tensor.

def _rotation(angle):
  """Build a T.Lambda that rotates its input by `angle` degrees."""
  return T.Lambda(lambda t: T.functional.rotate(t, angle=angle))


# Scale every intensity by 0.2.
change_gray_level = T.Compose([T.Lambda(lambda t: t * 0.2)])

# Mirror flips.
horizontal_flip = T.Compose([T.functional.hflip])
vertical_flip = T.Compose([T.functional.vflip])

# Fixed rotations of +45 and -45 degrees.
rotate_45 = T.Compose([_rotation(45)])
rotate_minus_45 = T.Compose([_rotation(-45)])

In [None]:
n_train_samples = ct_train.shape[0]
n_val_samples = ct_val.shape[0]
n_test_samples = ct_test.shape[0]

# Move the first 69 test samples into the training pool.
n_add_from_test_to_train = 69
if n_add_from_test_to_train > n_test_samples:
  raise ValueError(
      f'cannot move {n_add_from_test_to_train} test samples to the training set: '
      f'only {n_test_samples} test samples are available'
  )

# Data augmentation: every training sample is stored together with 5 variants
# of itself, so the training set becomes 6 times bigger.
n_variants = 6

n_train_new = (n_train_samples + n_val_samples + n_add_from_test_to_train) * n_variants
n_test_new = n_test_samples - n_add_from_test_to_train


def _augment(image):
  """Resize `image` to 128 and return it followed by its five augmented variants.

  The order is fixed (original, gray-level change, horizontal flip, vertical
  flip, +45 rotation, -45 rotation) and every transform is deterministic, so
  calling this on an MR slice and on its CT slice keeps the two lists aligned.
  """
  resized = resize(image, 128)
  return [
      resized,
      change_gray_level(resized),
      horizontal_flip(resized),
      vertical_flip(resized),
      rotate_45(resized),
      rotate_minus_45(resized),
  ]


mr_train_resized = []
ct_train_resized = []

# The new training set is, in this order: all train samples, all validation
# samples, then the first `n_add_from_test_to_train` test samples.
for mr_source, ct_source, n_samples in (
    (mr_train, ct_train, n_train_samples),
    (mr_val, ct_val, n_val_samples),
    (mr_test, ct_test, n_add_from_test_to_train),
):
  for i in range(n_samples):
    mr_train_resized.extend(_augment(mr_source[i]))
    ct_train_resized.extend(_augment(ct_source[i]))

assert len(mr_train_resized) == len(ct_train_resized) == n_train_new

# The remaining test samples are only resized, never augmented.
mr_test_resized = [
    resize(mr_test[j], 128) for j in range(n_add_from_test_to_train, n_test_samples)
]
ct_test_resized = [
    resize(ct_test[j], 128) for j in range(n_add_from_test_to_train, n_test_samples)
]

In [None]:
# Replace every stored sample, in place, by a numpy array with its
# size-1 dimensions squeezed away (train lists first, then test lists).
for images, count in (
    (mr_train_resized, n_train_new),
    (ct_train_resized, n_train_new),
    (mr_test_resized, n_test_new),
    (ct_test_resized, n_test_new),
):
  for idx in range(count):
    images[idx] = np.array(images[idx].squeeze())

In [None]:
# Pack each list of per-sample arrays into a single numpy array.
mr_train_resized, ct_train_resized, mr_test_resized, ct_test_resized = (
    np.array(images)
    for images in (mr_train_resized, ct_train_resized, mr_test_resized, ct_test_resized)
)

# Report the resulting shapes of the MR train / test arrays.
for label, images in (('train', mr_train_resized), ('test', mr_test_resized)):
  print(label + ' images shape:', images.shape)

In [None]:
import h5py

# Write each array to its own HDF5 file, always under the dataset name "data".
hdf5_outputs = {
    'mr_train_resized.hdf5': mr_train_resized,
    'ct_train_resized.hdf5': ct_train_resized,
    'mr_test_resized.hdf5': mr_test_resized,
    'ct_test_resized.hdf5': ct_test_resized,
}

for hdf5_path, array in hdf5_outputs.items():
    # 'w' creates the file, truncating any existing one with the same name.
    with h5py.File(hdf5_path, 'w') as f:
        dset = f.create_dataset("data", data = array)

In [None]:
# Sanity check: re-open one of the files that was just written.
path = 'mr_train_resized.hdf5'

# Open read-only inside a context manager so the handle is released again;
# a handle left open can make re-running the writing cell fail on this file.
with h5py.File(path, 'r') as f:
    load_data = f['data']
    # Copy out the metadata while the file is still open: `load_data` is only
    # usable inside this block.
    load_data_shape, load_data_dtype = load_data.shape, load_data.dtype

load_data_shape, load_data_dtype

In [None]:
# train.py
import os
import warnings
import scipy.io as sio
# from absl import app, flags
from tqdm import tqdm
from matplotlib import pyplot as plt
from matplotlib import gridspec
import functools
import numpy as np

import torch
# from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from dataset import Train_Data, Test_Data
from diffusion import ode_sampler, euler_sampler, pc_sampler
from model import UNet


# NOTE: no boolean `train` flag is kept here -- the name `train` is bound to
# the training function defined right below, which would overwrite the flag.

# When True, train() restores model/optimizer state and the loss history from disk.
continue_train = False

# UNet (values are forwarded to the UNet constructor in train())
ch = 64
ch_mult = [1, 2, 2, 4, 4]
attn = [1]
num_res_blocks = 2
dropout = 0.

# Gaussian Diffusion
beta_1 = 1e-4
beta_T = 0.02
# NOTE(review): this rebinding hides the `torchvision.transforms as T` alias
# imported in the first cell. Re-run that import cell before re-running any
# cell that uses the torchvision alias (resize / augmentation transforms).
T = 1000

# Training
lr = 1e-4
grad_clip = 1.
img_size = 128
batch_size = 2
num_workers = 1
ema_decay = 0.9999

# Batch size of the test loader used while sampling.
sample_size = 1

# Early stopping is only considered after `min_epoch` epochs; training never
# runs longer than `max_epoch` epochs.
min_epoch = 100
max_epoch = 1000

# Mean training loss per epoch (filled in by train()).
epoch_mean_loss = max_epoch * [None]
# Number of previous epochs averaged by the early-stopping criterion.
n_prev_epochs = 20

# Experiment name, used as folder name for checkpoints, losses and figures.
DIREC = f'score-unet_min-epoch_{min_epoch}_n-train-samples_{n_train_new}n-test-samples_{n_test_new}_batch-size_{batch_size}_T_{T}_img-size_{img_size}_data_augmentation_all'

# Use the first GPU when available, otherwise fall back to the CPU instead of
# failing on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def train():
    """Train the conditional score model, then sample the whole test set once.

    Training runs from `restore_epoch` up to `max_epoch` epochs. After
    `min_epoch` epochs, training stops as soon as the mean loss of the current
    epoch is not at least 1% below the average of the previous
    `n_prev_epochs` epochs; it also stops at the final epoch. At the stopping
    epoch every test sample is pushed through the three samplers (ODE, Euler,
    predictor-corrector); the raw outputs are saved as .npy files and a
    comparison figure is saved per test sample.

    NOTE(review): `marginal_prob_std`, `diffusion_coeff`, `EMA` and `loss_fn`
    are not imported in the visible cells -- they must already be defined in
    the kernel before this function is called.

    Returns:
        tuple: `(od_sum_time, eu_sum_time, pc_sum_time, last_epoch_num)` --
        total wall-clock seconds spent in each sampler over the test loader,
        and the 1-based epoch at which sampling happened (0 if the epoch loop
        never ran).
    """

    od_sum_time = 0
    eu_sum_time = 0
    pc_sum_time = 0
    last_epoch_num = 0

    sigma = 25.
    # Bind sigma so both functions can be called without it.
    marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)
    diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)

    # dataset
    train_data = Train_Data()
    train_loader = DataLoader(train_data, batch_size=batch_size, num_workers=num_workers,
                             pin_memory=True, shuffle=True)
    test_data = Test_Data()
    test_loader = DataLoader(test_data, batch_size=sample_size, num_workers=num_workers,
                             pin_memory=True, shuffle=False)

    # model setup
    score_model = UNet(T=T, ch=ch, ch_mult=ch_mult, attn=attn,
                       num_res_blocks=num_res_blocks, dropout=dropout,
                       marginal_prob_std=marginal_prob_std_fn).to(device)

    ema_model = EMA(score_model).to(device)

    optim = torch.optim.Adam(score_model.parameters(), lr=lr)

    # sampler setup
    sampler_od = ode_sampler
    sampler_eu = euler_sampler
    sampler_pc = pc_sampler

    # show model size
    model_size = 0
    for param in score_model.parameters():
        model_size += param.data.nelement()
    print('Model params: %.2f M' % (model_size / 1024 / 1024))

    if continue_train:
        checkpoint = torch.load('./Save/' + DIREC + '/model_latest.pkl')
        score_model.load_state_dict(checkpoint['score_model'])
        ema_model.load_state_dict(checkpoint['ema_model'])
        optim.load_state_dict(checkpoint['optim'])
        restore_epoch = checkpoint['epoch']
        print('Finish loading model')
    else:
        restore_epoch = 0

    # Mean training loss of every finished epoch (restored when resuming).
    tr_ls = []
    if continue_train:
        readmat = sio.loadmat('./Loss/' + DIREC)
        load_tr_ls = readmat['loss']
        for i in range(restore_epoch):
            tr_ls.append(load_tr_ls[0][i])
        print('Finish loading loss!')

    # Output folders; exist_ok makes re-runs harmless.
    os.makedirs('current experiment', exist_ok=True)
    os.makedirs('./current experiment/diff_results', exist_ok=True)
    os.makedirs('./current experiment/Train_Output/' + DIREC, exist_ok=True)

    last_epoch = False

    for epoch in range(restore_epoch, max_epoch):
        with tqdm(train_loader, unit="batch") as tepoch:
            tmp_tr_loss = 0
            tr_sample = 0
            score_model.train()
            for data, target in tepoch:
                tepoch.set_description(f"Epoch {epoch+1}")

                # train: `data` is the conditioning input, `target` the image to generate
                condition = data.to(device)
                x_0 = target.to(device)

                loss = loss_fn(score_model, condition, x_0, marginal_prob_std_fn)

                tmp_tr_loss += loss.item()
                tr_sample += len(data)

                optim.zero_grad()
                loss.backward()
                optim.step()
                ema_model.update(score_model)

                tepoch.set_postfix({'Loss': loss.item()})

        mean_loss = tmp_tr_loss / tr_sample
        print('mean loss:', mean_loss)

        epoch_mean_loss[epoch] = mean_loss

        # Early stopping. The window is taken from `tr_ls`, which at this point
        # holds the previous epochs only and, unlike `epoch_mean_loss`, also
        # contains the restored history when resuming.
        if epoch+1 > min_epoch and len(tr_ls) >= n_prev_epochs > 0:
            prev_mean_loss = sum(tr_ls[-n_prev_epochs:]) / n_prev_epochs

            if mean_loss > (prev_mean_loss - 0.01*prev_mean_loss):
                last_epoch = True

        # Always sample at the final epoch so a run that never meets the
        # stopping criterion still produces results.
        if epoch+1 == max_epoch:
            last_epoch = True

        tr_ls.append(mean_loss)

        score_model.eval()
        if last_epoch:
            last_epoch_num = epoch+1
            with torch.no_grad():
                for idx, (data, target) in enumerate(test_loader):
                    condition = data.to(device)

                    tic_od = time.time()
                    samples1 = sampler_od(score_model, condition, marginal_prob_std_fn, diffusion_coeff_fn, sample_size)
                    toc_od = time.time()
                    time_interval_od = toc_od - tic_od
                    od_sum_time += time_interval_od

                    tic_eu = time.time()
                    samples2 = sampler_eu(score_model, condition, marginal_prob_std_fn, diffusion_coeff_fn, sample_size)
                    toc_eu = time.time()
                    time_interval_eu = toc_eu - tic_eu
                    eu_sum_time += time_interval_eu

                    tic_pc = time.time()
                    samples3 = sampler_pc(score_model, condition, marginal_prob_std_fn, diffusion_coeff_fn, sample_size)
                    toc_pc = time.time()
                    time_interval_pc = toc_pc - tic_pc
                    pc_sum_time += time_interval_pc

                    # Save the raw (unclamped) sampler outputs.
                    diff_out_od = np.array(samples1.cpu())
                    save_path = f'./current experiment/diff_results/x0_od_number_{idx+1}_epoch_{epoch+1}.npy'
                    np.save(save_path, diff_out_od)

                    diff_out_eu = np.array(samples2.cpu())
                    save_path = f'./current experiment/diff_results/x0_eu_number_{idx+1}_epoch_{epoch+1}.npy'
                    np.save(save_path, diff_out_eu)

                    diff_out_pc = np.array(samples3.cpu())
                    save_path = f'./current experiment/diff_results/x0_pc_number_{idx+1}_epoch_{epoch+1}.npy'
                    np.save(save_path, diff_out_pc)

                    # sample visualization (clamped to [0, 1] for display only)
                    samples1 = samples1.clamp(0., 1.)
                    samples2 = samples2.clamp(0., 1.)
                    samples3 = samples3.clamp(0., 1.)

                    fig = plt.figure()
                    fig.set_figheight(4)
                    fig.set_figwidth(20)
                    spec = gridspec.GridSpec(ncols=5, nrows=1,
                                          width_ratios=[1,1,1,1,1], wspace=0.01,
                                          hspace=0.01, height_ratios=[1],left=0,right=1,top=1,bottom=0)
                    # Panels, left to right: input, ODE, Euler, PC, target.
                    panels = (
                        data[0].data.squeeze().cpu(),
                        samples1[0].squeeze().cpu(),
                        samples2[0].squeeze().cpu(),
                        samples3[0].squeeze().cpu(),
                        target[0].data.squeeze().cpu(),
                    )
                    for col, image in enumerate(panels):
                        ax = fig.add_subplot(spec[col])
                        ax.imshow(image, cmap='gray', vmin=0, vmax=1)
                        ax.axis('off')

                    # One file per test sample; a shared name would make every
                    # sample overwrite the previous figure.
                    plt.savefig('./current experiment/Train_Output/'+ DIREC + '/Epoch_' + str(epoch+1)
                                + '_sample_' + str(idx+1) + '.png',
                                bbox_inches='tight', pad_inches=0)
                    # Release the figure; otherwise one figure per test sample stays in memory.
                    plt.close(fig)

            # Sampling is done: stop training.
            break

    return od_sum_time, eu_sum_time, pc_sum_time, last_epoch_num


In [None]:
# Run training and keep the per-sampler total sampling times together with
# the epoch number at which sampling took place.
(
    od_sum_time,
    eu_sum_time,
    pc_sum_time,
    last_epoch_num,
) = train()

In [None]:
# Average sampling time per test sample for each sampler.
# train() samples the test set exactly once (at the epoch where it stops), so
# the totals are divided by the number of test samples only; dividing by the
# epoch number as well would understate the time by that factor.
od_avg_time = od_sum_time / n_test_new
eu_avg_time = eu_sum_time / n_test_new
pc_avg_time = pc_sum_time / n_test_new

od_avg_time, eu_avg_time, pc_avg_time

In [None]:
# Which saved sampler outputs to evaluate: sampler tag, epoch of the saved
# files, and the image side length used to reshape them.
sampler = 'pc'
last_epoch_num = 102
img_size = 128


def _load_diff_out(i):
  """Load the saved output for test sample `i` (0-based) as a (1, 1, img_size, img_size) tensor."""
  raw = np.load(f'./current experiment/diff_results/x0_{sampler}_number_{i+1}_epoch_{last_epoch_num}.npy')
  image = torch.tensor(np.reshape(raw, (1, img_size, img_size)))
  return image.unsqueeze(0)


diff_outs = [_load_diff_out(i) for i in range(n_test_new)]

In [None]:
from ignite.metrics import PSNR, SSIM
from collections import OrderedDict

import torch
from torch import nn, optim

from ignite.engine import *
from ignite.handlers import *
from ignite.metrics import *
from ignite.utils import *
from ignite.contrib.metrics.regression import *
from ignite.contrib.metrics import *

# create default evaluator for doctests

def eval_step(engine, batch):
    """Identity process function: hand the batch back unchanged as the engine output."""
    output = batch
    return output

# Wrap eval_step in an ignite Engine so metrics can be attached to it and
# evaluated through `.run(...)`.
default_evaluator = Engine(eval_step)

In [None]:
# Ground-truth CT test images, one tensor per sample, each reshaped to
# (1, img_size, img_size) and then given an extra leading dimension.
targets = [
    torch.tensor(np.reshape(ct_test_resized[i], (1, img_size, img_size))).unsqueeze(0)
    for i in range(n_test_new)
]


In [None]:
# Compare the dtypes of the first generated sample and the first target.
diff_out_sample, targets_sample = diff_outs[0], targets[0]

diff_out_sample.dtype, targets_sample.dtype

In [None]:
# Use a dedicated evaluator for SSIM. Attaching to a shared engine would add
# another set of handlers every time this cell is re-run and would keep this
# metric running during later evaluations of other metrics.
ssim_evaluator = Engine(eval_step)

metric = SSIM(data_range=1.0)
metric.attach(ssim_evaluator, 'ssim')


sum_ssims = 0

# Evaluate one (generated, target) pair per run and accumulate the score.
for i in range(n_test_new):
  state = ssim_evaluator.run([[diff_outs[i].float(), targets[i].float()]])
  ssim_value = state.metrics['ssim']
  sum_ssims += ssim_value

# Mean SSIM over the test set.
avg_ssim = sum_ssims / n_test_new

avg_ssim

In [None]:
# Use a dedicated evaluator for PSNR. The shared `default_evaluator` still has
# the SSIM metric attached, so running it here would recompute SSIM for every
# sample; a fresh engine also keeps this cell safe to re-run.
psnr_evaluator = Engine(eval_step)

metric = PSNR(data_range=1.0)
metric.attach(psnr_evaluator, 'psnr')


sum_psnrs = 0

# Evaluate one (generated, target) pair per run and accumulate the score.
for i in range(n_test_new):
  state = psnr_evaluator.run([[diff_outs[i].float(), targets[i].float()]])
  psnr_value = state.metrics['psnr']
  sum_psnrs += psnr_value

# Mean PSNR over the test set.
avg_psnr = sum_psnrs / n_test_new

avg_psnr