# Generating synthetic MRI data using diffusion models

This notebook loads and preprocesses T2-weighted MR images of patients with rectal cancer. Preprocessing consists of normalizing and downsampling the images; cropping is also possible. A diffusion model is then defined and trained. The following tutorial from the MONAI Generative Models framework was used when defining and training the diffusion model:  
https://github.com/Project-MONAI/GenerativeModels/blob/main/tutorials/generative/2d_ddpm/2d_ddpm_tutorial.ipynb

### Importing necessary libraries

In [None]:
import os
import shutil
import tempfile
import time

import matplotlib.pyplot as plt
import numpy as np
import nibabel as nib
import torch
import torch.nn.functional as F
from monai.transforms import Compose, LoadImage, ToTensor, ScaleIntensity, CenterSpatialCrop, Resize, EnsureChannelFirst, RandAffined, SaveImage,Rotate90
from monai.apps import MedNISTDataset
from monai.config import print_config
from monai.data import CacheDataset, DataLoader, Dataset, ArrayDataset
from monai.utils import first, set_determinism
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm
from torch.utils.data import ConcatDataset,random_split
from torch.utils.tensorboard import SummaryWriter

from generative.inferers import DiffusionInferer
from generative.networks.nets import DiffusionModelUNet
from generative.networks.schedulers import DDPMScheduler
import monai
import cv2
from datetime import datetime
#from PIL import Image

print_config()

In [None]:
# Ask PyTorch to release cached GPU memory that is no longer in use before the data and model are loaded.
torch.cuda.empty_cache()

Checking GPU

In [None]:
# Shell escape: run nvidia-smi to show the GPUs visible to this kernel and their current memory usage.
!nvidia-smi

### Loading and preprocessing MRI dataset

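PREPROCESSING: Each volume is loaded and resized in-plane to 128 x 128 (center cropping is available but currently disabled). Each 2D slice is afterwards scaled to the intensity interval [0, 1] and rotated by 90 degrees three times.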

In [None]:
# Volume-level pipeline: read a NIfTI file as a channel-first tensor and
# resize it in-plane to 128 x 128 (the -1 keeps the size of the slice axis).
# Disabled alternatives: volume-level ScaleIntensity(minv=0.0, maxv=1.0),
# CenterSpatialCrop(roi_size=(256, 256, -1)) and Resize(spatial_size=(64, 64, -1)).
volume_steps = [
    LoadImage(image_only=True),
    EnsureChannelFirst(),
    ToTensor(),
    Resize(spatial_size=(128, 128, -1)),
]
image_transform = Compose(volume_steps)

# Slice-level pipeline: scale the intensities of a single 2D slice to [0, 1]
# and rotate it by 3 x 90 degrees in the image plane.
slice_steps = [
    ScaleIntensity(minv=0.0, maxv=1.0),
    Rotate90(k=3, spatial_axes=(0, 1), lazy=False),
]
image_transforms = Compose(slice_steps)

Creating a NiFTIDataset class that inherits from monai.Dataset. Each data element is a NIfTI file. If a transform is applied, only the images are loaded (any additional info is removed). 
Hence, each data element is a stack of images. The function extract_slices collects the image slices and creates a new dataset consisting only of these 2D images. All slices, except for the edge slices, are kept for further training. 

In [None]:
class NiFTIDataset(Dataset):
    """Dataset with one element per file found in ``data_dir``.

    Args:
        data_dir: Directory containing the NIfTI files (one file per patient).
        transform: Optional callable applied to the file path. When ``None``
            the element is the plain file path.
    """

    def __init__(self, data_dir, transform = None):
        self.data_dir = data_dir
        # os.listdir returns entries in arbitrary order, so sort them to get
        # a stable patient order (and therefore a reproducible train/val
        # split). Hidden entries such as ".ipynb_checkpoints" are not images.
        self.data = sorted(
            entry for entry in os.listdir(data_dir) if not entry.startswith(".")
        )
        self.transform = transform

    def __len__(self):
        """Return the number of files in the dataset."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the transformed file, or its path when no transform is set."""
        nifti_file = os.path.join(self.data_dir, self.data[index])
        if self.transform is not None:
            nifti_file = self.transform(nifti_file)
        return nifti_file

### Loading more images into the dataset

In [None]:
def extract_slices(nifti_dataset, edge_slices=3):
    """Extract 2D slices from every volume and collect them in one dataset.

    Each element of ``nifti_dataset`` is indexed as a 4D stack whose last axis
    is the slice axis. The first and last ``edge_slices`` slices of every
    stack are dropped; every remaining slice is passed through the
    module-level ``image_transforms`` pipeline.

    Args:
        nifti_dataset: Indexable dataset of image stacks (supports ``len``
            and integer indexing).
        edge_slices: Number of slices discarded at each end of a stack.

    Returns:
        A ``ConcatDataset`` with one sub-dataset per volume, or an empty
        ``Dataset`` when no slices were extracted.
    """
    per_volume_datasets = []
    for i in range(len(nifti_dataset)):
        image_stack = nifti_dataset[i]
        kept_indices = range(edge_slices, image_stack.shape[3] - edge_slices)
        # Only the slices that are kept are transformed; the stack was freshly
        # produced by the indexing above, so writing into it is safe here.
        for k in kept_indices:
            image_stack[:, :, :, k] = image_transforms(image_stack[:, :, :, k])
        slices = [image_stack[:, :, :, k] for k in kept_indices]
        if slices:
            per_volume_datasets.append(Dataset(slices))

    if not per_volume_datasets:
        return Dataset([])
    # Concatenate once instead of re-wrapping the running dataset for every
    # patient, which would nest ConcatDatasets one level deeper per volume.
    return ConcatDataset(per_volume_datasets)

In [None]:
# Preprocessed volumes (one per file in "T2_images"), and all 2D slices
# extracted from them for a first look at the data.
nifti_dataset = NiFTIDataset(
    data_dir="T2_images",
    transform=image_transform,
)
data = extract_slices(nifti_dataset)

In [None]:
# Same directory without a transform: each element is just the file path,
# which is used below to inspect the original, unprocessed volumes.
nifti_dataset_org = NiFTIDataset(
    data_dir="T2_images",
)

### Preliminary visualization of the data

In [None]:

def _show_original_slice(nifti_path, slice_index=3, title=None):
    """Load a NIfTI file with nibabel and display one rotated slice of it.

    Prints the volume shape and the slice shape, then shows the slice
    (rotated with ``np.rot90(k=3)``) in grayscale with a colorbar.

    Args:
        nifti_path: Path of the NIfTI file to load.
        slice_index: Index along the third axis of the slice to show.
        title: Optional figure title.

    Returns:
        The loaded nibabel image.
    """
    volume = nib.load(nifti_path)
    print(volume.shape)
    slice_2d = volume.get_fdata()[:, :, slice_index]
    print(slice_2d.shape)
    slice_2d = np.rot90(slice_2d, k=3)

    plt.figure()
    plt.imshow(slice_2d, cmap="gray")
    plt.colorbar()
    if title is not None:
        plt.title(title)
    plt.show()
    return volume


# First preprocessed slice.
plt.figure()
plt.imshow(data[0][0, :, :], cmap="gray")
plt.colorbar()
plt.show()

# Corresponding original (unprocessed) volume.
_show_original_slice(nifti_dataset_org[0])

# One slice of every original volume. The loop length follows the dataset
# instead of a hard-coded patient count, and each file is loaded only once.
for i in range(len(nifti_dataset_org)):
    print(i)
    org_file = _show_original_slice(nifti_dataset_org[i], title=str(i))
    print(org_file)
    
    
    



### Creating validation and training dataset

In [None]:
# Hold-out validation: the split is done on whole volumes (patients) in
# nifti_dataset, so that slices of one patient never end up in both sets.
train_ratio = 0.8  # Possible to choose another split ratio

train_patiens = int(train_ratio * len(nifti_dataset))
val_patiens = len(nifti_dataset) - train_patiens

# Seed the split explicitly. set_determinism() is only called further down in
# the notebook, so without a generator the split differs on every run.
split_generator = torch.Generator().manual_seed(42)
train_nifti_dataset, val_nifti_dataset = random_split(
    nifti_dataset, [train_patiens, val_patiens], generator=split_generator
)

print(len(train_nifti_dataset), len(val_nifti_dataset))

# Turn the per-patient volumes into datasets of 2D slices.
train_dataset = extract_slices(train_nifti_dataset)
val_dataset = extract_slices(val_nifti_dataset)

print(len(train_dataset), len(val_dataset))

### Visualizing training dataset

In [None]:
# Set to True to write every training slice to TRAIN_SLICE_DIR as a NIfTI file.
SAVE_TRAIN_SLICES = False
TRAIN_SLICE_DIR = "Real_images/Real_training_data_bs8_8nov"

# Show every 100th training slice.
for i in range(0, len(train_dataset), 100):
    plt.figure()
    plt.imshow(train_dataset[i][0, :, :], cmap="bone")
    plt.colorbar()
    plt.title("Dataelement " + str(i))
    plt.show()

# Saving is opt-in. When it is disabled nothing is built, instead of creating
# and discarding a NIfTI image for every slice.
if SAVE_TRAIN_SLICES:
    os.makedirs(TRAIN_SLICE_DIR, exist_ok=True)
    for j, real_image in enumerate(train_dataset):
        # Channel 0 of the slice, stored with an identity affine.
        nifti_image = nib.Nifti1Image(real_image.numpy()[0], np.eye(4))
        nib.save(nifti_image, os.path.join(TRAIN_SLICE_DIR, "nifti_file_" + str(j) + ".nii"))

### Visualizing and saving validation dataset

In [None]:
VAL_SLICE_DIR = "Real_images/Real_validation_data_bs16_for_epoch250"

# Show every 100th validation slice.
for i in range(0, len(val_dataset), 100):
    plt.figure()
    plt.imshow(val_dataset[i][0, :, :], cmap="bone")
    plt.colorbar()
    plt.title("Dataelement " + str(i))
    plt.show()

# Save every validation slice as a NIfTI file. The directory is created
# first so that nib.save does not fail on a fresh checkout.
os.makedirs(VAL_SLICE_DIR, exist_ok=True)
for j, val_image in enumerate(val_dataset):
    # Channel 0 of the slice, stored with an identity affine.
    nifti_image = nib.Nifti1Image(val_image.numpy()[0], np.eye(4))
    nib.save(nifti_image, os.path.join(VAL_SLICE_DIR, "nifti_file_" + str(j) + ".nii"))

### Loading dataset into dataloader

In [None]:
bs = 16

# Both loaders share the same settings: shuffled batches of size bs, served
# by 4 worker processes that are kept alive between epochs.
_loader_options = dict(batch_size=bs, shuffle=True, num_workers=4, persistent_workers=True)

train_data_loader = DataLoader(train_dataset, **_loader_options)
val_data_loader = DataLoader(val_dataset, **_loader_options)

In [None]:
# Fix the random seed (MONAI helper) for everything that runs from here on.
# NOTE(review): this is called after the train/validation split and the slice
# extraction above, so it does not make those steps reproducible.
set_determinism(42)

### Defining diffusion model

In [None]:
# Use the GPU when one is available; fall back to the CPU otherwise so that
# the cell does not fail on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_channels=(128, 256, 256), #256, 256, 512
    attention_levels=(False, True, True),
    num_res_blocks=1,
    num_head_channels=256,
)

# Loading a pre-trained model that is expected to have the same architecture
# as above. This model is however trained on the MedNIST Hand dataset.
pre_trained_model = torch.hub.load("marksgraham/pretrained_generative_models:v0.2", model="ddpm_2d", verbose=True)
state_dict = pre_trained_model.state_dict()

# strict=False silently skips keys that do not match, so report them: a
# non-empty list means part of the network was NOT initialised from the
# pre-trained weights.
load_result = model.load_state_dict(state_dict, strict = False)
if load_result.missing_keys or load_result.unexpected_keys:
    print("Keys missing from the pre-trained weights:", load_result.missing_keys)
    print("Unexpected keys in the pre-trained weights:", load_result.unexpected_keys)

model.to(device)

# Noise schedule with 1000 training timesteps.
scheduler = DDPMScheduler(num_train_timesteps=1000)

optimizer = torch.optim.Adam(params=model.parameters(), lr=2.5e-5)

inferer = DiffusionInferer(scheduler)

In [None]:
# Writer for saving the images generated during training: unresampled
# 8-bit PNG files written with the ITK writer into "Generated_images".
_saver_options = {
    "output_dir": "Generated_images",
    "output_ext": ".png",
    "output_postfix": "itk",
    "output_dtype": np.uint8,
    "resample": False,
    "writer": "ITKWriter",
}
saver = SaveImage(**_saver_options)

### Training the model

In [None]:
n_epochs = 200
val_interval = 25

# Per-epoch histories, kept in order to tune hyperparameters. Training lists
# get one entry per epoch, validation lists one entry per validation run.
epoch_accuracy_list = []
val_epoch_accuracy_list = []
epoch_loss_list = []
val_epoch_loss_list = []
epoch_f1_score_list = []
val_epoch_f1_score_list = []
epoch_precision_list = []
val_epoch_precision_list = []
epoch_recall_list = []
val_epoch_recall_list = []

# MONAI confusion-matrix metric name -> TensorBoard tag prefix.
METRIC_TAGS = {
    "accuracy": "Accuracy",
    "f1 score": "F1_score",
    "precision": "Precision",
    "recall": "Recall",
}
train_history = {
    "accuracy": epoch_accuracy_list,
    "f1 score": epoch_f1_score_list,
    "precision": epoch_precision_list,
    "recall": epoch_recall_list,
}
val_history = {
    "accuracy": val_epoch_accuracy_list,
    "f1 score": val_epoch_f1_score_list,
    "precision": val_epoch_precision_list,
    "recall": val_epoch_recall_list,
}

scaler = GradScaler()
total_start = time.time()

# For tensorboard
writer = SummaryWriter()


def _noise_prediction_step(images):
    """Run one noise-prediction pass on a batch and score it.

    Random noise with the shape of ``images`` and one random timestep per
    image are drawn, the model predicts the noise through ``inferer``, and the
    prediction is compared to the actual noise.

    Args:
        images: Batch of images, already on the target device.

    Returns:
        Tuple ``(loss, metrics)``: the MSE loss tensor (still attached to the
        graph when gradients are enabled) and a dict mapping every key of
        ``METRIC_TAGS`` to the batch-mean metric as a Python float.
    """
    noise = torch.randn_like(images)
    timesteps = torch.randint(
        0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device
    ).long()
    noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps)
    loss = F.mse_loss(noise_pred.float(), noise.float())

    # NOTE(review): get_confusion_matrix is documented for binarized inputs,
    # but here it receives continuous noise tensors, so these four metrics are
    # probably not meaningful for this regression task. They are kept because
    # later cells plot them; confirm or remove.
    conf_matrix = monai.metrics.get_confusion_matrix(noise_pred.detach().float(), noise.float())
    metrics = {
        name: torch.mean(monai.metrics.compute_confusion_matrix_metric(name, conf_matrix)).item()
        for name in METRIC_TAGS
    }
    return loss, metrics


for epoch in range(n_epochs):
    model.train()
    # Running sums over the batches of this epoch.
    epoch_loss = 0.0
    epoch_metrics = dict.fromkeys(METRIC_TAGS, 0.0)

    progress_bar = tqdm(enumerate(train_data_loader), total=len(train_data_loader), ncols=70)
    progress_bar.set_description(f"Epoch {epoch}")
    for step, batch in progress_bar:
        images = batch.to(device)
        optimizer.zero_grad(set_to_none=True)

        with autocast(enabled=True):
            loss, batch_metrics = _noise_prediction_step(images)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()
        for name, value in batch_metrics.items():
            epoch_metrics[name] += value

        progress_bar.set_postfix({"loss": epoch_loss / (step + 1)})

    num_train_batches = step + 1
    epoch_loss_list.append(epoch_loss / num_train_batches)
    writer.add_scalar("Loss/train", epoch_loss / num_train_batches, epoch)
    for name, tag in METRIC_TAGS.items():
        train_history[name].append(epoch_metrics[name] / num_train_batches)
        writer.add_scalar(tag + "/train", epoch_metrics[name] / num_train_batches, epoch)

    if (epoch + 1) % val_interval == 0:
        model.eval()
        # Saving model (the directory is created first so torch.save cannot
        # fail after many epochs of training).
        os.makedirs("Models", exist_ok=True)
        path = "Models/bs" + str(bs) + "_Epoch" + str(epoch) + "_of_" + str(n_epochs) + "8nov"
        torch.save(model.state_dict(), path)

        # Running sums over the validation batches.
        val_epoch_loss = 0.0
        val_epoch_metrics = dict.fromkeys(METRIC_TAGS, 0.0)

        for step, batch in enumerate(val_data_loader):
            images = batch.to(device)
            with torch.no_grad():
                with autocast(enabled=True):
                    val_loss, batch_metrics = _noise_prediction_step(images)

            val_epoch_loss += val_loss.item()
            # Accumulate over all batches; overwriting here would leave only
            # the last batch, which is then divided by the batch count below.
            for name, value in batch_metrics.items():
                val_epoch_metrics[name] += value

        num_val_batches = step + 1
        val_epoch_loss_list.append(val_epoch_loss / num_val_batches)
        writer.add_scalar("Loss/val", val_epoch_loss / num_val_batches, epoch)
        for name, tag in METRIC_TAGS.items():
            # Each history gets its own metric (the accuracy list must not
            # receive the loss).
            val_history[name].append(val_epoch_metrics[name] / num_val_batches)
            writer.add_scalar(tag + "/val", val_epoch_metrics[name] / num_val_batches, epoch)

        # Sampling images from random noise to visualize during training
        noise = torch.randn((1, 1, 128, 128))
        noise = noise.to(device)
        scheduler.set_timesteps(num_inference_steps=1000)
        with autocast(enabled=True):
            image = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=scheduler)

        plt.figure()
        plt.imshow(image[0, 0].cpu(), cmap="gray")
        plt.colorbar()
        plt.show()

# Flush pending TensorBoard events to disk.
writer.close()

total_time = time.time() - total_start
print(f"train completed, total time: {total_time}.")

### Training progress

Plotting the loss function over the number of epochs

In [None]:

def _plot_history(train_values, val_values, ylabel):
    """Plot a per-epoch training curve together with its validation points.

    Args:
        train_values: One value per training epoch.
        val_values: One value per validation run; validation runs after every
            ``val_interval``-th epoch.
        ylabel: Label of the y-axis.
    """
    train_values = np.asarray([float(v) for v in train_values])
    val_values = np.asarray([float(v) for v in val_values])
    train_epochs = np.arange(1, len(train_values) + 1)
    # Validation happens after epochs val_interval, 2 * val_interval, ...;
    # place the points there instead of spreading them evenly over the axis.
    val_epochs = val_interval * np.arange(1, len(val_values) + 1)

    plt.figure()
    plt.plot(train_epochs, train_values, color="red", linewidth=2.0, label="Train")
    plt.plot(val_epochs, val_values, "go-", linewidth=2.0, label="Validation")
    plt.xlabel("Epochs", fontsize=16)
    plt.ylabel(ylabel, fontsize=16)
    plt.legend(prop={"size": 14})
    plt.show()


_plot_history(epoch_loss_list, val_epoch_loss_list, "Loss")
_plot_history(epoch_recall_list, val_epoch_recall_list, "Recall")

loss = np.array([float(v) for v in epoch_loss_list])
print("Training loss:", np.around(loss, 3))

val_loss = np.array([float(v) for v in val_epoch_loss_list])
print("Validation loss:", np.around(val_loss, 3))

recall_train = np.array([float(v) for v in epoch_recall_list])
print("Training recall:", np.around(recall_train, 3))