In [None]:
!pip -q install --upgrade pip
!pip install -U numpy==1.26.4



In [None]:
!pip -q install --upgrade "torch==2.3.1" "torchvision==0.18.1" --index-url https://download.pytorch.org/whl/cpu


In [None]:
!pip -q install --upgrade diffusers==0.30.2 transformers==4.44.0 accelerate==0.33.0 safetensors==0.4.4 peft==0.12.0 datasets==2.21.0 pillow==10.4.0 tqdm==4.66.5

In [None]:
from pathlib import Path
import os

import torch  # only needed here to check whether a CUDA device is actually attached

USE_GPU = False  # set True if you later enable GPU in Colab (Runtime > Change runtime type > GPU)

PROJECT = "wedding_dress_lora"
# The default location is Google Drive, which must be mounted before this path exists
# (from google.colab import drive; drive.mount("/content/drive")).
# If you upload the zip straight into the Colab session instead, point this at /content/<name>.zip.
DATA_ZIP_PATH = "/content/drive/MyDrive/wedding_dataset.zip"
DATA_DIR = Path("/content/dress_images")          # the zip is extracted here
OUTPUT_DIR = Path(f"/content/{PROJECT}_output")  # LoRA checkpoints are written here
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

MODEL_ID = "runwayml/stable-diffusion-v1-5"  # SD 1.5 (smaller & lighter than SDXL for CPU)
LORA_RANK = 8                                # LoRA rank (keep small on CPU)
RESOLUTION = 512
BATCH_SIZE = 1                               # keep small on CPU
MAX_TRAIN_STEPS = 400                        # small demo; increase if you can wait longer
LEARNING_RATE = 1e-4
CHECKPOINT_STEPS = 200
SEED = 42

# A GPU was requested but none is attached: fall back to CPU with a visible message
# instead of failing later with an opaque CUDA error on the first .to("cuda").
gpu_available = torch.cuda.is_available()
if USE_GPU and not gpu_available:
    print("USE_GPU is True but no CUDA device is available; falling back to CPU.")
device = "cuda" if USE_GPU and gpu_available else "cpu"
print(f"Device: {device}")


Device: cpu


In [None]:
import zipfile, shutil

# Both directories are scratch space owned by this notebook. Reset them so that files
# left over from a previous run (e.g. an older or larger zip) are not silently mixed
# into this run's dataset.
shutil.rmtree(DATA_DIR, ignore_errors=True)
DATA_DIR.mkdir(exist_ok=True, parents=True)
with zipfile.ZipFile(DATA_ZIP_PATH, "r") as z:
    z.extractall(DATA_DIR)

# Flatten nested folders into a single images/ dir; every later cell reads from IMAGES_DIR.
IMAGES_DIR = Path("/content/images")
shutil.rmtree(IMAGES_DIR, ignore_errors=True)
IMAGES_DIR.mkdir(parents=True, exist_ok=True)

def collect_images(src_dir, dst_dir, exts=None):
    """Copy every image file under `src_dir` (recursively) into `dst_dir` with sequential names.

    Files are renamed 000000<ext>, 000001<ext>, ... with the extension lower-cased.
    Candidates are visited in sorted path order, so the numbering is reproducible across runs.

    Skipped entries:
      - directories, even if their name ends in an image suffix;
      - anything inside a hidden folder or a ``__MACOSX`` folder;
      - hidden files. macOS-created zips commonly contain ``._name.jpg`` metadata stubs that
        carry an image suffix but are not images.

    Args:
        src_dir: Directory to search (Path or str).
        dst_dir: Existing directory to copy into (Path or str).
        exts: Iterable of suffixes to accept (compared case-insensitively);
            defaults to .jpg/.jpeg/.png/.webp.

    Returns:
        Number of files copied.
    """
    src_dir, dst_dir = Path(src_dir), Path(dst_dir)
    if exts is None:
        allowed = {".jpg", ".jpeg", ".png", ".webp"}
    else:
        allowed = {e.lower() for e in exts}

    count = 0
    for p in sorted(src_dir.rglob("*")):
        rel_parts = p.relative_to(src_dir).parts
        if any(part.startswith(".") or part == "__MACOSX" for part in rel_parts):
            continue
        if not p.is_file() or p.suffix.lower() not in allowed:
            continue
        shutil.copy2(p, dst_dir / f"{count:06d}{p.suffix.lower()}")
        count += 1
    return count

# Flatten the extracted dataset into IMAGES_DIR and refuse to continue on a tiny dataset.
num = collect_images(src_dir=DATA_DIR, dst_dir=IMAGES_DIR)
print("Collected images:", num)
assert 5 < num, "Please provide more than 5 images for meaningful fine-tuning."


Collected images: 572


In [None]:
import torch
from PIL import Image
from tqdm import tqdm
from transformers import BlipProcessor, BlipForConditionalGeneration

# One caption file per image, named <image stem>.txt, lives in cap_dir.
cap_dir = Path("/content/captions")
cap_dir.mkdir(parents=True, exist_ok=True)

# BLIP base captioning model and its matching processor, placed on the configured device.
blip_checkpoint = "Salesforce/blip-image-captioning-base"
processor = BlipProcessor.from_pretrained(blip_checkpoint)
blip = BlipForConditionalGeneration.from_pretrained(blip_checkpoint).to(device)

def caption_image(path, max_new_tokens=30):
    """Return a BLIP-generated caption for the image file at `path`.

    Uses the module-level `processor`, `blip` and `device`. The image is converted to RGB.
    Decoding drops special tokens.

    Args:
        path: Image file path.
        max_new_tokens: Upper bound on generated caption tokens (default 30, as before).

    Returns:
        The decoded caption string.
    """
    # The context manager closes the file handle deterministically instead of leaving it to GC.
    with Image.open(path) as img:
        raw_image = img.convert("RGB")
    inputs = processor(raw_image, return_tensors="pt").to(device)
    out = blip.generate(**inputs, max_new_tokens=max_new_tokens)
    return processor.decode(out[0], skip_special_tokens=True)

# Caption every collected image. A caption is reused only when its .txt file is newer than
# the image, so re-running just this (~1 h on CPU) cell skips finished work. Re-extracted or
# replaced images are newer than their old captions and get recaptioned.
n_captioned = 0
n_reused = 0
for img_path in tqdm(sorted(IMAGES_DIR.glob("*"))):
    cap_path = cap_dir / (img_path.stem + ".txt")
    if cap_path.exists() and cap_path.stat().st_mtime >= img_path.stat().st_mtime:
        n_reused += 1
        continue
    cap = caption_image(img_path)
    # Keep captions simple; we'll augment with domain tokens during training prompts
    cap_path.write_text(cap, encoding="utf-8")
    n_captioned += 1

print("Captioning complete.")
print(f"New captions: {n_captioned}, reused: {n_reused}")


100%|██████████| 572/572 [1:03:12<00:00,  6.63s/it]

Captioning complete.





In [None]:
from datasets import load_dataset

# Each image in IMAGES_DIR is paired with a same-stem .txt caption in cap_dir.
# The training cell's DressCaptionDataset looks captions up by file stem.
print("Images:", len(list(IMAGES_DIR.glob("*"))))
print("Captions:", len(list(cap_dir.glob("*.txt"))))

# sanity: ensure each image has a caption. Fail loudly rather than only displaying the list;
# otherwise the dataset would silently fall back to a generic caption.
missing = [img for img in IMAGES_DIR.glob("*") if not (cap_dir / f"{img.stem}.txt").exists()]
assert not missing, f"{len(missing)} image(s) have no caption, e.g. {missing[:3]}"


Images: 572
Captions: 572


(0, [])

In [None]:
# @title LoRA fine-tuning script (minimal)
import math, torch, random
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, StableDiffusionPipeline
from transformers import CLIPTextModel, CLIPTokenizer
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict
from torchvision import transforms
from tqdm import tqdm

# Seed the torch, Python and NumPy RNGs with the configured SEED.
for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    seed_fn(SEED)

class DressCaptionDataset(Dataset):
    """Pairs each image in `img_dir` with its same-stem caption file in `cap_dir`.

    Each item is a dict with two keys:
      - "pixel_values": the RGB image resized to exactly (resolution, resolution), then
        converted by ToTensor and normalised with mean 0.5 / std 0.5 (i.e. [0, 1] -> [-1, 1]).
        Non-square images are stretched, not cropped.
      - "caption": the stripped caption text, or "wedding dress" if no caption file exists.
    """

    # Suffixes (lower-case) that are treated as images.
    IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp"}

    def __init__(self, img_dir, cap_dir, resolution):
        self.images = sorted(p for p in img_dir.glob("*") if p.suffix.lower() in self.IMAGE_EXTS)
        self.cap_dir = cap_dir
        self.prep = transforms.Compose([
            transforms.Resize((resolution, resolution), interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_p = self.images[idx]
        txt_p = self.cap_dir / (img_p.stem + ".txt")
        caption = txt_p.read_text(encoding="utf-8").strip() if txt_p.exists() else "wedding dress"
        # Close the file as soon as the pixels are decoded; the dataset is read every epoch.
        with Image.open(img_p) as img:
            image = img.convert("RGB")
        return {"pixel_values": self.prep(image), "caption": caption}

# Caption-paired training set plus a shuffled loader that reads batches in the main process.
dataset = DressCaptionDataset(img_dir=IMAGES_DIR, cap_dir=cap_dir, resolution=RESOLUTION)
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)

# Load SD1.5 components
tokenizer = CLIPTokenizer.from_pretrained(MODEL_ID, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(MODEL_ID, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(MODEL_ID, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(MODEL_ID, subfolder="unet")

# Only the LoRA adapter is trained: freeze every pretrained weight explicitly.
for frozen_model in (text_encoder, vae, unet):
    frozen_model.requires_grad_(False)

# Inject LoRA into the UNet attention projections using diffusers' PEFT integration.
# (The previous task_type="UNET_T2I" is not a PEFT task type and is not needed here.)
peft_config = LoraConfig(
    r=LORA_RANK,
    lora_alpha=LORA_RANK * 2,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],  # common attention proj names
    lora_dropout=0.1,
    bias="none",
)
unet.add_adapter(peft_config)

# noise scheduler
noise_scheduler = DDPMScheduler.from_pretrained(MODEL_ID, subfolder="scheduler")

# move to device
text_encoder = text_encoder.to(device)
vae = vae.to(device)
unet = unet.to(device)

# Optimise only the parameters that are still trainable (the freshly injected LoRA weights).
lora_params = [p for p in unet.parameters() if p.requires_grad]
assert lora_params, "No trainable LoRA parameters found - check target_modules."
optimizer = torch.optim.AdamW(lora_params, lr=LEARNING_RATE)

def encode_text(captions):
    """Embed a batch of caption strings with the frozen CLIP text encoder.

    Captions are padded/truncated to the tokenizer's model_max_length. The result is the
    encoder's last_hidden_state, computed without gradient tracking.
    """
    tokens = tokenizer(
        captions,
        padding="max_length",
        truncation=True,
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    )
    model_inputs = {name: tensor.to(device) for name, tensor in tokens.items()}
    with torch.no_grad():
        return text_encoder(**model_inputs).last_hidden_state

from safetensors.torch import save_file  # writes real safetensors (torch.save would write a pickle)


def save_unet_lora(path):
    """Save only the LoRA adapter weights of `unet` to `path` in safetensors format.

    Tensors are detached, moved to CPU and made contiguous before serialisation.
    """
    state = {
        name: tensor.detach().to("cpu").contiguous()
        for name, tensor in get_peft_model_state_dict(unet).items()
    }
    save_file(state, str(path))


# An empty loader would make the step-budget loop below spin forever.
assert len(loader) > 0, "Training dataset is empty - check IMAGES_DIR."

global_step = 0
unet.train()
num_train_timesteps = noise_scheduler.config.num_train_timesteps  # non-deprecated config access
latent_scale = vae.config.scaling_factor  # SD 1.5's VAE config sets this to 0.18215

# Keep making passes over the data until exactly MAX_TRAIN_STEPS optimizer steps are done.
# A `break` inside the inner loop only ends the current pass, so the outer loop must test
# the step budget itself. A `for ... in range(MAX_TRAIN_STEPS)` outer loop would keep
# training past the budget.
while global_step < MAX_TRAIN_STEPS:
    for batch in loader:
        pixel_values = batch["pixel_values"].to(device)
        captions = batch["caption"]
        # encode images to latents (the VAE is frozen)
        with torch.no_grad():
            latents = vae.encode(pixel_values).latent_dist.sample() * latent_scale
        # sample noise and a random timestep per sample, then noise the latents
        noise = torch.randn_like(latents)
        timesteps = torch.randint(0, num_train_timesteps, (latents.shape[0],), dtype=torch.long, device=device)
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # text encoding
        encoder_hidden_states = encode_text(captions)

        # UNet is trained to predict the added noise
        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
        loss = torch.nn.functional.mse_loss(noise_pred, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        global_step += 1
        if global_step % 25 == 0:
            print(f"step {global_step}/{MAX_TRAIN_STEPS} - loss {loss.item():.4f}")
        if global_step % CHECKPOINT_STEPS == 0:
            save_unet_lora(OUTPUT_DIR / f"unet_lora_step{global_step}.safetensors")
        if global_step >= MAX_TRAIN_STEPS:
            break

# final save
save_unet_lora(OUTPUT_DIR / "unet_lora_final.safetensors")
print("Training finished. LoRA saved to:", OUTPUT_DIR)


tokenizer_config.json:   0%|          | 0.00/806 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/472 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/617 [00:00<?, ?B/s]

text_encoder/model.safetensors:   0%|          | 0.00/492M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/547 [00:00<?, ?B/s]

vae/diffusion_pytorch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/743 [00:00<?, ?B/s]

unet/diffusion_pytorch_model.safetensors:   0%|          | 0.00/3.44G [00:00<?, ?B/s]

scheduler_config.json:   0%|          | 0.00/308 [00:00<?, ?B/s]

  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)


step 25/400 - loss 0.1050
step 50/400 - loss 0.0102
step 75/400 - loss 0.0051
step 100/400 - loss 0.0045
step 125/400 - loss 0.3809
step 150/400 - loss 0.2819
step 175/400 - loss 0.2563
step 200/400 - loss 0.0181
step 225/400 - loss 0.1265
step 250/400 - loss 0.0125
step 275/400 - loss 0.2248
step 300/400 - loss 0.0398
step 325/400 - loss 0.0253
step 350/400 - loss 0.0275
step 375/400 - loss 0.0439
step 400/400 - loss 0.0276
step 425/400 - loss 0.0051
step 450/400 - loss 0.0031
step 475/400 - loss 0.2067
step 500/400 - loss 0.0034
step 525/400 - loss 0.0536
step 550/400 - loss 0.3075
step 575/400 - loss 0.1720
step 600/400 - loss 0.0030
step 625/400 - loss 0.0071


In [None]:
# @title Rule-based prompt builder from body type + components
# Style hints appended to the prompt for each supported body type (comma-separated fragments).
BODY_RULES = {
    "hourglass": "cinched waist, structured bodice, balanced skirt",
    "pear": "A-line skirt, embellished bodice, cap sleeves",
    "apple": "empire waist, flowy skirt, deep V neckline",
    "rectangle": "ruffled skirt, sweetheart neckline, defined waist",
    "inverted_triangle": "A-line or ballgown skirt, off-the-shoulder neckline",
}

def build_prompt(
    body_type="hourglass",
    sleeve="long sleeves",
    neckline="sweetheart neckline",
    bodice="structured bodice",
    skirt="A-line skirt",
    train="chapel train",
    structure="tailored fit",
    extra="high-quality fabric, realistic, detailed, studio lighting",
    rules=None,
):
    """Build a (prompt, negative_prompt) pair for generating a wedding dress.

    The prompt is assembled in a fixed order:
      1. "elegant white wedding dress";
      2. the explicit components (sleeve, neckline, bodice, skirt, train, structure);
      3. the style hints for `body_type` from `rules`;
      4. `extra`.
    Every comma-separated fragment is stripped, and empty or None fragments are dropped.
    Case-insensitive duplicates are removed (first occurrence wins), so the prompt has no
    ", ," gaps and no repeated terms. An unknown `body_type` contributes no hints.

    Args:
        body_type: Key into `rules`.
        sleeve, neckline, bodice, skirt, train, structure, extra: Prompt fragments.
        rules: Mapping of body type -> hint string; defaults to the module-level BODY_RULES.

    Returns:
        Tuple (prompt, negative) of strings.
    """
    if rules is None:
        rules = BODY_RULES
    fragments = [
        "elegant white wedding dress",
        sleeve, neckline, bodice, skirt, train, structure,
        rules.get(body_type, ""),
        extra,
    ]
    seen = set()
    terms = []
    for fragment in fragments:
        for term in (fragment or "").split(","):
            term = term.strip()
            key = term.lower()
            if term and key not in seen:
                seen.add(key)
                terms.append(term)
    prompt = ", ".join(terms)
    negative = "low quality, deformed, extra limbs, text, watermark, blurry, bad anatomy, worst quality"
    return prompt, negative

# Quick look at the (prompt, negative) pair produced for a pear body type.
print(build_prompt(body_type="pear"))


('elegant white wedding dress, long sleeves, sweetheart neckline, structured bodice, A-line skirt, chapel train, tailored fit, A-line skirt, embellished bodice, cap sleeves, high-quality fabric, realistic, detailed, studio lighting', 'low quality, deformed, extra limbs, text, watermark, blurry, bad anatomy, worst quality')


In [None]:
# @title Inference: Load base SD + your LoRA and generate
import torch
from diffusers import StableDiffusionPipeline
from peft import PeftModel

# float16 weights only when running on GPU; CPU inference stays in float32.
pipe_dtype = torch.float32 if device == "cpu" else torch.float16
pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=pipe_dtype, safety_checker=None)
pipe = pipe.to(device)

# Attach LoRA
# The pipeline's UNet starts without LoRA layers, so loading a state dict alone cannot apply
# anything. First inject adapter layers with the training configuration, then copy the saved
# adapter weights into them and verify that every LoRA weight was matched.
from peft import LoraConfig, set_peft_model_state_dict
from diffusers.models.unet_2d_condition import UNet2DConditionModel

# load LoRA state dict
lora_sd_path = OUTPUT_DIR / "unet_lora_final.safetensors"
import safetensors.torch as sf

if not lora_sd_path.exists():
    raise FileNotFoundError(f"LoRA weights not found: {lora_sd_path} - run the training cell first.")
try:
    raw_lora_state = sf.load_file(str(lora_sd_path))
except Exception:
    # Older runs of the training cell wrote this file with torch.save (a pickle) despite the
    # .safetensors suffix. weights_only=True restricts unpickling to tensors and containers.
    raw_lora_state = torch.load(lora_sd_path, map_location="cpu", weights_only=True)

# Keys saved from a PeftModel wrapper start with "base_model.model."; the pipeline UNet's
# module paths do not, so strip that prefix when present.
lora_state = {k.removeprefix("base_model.model."): v for k, v in raw_lora_state.items()}
if not lora_state:
    raise RuntimeError(f"{lora_sd_path} contains no tensors.")

# Must mirror the LoraConfig in the training cell: rank and alpha determine the LoRA scale,
# and target_modules determine which layers receive adapters.
inference_lora_config = LoraConfig(
    r=LORA_RANK,
    lora_alpha=LORA_RANK * 2,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
)
if "default" not in getattr(pipe.unet, "peft_config", {}):
    pipe.unet.add_adapter(inference_lora_config, adapter_name="default")

load_result = set_peft_model_state_dict(pipe.unet, lora_state, adapter_name="default")
unexpected = list(getattr(load_result, "unexpected_keys", []))
missing_lora = [k for k in getattr(load_result, "missing_keys", []) if "lora_" in k]
if unexpected or missing_lora:
    raise RuntimeError(
        f"LoRA weights did not match the UNet adapter: {len(unexpected)} unexpected "
        f"(e.g. {unexpected[:3]}), {len(missing_lora)} missing (e.g. {missing_lora[:3]})."
    )

# Sampling: render one hourglass-style dress with a fixed seed so the result is repeatable.
body_type = "hourglass"
prompt, negative = build_prompt(
    body_type=body_type,
    sleeve="off-the-shoulder sleeves",
    neckline="sweetheart neckline",
    bodice="corset bodice",
    skirt="A-line skirt",
    train="cathedral train",
    structure="tailored fit",
)

sampling_kwargs = {
    "prompt": prompt,
    "negative_prompt": negative,
    "guidance_scale": 7.0,
    "num_inference_steps": 25,  # keep small on CPU
    "height": RESOLUTION,
    "width": RESOLUTION,
    "generator": torch.Generator(device=device).manual_seed(1234),
}
image = pipe(**sampling_kwargs).images[0]

display(image)
