<a href="https://colab.research.google.com/github/JohnYechanJo/Novo-Nordisk_Anomaly-Detection/blob/initial-diffusion-model/diffusion_model_tuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install kagglehub torch torchvision transformers diffusers accelerate datasets xformers pytorch-fid pandas
import os
import gc
import torch
import numpy as np
import random
from PIL import Image
import pandas as pd
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset, TensorDataset
from transformers import ViTModel, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, StableDiffusionPipeline
from accelerate import Accelerator
from sklearn.metrics import classification_report, accuracy_score
import matplotlib.pyplot as plt

# Pin the process to the first GPU. This has to happen before the first CUDA
# query below: the variable is read when CUDA enumerates devices, so setting
# it afterwards may silently have no effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Set random seed (Python, NumPy and PyTorch CPU/GPU generators)
seed = 123
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)

# GPU setup: use CUDA when available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Memory cleanup function
def clear_memory():
    """Run the garbage collector, then release PyTorch's cached CUDA memory."""
    # Order matters: collecting first drops unreferenced objects so that the
    # cache release that follows can hand their memory back.
    for release in (gc.collect, torch.cuda.empty_cache):
        release()

## Diffusion Model Fine-tuning
Fine-tune the UNet of the Stable Diffusion model on CNV images.
Download the dataset from Kaggle, preprocess the CNV images, and use them for training.
The trained model will be saved under `/content/models/sd_cnv_finetuned`.

### 1. Data Preparation
Download CNV images from the Kaggle dataset, resize them to 512x512, and save them to `/content/processed/CNV/`.

In [None]:
import kagglehub
import os
from PIL import Image

def prepare_cnv_images(out_dir="/content/processed/CNV/", size=512):
    """Download the Kermany OCT dataset and write resized RGB copies of the CNV training images.

    Args:
        out_dir: Directory that receives the processed images (created if missing).
        size: Edge length, in pixels, of the square output images.

    Files that do not have a .png/.jpg/.jpeg extension are skipped so that stray
    non-image entries in the source directory do not abort the run.
    """
    # Kaggle dataset download
    path = kagglehub.dataset_download("paultimothymooney/kermany2018")
    print(f"Dataset downloaded to: {path}")

    # Set dataset path and copy CNV images
    # NOTE(review): the space after "OCT2017" is kept exactly as written in the
    # original path; presumably it mirrors the dataset's folder name - confirm.
    in_dir = os.path.join(path, "OCT2017 /train/CNV")
    os.makedirs(out_dir, exist_ok=True)

    image_exts = (".png", ".jpg", ".jpeg")
    for fn in sorted(os.listdir(in_dir)):
        if not fn.lower().endswith(image_exts):
            continue
        # The context manager closes the source file right away instead of
        # leaving one open handle per image for the garbage collector.
        with Image.open(os.path.join(in_dir, fn)) as src:
            img = src.convert("RGB").resize((size, size), resample=Image.LANCZOS)
        img.save(os.path.join(out_dir, fn))
    print(f"Processed CNV images saved to {out_dir}")

# Run the download + preprocessing step once for this session.
prepare_cnv_images()

### 2. Dataset Class
Define the `ImgDataset` class, which loads the CNV images, transforms them into the tensor format required for training, and pairs each image with a tokenized text prompt.

In [None]:
from torch.utils.data import Dataset
from torchvision import transforms

class ImgDataset(Dataset):
    """Image folder dataset that pairs every image with one fixed, tokenized prompt.

    Each item is a dict with:
        "pixel_values": the image resized to (resolution, resolution), converted
            to a tensor and normalized with mean 0.5 / std 0.5.
        "input_ids": the prompt token ids, padded/truncated to ``max_length``.
    """

    # File extensions (case-insensitive) that are treated as images.
    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, root_dir:str, tokenizer, prompt:str = "OCT scan showing CNV", resolution=512, max_length=77, dataset_size:int = 6400) -> None:
        """Index the images in ``root_dir`` and build the preprocessing pipeline.

        Args:
            root_dir: Directory containing the image files.
            tokenizer: Callable tokenizer used on ``prompt`` for every item.
            prompt: Text prompt attached to every image.
            resolution: Edge length of the square images fed to the model.
            max_length: Token sequence length (padding and truncation target).
            dataset_size: Maximum number of images to keep.
        """
        self.files = self._list_images(root_dir, dataset_size)
        self.tokenizer = tokenizer
        self.transform = transforms.Compose([
            transforms.Resize((resolution, resolution), transforms.InterpolationMode.LANCZOS),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])
        self.prompt = prompt
        self.max_length = max_length

    @staticmethod
    def _list_images(root_dir: str, dataset_size: int) -> list:
        """Return at most ``dataset_size`` image paths from ``root_dir`` in sorted order.

        ``os.listdir`` returns entries in arbitrary order, so the names are
        sorted before truncating; otherwise the selected subset could differ
        between runs even with a fixed random seed.
        """
        names = sorted(
            f for f in os.listdir(root_dir)
            if f.lower().endswith(ImgDataset.IMAGE_EXTENSIONS)
        )
        return [os.path.join(root_dir, f) for f in names[:dataset_size]]

    def __len__(self) -> int:
        """Number of indexed images."""
        return len(self.files)

    def __getitem__(self, idx:int) -> dict:
        """Load image ``idx`` and return it together with the tokenized prompt."""
        # Close the file as soon as the RGB copy exists.
        with Image.open(self.files[idx]) as src:
            img = src.convert("RGB")
        img = self.transform(img)
        tokens = self.tokenizer(
            self.prompt,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        # squeeze(0) drops the batch dimension added by return_tensors="pt"
        return {"pixel_values": img, "input_ids": tokens.input_ids.squeeze(0)}

### 3. Training Function
Define the `train` function to fine-tune the Stable Diffusion UNet using CNV images. It supports mixed precision training and checkpoint saving.

In [None]:
import torch
from torch.utils.data import DataLoader
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from accelerate import Accelerator

def train(
    pretrained_model: str,
    data_dir: str,
    output_dir: str = "sd_cnv_finetuned",
    prompt:str = "OCT scan showing CNV",
    resolution: int = 512,
    dataset_size: int = 6400,
    batch_size: int = 4,
    learning_rate: float = 1e-4,
    epochs: int = 5,
    grad_accum_steps: int = 1,
    save_steps: int = 1000,
    resume_checkpoint: str = None,
    unet: UNet2DConditionModel = None,
    accelerator: Accelerator = None,
):
    # 1) Prepare output directory, accelerator & device
    os.makedirs(output_dir, exist_ok=True)
    accel = accelerator or Accelerator()
    device = accel.device

    # 2) Load or resume UNet
    if resume_checkpoint and unet is None:
        unet = UNet2DConditionModel.from_pretrained(resume_checkpoint).to(device)
    elif unet is None:
        unet = UNet2DConditionModel.from_pretrained(pretrained_model, subfolder="unet").to(device)

    # 3) Load & freeze tokenizer + text encoder
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model, subfolder="text_encoder").to(device)
    text_encoder.requires_grad_(False)

    # 4) Load & freeze VAE
    vae = AutoencoderKL.from_pretrained(pretrained_model, subfolder="vae").to(device)
    vae.requires_grad_(False)

    # 5) Load noise scheduler
    scheduler = DDPMScheduler.from_pretrained(pretrained_model, subfolder="scheduler")

    # 6) Prepare dataset & dataloader
    dataset = ImgDataset(data_dir, tokenizer, prompt = prompt, resolution=resolution, dataset_size=dataset_size)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # 7) Optimizer (only UNet params)
    optimizer = torch.optim.AdamW(unet.parameters(), lr=learning_rate)

    # 8) Wrap models, optimizer, and dataloader for mixed-precision / distributed
    unet, optimizer, dataloader = accel.prepare(unet, optimizer, dataloader)

    # 9) Resume state if requested
    if resume_checkpoint:
        accel.load_state(resume_checkpoint)
        global_step = int(resume_checkpoint.rsplit("_", 1)[-1])
    else:
        global_step = 0

    # 10) Training loop
    for epoch in range(1, epochs + 1):
        unet.train()
        for batch in dataloader:
            with accel.accumulate(unet):
                # Encode images to latents
                pixels = batch["pixel_values"].to(device)
                latents = vae.encode(pixels).latent_dist.sample() * 0.18215

                # Add noise
                noise = torch.randn_like(latents)
                timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (latents.shape[0],), device=device)
                noisy_latents = scheduler.add_noise(latents, noise, timesteps)

                # Text conditioning
                input_ids = batch["input_ids"].to(device)
                encoder_hidden_states = text_encoder(input_ids)[0]

                # Noise prediction & loss
                pred_noise = unet(noisy_latents, timesteps, encoder_hidden_states).sample
                loss = torch.nn.functional.mse_loss(pred_noise, noise)

                # Backpropagate
                accel.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

            global_step += 1
            if global_step % save_steps == 0:
                accel.wait_for_everyone()
                ckpt_dir = os.path.join(output_dir, f"checkpoint_{global_step}")
                unet.save_pretrained(ckpt_dir)
                if accel.is_main_process:
                    tokenizer.save_pretrained(ckpt_dir)
                accel.save_state(ckpt_dir)

        print(f"Epoch {epoch}/{epochs} complete")

    # 11) Final save
    accel.wait_for_everyone()
    final_dir = os.path.join(output_dir, "final_unet")
    unet.save_pretrained(final_dir)
    if accel.is_main_process:
        tokenizer.save_pretrained(output_dir)
    print("Fine-tuning complete — models saved to", output_dir)

### 4. Execute Diffusion Model Training
Use the pre-trained model (`nota-ai/bk-sdm-small`) to fine-tune with CNV images.

In [None]:
from accelerate import Accelerator

# Accelerator initialization
accel = Accelerator()

# Conditioning prompt used for every training image (and reused for generation).
# Change the prompt and see what happens.
# Fixed the misspelling "Scacn": the text is tokenized as written, so a typo
# becomes part of the conditioning.
detailed_prompt = "Optical Coherence Tomography Scan of Choroidal neovascularization retina"

# Execute Diffusion Model training
train(
    pretrained_model="nota-ai/bk-sdm-small",
    data_dir="/content/processed/CNV",
    output_dir="/content/models/sd_cnv_finetuned",
    prompt = detailed_prompt,
    resolution=512,
    dataset_size=500,  # Kept small to limit training time
    batch_size=4,
    learning_rate=1e-4,
    epochs=10,
    grad_accum_steps=1,
    save_steps=1000,
    resume_checkpoint=None,
    accelerator=accel
)

# Memory clean up
clear_memory()

## 5. Generate Synthetic CNV Images
Use the fine-tuned Diffusion Model to generate synthetic CNV images.

In [None]:
def generate_synthetic_images():
    # Load the Base Pipeline from the Original Model
    pipeline = StableDiffusionPipeline.from_pretrained(
        "nota-ai/bk-sdm-small",
        torch_dtype=torch.float16,
        use_auth_token=False
    ).to(device)

    # Load the Fine-tuned UNet
    unet = UNet2DConditionModel.from_pretrained(
        "/content/models/sd_cnv_finetuned/final_unet",
        torch_dtype=torch.float16
    ).to(device)

    # Replace the UNet in the Pipeline
    pipeline.unet = unet

    # Generate Synthetic Images
    synthetic_dir = "/content/synthetic_cnv/"
    os.makedirs(synthetic_dir, exist_ok=True)
    num_images = 640
    prompt = (
    f"High-resolution grayscale {detailed_prompt}"
    "Clear layer definition, with details of retinal features. Dark background, minimal color, realistic biomedical texture."
    )


    for i in range(num_images):
        image = pipeline(prompt, num_inference_steps=50).images[0]
        image.save(os.path.join(synthetic_dir, f"synthetic_cnv_{i}.png"))
        if i % 50 == 0:
            print(f"Generated {i}/{num_images} images")
        clear_memory()

# Generate the synthetic images, then run the memory cleanup helper.
generate_synthetic_images()
clear_memory()

## 6. Examples of Generated Images

In [None]:
import random

# Function to display images in a grid
def display_images_grid(image_paths, num_cols=4):
    """Show the images at ``image_paths`` in a grid with ``num_cols`` columns.

    Images are placed row by row. A path that cannot be loaded is reported
    with a printed message and its cell is left blank; the remaining images
    are still shown. Nothing is drawn when ``image_paths`` is empty.
    """
    num_images = len(image_paths)
    if not num_images:
        print("No images to display.")
        return

    num_rows = (num_images + num_cols - 1) // num_cols
    # squeeze=False always yields a 2-D axes array; with the default, a single
    # row (or a single cell) comes back with fewer dimensions.
    fig, axes = plt.subplots(
        num_rows, num_cols, figsize=(12, 3 * num_rows), squeeze=False
    )

    # Hide every frame up front so unused or failed cells stay blank
    for ax in axes.flat:
        ax.axis('off')

    # axes.flat walks the grid row by row, matching the image order
    for ax, path in zip(axes.flat, image_paths):
        try:
            ax.imshow(plt.imread(path))
        except FileNotFoundError:
            print(f"File not found: {path}")
        except Exception as e:
            # Best effort: one unreadable file should not abort the whole grid
            print(f"Error loading image {path}: {e}")
    plt.tight_layout()
    plt.show()

# Get a list of image file paths (sorted so the seeded sample is reproducible;
# os.listdir order is arbitrary)
image_dir = "/content/synthetic_cnv/"
image_files = sorted(
    os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith('.png')
)

# Randomly select up to 16 images; random.sample raises ValueError when asked
# for more items than are available
random_images = random.sample(image_files, min(16, len(image_files)))

# Display the selected images in a grid
display_images_grid(random_images)
