# Data Download

In [1]:
import kagglehub

# Download a pinned dataset version so a future Kaggle re-upload cannot silently change
# the training data. Version 2 is the one this notebook was run against (see output).
# Remove the "/versions/2" suffix to fetch the latest release instead.
path = kagglehub.dataset_download("paultimothymooney/kermany2018/versions/2")

print("Path to dataset files:", path)

Downloading from https://www.kaggle.com/api/v1/datasets/download/paultimothymooney/kermany2018?dataset_version_number=2...


100%|██████████| 10.8G/10.8G [02:02<00:00, 95.4MB/s]

Extracting files...





Path to dataset files: /root/.cache/kagglehub/datasets/paultimothymooney/kermany2018/versions/2


In [2]:
import os

import shutil

os.makedirs("./data/", exist_ok=True)

# Copy each cohort's training images into ./data/<cohort>.
# NOTE: the top-level folder in this Kaggle archive is spelled "OCT2017 " (with a trailing
# space), as in the original path -- keep it.
# shutil.copytree raises if the source is missing, whereas a failing `!cp` only prints an
# error and lets the notebook carry on. dirs_exist_ok makes re-running the cell safe.
for cohort in ["CNV", "NORMAL"]:
    src_dir = os.path.join(path, "OCT2017 ", "train", cohort)
    shutil.copytree(src_dir, os.path.join("./data", cohort), dirs_exist_ok=True)

# Train the Model

In [1]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, StableDiffusionPipeline
from transformers import CLIPTextModel, CLIPTokenizer
from accelerate import Accelerator

In [None]:
class ImgDataset(Dataset):
    """Images from one folder, all paired with the same text prompt, for text-to-image fine-tuning.

    Every ``.png``/``.jpg``/``.jpeg`` file directly inside ``root_dir`` (not recursive)
    becomes one sample. Every sample is captioned ``"OCT scan showing {cohort}"``.

    Args:
        root_dir: Directory containing the images.
        tokenizer: Tokenizer used to encode the prompt (presumably a Hugging Face
            ``CLIPTokenizer``; it must accept ``padding``/``truncation``/``max_length``/
            ``return_tensors`` and return an object with ``input_ids``).
        cohort: Label interpolated into the prompt.
        resolution: Images are resized (no cropping) to ``resolution x resolution``.
        max_length: Length the prompt tokens are padded/truncated to.
        dataset_size: Maximum number of files kept (``None`` keeps all). Files are sorted
            by path first, so the kept subset is the same on every run.

    Each item is a dict with:
        ``"pixel_values"``: float tensor of shape (3, resolution, resolution) scaled to [-1, 1].
        ``"input_ids"``: 1-D tensor of ``max_length`` prompt token ids.
    """

    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, root_dir:str, tokenizer, cohort:str = "CNV", resolution=512, max_length=77, dataset_size:int = 6400) -> None:

        # os.listdir order is filesystem-dependent; sort so truncation to dataset_size
        # selects a reproducible subset.
        self.files = sorted(
            os.path.join(root_dir, f)
            for f in os.listdir(root_dir)
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )[:dataset_size]

        self.tokenizer = tokenizer

        self.transform = transforms.Compose([
            transforms.Resize((resolution, resolution), transforms.InterpolationMode.LANCZOS),
            transforms.ToTensor(),                                   # -> [0, 1]
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # -> [-1, 1]
        ])

        self.prompt = f"OCT scan showing {cohort}"
        self.max_length = max_length

        # The prompt is identical for every sample, so tokenize it once instead of per item.
        self.input_ids = tokenizer(
            self.prompt,
            padding="max_length",
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        ).input_ids.squeeze(0)

    def __len__(self) -> int:
        """Number of image files in the dataset."""
        return len(self.files)

    def __getitem__(self, idx:int) -> dict:
        """Load, resize and normalize image ``idx`` and pair it with the prompt token ids."""
        # The context manager closes the file handle. convert() reads the pixel data into a
        # new image before the handle is closed.
        with Image.open(self.files[idx]) as raw:
            rgb = raw.convert("RGB")

        # clone() so callers never share (and could mutate) the cached token tensor.
        return {"pixel_values": self.transform(rgb), "input_ids": self.input_ids.clone()}

In [3]:
def _checkpoint_step(checkpoint_dir: str) -> int:
    """Parse the step number from a ``.../checkpoint_<step>`` directory path."""
    # normpath drops any trailing separator so basename is never empty.
    name = os.path.basename(os.path.normpath(checkpoint_dir))
    return int(name.rsplit("_", 1)[-1])


def _save_checkpoint(accel, unet, tokenizer, ckpt_dir: str) -> None:
    """Write a resumable checkpoint: diffusers-format UNet, tokenizer and accelerate state."""
    accel.wait_for_everyone()
    if accel.is_main_process:
        # unwrap_model strips the DistributedDataParallel wrapper (which has no
        # save_pretrained) in multi-GPU runs; only one rank writes the files.
        accel.unwrap_model(unet).save_pretrained(ckpt_dir)
        tokenizer.save_pretrained(ckpt_dir)
    # save_state is called on every process; accelerate decides which ranks write.
    accel.save_state(ckpt_dir)


def train(
    pretrained_model: str,
    data_dir: str,
    output_dir: str = "sd_cnv_finetuned",
    cohort: str = "CNV",
    resolution: int = 512,
    dataset_size:int = 6400,
    batch_size: int = 4,
    learning_rate: float = 1e-4,
    epochs: int = 5,
    grad_accum_steps: int = 1,
    save_steps: int = 1000,
    resume_checkpoint: str = None,
    unet: UNet2DConditionModel = None,
    accelerator: Accelerator = None,
):
    """Fine-tune the UNet of a Stable Diffusion checkpoint on one image cohort.

    Only the UNet is trained. The CLIP text encoder and the VAE are loaded from
    ``pretrained_model`` and frozen. Every sample is captioned
    ``"OCT scan showing {cohort}"`` by ``ImgDataset``, which must be defined in an
    earlier cell.

    Args:
        pretrained_model: Hub id or local path of a diffusers repo with ``unet``,
            ``tokenizer``, ``text_encoder``, ``vae`` and ``scheduler`` subfolders.
        data_dir: Folder of training images for ``cohort``.
        output_dir: Where checkpoints and the final UNet are written.
        cohort: Label used in the training prompt.
        resolution: Square training resolution in pixels.
        dataset_size: Maximum number of images taken from ``data_dir``.
        batch_size: Per-device batch size.
        learning_rate: AdamW learning rate for the UNet.
        epochs: Number of passes over the dataset (all of them run again after a resume).
        grad_accum_steps: Micro-batches per optimizer step. Only applied when
            ``accelerator`` is None; a caller-supplied Accelerator keeps its own
            ``gradient_accumulation_steps``.
        save_steps: Save a resumable checkpoint every this many optimizer steps.
        resume_checkpoint: A ``checkpoint_<step>`` directory written by this function.
            UNet weights, optimizer/RNG state and the step counter are restored from it.
        unet: Optional pre-built UNet; takes precedence over loading one.
        accelerator: Optional pre-configured ``Accelerator``.

    Raises:
        ValueError: if ``data_dir`` contains no usable images.

    Side effects:
        Writes ``checkpoint_<step>/`` and ``final_unet/`` under ``output_dir``, and the
        tokenizer files into ``output_dir`` itself.
    """
    # 1) Prepare output directory, accelerator & device.
    os.makedirs(output_dir, exist_ok=True)
    # accel.accumulate() only accumulates when the Accelerator itself is configured with
    # gradient_accumulation_steps, so wire the argument through here.
    accel = accelerator or Accelerator(gradient_accumulation_steps=grad_accum_steps)
    device = accel.device

    # 2) Load the UNet (from the resume checkpoint if given) unless one was passed in.
    if unet is None:
        if resume_checkpoint:
            unet = UNet2DConditionModel.from_pretrained(resume_checkpoint)
        else:
            unet = UNet2DConditionModel.from_pretrained(pretrained_model, subfolder="unet")
        unet = unet.to(device)

    # 3) Load & freeze tokenizer + text encoder.
    tokenizer    = CLIPTokenizer.from_pretrained(pretrained_model, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model, subfolder="text_encoder").to(device)
    text_encoder.requires_grad_(False)

    # 4) Load & freeze VAE.
    vae = AutoencoderKL.from_pretrained(pretrained_model, subfolder="vae").to(device)
    vae.requires_grad_(False)

    # 5) Load noise scheduler. Read settings via .config: direct attribute access is
    #    deprecated in diffusers and triggers a warning.
    scheduler = DDPMScheduler.from_pretrained(pretrained_model, subfolder="scheduler")
    num_train_timesteps = scheduler.config.num_train_timesteps
    # Latent scale taken from the VAE's own config instead of a hard-coded 0.18215.
    latent_scale = vae.config.scaling_factor

    # 6) Prepare dataset & dataloader (ImgDataset is defined in an earlier cell).
    dataset = ImgDataset(data_dir, tokenizer, cohort=cohort, resolution=resolution, dataset_size=dataset_size)
    if len(dataset) == 0:
        raise ValueError(f"No .png/.jpg/.jpeg images found in {data_dir!r}")
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # 7) Optimizer (only UNet params are trained).
    optimizer = torch.optim.AdamW(unet.parameters(), lr=learning_rate)

    # 8) Wrap model, optimizer and dataloader for mixed-precision / distributed training.
    unet, optimizer, dataloader = accel.prepare(unet, optimizer, dataloader)

    # 9) Resume training state and step counter if requested.
    global_step = 0
    if resume_checkpoint:
        accel.load_state(resume_checkpoint)
        global_step = _checkpoint_step(resume_checkpoint)

    # 10) Training loop.
    for epoch in range(1, epochs + 1):
        unet.train()
        for batch in dataloader:
            with accel.accumulate(unet):
                pixels    = batch["pixel_values"].to(device)
                input_ids = batch["input_ids"].to(device)

                # Frozen models: skip autograd bookkeeping for their forward passes.
                with torch.no_grad():
                    latents = vae.encode(pixels).latent_dist.sample() * latent_scale

                # Forward diffusion: noise the latents at a random timestep per sample.
                noise     = torch.randn_like(latents)
                timesteps = torch.randint(0, num_train_timesteps,
                                          (latents.shape[0],), device=device)
                noisy_latents = scheduler.add_noise(latents, noise, timesteps)

                # Text conditioning (frozen encoder).
                with torch.no_grad():
                    encoder_hidden_states = text_encoder(input_ids)[0]

                # Noise prediction & loss. The loss is computed in float32 so it stays
                # numerically stable under mixed precision.
                pred_noise = unet(noisy_latents, timesteps, encoder_hidden_states).sample
                loss       = torch.nn.functional.mse_loss(pred_noise.float(), noise.float())

                # Backpropagate. The prepared optimizer only steps on sync boundaries.
                accel.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

            # Count optimizer updates, not micro-batches, so save_steps and the resume
            # step number stay correct under gradient accumulation.
            if accel.sync_gradients:
                global_step += 1
                if global_step % save_steps == 0:
                    ckpt_dir = os.path.join(output_dir, f"checkpoint_{global_step}")
                    _save_checkpoint(accel, unet, tokenizer, ckpt_dir)

        accel.print(f"Epoch {epoch}/{epochs} complete")

    # 11) Final save (main process only; unwrap any distributed wrapper first).
    accel.wait_for_everyone()
    final_dir = os.path.join(output_dir, "final_unet")
    if accel.is_main_process:
        accel.unwrap_model(unet).save_pretrained(final_dir)
        tokenizer.save_pretrained(output_dir)
    accel.print("Fine-tuning complete — models saved to", output_dir)

In [4]:
from accelerate import Accelerator
# Fine-tune the small distilled SD model (https://huggingface.co/nota-ai/bk-sdm-small)
# on the CNV cohort. To continue an interrupted run, set RESUME_CHECKPOINT to a
# "checkpoint_<step>" directory written by train().
SEED = 42
GRAD_ACCUM_STEPS = 1
RESUME_CHECKPOINT = None

# Seed torch so the shuffle order, latent sampling and noise draws are reproducible.
torch.manual_seed(SEED)

# When an Accelerator is passed in, its gradient_accumulation_steps is what
# accel.accumulate() uses, so configure it from the same constant as grad_accum_steps.
accel = Accelerator(gradient_accumulation_steps=GRAD_ACCUM_STEPS)
train(
    pretrained_model="nota-ai/bk-sdm-small",
    data_dir="./data/CNV",
    output_dir="./models/sd_cnv_finetuned",
    cohort="CNV",
    resolution=512,
    dataset_size=500,  # ~6 min of training per 100 images on the original run
    batch_size=4,
    learning_rate=1e-4,
    epochs=10,
    grad_accum_steps=GRAD_ACCUM_STEPS,
    save_steps=1000,
    resume_checkpoint=RESUME_CHECKPOINT,
    accelerator=accel,
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)


Epoch 1/10 complete
Epoch 2/10 complete
Epoch 3/10 complete
Epoch 4/10 complete
Epoch 5/10 complete
Epoch 6/10 complete
Epoch 7/10 complete
Epoch 8/10 complete
Epoch 9/10 complete
Epoch 10/10 complete
Fine-tuning complete — models saved to ./models/sd_cnv_finetuned


In [10]:
!zip -r models.zip ./models/


  adding: models/ (stored 0%)
  adding: models/.ipynb_checkpoints/ (stored 0%)
  adding: models/sd_cnv_finetuned/ (stored 0%)
  adding: models/sd_cnv_finetuned/final_unet/ (stored 0%)
  adding: models/sd_cnv_finetuned/final_unet/config.json (deflated 66%)
  adding: models/sd_cnv_finetuned/final_unet/diffusion_pytorch_model.safetensors (deflated 7%)
  adding: models/sd_cnv_finetuned/checkpoint_1000/ (stored 0%)
  adding: models/sd_cnv_finetuned/checkpoint_1000/random_states_0.pkl (deflated 25%)
  adding: models/sd_cnv_finetuned/checkpoint_1000/optimizer.bin (deflated 9%)
  adding: models/sd_cnv_finetuned/checkpoint_1000/model.safetensors (deflated 7%)
  adding: models/sd_cnv_finetuned/checkpoint_1000/config.json (deflated 66%)
  adding: models/sd_cnv_finetuned/checkpoint_1000/special_tokens_map.json (deflated 73%)
  adding: models/sd_cnv_finetuned/checkpoint_1000/diffusion_pytorch_model.safetensors (deflated 7%)
  adding: models/sd_cnv_finetuned/checkpoint_1000/tokenizer_config.json (de