<a href="https://colab.research.google.com/github/adiagarwal191/Generation-of-High-Fidelity-Fluorescent-Cell-Images-via-a-Self-Developed-Diffusion-Model/blob/main/FYP_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional
from diffusers import UNet2DModel
from torch.optim import Adam
from torch.amp import autocast, GradScaler
import torch.amp
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data import DataLoader, Dataset
import glob
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

In [None]:
# Run nvidia-smi in a shell to print the attached GPU's status (sanity check before training).
!nvidia-smi

In [None]:
# The CUDA caching allocator takes its options from PYTORCH_CUDA_ALLOC_CONF, so
# set it before anything touches CUDA (torch.cuda.* calls, GPU allocations).
# NOTE(review): safest of all would be to set it above `import torch`.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Use the GPU when one is attached, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
class config:
  """Notebook-wide hyper-parameters, read as class attributes through the instance below."""
  # Optimisation
  num_epochs = 10
  batch_size = 8
  learning_rate = 1e-4
  # Number of diffusion (noising / denoising) steps
  T = 100
  # Size of the generated samples; also passed to the UNet as sample_size
  image_size = (256, 256)
  # Destination of the trained weights; adjust to match your own Drive folder layout.
  output_dir = '/content/drive/MyDrive/gen_images'

# The single instance shadows the class name from here on.
config = config()

In [None]:
# Folder on Google Drive holding the training images (adjust to your own layout).
data_path = '/content/drive/MyDrive/all_images'

# Mount Google Drive so data_path (and config.output_dir) are reachable from the runtime.
from google.colab import drive
drive.mount('/content/drive')

# Per-image pipeline: take a random 256x256 window from anywhere in the image,
# then turn the PIL image into a tensor.
transform = transforms.Compose(
    [
        transforms.RandomCrop(256),
        transforms.ToTensor(),
    ]
)


class imageDataSet(Dataset):
  """Dataset over every regular file directly inside ``folder_path``, loaded as RGB PIL images.

  Args:
    folder_path: directory holding the training images (not searched recursively).
    transform: optional callable applied to each PIL image (e.g. random crop + ToTensor).
      When falsy, the RGB PIL image itself is returned.
  """

  def __init__(self, folder_path, transform):
    # glob order is filesystem-dependent, so sort for a reproducible index -> file
    # mapping. Sub-directories are skipped because Image.open cannot read them.
    self.paths = sorted(
        p for p in glob.glob(os.path.join(folder_path, '*')) if os.path.isfile(p)
    )
    self.transform = transform

  def __len__(self):
    return len(self.paths)

  def __getitem__(self, idx):
    img_path = self.paths[idx]
    # Pillow opens files lazily; the context manager guarantees the handle is closed.
    # convert('RGB') returns a loaded copy with exactly 3 channels (the UNet is built
    # with in_channels=3), so grayscale or RGBA files no longer break the model.
    with Image.open(img_path) as img:
      img = img.convert('RGB')
    if self.transform:
      img = self.transform(img)
    return img

# Wrap the image folder in imageDataSet and batch it; shuffle=True reorders samples every epoch.
dataset = imageDataSet(folder_path=data_path, transform=transform)
dataloader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=config.batch_size,
)

print(f"total images = {len(dataset)}")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
total images = 1440


In [None]:
# Noise-prediction U-Net with six resolution levels. Self-attention is used only
# at the fifth level going down and at the mirrored level going up.
model = UNet2DModel(
    sample_size=config.image_size,
    in_channels=3,   # RGB image in ...
    out_channels=3,  # ... predicted noise out, same shape (compared to the noise in ddpm_loss)
    layers_per_block=2,
    block_out_channels=(64, 64, 128, 128, 256, 256),  # Consider reducing these numbers if needed.
    down_block_types=(
        "DownBlock2D", "DownBlock2D", "DownBlock2D", "DownBlock2D",
        "AttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D",
    ),
).to(device)

# Trade recomputation for activation memory when the installed diffusers version offers it.
if hasattr(model, "enable_gradient_checkpointing"):
    model.enable_gradient_checkpointing()
print("Model loaded on", device)

Model loaded on cuda


In [None]:
# Sanity-check the data pipeline (the model is not involved): fetch a single batch
# and print its shape, expected to be (batch_size, 3, H, W).
batch = next(iter(dataloader), None)
if batch is None:
    print("dataloader is empty - check data_path")
else:
    print(batch.shape)

torch.Size([8, 3, 256, 256])


In [None]:
def cosine_beta_schedule(timesteps, s=0.008):
  """Cosine noise schedule: per-step betas derived from a squared-cosine alpha-bar curve.

  Args:
    timesteps: number of diffusion steps T.
    s: small offset that keeps the first betas from being vanishingly small.

  Returns:
    1-D tensor of length ``timesteps`` with betas clamped to [0, 0.999].
  """
  steps = torch.linspace(0, timesteps, timesteps + 1)
  f = torch.cos(((steps / timesteps + s) / (1 + s)) * (torch.pi / 2)).pow(2)
  alpha_bar = f / f[0]  # normalise so alpha_bar(0) == 1
  # beta_t = 1 - alpha_bar(t) / alpha_bar(t-1); the clamp avoids beta == 1 at the final step.
  return (1 - alpha_bar[1:] / alpha_bar[:-1]).clamp(0, 0.999)

# Diffusion schedule tensors, each of length config.T and placed on `device`:
#   betas[t]               - per-step noise variance
#   alphas[t]              - 1 - betas[t]
#   alphas_cumprod[t]      - running product of alphas up to and including t
#   alphas_cumprod_prev[t] - alphas_cumprod[t - 1], with 1.0 at t = 0
#   posterior_variance[t]  - variance of q(x_{t-1} | x_t, x_0)
betas = cosine_beta_schedule(config.T).to(device)
alphas = 1 - betas
alphas_cumprod = alphas.cumprod(dim=0)
alphas_cumprod_prev = torch.cat(
    (torch.tensor([1.0], device=device), alphas_cumprod[:-1])
)
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

In [None]:
def forward_diffusion_sample(x_0, t):
  """Sample x_t ~ q(x_t | x_0) in closed form.

  x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps, with eps ~ N(0, I) and
  abar_t = alphas_cumprod[t] (module-level schedule).

  Args:
    x_0: clean images with the batch dimension first (assumed 4-D, since the
      schedule values are reshaped to (-1, 1, 1, 1) for broadcasting).
    t: integer timesteps, one per batch element.

  Returns:
    (x_t, eps): the noisy images and the exact noise that was added.
  """
  abar = alphas_cumprod[t.long()].view(-1, 1, 1, 1)
  eps = torch.randn_like(x_0)
  x_t = torch.sqrt(abar) * x_0 + torch.sqrt(1 - abar) * eps
  return x_t, eps

In [None]:
def ddpm_loss(model, x_0, t):
  """Simple DDPM objective: mean-squared error between the added noise and the model's prediction.

  Args:
    model: network called as ``model(x_t, t)`` whose output exposes ``.sample``.
    x_0: clean image batch (moved to ``device`` here).
    t: per-sample timesteps (moved to ``device`` here).

  Returns:
    Scalar loss tensor (mean reduction).
  """
  x_0, t = x_0.to(device), t.to(device)
  noisy, target = forward_diffusion_sample(x_0, t)
  prediction = model(noisy, t).sample
  return torch.nn.functional.mse_loss(prediction, target)

In [None]:
# model_load = torch.load('/content/drive/MyDrive/gen_images/diffusion_model.pth', map_location=torch.device('cpu'))
# model.load_state_dict(model_load)
# Uncomment the two lines above to load previously saved model weights (state_dict) before training.


def train_model(model, dataloader, epochs, learning_rate): #The training loop for the model
  optimiser = Adam(model.parameters(), lr = config.learning_rate)
  scaler = GradScaler()
  model.train()
  z = 0
  for epoch in range(epochs):
    for step, batch_images in enumerate(dataloader):

      t = torch.randint(0, config.T, (batch_images.size(0),), device=device)

      optimiser.zero_grad()
      with autocast(device_type="cuda"):
          loss = ddpm_loss(model, batch_images, t)

      scaler.scale(loss).backward()
      torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
      scaler.step(optimiser)
      scaler.update()

      if (step% 5 == 0):
        print (f"Epoch: {epoch} | step: {step} | loss: {loss.item()}")

      if (step % 20 == 0):
        z = z + 1
        print("Sampling images from the trained model...")
        generated = sample(model, sample_size=1)  #The number of samples can be changed, however this will affect training times
        for i in range(generated.size(0)):
          img_np = generated[i].permute(1,2,0).cpu().numpy()
          plt.figure(figsize=(6,6))
          plt.imshow(img_np)
          plt.axis('off')
          plt.title(f"Generated sample {z}")
          plt.show()


    os.makedirs(config.output_dir, exist_ok=True)
    save_path = os.path.join(config.output_dir, f"epoch_{epoch}_image.png")


  model_save_path = os.path.join(config.output_dir, "diffusion_model.pth")
  torch.save(model.state_dict(), model_save_path) #Saves the trained model
  print(f"Model saved to {model_save_path}")


In [None]:
def sample(model, sample_size=1):
  """Generate images by running the DDPM reverse (ancestral sampling) process.

  Starts from Gaussian noise and applies ``config.T`` denoising steps using the
  module-level ``betas`` / ``alphas`` / ``alphas_cumprod`` / ``alphas_cumprod_prev``.

  Args:
    model: noise-prediction network called as ``model(x_t, t).sample``.
    sample_size: number of images to generate.

  Returns:
    Tensor of shape (sample_size, 3, config.image_size[1], config.image_size[0])
    on ``device``, clamped to [0, 1].

  The model's train/eval mode is restored before returning, so this can be
  called from inside a training loop without leaving the model in eval mode.
  """
  was_training = model.training
  model.eval()
  try:
    with torch.no_grad():
      # Start from pure noise.
      x = torch.randn(sample_size, 3, config.image_size[1], config.image_size[0], device=device)

      for i in reversed(range(config.T)):
        t = torch.tensor([i] * sample_size, device=device, dtype=torch.long)
        beta_t = betas[t].view(-1, 1, 1, 1)
        alpha_t = alphas[t].view(-1, 1, 1, 1)
        alpha_cumprod_t = alphas_cumprod[t].view(-1, 1, 1, 1)
        alpha_cumprod_t_prev = alphas_cumprod_prev[t].view(-1, 1, 1, 1)

        eps_pred = model(x, t).sample
        # Posterior mean: (x_t - beta_t / sqrt(1 - abar_t) * eps) / sqrt(alpha_t)
        eps_coeff = beta_t / torch.sqrt(1 - alpha_cumprod_t)
        posterior_mean = (1.0 / torch.sqrt(alpha_t)) * (x - eps_coeff * eps_pred)

        if i > 0:
          # Posterior std: sqrt(beta_t * (1 - abar_{t-1}) / (1 - abar_t))
          sigma_t = torch.sqrt((1 - alpha_cumprod_t_prev) / (1 - alpha_cumprod_t) * beta_t)
          x = posterior_mean + sigma_t * torch.randn_like(x)
        else:
          # The final step is deterministic (no noise added).
          x = posterior_mean

    return x.clamp(0, 1)
  finally:
    model.train(was_training)



In [None]:
# In a Jupyter/Colab kernel __name__ == "__main__", so this guard is always true here.
if __name__ == "__main__":
  print("Starting training...")
  train_model(model, dataloader, learning_rate=config.learning_rate, epochs=config.num_epochs)

  # Draw two images from the freshly trained model and show each in its own figure.
  print("Sampling images from the trained model...")
  generated = sample(model, sample_size=2)

  for i, img in enumerate(generated):
      img_np = img.permute(1, 2, 0).cpu().numpy()
      fig, ax = plt.subplots(figsize=(6, 6))
      ax.imshow(img_np)
      ax.axis('off')
      ax.set_title(f"Generated sample {i}")
      plt.show()


In [None]:
  # Post-training check: draw 16 images from the trained model and display each one.
  # NOTE(review): this cell is indented by two spaces. IPython strips a uniform
  # leading indent before running a cell, but it would not run as a plain script.
  print("Sampling images from the trained model...")
  generated = sample(model, sample_size=16)

  for i in range(len(generated)):
      img_np = generated[i].permute(1, 2, 0).cpu().numpy()
      fig, ax = plt.subplots(figsize=(6, 6))
      ax.imshow(img_np)
      ax.axis('off')
      ax.set_title(f"Generated sample {i}")
      plt.show()
