In [None]:
!pip install -q \
  diffusers==0.24.0 \
  transformers==4.36.2 \
  huggingface_hub==0.19.4 \
  accelerate==0.25.0 \
  datasets==2.14.5 \
  safetensors==0.4.0 \
  Pillow

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m126.8/126.8 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m45.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.2/8.2 MB[0m [31m117.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m311.7/311.7 kB[0m [31m25.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m265.7/265.7 kB[0m [31m24.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.6/519.6 kB[0m [31m41.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m70.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [1]:
import os, torch, torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_dataset
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm

from diffusers import StableDiffusionPipeline, DDPMScheduler
from diffusers.models.attention_processor import LoRAAttnProcessor
from huggingface_hub import hf_hub_download, model_info

In [2]:
import os, io, glob, math
from typing import Any
from PIL import Image
from torchvision.transforms import InterpolationMode
from tqdm.auto import tqdm

In [None]:
MODEL_ID         = "runwayml/stable-diffusion-v1-5"
LOCAL_PARQUET    = "train-00000-of-00001.parquet"    # downloaded file (optional)
FALLBACK_DATASET = "huggan/pokemon"                  # public; no auth
OUT_DIR          = "sd15_pokemon_lora"
MAX_STEPS        = 1000
BATCH_SIZE       = 2
LR               = 1e-4
IMG_SIZE         = 512
GUIDANCE         = 7.5
SEED             = 42

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE  = torch.float16 if DEVICE == "cuda" else torch.float32
os.makedirs(OUT_DIR, exist_ok=True)
torch.manual_seed(SEED)

# -------------------
# Transforms (SD expects [-1,1])
# -------------------
tf = transforms.Compose([
    transforms.Resize(IMG_SIZE, interpolation=InterpolationMode.BICUBIC),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.5],[0.5]),
])

# -------------------
# Helpers: robust image normalization → PIL.Image
# -------------------
def to_pil(img_obj: Any) -> Image.Image:
    """
    Convert various dataset image representations to a RGB PIL.Image:
    - PIL.Image.Image
    - bytes / bytearray
    - dict with 'bytes' and/or 'path'
    - str path
    """
    if isinstance(img_obj, Image.Image):
        return img_obj.convert("RGB")
    if isinstance(img_obj, (bytes, bytearray)):
        return Image.open(io.BytesIO(img_obj)).convert("RGB")
    if isinstance(img_obj, dict):
        b = img_obj.get("bytes", None)
        p = img_obj.get("path", None)
        if b is not None:
            return Image.open(io.BytesIO(b)).convert("RGB")
        if p is not None:
            return Image.open(p).convert("RGB")
    if isinstance(img_obj, str):
        return Image.open(img_obj).convert("RGB")
    # Last-ditch: some datasets wrap a dict in another container
    raise TypeError(f"Unsupported image type for PIL conversion: {type(img_obj)}")

def map_record(ex):
    """
    Standardize a dataset example to:
      { 'pixel_values': tensor[-1..1], 'caption': str }
    Picks a plausible image-like field automatically; synthesizes caption if missing.
    """
    # try common image keys first
    img_like = ex.get("image", None) or ex.get("img", None) or ex.get("image_bytes", None) or ex.get("file", None)
    if img_like is None:
        # pick first non-caption-like field as image
        for k, v in ex.items():
            if k.lower() not in {"text","prompt","caption","label","labels","class"}:
                img_like = v
                break
    if img_like is None:
        raise KeyError("No image-like field found in example. Keys: " + ", ".join(ex.keys()))

    pil = to_pil(img_like)
    cap = ex.get("text") or ex.get("prompt") or ex.get("caption") or "a cute pokemon creature, high detail"
    return {"pixel_values": tf(pil), "caption": cap}

# -------------------
# Dataset loader (local parquet → fallback to public)
# -------------------
def load_pokemon_dataset():
    if os.path.exists(LOCAL_PARQUET):
        ds = load_dataset("parquet", data_files=LOCAL_PARQUET, split="train")
        ds = ds.map(map_record, remove_columns=ds.column_names)
        return ds
    ds = load_dataset(FALLBACK_DATASET, split="train")
    ds = ds.map(map_record, remove_columns=ds.column_names)
    return ds

dataset = load_pokemon_dataset()
print(dataset)
print("Example keys:", dataset[0].keys())
print("Example caption:", dataset[0]["caption"][:80])

loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, drop_last=True)

# -------------------
# SD Pipeline + LoRA setup
# -------------------
pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=DTYPE)
pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to(DEVICE)

# freeze base; train only LoRA
pipe.vae.requires_grad_(False)
pipe.text_encoder.requires_grad_(False)
pipe.unet.requires_grad_(False)

# Attach LoRA adapters to all attention processors
rank = 8
attn_procs = {}
for name, proc in pipe.unet.attn_processors.items():
    cross_dim = getattr(proc, "cross_attention_dim", None)
    hidden_size = getattr(proc, "hidden_size", pipe.unet.config.attention_head_dim)
    attn_procs[name] = LoRAAttnProcessor(
        hidden_size=hidden_size, cross_attention_dim=cross_dim, rank=rank
    )
pipe.unet.set_attn_processor(attn_procs)

# optimizer on LoRA params only
lora_params = [p for p in pipe.unet.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(lora_params, lr=LR)

# shortcuts
tokenizer       = pipe.tokenizer
text_encoder    = pipe.text_encoder
vae             = pipe.vae
unet            = pipe.unet
noise_scheduler = pipe.scheduler

scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE=="cuda"))
global_step, loss_acc = 0, 0.0
pbar = tqdm(total=MAX_STEPS, desc="Training (LoRA)")

# -------------------
# Training loop
# -------------------
while global_step < MAX_STEPS:
    for batch in loader:
        if global_step >= MAX_STEPS:
            break

        with torch.no_grad():
            # captions -> text embeddings
            tok = tokenizer(
                batch["caption"], padding="max_length", truncation=True,
                max_length=tokenizer.model_max_length, return_tensors="pt"
            )
            enc = text_encoder(tok.input_ids.to(DEVICE))[0]

            # images -> latents
            imgs = batch["pixel_values"].to(DEVICE, dtype=DTYPE)
            latents = vae.encode(imgs).latent_dist.sample() * 0.18215  # SD scaling

        # add noise
        noise = torch.randn_like(latents)
        t = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                          (latents.shape[0],), device=DEVICE, dtype=torch.long)
        noisy = noise_scheduler.add_noise(latents, noise, t)

        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=(DEVICE=="cuda")):
            pred = unet(noisy, t, encoder_hidden_states=enc).sample
            loss = nn.functional.mse_loss(pred, noise)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        global_step += 1
        loss_acc += loss.item()

        if global_step % 50 == 0:
            print(f"step {global_step}: loss {loss_acc/50:.4f}")
            loss_acc = 0.0

        if global_step % 250 == 0 or global_step == MAX_STEPS:
            save_dir = os.path.join(OUT_DIR, f"lora_step_{global_step}")
            os.makedirs(save_dir, exist_ok=True)
            pipe.unet.save_attn_procs(save_dir)
            print("Saved LoRA to:", save_dir)

        pbar.update(1)

pbar.close()
print("Training done!")

# -------------------
# Inference with latest LoRA
# -------------------
chkpts = sorted(glob.glob(os.path.join(OUT_DIR, "lora_step_*")),
                key=lambda p: int(p.split("_")[-1]))
assert len(chkpts) > 0, "No LoRA checkpoints found!"
latest = chkpts[-1]
pipe.unet.load_attn_procs(latest)
print("Loaded LoRA:", latest)

prompt = "a cute watercolor pokemon, pastel colors, high detail"
neg    = "blurry, low quality, watermark"
images = pipe(prompt, num_inference_steps=30, guidance_scale=GUIDANCE,
              negative_prompt=neg, num_images_per_prompt=2).images

os.makedirs("samples", exist_ok=True)
for i, im in enumerate(images):
    fp = f"samples/poke_{i}.png"
    im.save(fp)
    print("Saved:", fp)

print("All done ✅")

Map:   0%|          | 0/7357 [00:00<?, ? examples/s]