# Sampling images from the trained diffusion models

This notebook samples synthetic images from the trained diffusion models. Sampling follows the 2D DDPM tutorial from the MONAI Generative Models framework: https://github.com/Project-MONAI/GenerativeModels/blob/main/tutorials/generative/2d_ddpm/2d_ddpm_tutorial.ipynb

In [None]:
# Core PyTorch and MONAI Generative Models components for DDPM sampling.
import torch
from generative.inferers import DiffusionInferer
from generative.networks.nets import DiffusionModelUNet
from generative.networks.schedulers import DDPMScheduler
from torch.cuda.amp import autocast
# Plotting, image I/O and MONAI utilities.
import matplotlib.pyplot as plt
import cv2
from monai.utils import set_determinism
from monai.config import print_config
import nibabel as nib
import numpy as np
# Print MONAI / dependency version information so results can be tied to an environment.
print_config()

In [None]:
# Fix the global random seed so the sampled noise (and hence the generated images) is reproducible.
set_determinism(seed=42)

### Loading the trained model

In [None]:
# Network architecture. These hyper-parameters must match the ones the checkpoint was
# trained with, otherwise the saved weights will not line up with the layers.
model = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_channels=(128, 256, 256),  # alternative configuration noted earlier: (256, 256, 512)
    attention_levels=(False, True, True),
    num_res_blocks=1,
    num_head_channels=256,
)
# Fall back to CPU so the notebook still runs (slowly) on a machine without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Specify which model you wish to sample from. Other checkpoints used so far:
#   "Models/bs16_Epoch124_of_2503nov"  (also Epoch74 / Epoch174 variants)
#   "Models/bs16_Epoch149_of_2008nov_timestep500"
#   "Models/bs8_Epoch149_of_2008nov"
modelname = "Models/bs16_Epoch149_of_2503nov"
# map_location remaps tensors saved on any device (e.g. another GPU) onto `device`.
pre_trained_model = torch.load(modelname, map_location=device)
# strict=False tolerates key mismatches, but ignoring them silently could leave layers
# randomly initialised, so report anything that did not match.
load_result = model.load_state_dict(pre_trained_model, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        "WARNING: checkpoint and model do not match exactly\n"
        f"  missing keys:    {load_result.missing_keys}\n"
        f"  unexpected keys: {load_result.unexpected_keys}"
    )
model.to(device)
model.eval()  # inference only: sampling must not run in training mode

# DDPM noise schedule. num_train_timesteps presumably has to match training
# (e.g. 500 for the *_timestep500 checkpoint) -- confirm when switching models.
scheduler = DDPMScheduler(num_train_timesteps=1000)
inferer = DiffusionInferer(scheduler)

### Sampling images

In [None]:
import os

# How many synthetic images to generate and where to write them as NIfTI files.
num_samples = 100
output_dir = "Synthetic_images/bs16_150epochs_22_nov_larger_dataset"
# Output folders used for other checkpoints:
#   "Synthetic_images/bs16_125epochs_3nov"
#   "Synthetic_images/bs8_150epochs_8nov"
#   "Synthetic_images/bs16_150epochs_timestep500"
os.makedirs(output_dir, exist_ok=True)  # nib.save fails if the folder does not exist

model.eval()  # sample in inference mode regardless of which cells ran before
# Draw the noise on CPU (seeded by set_determinism) and then move it, so the
# generated images stay reproducible; shape is (num_samples, 1, 128, 128).
noise = torch.randn((num_samples, 1, 128, 128))
noise = noise.to(device)
scheduler.set_timesteps(num_inference_steps=1000)

images = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=scheduler)
print(len(images))

# Copy all samples (channel 0) to the host in one transfer instead of once per image.
samples = images[:, 0].detach().cpu()

# Reuse a single axes so the preview does not pile 100 images on top of each other.
fig, ax = plt.subplots()
for i, sample in enumerate(samples):
    print(sample.shape, sample.dtype)
    numpy_arr = sample.numpy()
    ax.clear()
    ax.imshow(numpy_arr, vmin=0, vmax=1, cmap="bone")
    # Optional 8-bit PNG export:
    #   cv2.imwrite(os.path.join(output_dir, f"Genererte_{i}.png"), 255 * numpy_arr)
    nifti_image = nib.Nifti1Image(numpy_arr, np.eye(4))  # identity affine
    nib.save(nifti_image, os.path.join(output_dir, f"nifti_file_{i}.nii"))

### Displaying the diffusion process

In [None]:
model.eval()
# Single-image run (noise drawn on CPU, then moved) used to visualise the denoising chain.
noise = torch.randn((1, 1, 128, 128))
noise = noise.to(device)
scheduler.set_timesteps(num_inference_steps=1000)
# Mixed precision only applies on CUDA. torch.autocast is the non-deprecated
# replacement for torch.cuda.amp.autocast (same float16 default on CUDA).
with torch.autocast(device_type=device.type, enabled=device.type == "cuda"):
    image, intermediates = inferer.sample(
        input_noise=noise,
        diffusion_model=model,
        scheduler=scheduler,
        save_intermediates=True,
        intermediate_steps=200,  # keep a snapshot every 200 denoising steps
    )

# Place the saved intermediate samples side by side along the width axis.
chain = torch.cat(intermediates, dim=-1)

plt.style.use("default")
# Create the figure BEFORE drawing: calling plt.figure() after imshow opens a second,
# empty figure and leaves the chain in a default-sized one.
plt.figure(figsize=(30, 10))
# Cast to float32 so matplotlib never receives a reduced-precision array.
plt.imshow(chain[0, 0].float().cpu().numpy(), vmin=0, vmax=1, cmap="gray")
plt.axis("off")
plt.tight_layout()
plt.show()