In [None]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader
import torchvision.transforms as T

import matplotlib.pyplot as plt
import numpy as np
import os
import time

from utils import resize, convert_to_tensor, norm, fill_nan_reshape, pre_process
from utils import get_ssim, get_psnr, get_mse, get_mae, save_best_samples, save_metrics

In [None]:
# Working image resolution (pixels per side) used throughout preprocessing.
img_size = 128

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

In [None]:
from functools import partial

# Paired MR/CT augmentations. Each one is wrapped in a single-step T.Compose so
# every augmentation exposes the same callable interface.
#
# NOTE: a later config cell rebinds the global name `T` to the integer 1000
# (number of diffusion steps). A lambda such as `lambda t: T.functional.rotate(...)`
# looks `T` up when it is *called*, so it breaks once that cell has run. The
# rotation function is therefore bound here, at definition time, via `partial`.
# (This cell itself still cannot be re-run after the config cell without
# re-importing torchvision.transforms as T.)

# Scale all intensities by 0.2.
change_gray_level = T.Compose([
    T.Lambda(lambda t: t * 0.2)
])

horizontal_flip = T.Compose([
    T.functional.hflip
])

vertical_flip = T.Compose([
    T.functional.vflip
])

rotate_45 = T.Compose([
    T.Lambda(partial(T.functional.rotate, angle=45))
])

rotate_minus_45 = T.Compose([
    T.Lambda(partial(T.functional.rotate, angle=-45))
])

In [None]:
# read CERMEP dataset
# The .npz archive is opened in a context manager so its underlying zip file
# handle is released once the array has been read.
data_path = '../Brain_tumor_MR_CT.npz'
with np.load(data_path) as data_loaded:
    # Indexing an NpzFile reads the member into an in-memory ndarray, so
    # data_pairs remains valid after the archive is closed.
    data_pairs = data_loaded["arr_0"]

data_pairs.shape

In [None]:
import random

# Hold out (1 - train_percent) of the pairs as the test set.
train_percent = 0.9
# Fixed seed so the split is identical across kernel restarts. Later cells reload
# saved sampler outputs by test index and compare them with ct_test[i], which is
# only meaningful if the same pairs land in the test set every run.
SPLIT_SEED = 42

n_samples = data_pairs.shape[0]
train_size = int(train_percent*n_samples)
test_size = n_samples - train_size

# Draw test_size distinct indices uniformly from [0, n_samples). A private
# generator keeps the global `random` state untouched.
split_rng = random.Random(SPLIT_SEED)
test_indices = split_rng.sample(range(n_samples), test_size)

# Every index not held out goes to training, in ascending order. Using a set
# makes each membership test O(1) instead of scanning the list.
test_index_set = set(test_indices)
train_indices = [i for i in range(n_samples) if i not in test_index_set]

assert len(train_indices) == train_size
assert test_index_set.isdisjoint(train_indices)

train_data = data_pairs[train_indices]
test_data = data_pairs[test_indices]


train_data.shape, test_data.shape

In [None]:
def _clean_pairs(pairs):
    """Apply fill_nan_reshape to the MR (index 0) and CT (index 1) image of each pair.

    Returns the cleaned MR images and CT images as two stacked numpy arrays.
    """
    mr_images, ct_images = [], []
    for pair in pairs:
        mr_images.append(fill_nan_reshape(pair[0], img_size))
        ct_images.append(fill_nan_reshape(pair[1], img_size))
    return np.array(mr_images), np.array(ct_images)


mr_train, ct_train = _clean_pairs(train_data)
mr_test, ct_test = _clean_pairs(test_data)

mr_train.shape, mr_test.shape, ct_train.shape, ct_test.shape

In [None]:
# Augment the training split. Every training pair expands into 6 pairs: the
# original plus 5 augmentations. The same transform is applied to MR and CT so
# the pair stays spatially aligned. The test split is not augmented.
#
# NOTE(review): every augmented image is passed through `norm` below. If `norm`
# rescales each image to max 1 (the following cell checks for exactly that), the
# x0.2 `change_gray_level` copy ends up identical to the original. Confirm
# against utils.norm.
pair_augmentations = (
    lambda img: img,      # unmodified original
    change_gray_level,
    horizontal_flip,
    vertical_flip,
    rotate_45,
    rotate_minus_45,
)
train_size_augmented = train_size * len(pair_augmentations)


def _augment_and_normalize(image):
    """Return a normalized numpy array for every augmentation of `image`, in table order."""
    as_tensor = convert_to_tensor(image)
    return [norm(np.array(augment(as_tensor).squeeze())) for augment in pair_augmentations]


mr_train_augmented = []
ct_train_augmented = []
for sample_idx in range(train_size):
    mr_train_augmented.extend(_augment_and_normalize(mr_train[sample_idx]))
    ct_train_augmented.extend(_augment_and_normalize(ct_train[sample_idx]))

# Stack the per-image arrays into one array per modality (train_size * 6 images).
mr_train_augmented = np.array(mr_train_augmented)
ct_train_augmented = np.array(ct_train_augmented)

# Normalize the test images in place, keeping the arrays' existing dtype.
for sample_idx in range(test_size):
    mr_test[sample_idx] = norm(mr_test[sample_idx])
    ct_test[sample_idx] = norm(ct_test[sample_idx])


print('train images shape:', mr_train_augmented.shape)
print('test images shape:', mr_test.shape)

In [None]:
# Sanity check: after `norm`, every test MR image should peak at exactly 1.0.
# (Iterate over mr_train_augmented instead to check the training set.)
normalized = all(sample.max() == 1.0 for sample in mr_test)

normalized

In [None]:
# Write the augmented training split and the (un-augmented) test split to HDF5:
# one file per modality/split, each holding a single dataset named "data".
import h5py

hdf5_outputs = {
    'mr_train_resized.hdf5': mr_train_augmented,
    'ct_train_resized.hdf5': ct_train_augmented,
    'mr_test_resized.hdf5': mr_test,
    'ct_test_resized.hdf5': ct_test,
}

for out_path, array in hdf5_outputs.items():
    with h5py.File(out_path, 'w') as f:
        f.create_dataset("data", data=array)

In [None]:
# for Datasets without augmentation
#
# Alternative to the previous cell. Both cells write the SAME file names, so
# running them back to back (e.g. "Restart & Run All") would silently replace the
# augmented training data with the raw split. Meanwhile the experiment directory
# name (DIREC, defined later) still reports train_size_augmented samples. This
# write is therefore opt-in.
WRITE_WITHOUT_AUGMENTATION = False

import h5py

if WRITE_WITHOUT_AUGMENTATION:
    # creating hdf5 data from numpy arrays
    for out_path, array in (
        ('mr_train_resized.hdf5', mr_train),
        ('ct_train_resized.hdf5', ct_train),
        ('mr_test_resized.hdf5', mr_test),
        ('ct_test_resized.hdf5', ct_test),
    ):
        with h5py.File(out_path, 'w') as f:
            f.create_dataset("data", data=array)

In [None]:
# Quick sanity check of the written training file. The file is opened in a
# context manager so the handle is released afterwards. HDF5 refuses to
# re-create/truncate a file that this process still holds open, so a dangling
# read handle would break re-running the write cells above.
path = 'mr_train_resized.hdf5'
with h5py.File(path, 'r') as f:
    load_data = f['data']
    load_data_info = (load_data.shape, load_data.dtype)

load_data_info

In [None]:
# train.py DDPM folder
import sys
import copy
import os
import warnings
import scipy.io as sio
from absl import app, flags
from tqdm import tqdm
from matplotlib import pyplot as plt
from matplotlib import gridspec

import torch
from torch.utils.data import DataLoader

from diffusion import GaussianDiffusionTrainer, GaussianDiffusionSampler
from model import UNet
from dataset import Train_Data, Test_Data


# Leftover flag; the name is rebound by `def train()` in the next cell.
train = True

# UNet
ch = 64
ch_mult = [1, 2, 2, 4, 4]
attn = [1]
num_res_blocks = 2
dropout = 0.

# Gaussian Diffusion (passed to GaussianDiffusionTrainer / GaussianDiffusionSampler)
beta_1 = 1e-4
beta_T = 0.02
# Number of diffusion steps.
# NOTE: this rebinds the global name `T`, which until now referred to
# torchvision.transforms (imported as T). Cells that use T.Compose / T.functional
# cannot be re-run after this cell without re-importing.
T = 1000

# Training
lr = 1e-4
grad_clip = 1.  # max gradient norm passed to clip_grad_norm_
img_size = 128
batch_size = 2
num_workers = 1
ema_decay = 0.9999

# Batch size of the test-time DataLoader in test().
sample_size = 1

# Early stopping (see train()): no stopping check until more than min_epoch
# epochs have run; at most max_epoch epochs; compare each epoch against the
# mean loss of the previous n_prev_epochs epochs.
min_epoch = 100
max_epoch = 110
n_prev_epochs = 20

# Per-epoch mean training loss, filled in by train().
epoch_mean_loss = max_epoch * [None]
# Logging & Sampling
DIREC = f'ddpm-unet_n-train-samples_{train_size_augmented}_n-test-samples_{test_size}_batch-size_{batch_size}_T_{T}_img-size_{img_size}_atlas_data'

# Fall back to the CPU when no GPU is visible (consistent with the device cell
# at the top) so the notebook still runs on CPU-only machines.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

def ema(source, target, decay):
    """Update `target` in place as an exponential moving average of `source`.

    For every floating-point entry of the state dict:
        target = decay * target + (1 - decay) * source
    Non-floating entries (e.g. integer counters such as BatchNorm's
    ``num_batches_tracked``) are copied from `source`. Averaging them would
    be truncated back to an integer by ``copy_`` and never advance.

    Args:
        source: module whose current weights are blended in.
        target: module holding the running average; modified in place.
        decay: weight kept from the previous average, in [0, 1].
    """
    source_dict = source.state_dict()
    target_dict = target.state_dict()
    # state_dict() tensors share storage with the modules' parameters/buffers,
    # so in-place writes update `target` directly. No autograd tracking needed.
    with torch.no_grad():
        for key, src_tensor in source_dict.items():
            tgt_tensor = target_dict[key]
            if tgt_tensor.is_floating_point():
                tgt_tensor.copy_(tgt_tensor * decay + src_tensor * (1 - decay))
            else:
                tgt_tensor.copy_(src_tensor)

In [None]:
def train():
    """Train the conditional DDPM UNet and save the final EMA weights.

    The model, diffusion, optimisation and early-stopping settings, and
    ``device``, are read from module-level globals set in the config cell.
    Each epoch's mean loss is written to the global ``epoch_mean_loss`` list.

    Early stopping: once more than ``min_epoch`` epochs have run (and at least
    ``n_prev_epochs`` are recorded), training stops as soon as an epoch's mean
    loss is not at least 1% below the mean of the previous ``n_prev_epochs``
    epochs.

    Returns:
        int: number of epochs run. The EMA weights are saved to
        ``./current experiment/Saved_model/ddpm-unet_epoch_{n}.pt``.

    Raises:
        RuntimeError: if the training DataLoader yields no samples.
    """
    # dataset
    train_data = Train_Data()
    train_loader = DataLoader(train_data, batch_size=batch_size, num_workers=num_workers,
                              pin_memory=True, shuffle=True)

    # model setup: ema_model starts as an exact copy of net_model. After every
    # optimiser step it is updated as an exponential moving average of net_model.
    net_model = UNet(
        T=T, ch=ch, ch_mult=ch_mult, attn=attn,
        num_res_blocks=num_res_blocks, dropout=dropout)
    ema_model = copy.deepcopy(net_model)

    net_model.to(device)
    ema_model.to(device)

    optim = torch.optim.Adam(net_model.parameters(), lr=lr)

    trainer = GaussianDiffusionTrainer(
        net_model, beta_1, beta_T, T).to(device)

    # show model size (parameter count, reported in units of 2**20)
    model_size = sum(param.data.nelement() for param in net_model.parameters())
    print('Model params: %.2f M' % (model_size / 1024 / 1024))

    # makedirs also creates the parent 'current experiment' directory.
    os.makedirs('./current experiment/Saved_model', exist_ok=True)

    epochs_run = 0
    for epoch in range(max_epoch):
        net_model.train()
        tmp_tr_loss = 0
        tr_sample = 0
        with tqdm(train_loader, unit="batch") as tepoch:
            tepoch.set_description(f"Epoch {epoch+1}")
            for data, target in tepoch:
                # train: the MR image conditions the model, the CT image is the target x_0
                optim.zero_grad()
                condition = data.to(device)
                x_0 = target.to(device)

                loss = trainer(x_0, condition)
                batch_loss = loss.item()
                # NOTE(review): this sums one loss value per batch but divides by
                # the number of samples below. If the trainer returns a batch
                # mean, the logged value is mean-loss / batch_size. The early-stop
                # test below is a ratio, so only an uneven last batch affects it.
                tmp_tr_loss += batch_loss
                tr_sample += len(data)

                loss.backward()

                torch.nn.utils.clip_grad_norm_(
                    net_model.parameters(), grad_clip)
                optim.step()
                ema(net_model, ema_model, ema_decay)

                tepoch.set_postfix({'Loss': batch_loss})
        epochs_run = epoch + 1

        if tr_sample == 0:
            raise RuntimeError('training DataLoader yielded no samples')

        mean_loss = tmp_tr_loss / tr_sample
        print('mean loss', mean_loss)

        epoch_mean_loss[epoch] = mean_loss

        # A full window of previous epochs is required. Otherwise negative
        # indices would wrap around to unfilled (None) slots at the list's end.
        if epoch+1 > min_epoch and 0 < n_prev_epochs <= epoch:
            prev_mean_loss = sum(epoch_mean_loss[epoch - (i+1)] for i in range(n_prev_epochs))
            prev_mean_loss /= n_prev_epochs

            # Stop when this epoch is not at least 1% better than the recent mean.
            if mean_loss > (prev_mean_loss - 0.01*prev_mean_loss):
                break

    torch.save(ema_model.state_dict(), f'./current experiment/Saved_model/ddpm-unet_epoch_{epochs_run}.pt')

    return epochs_run

In [None]:
# Train the model. train() returns the number of epochs it ran, which is also the
# epoch tag of the saved EMA checkpoint that test() loads.
last_epoch_num = train()

In [None]:
def _save_sample_figure(data, x_0, target, save_path):
    """Save a one-row comparison strip for the first image of a batch.

    Panels, left to right: the MR condition, entries 0..4 of the sampler output
    ``x_0`` (each indexed as ``x_0[k][0]``), and the CT target. All panels are
    drawn in grayscale on a fixed [0, 1] range.
    """
    fig = plt.figure()
    fig.set_figheight(8)
    fig.set_figwidth(28)
    spec = gridspec.GridSpec(ncols=7, nrows=2,
                             width_ratios=[1, 1, 1, 1, 1, 1, 1], wspace=0.01,
                             hspace=0.01, height_ratios=[1, 1], left=0, right=1, top=1, bottom=0)

    panels = [data[0].data.squeeze()]
    # x_0 is presumably [5, b, 1, h, w] (sampler snapshots) -- TODO confirm
    # against GaussianDiffusionSampler.
    panels.extend(x_0[kk][0].data.squeeze().cpu() for kk in range(5))
    panels.append(target[0].data.squeeze().cpu())

    for slot, img in enumerate(panels):
        ax = fig.add_subplot(spec[slot])
        ax.imshow(img, cmap='gray', vmin=0, vmax=1)
        ax.axis('off')

    fig.savefig(save_path, bbox_inches='tight', pad_inches=0)
    plt.close(fig)


def test():
    """Sample every test image with the saved EMA model and store the results.

    Loads ``ddpm-unet_epoch_{last_epoch_num}.pt`` and runs the diffusion sampler
    on each batch of ``Test_Data``, conditioned on the MR input. For each batch
    it saves the final sampler output ``x_0[-1]`` as an ``.npy`` file under
    ``./current experiment/diff_results``. It also saves a comparison figure
    under ``./current experiment/Train_Output/{DIREC}``.

    Returns:
        float: total wall-clock time, in seconds, spent inside the sampler.
    """
    os.makedirs('./current experiment/Train_Output/' + DIREC, exist_ok=True)
    os.makedirs('./current experiment/diff_results', exist_ok=True)

    ddpm_sum_time = 0

    test_data = Test_Data()
    test_loader = DataLoader(test_data, batch_size=sample_size, num_workers=num_workers,
                             pin_memory=True, shuffle=False)

    ema_model = UNet(
        T=T, ch=ch, ch_mult=ch_mult, attn=attn,
        num_res_blocks=num_res_blocks, dropout=dropout)

    model_path = f'./current experiment/Saved_model/ddpm-unet_epoch_{last_epoch_num}.pt'
    # map_location allows a checkpoint saved on a GPU to be loaded on the current
    # device, including CPU-only machines.
    ema_model.load_state_dict(torch.load(model_path, map_location=device))
    # Inference only: switch train/eval-dependent layers (the UNet takes a
    # dropout argument) to evaluation mode.
    ema_model.eval()

    ema_sampler = GaussianDiffusionSampler(
        ema_model, beta_1, beta_T, T, img_size).to(device)

    # CUDA kernels run asynchronously. Without synchronising, the timer could
    # stop before the GPU has actually finished sampling.
    sync_cuda = torch.device(device).type == 'cuda'

    with torch.no_grad():
        for idx, (data, target) in enumerate(test_loader):
            # One noise image per item actually in the batch (the last batch can
            # be smaller than sample_size).
            x_T = torch.randn(data.shape[0], 1, img_size, img_size)
            x_T = x_T.to(device)

            condition = data.to(device)

            if sync_cuda:
                torch.cuda.synchronize()
            tic_ddpm = time.perf_counter()
            x_0 = ema_sampler(x_T, condition)
            if sync_cuda:
                torch.cuda.synchronize()
            ddpm_sum_time += time.perf_counter() - tic_ddpm

            # Keep only the last entry of the sampler output as the final result.
            ddpm_out = np.array(x_0[-1].cpu())

            save_path = f'./current experiment/diff_results/x0_number_{idx+1}_epoch_{last_epoch_num}.npy'
            np.save(save_path, ddpm_out)

            _save_sample_figure(
                data, x_0, target,
                './current experiment/Train_Output/'+ DIREC + '/sample_' + str(idx) + '.png')

    return ddpm_sum_time

In [None]:
# Sample every test image with the EMA model. test() returns the summed sampling time in seconds.
ddpm_sum_time = test()

In [None]:
# Average sampling time per test image (seconds).
# NOTE(review): assumes Test_Data yields exactly test_size images (the split
# written to the *_test_resized.hdf5 files) -- confirm in dataset.py.
ddpm_avg_time = ddpm_sum_time / test_size
ddpm_avg_time

In [None]:
# Evaluation can also run in a fresh kernel against outputs saved by an earlier
# session. That requires the data cells above and the same train/test split as
# that session. Only in that case fall back to a fixed checkpoint epoch. If
# train() ran in this kernel, keep the epoch it returned: overwriting it would
# load another run's (or non-existent) .npy files below.
RESUME_EPOCH = 102
if 'last_epoch_num' not in globals():
    last_epoch_num = RESUME_EPOCH
DIREC = f'ddpm-unet_n-train-samples_{train_size_augmented}_n-test-samples_{test_size}_batch-size_{batch_size}_T_{T}_img-size_{img_size}_atlas_data'

In [None]:
# Reload the saved sampler outputs (the last entry of x_0 saved by test()) for
# every test image, and bring them into the same format as the targets.
diff_outs = []
for sample_no in range(1, test_size + 1):
    saved_out = np.load(f'./current experiment/diff_results/x0_number_{sample_no}_epoch_{last_epoch_num}.npy')
    diff_outs.append(pre_process(saved_out, img_size))

In [None]:
# Ground-truth CT images for the test split, pre-processed like the model outputs.
targets = [pre_process(ct_test[sample_idx], img_size) for sample_idx in range(test_size)]

In [None]:
# Evaluate the DDPM outputs against the CT targets.
# Judging by the unpacked names, each metric helper returns
# (best value, index of the best sample, average). "Best" means max for
# SSIM/PSNR and min for MSE/MAE.
sampler_type = 'ddpm'

max_ssim, argmax_ssim, avg_ssim = get_ssim(diff_outs, targets)
max_psnr, argmax_psnr, avg_psnr = get_psnr(diff_outs, targets)
min_mse, argmin_mse, avg_mse = get_mse(diff_outs, targets)
min_mae, argmin_mae, avg_mae = get_mae(diff_outs, targets)

# Save the best-scoring samples for each metric (in separate folders), then
# save the average and best metric values together with the mean sampling time.
save_best_samples(sampler_type, DIREC, diff_outs, targets,
                  argmax_ssim, argmax_psnr, argmin_mse, argmin_mae)
save_metrics(sampler_type, ddpm_avg_time,
             avg_ssim, avg_psnr, avg_mse, avg_mae,
             max_ssim, max_psnr, min_mse, min_mae)
