<a href="https://colab.research.google.com/github/JohnYechanJo/Novo-Nordisk_Anomaly-Detection/blob/initial-diffusion-model/Training_Diffusion_Models_for_CNV_Images.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Data Download

In [None]:
import kagglehub

# Fetch the latest published version of the Kermany 2018 OCT dataset.
# `path` holds whatever location kagglehub reports for the downloaded files.
KERMANY_DATASET = "paultimothymooney/kermany2018"
path = kagglehub.dataset_download(KERMANY_DATASET)

print("Path to dataset files:", path)

Path to dataset files: /kaggle/input/kermany2018


In [None]:
# Resize every CNV training scan to 512x512 RGB and write it to `out_dir`.
from PIL import Image
import os

# Build the input directory from `path` (the dataset root returned by
# kagglehub.dataset_download in the download cell) instead of hard-coding it.
# The folder name "OCT2017 " is kept exactly as in the original path,
# including its trailing space.
in_dir = os.path.join(path, "OCT2017 ", "train", "CNV")
out_dir = "processed/CNV/"
os.makedirs(out_dir, exist_ok=True)

# Same extensions that CNVDataset accepts later in the notebook.
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

for fn in sorted(os.listdir(in_dir)):
    # Skip stray non-image entries instead of crashing in Image.open.
    if not fn.lower().endswith(IMAGE_EXTENSIONS):
        continue
    # The context manager closes the source file as soon as it is decoded.
    with Image.open(os.path.join(in_dir, fn)) as src:
        img = src.convert("RGB")                          # ensure 3-channel
    img = img.resize((512, 512), resample=Image.LANCZOS)  # model's expected resolution
    img.save(os.path.join(out_dir, fn))

In [None]:
# Archive the resized images so they can be downloaded and reused in a later session.
# NOTE(review): zip is given an absolute input path, so entries are presumably stored
# as "content/processed/..." -- that would explain the nested
# "/content/processed/content/processed/CNV" data_dir passed to train() further down.
# Confirm before changing either path.
!zip -r /content/processed_images.zip /content/processed

UnZip 6.00 of 20 April 2009, by Debian. Original by Info-ZIP.

Usage: unzip [-Z] [-opts[modifiers]] file[.zip] [list] [-x xlist] [-d exdir]
  Default action is to extract files in list, except those in xlist, to exdir;
  file[.zip] may be a wildcard.  -Z => ZipInfo mode ("unzip -Z" for usage).

  -p  extract files to pipe, no messages     -l  list files (short format)
  -f  freshen existing files, create none    -t  test compressed archive data
  -u  update files, create if necessary      -z  display archive comment only
  -v  list verbosely/show version info       -T  timestamp archive to latest
  -x  exclude files that follow (in xlist)   -d  extract files into exdir
modifiers:
  -n  never overwrite existing files         -q  quiet mode (-qq => quieter)
  -o  overwrite files WITHOUT prompting      -a  auto-convert any text files
  -j  junk paths (do not make directories)   -aa treat ALL files as text
  -U  use escapes for all non-ASCII Unicode  -UU ignore any Unicode fields
  -C  mat

In [None]:
import os
import zipfile

# The zip cell in this notebook writes the archive to /content, while this cell
# originally read it from "/". Try the original location first (for sessions
# where the archive was placed there), then fall back to the location the zip
# cell writes to.
zip_candidates = ["/processed_images.zip", "/content/processed_images.zip"]
extract_to = "/content/processed/"

zip_path = next((p for p in zip_candidates if os.path.isfile(p)), None)
if zip_path is None:
    raise FileNotFoundError(
        "processed_images.zip not found; looked in: " + ", ".join(zip_candidates)
    )

with zipfile.ZipFile(zip_path, "r") as z:
    z.extractall(extract_to)

print(f"Extracted all files to {extract_to}")

Extracted all files to /content/processed/


In [None]:
!pip install diffusers transformers accelerate datasets xformers

Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting xformers
  Downloading xformers-0.0.29.post3-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (1.0 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cuda_runtime_

In [None]:
# Smoke test: the cell only succeeds (and displays the function) if the
# xformers memory-efficient attention op can be resolved.
import xformers
xformers.ops.memory_efficient_attention

<function xformers.ops.fmha.memory_efficient_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_bias: Union[torch.Tensor, xformers.ops.fmha.attn_bias.AttentionBias, NoneType] = None, p: float = 0.0, scale: Optional[float] = None, *, op: Optional[Tuple[Optional[Type[xformers.ops.fmha.common.AttentionFwOpBase]], Optional[Type[xformers.ops.fmha.common.AttentionBwOpBase]]]] = None, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor>

# Train the Model

In [None]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from accelerate import Accelerator

In [None]:
class CNVDataset(Dataset):
    """Image/caption dataset that pairs every scan in a folder with one fixed prompt.

    Args:
        root_dir: Directory scanned (non-recursively) for .png/.jpg/.jpeg files.
        tokenizer: Callable invoked as ``tokenizer(prompt, padding="max_length",
            truncation=True, max_length=..., return_tensors="pt")``; its result
            must expose an ``input_ids`` tensor.
        resolution: Side length, in pixels, of the square images produced.
        max_length: Token sequence length the prompt is padded/truncated to.

    Each item is a dict with:
        ``pixel_values``: the image after resize, ``ToTensor`` and
            normalisation with mean 0.5 / std 0.5.
        ``input_ids``: token ids of the prompt with the leading batch axis removed.
    """

    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, root_dir, tokenizer, resolution=512, max_length=77):
        # os.listdir gives no ordering guarantee; sorting keeps the
        # index -> file mapping stable across runs and machines.
        self.files = sorted(
            os.path.join(root_dir, f)
            for f in os.listdir(root_dir)
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        self.tokenizer = tokenizer
        self.transform = transforms.Compose([
            transforms.Resize(
                (resolution, resolution),
                interpolation=transforms.InterpolationMode.LANCZOS,
            ),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])
        self.prompt = "OCT scan showing CNV"
        self.max_length = max_length

        # The prompt is the same for every sample, so tokenise it once here
        # instead of on every __getitem__ call. The prompt is therefore fixed
        # at construction time.
        self._input_ids = self.tokenizer(
            self.prompt,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        ).input_ids.squeeze(0)

    def __len__(self):
        """Return the number of image files found in ``root_dir``."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load, transform and return the ``idx``-th image with the prompt tokens."""
        # The context manager closes the file handle once the pixels are decoded.
        with Image.open(self.files[idx]) as src:
            img = src.convert("RGB")
        img = self.transform(img)
        # clone() so every sample owns its tensor, as it did when the prompt
        # was re-tokenised per item.
        return {"pixel_values": img, "input_ids": self._input_ids.clone()}

In [None]:
def train(
    pretrained_model: str,
    data_dir: str,
    output_dir: str = "sd_cnv_finetuned",
    resolution: int = 512,
    batch_size: int = 4,
    learning_rate: float = 1e-4,
    epochs: int = 5,
    grad_accum_steps: int = 1,
    save_steps: int = 1000,
    resume_checkpoint: str = None,
    unet: UNet2DConditionModel = None,
    accelerator: Accelerator = None,
):
    """Fine-tune the UNet of a Stable Diffusion checkpoint on the images in ``data_dir``.

    Only the UNet is optimised; the text encoder and VAE are loaded from
    ``pretrained_model`` and frozen. The loss is the MSE between the noise
    added to the image latents and the noise predicted by the UNet.

    Args:
        pretrained_model: Model id or path providing the ``unet``, ``tokenizer``,
            ``text_encoder``, ``vae`` and ``scheduler`` subfolders.
        data_dir: Folder of training images, read through ``CNVDataset``.
        output_dir: Where checkpoints and the final UNet are written.
        resolution: Square image size passed to ``CNVDataset``.
        batch_size: Per-step batch size of the dataloader.
        learning_rate: AdamW learning rate.
        epochs: Number of passes over the dataloader made by this call. A
            resumed run still performs ``epochs`` full passes.
        grad_accum_steps: Gradient accumulation steps. Used only when this
            function creates the ``Accelerator`` itself; a caller-supplied
            ``accelerator`` keeps its own setting.
        save_steps: A checkpoint is written every ``save_steps`` batches.
        resume_checkpoint: Directory named ``checkpoint_<step>`` produced by a
            previous run. The UNet weights (unless ``unet`` is given) and the
            accelerator state are restored from it, and the step counter
            continues from ``<step>``.
        unet: Optional already-instantiated UNet to train instead of loading one.
        accelerator: Optional pre-built ``Accelerator`` to reuse.

    Raises:
        ValueError: If ``resume_checkpoint`` does not end in ``_<integer>``.
    """

    # 1) Prepare output directory, accelerator & device
    os.makedirs(output_dir, exist_ok=True)
    if accelerator is not None:
        accel = accelerator
    else:
        # Pass the accumulation setting through; otherwise accel.accumulate()
        # below would always act as if grad_accum_steps were 1.
        accel = Accelerator(gradient_accumulation_steps=grad_accum_steps)
    device = accel.device

    # 2) Load or resume UNet
    if resume_checkpoint and unet is None:
        unet = UNet2DConditionModel.from_pretrained(resume_checkpoint).to(device)
    elif unet is None:
        unet = UNet2DConditionModel.from_pretrained(pretrained_model, subfolder="unet").to(device)

    # 3) Load & freeze tokenizer + text encoder
    tokenizer    = CLIPTokenizer.from_pretrained(pretrained_model, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model, subfolder="text_encoder").to(device)
    text_encoder.requires_grad_(False)

    # 4) Load & freeze VAE
    vae = AutoencoderKL.from_pretrained(pretrained_model, subfolder="vae").to(device)
    vae.requires_grad_(False)
    # Read the latent scale from the VAE config rather than hard-coding the
    # SD 1.x constant; fall back to that constant if the config lacks the field.
    latent_scale = getattr(vae.config, "scaling_factor", 0.18215)

    # 5) Load noise scheduler
    scheduler = DDPMScheduler.from_pretrained(pretrained_model, subfolder="scheduler")
    # Config values live on scheduler.config; direct attribute access is deprecated.
    num_train_timesteps = scheduler.config.num_train_timesteps

    # 6) Prepare dataset & dataloader (assumes CNVDataset is defined)
    dataset    = CNVDataset(data_dir, tokenizer, resolution=resolution)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # 7) Optimizer (only UNet params)
    optimizer = torch.optim.AdamW(unet.parameters(), lr=learning_rate)

    # 8) Wrap models, optimizer, and dataloader for mixed-precision / distributed
    unet, optimizer, dataloader = accel.prepare(unet, optimizer, dataloader)

    # 9) Resume state if requested
    if resume_checkpoint:
        accel.load_state(resume_checkpoint)
        # Pick up the step counter from the directory name. normpath drops a
        # trailing slash so ".../checkpoint_6000/" parses as well.
        ckpt_name = os.path.basename(os.path.normpath(resume_checkpoint))
        try:
            global_step = int(ckpt_name.rsplit("_", 1)[-1])
        except ValueError as exc:
            raise ValueError(
                f"Cannot infer the step from {resume_checkpoint!r}; "
                "expected a directory named 'checkpoint_<step>'"
            ) from exc
    else:
        global_step = 0

    # 10) Training loop
    for epoch in range(1, epochs + 1):
        unet.train()
        for batch in dataloader:
            with accel.accumulate(unet):
                # The VAE and text encoder are frozen, so no autograd graph is
                # needed for them.
                with torch.no_grad():
                    # Encode images to latents
                    pixels  = batch["pixel_values"].to(device)
                    latents = vae.encode(pixels).latent_dist.sample() * latent_scale

                    # Text conditioning
                    input_ids             = batch["input_ids"].to(device)
                    encoder_hidden_states = text_encoder(input_ids)[0]

                # Add noise at a random timestep per sample
                noise     = torch.randn_like(latents)
                timesteps = torch.randint(0, num_train_timesteps,
                                          (latents.shape[0],), device=device)
                noisy_latents = scheduler.add_noise(latents, noise, timesteps)

                # Noise prediction & loss
                pred_noise = unet(noisy_latents, timesteps, encoder_hidden_states).sample
                loss       = torch.nn.functional.mse_loss(pred_noise, noise)

                # Backpropagate
                accel.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

            # global_step counts batches (not optimizer updates).
            global_step += 1
            if global_step % save_steps == 0:
                accel.wait_for_everyone()
                ckpt_dir = os.path.join(output_dir, f"checkpoint_{global_step}")
                if accel.is_main_process:
                    # prepare() may wrap the model; unwrap it so save_pretrained
                    # is available, and write from a single process only.
                    accel.unwrap_model(unet).save_pretrained(ckpt_dir)
                    tokenizer.save_pretrained(ckpt_dir)
                accel.save_state(ckpt_dir)

        print(f"Epoch {epoch}/{epochs} complete")

    # 11) Final save
    accel.wait_for_everyone()
    final_dir = os.path.join(output_dir, "final_unet")
    if accel.is_main_process:
        accel.unwrap_model(unet).save_pretrained(final_dir)
        tokenizer.save_pretrained(output_dir)
    print("Fine‑tuning complete — models saved to", output_dir)

In [None]:
import os
from accelerate import Accelerator

# Resume only when the checkpoint directory actually exists. On a fresh
# runtime it does not, and training then starts from the base weights instead
# of failing on a missing path.
RESUME_CHECKPOINT = "/content/sd_cnv_finetuned/checkpoint_6000"

accel = Accelerator()
train(
    pretrained_model="runwayml/stable-diffusion-v1-5",
    data_dir="/content/processed/content/processed/CNV",
    output_dir="/content/sd_cnv_finetuned",
    resolution=512,
    batch_size=4,
    learning_rate=1e-4,
    epochs=5,
    grad_accum_steps=1,
    save_steps=1000,
    resume_checkpoint=RESUME_CHECKPOINT if os.path.isdir(RESUME_CHECKPOINT) else None,
    accelerator=accel
)

KeyboardInterrupt: 

In [None]:
# Run this cell if you hit out-of-memory errors.

import gc
import torch

# Order matters: collect garbage first so unreferenced Python objects are
# dropped, then ask PyTorch to release its cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Bundle one checkpoint directory into a zip so it can be downloaded from Colab.
# Update the step number in both paths to match the checkpoint you want to export.

!zip -r /content/checkpoints_11000.zip /content/sd_cnv_finetuned/checkpoint_11000/

  adding: content/sd_cnv_finetuned/checkpoint_11000/ (stored 0%)
  adding: content/sd_cnv_finetuned/checkpoint_11000/tokenizer_config.json (deflated 63%)
  adding: content/sd_cnv_finetuned/checkpoint_11000/random_states_0.pkl (deflated 25%)
  adding: content/sd_cnv_finetuned/checkpoint_11000/model.safetensors

# Inference

In [None]:
import torch
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
from transformers import CLIPTokenizer

# Use fp16 only on GPU; half precision is not reliably supported on CPU, so
# fall back to fp32 there.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# One checkpoint directory for both the UNet and the tokenizer, so the two can
# never come from different training steps.
CHECKPOINT_DIR = "/content/sd_cnv_finetuned/checkpoint_13000"

# 1) Load your fine-tuned UNet
finetuned_unet = UNet2DConditionModel.from_pretrained(
    CHECKPOINT_DIR,
    torch_dtype=dtype
)

# 2) Load the base pipeline (original VAE, text-encoder, scheduler, tokenizer),
#    passing the fine-tuned UNet in so the base UNet is never loaded.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    unet=finetuned_unet,
    torch_dtype=dtype
).to(device)

# 3) (Optional) Load & swap in the tokenizer saved alongside the checkpoint
ckpt_tokenizer = CLIPTokenizer.from_pretrained(CHECKPOINT_DIR)
pipe.tokenizer = ckpt_tokenizer

# 4) Generate!
prompt = "OCT scan showing CNV"
out = pipe(
    prompt,
    num_inference_steps=50,
    guidance_scale=7.5
)
img = out.images[0]
img.save("cnv_finetuned_example.png")

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]