# 🎨 LoRA Fine-Tuning with CFG + DDIM (Stable Diffusion v1.5)
This notebook runs on Google Colab. It fine-tunes LoRA adapters for Stable Diffusion v1.5 with Classifier-Free Guidance (CFG), then runs inference with every combination of DDIM/DDPM sampling and CFG on/off.

In [None]:
!pip install diffusers transformers torch accelerate peft bitsandbytes
!pip install pytorch-fid

## 📂 Upload your dataset
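Upload `captions.txt` and put your training images in the `images/` folder. Each line of `captions.txt` has the form `filename|caption`, where `filename` is the image's path relative to `images/`.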

## 🏋️ Train LoRA with CFG

In [None]:

import os
import torch
import numpy as np
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
from diffusers import StableDiffusionPipeline, DDIMScheduler, PNDMScheduler
from transformers import CLIPTokenizer
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import CosineAnnealingLR
import matplotlib.pyplot as plt
from pytorch_fid import fid_score

# --- Training configuration -------------------------------------------------
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Data / optimisation settings read by the loaders and the training loop.
IMAGE_SIZE = 512          # images are resized to IMAGE_SIZE x IMAGE_SIZE
BATCH_SIZE = 1
ACCUMULATION_STEPS = 4    # micro-batches accumulated per optimizer step
EPOCHS = 10
PATIENCE = 3              # epochs without improvement before early stopping
VALID_RATIO = 0.1         # fraction of the dataset held out for validation

# Keyword arguments forwarded to peft.LoraConfig (target modules are supplied
# separately for each model).
LORA_CONFIG = dict(
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    bias="none",
)

# Load the base Stable Diffusion v1.5 pipeline, requesting float16 weights
# from the safetensors checkpoint files.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", use_safetensors=True, torch_dtype=torch.float16
)

# The dataset and the training loop tokenize captions with the pipeline's own
# tokenizer.
tokenizer = pipe.tokenizer

def prepare_lora_model(model, target_modules):
    """Wrap ``model`` with LoRA adapters and leave only those adapters trainable.

    Args:
        model: The module to adapt (called here with the UNet and the text
            encoder).
        target_modules: Names of the sub-modules that receive LoRA adapters;
            passed to ``LoraConfig`` together with ``LORA_CONFIG``.

    Returns:
        The PEFT-wrapped model, in which a parameter requires gradients if and
        only if its name contains ``"lora"``.
    """
    prepared = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
    peft_model = get_peft_model(
        prepared,
        LoraConfig(target_modules=target_modules, **LORA_CONFIG),
    )
    # Make the trainable set explicit: LoRA weights on, everything else off.
    for param_name, tensor in peft_model.named_parameters():
        tensor.requires_grad = "lora" in param_name
    return peft_model

# LoRA adapters on the self-attention (attn1) and cross-attention (attn2)
# projections of the UNet.
unet = prepare_lora_model(pipe.unet.to(DEVICE), [
    "attn1.to_q", "attn1.to_k", "attn1.to_v", "attn1.to_out.0",
    "attn2.to_q", "attn2.to_k", "attn2.to_v", "attn2.to_out.0"
]).train()
# LoRA adapters on the attention projections of the CLIP text encoder.
text_encoder = prepare_lora_model(pipe.text_encoder.to(DEVICE), [
    "q_proj", "k_proj", "v_proj", "out_proj"
]).train()
# The VAE is only used as a fixed encoder/decoder. The pipeline was loaded in
# float16, but the dataset yields float32 image tensors, so run the VAE in
# float32 to avoid an input/weight dtype mismatch, and freeze it so encoding
# does not build an autograd graph.
vae = pipe.vae.to(DEVICE, dtype=torch.float32).eval()
vae.requires_grad_(False)

class GhibliDataset(Dataset):
    """Image/caption pairs described by a ``filename|caption`` text file.

    Each item is ``(image_tensor, input_ids)``: the image is resized to
    ``IMAGE_SIZE`` x ``IMAGE_SIZE`` and normalised with mean 0.5 / std 0.5, and
    the caption is tokenized with the module-level ``tokenizer`` and
    padded/truncated to 77 tokens.

    Args:
        image_dir: Directory that the file names in the caption file are
            relative to.
        caption_file: UTF-8 text file with one ``filename|caption`` record per
            line. Blank lines are ignored.

    Raises:
        ValueError: If a non-blank line has no ``|`` separator.
    """

    def __init__(self, image_dir, caption_file):
        self.image_dir = image_dir
        with open(caption_file, 'r', encoding='utf-8') as f:
            self.data = self._parse_captions(f)
        self.transform = transforms.Compose([
            transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])

    @staticmethod
    def _parse_captions(lines):
        """Turn an iterable of raw lines into ``[filename, caption]`` records."""
        records = []
        for line_no, raw in enumerate(lines, start=1):
            line = raw.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            # Split on the FIRST "|" only so captions may contain "|" themselves.
            fname, sep, caption = line.partition("|")
            if not sep:
                raise ValueError(
                    f"captions line {line_no}: expected 'filename|caption', got {line!r}"
                )
            records.append([fname.strip(), caption.strip()])
        return records

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        fname, caption = self.data[idx]
        # Use a context manager so the underlying file handle is closed once
        # the pixels have been copied by convert().
        with Image.open(os.path.join(self.image_dir, fname)) as img:
            image = img.convert("RGB")
        input_ids = tokenizer(
            caption,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=77,
        ).input_ids.squeeze(0)
        return self.transform(image), input_ids

# Hold out VALID_RATIO of the captioned images for validation.
full_dataset = GhibliDataset("images", "captions.txt")
val_size = int(len(full_dataset) * VALID_RATIO)
train_size = len(full_dataset) - val_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# A single optimizer over every trainable parameter of both adapted models.
trainable_params = [
    p
    for model in (unet, text_encoder)
    for p in model.parameters()
    if p.requires_grad
]
optimizer = torch.optim.AdamW(trainable_params, lr=1e-4)
scheduler_lr = CosineAnnealingLR(optimizer, T_max=EPOCHS)

# Build the noise scheduler from the pretrained pipeline's scheduler config so
# training uses the same beta schedule the base model was trained with; a bare
# DDIMScheduler(num_train_timesteps=1000) falls back to the library defaults.
scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
scaler = GradScaler()

def save_model(unet, text_encoder, vae, suffix):
    """Write the three models to sibling directories via ``save_pretrained``.

    Args:
        unet: Saved to ``ghibli-unet-lora{suffix}``.
        text_encoder: Saved to ``ghibli-text-lora{suffix}``.
        vae: Saved to ``ghibli-vae{suffix}``.
        suffix: String appended to each directory name (e.g. ``"-best"``).
    """
    destinations = (
        (unet, f"ghibli-unet-lora{suffix}"),
        (text_encoder, f"ghibli-text-lora{suffix}"),
        (vae, f"ghibli-vae{suffix}"),
    )
    for component, out_dir in destinations:
        component.save_pretrained(out_dir)

def evaluate_fid():
    """Reconstruct the validation images in one denoising step and score FID.

    For every validation batch the image is encoded to latents, noised at a
    random timestep, and the UNet predicts the added noise. That prediction is
    converted back into an estimate of the clean latents,
    ``x0 = (x_t - sqrt(1 - a_t) * eps) / sqrt(a_t)``, which is decoded and
    written to ``val_generated/gen_<n>.png``. The noise prediction itself is
    not a latent and must not be decoded directly.

    Side effects:
        Puts ``unet`` and ``text_encoder`` into eval mode and writes PNG files
        under ``val_generated/``.

    Returns:
        The FID between the ``images`` and ``val_generated`` directories, or
        ``nan`` when the validation loader produced no images.
    """
    os.makedirs("val_generated", exist_ok=True)
    unet.eval()
    text_encoder.eval()
    num_train_timesteps = scheduler.config.num_train_timesteps
    num_saved = 0
    # Pure evaluation: no autograd graph is needed anywhere in this loop.
    with torch.no_grad():
        for pixel_values, input_ids in val_loader:
            # Match the VAE's weight dtype to avoid an input/weight mismatch.
            pixel_values = pixel_values.to(DEVICE, dtype=vae.dtype)
            input_ids = input_ids.to(DEVICE)

            latents = (vae.encode(pixel_values).latent_dist.sample() * 0.18215).float()
            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0, num_train_timesteps, (latents.size(0),), device=latents.device
            ).long()
            noisy_latents = scheduler.add_noise(latents, noise, timesteps)

            with autocast():
                cond_emb = text_encoder(input_ids)[0]
                noise_pred = unet(noisy_latents, timesteps, cond_emb).sample

            # Invert the forward-noising equation to estimate the clean latents.
            alpha_bar = scheduler.alphas_cumprod.to(
                device=latents.device, dtype=latents.dtype
            )[timesteps].view(-1, 1, 1, 1)
            pred_latents = (
                noisy_latents - (1 - alpha_bar).sqrt() * noise_pred.float()
            ) / alpha_bar.sqrt()

            pred_image = vae.decode((pred_latents / 0.18215).to(vae.dtype)).sample
            pred_image = (pred_image.float().clamp(-1, 1) + 1) / 2.0
            # Save every image of the batch; with BATCH_SIZE == 1 the running
            # counter equals the batch index.
            for single in pred_image:
                save_path = f"val_generated/gen_{num_saved}.png"
                transforms.ToPILImage()(single.cpu()).save(save_path)
                num_saved += 1

    if num_saved == 0:
        # An empty validation split (very small dataset) has no FID to report.
        return float("nan")
    fid = fid_score.calculate_fid_given_paths(["images", "val_generated"], batch_size=1, device=DEVICE, dims=2048)
    return fid

# Classifier-free guidance is applied at *inference* time. Training for it
# means randomly replacing the caption with the empty prompt so the model
# learns both the conditional and the unconditional noise prediction; the loss
# stays the plain noise-prediction MSE. (Taking the loss on the
# guidance-scaled combination optimises a different objective, and the
# inference pipeline would then apply guidance a second time.)
CFG_DROPOUT_PROB = 0.1

best_loss, patience_counter = float('inf'), 0
train_losses, fids = [], []
for epoch in range(EPOCHS):
    unet.train()
    text_encoder.train()
    vae.eval()
    print(f"Epoch {epoch+1}/{EPOCHS}")
    epoch_loss = 0.0
    num_batches = len(train_loader)
    num_train_timesteps = scheduler.config.num_train_timesteps
    optimizer.zero_grad()
    for i, (pixel_values, input_ids) in enumerate(tqdm(train_loader)):
        # Match the VAE's weight dtype to avoid an input/weight mismatch.
        pixel_values = pixel_values.to(DEVICE, dtype=vae.dtype)
        input_ids = input_ids.to(DEVICE)
        batch = input_ids.size(0)

        # Caption dropout: swap a random subset of captions for the empty prompt.
        uncond_ids = tokenizer([""] * batch, return_tensors="pt", padding="max_length", truncation=True, max_length=77).input_ids.to(DEVICE)
        drop_mask = torch.rand(batch, device=DEVICE) < CFG_DROPOUT_PROB
        input_ids = torch.where(drop_mask[:, None], uncond_ids, input_ids)

        # The VAE is a fixed encoder here; do not track gradients through it.
        with torch.no_grad():
            latents = vae.encode(pixel_values).latent_dist.sample() * 0.18215
        latents = latents.float()
        noise = torch.randn_like(latents)
        timesteps = torch.randint(0, num_train_timesteps, (batch,), device=latents.device).long()
        noisy_latents = scheduler.add_noise(latents, noise, timesteps)

        with autocast():
            encoder_hidden_states = text_encoder(input_ids)[0]
            noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
            loss = nn.functional.mse_loss(noise_pred.float(), noise) / ACCUMULATION_STEPS

        scaler.scale(loss).backward()
        # Step every ACCUMULATION_STEPS micro-batches, and also on the last
        # batch of the epoch so trailing gradients are applied rather than
        # discarded by the zero_grad() at the start of the next epoch.
        if (i + 1) % ACCUMULATION_STEPS == 0 or (i + 1) == num_batches:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        # Undo the accumulation scaling so the logged value is the per-batch loss.
        epoch_loss += loss.item() * ACCUMULATION_STEPS

    avg_loss = epoch_loss / len(train_loader)
    train_losses.append(avg_loss)
    fid_score_val = evaluate_fid()
    fids.append(fid_score_val)

    print(f"Avg Loss: {avg_loss:.4f}, FID: {fid_score_val:.2f}")
    scheduler_lr.step()

    # Checkpoint on improvement; stop after PATIENCE epochs without one.
    if avg_loss < best_loss:
        best_loss = avg_loss
        patience_counter = 0
        save_model(unet, text_encoder, vae, "-best")
    else:
        patience_counter += 1
        if patience_counter >= PATIENCE:
            print("Early stopping")
            break

# Always keep the weights from the last epoch that ran, next to the "-best" ones.
save_model(unet, text_encoder, vae, "")

# Save the per-epoch metrics and plot them
for out_name, history in (("train_losses.npy", train_losses), ("fid_scores.npy", fids)):
    np.save(out_name, np.array(history))

# Both curves share one set of axes.
for history, curve_label in ((train_losses, "Loss"), (fids, "FID")):
    plt.plot(history, label=curve_label)
plt.title("Training Loss & FID")
plt.legend()
plt.savefig("training_curves.png")


## 📉 Visualize training curves

In [None]:
import matplotlib.pyplot as plt
import numpy as np
# Reload the metric arrays written by the training cell.
losses, fids = (np.load(path) for path in ('train_losses.npy', 'fid_scores.npy'))

# Draw both curves on the same axes and display the figure.
for series, curve_label in ((losses, 'Loss'), (fids, 'FID')):
    plt.plot(series, label=curve_label)
plt.title('Training Loss & FID')
plt.legend()
plt.show()

## 🧪 Run Inference with DDIM/DDPM × CFG/NoCFG

In [None]:

import torch
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DDPMScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from peft import PeftModel
from transformers import CLIPTextModel, CLIPTokenizer

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
prompt = "a girl standing in a greenhouse, Studio Ghibli style"
BASE_MODEL = "runwayml/stable-diffusion-v1-5"

# Load trained components: the base UNet / text encoder in float16 with the
# saved "-best" LoRA adapters applied on top.
unet = PeftModel.from_pretrained(
    UNet2DConditionModel.from_pretrained(BASE_MODEL, subfolder="unet", torch_dtype=torch.float16),
    "ghibli-unet-lora-best",
).to(DEVICE).eval()
text_encoder = PeftModel.from_pretrained(
    CLIPTextModel.from_pretrained(BASE_MODEL, subfolder="text_encoder", torch_dtype=torch.float16),
    "ghibli-text-lora-best",
).to(DEVICE).eval()
vae = AutoencoderKL.from_pretrained("ghibli-vae-best", torch_dtype=torch.float16).to(DEVICE).eval()
for param in vae.parameters():
    param.requires_grad = False
# Use the tokenizer that ships with the base model, i.e. the same one the
# training cell used (pipe.tokenizer), rather than an unrelated CLIP checkpoint.
tokenizer = CLIPTokenizer.from_pretrained(BASE_MODEL, subfolder="tokenizer")

def generate_image(use_ddim, guidance_scale, tag):
    """Generate one image for the module-level ``prompt`` and save it.

    Args:
        use_ddim: When true sample with ``DDIMScheduler``, otherwise with
            ``DDPMScheduler``; both are built from the base model's scheduler
            config.
        guidance_scale: Classifier-free guidance scale passed to the pipeline
            (the notebook uses 1.0 for its "nocfg" runs).
        tag: Suffix for the output file name ``gen_{tag}.png``.

    Returns:
        The generated PIL image (also written to disk and shown).
    """
    scheduler_cls = DDIMScheduler if use_ddim else DDPMScheduler
    scheduler = scheduler_cls.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
    # The pipeline constructor accepts only its components; `torch_dtype` is a
    # `from_pretrained` option and is rejected here. The components were
    # already loaded in float16. `requires_safety_checker=False` goes together
    # with `safety_checker=None`.
    pipe = StableDiffusionPipeline(
        unet=unet,
        text_encoder=text_encoder,
        vae=vae,
        tokenizer=tokenizer,
        scheduler=scheduler,
        safety_checker=None,
        feature_extractor=None,
        requires_safety_checker=False,
    ).to(DEVICE)

    image = pipe(prompt, num_inference_steps=30, guidance_scale=guidance_scale).images[0]
    image.save(f"gen_{tag}.png")
    image.show()
    return image

# Sweep both samplers at guidance scales 1.0 and 7.5, in a fixed order.
_RUNS = (
    (False, 1.0, "ddpm_nocfg"),
    (False, 7.5, "ddpm_cfg"),
    (True, 1.0, "ddim_nocfg"),
    (True, 7.5, "ddim_cfg"),
)
for _use_ddim, _scale, _tag in _RUNS:
    generate_image(use_ddim=_use_ddim, guidance_scale=_scale, tag=_tag)
