# RLHF Diffusion Fine-Tuning for Age, Gender, Ethnicity, Emotion
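We will:
1. Load the dataset and its label metadata.
2. Preprocess the images and turn each row's labels into a text prompt.
3. Inject a LoRA adapter into the U-Net.
4. Condition generation on the labels (age, gender, ethnicity, emotion) through those text prompts.
5. Train the LoRA adapter.
6. Run inference to generate label-conditioned images.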


In [1]:
# Install necessary packages
# !pip install pandas Pillow tqdm torch torchvision transformers diffusers scikit-learn IProgress ipywidgets accelerate peft matplotlib

In [2]:
# Imports and Configuration
import os
import gc
import uuid
import json
import random
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
from pathlib import Path
from typing import Tuple, Dict, List, Optional, Any
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from diffusers.models.attention_processor import LoRAAttnProcessor2_0, AttnProcessor
from diffusers import StableDiffusionPipeline

from transformers import CLIPTokenizer

from peft import LoraConfig

In [3]:
# Micro-batch size used by every DataLoader below.
batch_size: int = 2

# Dataset layout: one metadata CSV and one image folder per split.
dataset_root: Path = Path('./datasets/appa-real-dataset_v2')
_SPLITS = ('train', 'valid', 'test')

labels_md_train, labels_md_valid, labels_md_test = (
    dataset_root / f'labels_metadata_{split}.csv' for split in _SPLITS
)
ds_train, ds_valid, ds_test = (dataset_root / f'{split}_data' for split in _SPLITS)

# Read the metadata tables (train, valid, test order).
df_md_train, df_md_valid, df_md_test = (
    pd.read_csv(csv_path) for csv_path in (labels_md_train, labels_md_valid, labels_md_test)
)

print(f"Train: {df_md_train.shape}, Valid: {df_md_valid.shape}, Test: {df_md_test.shape}")

Train: (4065, 5), Valid: (1482, 5), Test: (1978, 5)


In [4]:
def setup_run_logging(
    output_base_dir: Path,
    resume_unet_checkpoint_path: Optional[Path] = None
) -> Dict[str, Any]:
    """
    Resolve (or create) the directory that holds everything for one training run.

    Run directories are named ``<6 hex chars>_run_<YYYYmmdd-HHMMSS>`` and live under
    ``output_base_dir``. When ``resume_unet_checkpoint_path`` points at an existing
    file, the run id is recovered from the path component nearest the file that
    matches that naming scheme; if none matches, a fresh id is generated. Otherwise
    a brand-new run is created and seeded with an empty ``training_history.json``.
    In both cases ``lora_checkpoints`` and ``lora_samples`` subdirectories are ensured.

    Args:
        output_base_dir (Path): Parent directory for all run directories.
        resume_unet_checkpoint_path (Optional[Path]): Checkpoint (.pth) being resumed, if any.

    Returns:
        Dict[str, Any]: Keys ``run_id``, ``run_dir``, ``history_file`` and ``is_resumed_run``.
    """

    def _fresh_run_id() -> str:
        # 6 random hex chars plus a local timestamp, e.g. "1a2b3c_run_20250101-120000".
        prefix = uuid.uuid4().hex[:6]
        stamp = pd.Timestamp.now().strftime('%Y%m%d-%H%M%S')
        return f"{prefix}_run_{stamp}"

    resuming = bool(resume_unet_checkpoint_path and resume_unet_checkpoint_path.exists())

    if not resuming:
        run_id = _fresh_run_id()
        run_dir = output_base_dir / run_id
        run_dir.mkdir(parents=True, exist_ok=True)
        history_file = run_dir / "training_history.json"
        # A new run starts with an empty epoch list that training appends to.
        with open(history_file, 'w') as fh:
            json.dump({"epochs": []}, fh)
        print(f"✨ Starting new run with RUN_ID: {run_id}")
    else:
        # Walk from the file upwards; the first "<6 chars>_run_..." component is the run id.
        run_id = next(
            (
                part
                for part in reversed(resume_unet_checkpoint_path.parts)
                if '_run_' in part and len(part.partition('_')[0]) == 6
            ),
            None,
        )
        if run_id is None:
            print("Warning: Could not extract RUN_ID from resume path. Generating new one.")
            run_id = _fresh_run_id()

        run_dir = output_base_dir / run_id
        if not run_dir.exists():
            print(f"Warning: Resuming from checkpoint but expected run directory {run_dir} not found. Creating it.")
            run_dir.mkdir(parents=True, exist_ok=True)

        history_file = run_dir / "training_history.json"
        print(f"🔄 Resuming existing run with RUN_ID: {run_id}")

    for subdir in ("lora_checkpoints", "lora_samples"):
        (run_dir / subdir).mkdir(exist_ok=True)

    return {
        "run_id": run_id,
        "run_dir": run_dir,
        "history_file": history_file,
        "is_resumed_run": resuming,
    }

In [5]:
class ImageWithPromptDataset(Dataset):
    """Pair each image listed in a metadata DataFrame with a text prompt built from its labels.

    Each row must provide ``imageId``, ``age``, ``gender``, ``ethnicity`` and ``emotion``.
    The image is read from ``images_dir / f"{imageId:06d}.jpg"``. Each item is a dict:

    - ``pixel_values``: the RGB image after ``transform`` (a PIL image if no transform is set)
    - ``prompt_ids``: the tokenized prompt, padded/truncated to ``max_length`` tokens
    """

    # Phrase used for each value of the ``emotion`` column. Built once per class
    # rather than on every build_prompt call.
    EMOTION_MAP: Dict[str, str] = {
        'neutral': "with a neutral expression",
        'happy': "smiling happily",
        'slightlyhappy': "smiling slightly",
        'other': "showing a subtle emotion",
    }
    # Fallback phrase for emotion labels missing from EMOTION_MAP.
    DEFAULT_EMOTION: str = "with an expression"

    def __init__(
        self,
        df_md: pd.DataFrame,
        images_dir: Path,
        tokenizer: "CLIPTokenizer",
        transform: "Optional[transforms.Compose]" = None,
        max_length: int = 77,
    ):
        """
        Args:
            df_md (pd.DataFrame): Metadata, one row per image.
            images_dir (Path): Folder holding the ``<imageId:06d>.jpg`` files.
            tokenizer (CLIPTokenizer): Tokenizer used to encode the prompts.
            transform (Optional[transforms.Compose]): Image transform; None keeps the PIL image.
            max_length (int): Token length prompts are padded/truncated to (77 = CLIP context).
        """
        self.df = df_md
        self.images_dir = images_dir
        self.tokenizer = tokenizer
        self.transform = transform
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.df)

    def build_prompt(self, row) -> str:
        """Return e.g. "A 37 years old caucasian female smiling slightly" for a metadata row.

        ``row`` may be a pandas Series or a dict; ``age`` is truncated to an int.
        """
        age_desc = f"{int(row['age'])} years old"
        emotion_desc = self.EMOTION_MAP.get(row['emotion'], self.DEFAULT_EMOTION)
        return f"A {age_desc} {row['ethnicity']} {row['gender']} {emotion_desc}"

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        row = self.df.iloc[idx]
        img_path = self.images_dir / f"{int(row['imageId']):06d}.jpg"

        # Context manager guarantees the file handle is released even if decoding fails.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        prompt_ids = self.tokenizer(
            self.build_prompt(row),
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        ).input_ids[0]

        return {
            'pixel_values': image,
            'prompt_ids': prompt_ids
        }

In [6]:
# Data Transforms & Loaders
# Tokenizer used by ImageWithPromptDataset to encode the label prompts.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

# Shared by both pipelines: resize to 512x512 (Stable Diffusion 1.5 resolution) and
# finish by mapping ToTensor's [0, 1] range to [-1, 1].
_SD_RESOLUTION = (512, 512)
_TO_SIGNED_UNIT_RANGE = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

train_transform = transforms.Compose([
    transforms.Resize(_SD_RESOLUTION),
    transforms.ToTensor(),
    # Augmentation, applied on tensors.
    # NOTE(review): hue/saturation jitter changes skin tone, which may blur the
    # ethnicity conditioning carried by the prompt — consider dropping hue.
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.Normalize(**_TO_SIGNED_UNIT_RANGE),
])

val_transform = transforms.Compose([
    transforms.Resize(_SD_RESOLUTION),
    transforms.ToTensor(),
    transforms.Normalize(**_TO_SIGNED_UNIT_RANGE),
])

train_dataset = ImageWithPromptDataset(df_md_train, ds_train, tokenizer, transform=train_transform)
valid_dataset = ImageWithPromptDataset(df_md_valid, ds_valid, tokenizer, transform=val_transform)

# Only the training split is shuffled; validation keeps a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

In [7]:
def load_model(
    lora_r: int = 64,
    lora_alpha: int = 64,
    lora_dropout: float = 0.05,
    device: str = 'cuda',
    verbose: bool = False,
    model_id: str = "runwayml/stable-diffusion-v1-5",
    adapter_name: str = "age_gender_lora",
) -> StableDiffusionPipeline:
    """
    Load a Stable Diffusion pipeline in fp16 and attach a trainable LoRA adapter to its U-Net.

    LoRA is injected into the attention ``to_q`` / ``to_v`` projections. The VAE and
    text encoder are frozen, and the function verifies that the only trainable U-Net
    parameters are LoRA parameters.

    Args:
        lora_r (int): LoRA rank.
        lora_alpha (int): LoRA alpha scaling.
        lora_dropout (float): Dropout probability for LoRA.
        device (str): Device to load the model onto.
        verbose (bool): If True, print shape, mean/std, dtype and NaN/Inf status of every
            LoRA parameter.
        model_id (str): Hugging Face model id or local path of the base pipeline.
        adapter_name (str): Name under which the LoRA adapter is registered on the U-Net.

    Returns:
        StableDiffusionPipeline: The configured pipeline ready for training.

    Raises:
        RuntimeError: If no LoRA parameters were injected, or if any non-LoRA U-Net
            parameter is still trainable.
    """

    # 1. Load base pipeline
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float16
    ).to(device)
    print("✅ Base model loaded.")

    # 2. Define LoRA PEFT config
    peft_lora_config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=["to_q", "to_v"],
        lora_dropout=lora_dropout,
        bias="none"
    )

    # 3. Add and enable the LoRA adapter
    pipe.unet.add_adapter(adapter_name=adapter_name, adapter_config=peft_lora_config)
    print("✅ LoRA adapter added.")
    pipe.unet.enable_lora()
    print("✅ LoRA adapter enabled for training.")

    # 4. Freeze VAE & Text Encoder
    pipe.vae.requires_grad_(False)
    pipe.text_encoder.requires_grad_(False)
    print("✅ VAE and Text Encoder frozen.")

    # 5. Verify LoRA injection
    print("\n🔍 Checking LoRA Layers Injected into UNet...")
    lora_named_params = [(name, param) for name, param in pipe.unet.named_parameters() if 'lora' in name]
    if verbose:
        # The NaN/Inf scan forces a device sync per tensor, so it only runs when requested.
        for name, param in lora_named_params:
            nan_inf = torch.isnan(param).any() or torch.isinf(param).any()
            print(f" - {name}: shape={param.shape}, mean={param.data.mean():.6f}, std={param.data.std():.6f}, dtype={param.dtype}, requires_grad={param.requires_grad}, nan/inf={nan_inf}")

    if not lora_named_params:
        raise RuntimeError("❌ No LoRA parameters found in UNet! Injection failed.")

    # 6. Ensure ONLY LoRA layers are trainable
    non_lora_trainable = [n for n, p in pipe.unet.named_parameters() if p.requires_grad and 'lora' not in n]
    if non_lora_trainable:
        print("❌ ERROR: Found non-LoRA parameters set to requires_grad=True:")
        for n in non_lora_trainable:
            print(f" - {n}")
        raise RuntimeError("Non-LoRA params are trainable! You must freeze them explicitly.")
    print("✅ Only LoRA layers are trainable.")

    print("\n🚀 LoRA setup successful. Ready for training loop.")
    return pipe

def unload_model(pipe: "StableDiffusionPipeline") -> None:
    """
    Drop the pipeline's UNet and release cached GPU memory.

    Only the ``unet`` attribute is removed; other components (VAE, text encoder)
    remain attached to ``pipe``. Calling this on a pipeline whose UNet was already
    removed is harmless.

    Args:
        pipe (StableDiffusionPipeline): The current pipeline to unload from.
    """
    print("🔻 Unloading UNet to free GPU memory...")
    try:
        del pipe.unet
    except AttributeError:
        # Already unloaded (e.g. called twice); still run the cleanup below.
        print("⚠️ Pipeline has no UNet to unload (already unloaded?).")
    # Collect first: empty_cache() can only hand back blocks whose tensors are already
    # freed, and the UNet may be held alive by reference cycles until gc runs.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print("✅ UNet unloaded and GPU cache cleared.")

# --------------------------------------------------------------------------------------------------
## LoRA training loop for Stable Diffusion (gradient checkpointing, accumulation, AMP, resume)
# --------------------------------------------------------------------------------------------------

def _compute_diffusion_loss(
    pipe: StableDiffusionPipeline,
    batch: Dict[str, torch.Tensor],
    device: str,
    batch_idx: int,
    generator: Optional[torch.Generator] = None,
) -> Optional[torch.Tensor]:
    """Compute the noise-prediction MSE loss for one batch.

    Steps: VAE-encode the images, add noise at random timesteps, and have the U-Net
    predict that noise conditioned on the frozen text encoder's prompt embedding.
    Returns ``None`` (after printing a message) if the latents, noisy latents, U-Net
    prediction or loss contain NaN/Inf, so the caller can skip the batch.
    ``generator`` optionally makes the latent sample, noise and timesteps reproducible.
    """
    device_type = torch.device(device).type
    # Match whatever dtype the VAE currently holds.
    pixel_values = batch['pixel_values'].to(device, dtype=pipe.vae.dtype)

    with torch.no_grad():
        latents = pipe.vae.encode(pixel_values).latent_dist.sample(generator=generator)
        # VAE latent scaling; 0.18215 is the SD 1.x value the original recipe hard-coded.
        scaling_factor = getattr(pipe.vae.config, "scaling_factor", 0.18215)
        latents = (latents * scaling_factor).to(dtype=torch.float16)

    if not torch.isfinite(latents).all():
        print(f"[Batch {batch_idx}] Latents NaN/Inf Detected! Skipping batch.")
        return None

    noise = torch.randn(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype)
    # Timesteps below 10 are never sampled (kept from the original recipe).
    timesteps = torch.randint(
        10, pipe.scheduler.config.num_train_timesteps, (latents.shape[0],),
        generator=generator, device=latents.device,
    ).long()
    noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)

    if not torch.isfinite(noisy_latents).all():
        print(f"[Batch {batch_idx}] Noisy Latents NaN/Inf Detected! Skipping batch.")
        return None

    with torch.no_grad():
        encoder_hidden_states = pipe.text_encoder(batch['prompt_ids'].to(device))[0]

    # autocast needs a device *type* ("cuda"), not a full device string ("cuda:0").
    with torch.amp.autocast(device_type, dtype=torch.float16):
        model_pred = pipe.unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample

    if not torch.isfinite(model_pred).all():
        print(f"[Batch {batch_idx}] Model Prediction NaN/Inf Detected! Mean: {model_pred.mean()}, Std: {model_pred.std()}. Skipping batch.")
        return None

    loss = nn.functional.mse_loss(model_pred.float(), noise.float(), reduction="mean")
    if not torch.isfinite(loss):
        print(f"[Batch {batch_idx}] Loss NaN/Inf Detected! Skipping batch.")
        return None
    return loss


def _optimizer_step(
    optimizer: torch.optim.Optimizer,
    scaler: torch.amp.GradScaler,
    params: List[torch.nn.Parameter],
    max_grad_norm: float = 1.0,
) -> bool:
    """Unscale, clip and apply the accumulated gradients; return False if the step was skipped.

    GradScaler itself skips ``optimizer.step`` when the unscaled gradients contain
    Inf/NaN, and lowers the loss scale in ``update()``. The scaler must be updated on
    every step: if it is not, an overflowing scale is never reduced and all later
    steps overflow as well.
    """
    scaler.unscale_(optimizer)
    total_norm = torch.nn.utils.clip_grad_norm_(params, max_norm=max_grad_norm)
    grads_finite = bool(torch.isfinite(total_norm))
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)
    return grads_finite


def _train_one_epoch(
    pipe: StableDiffusionPipeline,
    train_dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scaler: torch.amp.GradScaler,
    lora_params: List[torch.nn.Parameter],
    device: str,
    gradient_accumulation_steps: int,
    desc: str,
) -> float:
    """Run one pass over ``train_dataloader``; return the mean loss over non-skipped batches."""
    pipe.unet.train()
    optimizer.zero_grad(set_to_none=True)
    total_loss = 0.0
    used_batches = 0
    pending = 0  # micro-batches accumulated but not yet applied

    for batch_idx, batch in enumerate(tqdm(train_dataloader, desc=desc)):
        loss = _compute_diffusion_loss(pipe, batch, device, batch_idx)
        if loss is None:
            continue

        # Average over the accumulation window so the update size does not depend on it.
        scaler.scale(loss / gradient_accumulation_steps).backward()
        pending += 1

        loss_value = loss.item()
        total_loss += loss_value
        used_batches += 1
        if batch_idx % 100 == 0:
            print(f"  [Batch {batch_idx}] Loss: {loss_value:.6f}")

        # Count accepted micro-batches (not batch_idx) so a skipped batch cannot
        # silently merge two accumulation windows.
        if pending == gradient_accumulation_steps:
            if not _optimizer_step(optimizer, scaler, lora_params):
                print(f"[Batch {batch_idx}] NaN/Inf gradient detected. Optimizer step skipped; loss scale reduced.")
            pending = 0

    # Flush a trailing partial accumulation window at the end of the epoch.
    if pending and not _optimizer_step(optimizer, scaler, lora_params):
        print("[End of epoch] NaN/Inf gradient detected. Optimizer step skipped; loss scale reduced.")

    return total_loss / used_batches if used_batches else float('nan')


def _validation_loss(
    pipe: StableDiffusionPipeline,
    valid_dataloader: DataLoader,
    device: str,
    max_batches: Optional[int],
    seed: int = 0,
) -> Optional[float]:
    """Mean loss over up to ``max_batches`` validation batches (None = all).

    Returns None if validation is disabled (``max_batches <= 0``) or every batch was
    skipped. A fixed-seed generator scores every epoch on the same noise/timesteps,
    provided ``valid_dataloader`` iterates in a fixed order.
    """
    if max_batches is not None and max_batches <= 0:
        return None

    pipe.unet.eval()
    generator = torch.Generator(device=device).manual_seed(seed)
    losses: List[float] = []
    with torch.no_grad():
        for batch_idx, batch in enumerate(valid_dataloader):
            if max_batches is not None and batch_idx >= max_batches:
                break
            loss = _compute_diffusion_loss(pipe, batch, device, batch_idx, generator=generator)
            if loss is not None:
                losses.append(loss.item())
    return sum(losses) / len(losses) if losses else None


def _save_epoch_checkpoints(pipe: StableDiffusionPipeline, output_dir: Path, epoch_num: int) -> Tuple[Path, Path]:
    """Write this epoch's adapter directory and full U-Net state_dict; return both paths."""
    adapter_path = output_dir / f"lora_adapters_epoch_{epoch_num}"
    adapter_path.mkdir(parents=True, exist_ok=True)
    # NOTE(review): confirm save_pretrained(adapter_name=...) writes only the adapter in the
    # installed diffusers version; it may save the whole U-Net.
    pipe.unet.save_pretrained(adapter_path, adapter_name="age_gender_lora")
    print(f"Saved LoRA adapter weights only at {adapter_path}")

    # The full state_dict (base + LoRA) is the format resume_unet_checkpoint_path loads.
    unet_ckpt_path = output_dir / f"full_unet_checkpoint_epoch_{epoch_num}.pth"
    torch.save(pipe.unet.state_dict(), unet_ckpt_path)
    print(f"Saved full U-Net checkpoint (including LoRA state) at {unet_ckpt_path}")
    return adapter_path, unet_ckpt_path


def _generate_epoch_sample(
    pipe: StableDiffusionPipeline,
    dataset_for_sampling: Dataset,
    sample_dir: Path,
    epoch_num: int,
    device: str,
    seed: int = 42,
) -> Tuple[str, Path]:
    """Generate, save and display one image for a random dataset prompt.

    Returns the prompt and the saved image path.
    """
    # Eval mode disables LoRA dropout, which would otherwise stay active during sampling.
    pipe.unet.eval()
    df = dataset_for_sampling.df
    sample_row = df.iloc[random.randrange(len(df))]
    sample_prompt: str = dataset_for_sampling.build_prompt(sample_row)
    generator = torch.Generator(device=device).manual_seed(seed)

    with torch.no_grad(), torch.amp.autocast(torch.device(device).type):
        image = pipe(prompt=sample_prompt, num_inference_steps=75, guidance_scale=7.5, generator=generator).images[0]

    image_save_path = sample_dir / f"sample_epoch_{epoch_num}.png"
    image.save(image_save_path)
    print(f"Saved sample image at {image_save_path}")

    prompt_save_path = sample_dir / f"sample_epoch_{epoch_num}_prompt.txt"
    prompt_save_path.write_text(sample_prompt, encoding="utf-8")
    print(f"Saved prompt to {prompt_save_path}")

    fig = plt.figure(figsize=(8, 8))
    plt.imshow(image)
    plt.title(f"Epoch {epoch_num} Sample\nPrompt: {sample_prompt}", wrap=True)
    plt.axis('off')
    plt.show()
    plt.close(fig)  # avoid accumulating one open figure per epoch

    print(f"Prompt: {sample_prompt}")
    return sample_prompt, image_save_path


def train_lora_diffusion(
    num_epochs: int,
    train_dataloader: DataLoader,
    valid_dataloader: DataLoader,
    dataset_for_sampling: Dataset,
    learning_rate: float = 3e-5,
    lora_r: int = 64,
    lora_alpha: int = 64,
    lora_dropout: float = 0.05,
    gradient_checkpoint_enable: bool = True,
    gradient_accumulation_steps: int = 1,
    output_base_dir: Path = Path("./lora_training_runs"),
    resume_unet_checkpoint_path: Optional[Path] = None,
    device: str = 'cuda',
    verbose: bool = False,
    max_valid_batches: Optional[int] = 50,
) -> None:
    """
    Train a LoRA adapter on Stable Diffusion's U-Net for age, gender, ethnicity and emotion prompts.

    Each epoch: train with AMP and gradient accumulation, compute a validation loss,
    save the adapter directory and a full U-Net checkpoint, generate a sample image,
    and append an entry to the run's ``training_history.json``.

    Args:
        num_epochs (int): Total number of epochs for the run. When resuming, epochs
            already recorded in the history count toward this total, so training
            continues at epoch ``len(history) + 1`` and stops after ``num_epochs``.
        train_dataloader (DataLoader): DataLoader for the training dataset.
        valid_dataloader (DataLoader): DataLoader for the validation dataset.
        dataset_for_sampling (Dataset): Dataset with ``.df`` and ``.build_prompt`` used to
            pick a random prompt for the per-epoch sample image.
        learning_rate (float): Learning rate for the AdamW optimizer.
        lora_r (int): LoRA rank.
        lora_alpha (int): LoRA alpha scaling.
        lora_dropout (float): Dropout probability for LoRA.
        gradient_checkpoint_enable (bool): Enable gradient checkpointing to save GPU memory.
        gradient_accumulation_steps (int): Non-skipped micro-batches accumulated per optimizer
            step; effective batch = DataLoader batch size x this value. Must be >= 1.
        output_base_dir (Path): Base directory; each run gets ``<output_base_dir>/<run_id>/``
            containing ``lora_checkpoints/``, ``lora_samples/`` and ``training_history.json``.
        resume_unet_checkpoint_path (Optional[Path]): Full U-Net state_dict (.pth, including
            LoRA) to resume from. If None, train from scratch.
        device (str): Device to train on, e.g. 'cuda', 'cuda:0' or 'cpu'.
        verbose (bool): If True, print per-parameter LoRA details while loading the model.
        max_valid_batches (Optional[int]): Validation batches per epoch (None = all,
            0 = skip validation).

    Raises:
        ValueError: If ``gradient_accumulation_steps`` < 1.
    """
    if gradient_accumulation_steps < 1:
        raise ValueError(f"gradient_accumulation_steps must be >= 1, got {gradient_accumulation_steps}")

    # --- Run directory & history ---
    run_info = setup_run_logging(output_base_dir, resume_unet_checkpoint_path)
    run_dir: Path = run_info["run_dir"]
    history_file: Path = run_info["history_file"]
    is_resumed_run: bool = run_info["is_resumed_run"]
    output_dir = run_dir / "lora_checkpoints"
    sample_dir = run_dir / "lora_samples"

    training_history: Dict[str, Any] = {"epochs": []}
    if is_resumed_run and history_file.exists():
        with open(history_file, 'r', encoding='utf-8') as f:
            training_history = json.load(f)
        print(f"📊 Loaded existing training history from {history_file}")

    # Each history entry is one completed epoch.
    start_epoch = len(training_history["epochs"]) if is_resumed_run else 0
    if start_epoch:
        print(f"Resuming from epoch {start_epoch + 1}.")
    if start_epoch >= num_epochs:
        print(f"Nothing to train: {start_epoch} epoch(s) already recorded and num_epochs={num_epochs}.")
        return

    pipe: StableDiffusionPipeline = load_model(
        lora_r=lora_r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, device=device, verbose=verbose
    )
    try:
        pipe.safety_checker = lambda images, **kwargs: (images, [False] * len(images))

        if gradient_checkpoint_enable:
            pipe.unet.enable_gradient_checkpointing()
            print(f"✅ Gradient checkpointing enabled for UNet: gradient_accumulation_steps = {gradient_accumulation_steps}")

        if resume_unet_checkpoint_path is not None:
            if resume_unet_checkpoint_path.exists():
                print(f"🔄 Resuming training from full U-Net checkpoint: {resume_unet_checkpoint_path}")
                # The checkpoint is a plain tensor state_dict, so no pickled code needs loading.
                unet_state_dict = torch.load(resume_unet_checkpoint_path, map_location='cpu', weights_only=True)
                pipe.unet.load_state_dict(unet_state_dict)
                del unet_state_dict
                print("✅ Full U-Net state loaded for resuming training.")
            else:
                print(f"❌ Resume U-Net checkpoint not found at {resume_unet_checkpoint_path}. Starting training from scratch.")

        # Encode in fp32 (cast once here instead of on every batch); sampling reuses it.
        pipe.vae.to(dtype=torch.float32)

        # LoRA B weights are deliberately left as PEFT initialised them. LoRA's standard
        # init makes B zero so the adapter starts as an exact no-op on the pretrained U-Net;
        # re-initialising zero B weights with random values perturbs the base model before
        # any training happens.
        lora_params = [p for p in pipe.unet.parameters() if p.requires_grad]
        for p in lora_params:
            # Trainable weights in fp32 so GradScaler can unscale their gradients.
            p.data = p.data.float()

        optimizer = torch.optim.AdamW(lora_params, lr=learning_rate)
        scaler = torch.amp.GradScaler(torch.device(device).type)

        print(f"\n🚀 Starting training for {num_epochs - start_epoch} epochs...")

        for epoch in range(start_epoch, num_epochs):
            epoch_num = epoch + 1
            avg_loss = _train_one_epoch(
                pipe, train_dataloader, optimizer, scaler, lora_params, device,
                gradient_accumulation_steps, desc=f"Epoch {epoch_num}/{num_epochs}",
            )
            print(f"Epoch [{epoch_num}/{num_epochs}] Average Loss: {avg_loss:.6f}")

            valid_loss = _validation_loss(pipe, valid_dataloader, device, max_valid_batches)
            if valid_loss is not None:
                print(f"Epoch [{epoch_num}/{num_epochs}] Validation Loss: {valid_loss:.6f}")

            adapter_path, unet_ckpt_path = _save_epoch_checkpoints(pipe, output_dir, epoch_num)
            sample_prompt, image_save_path = _generate_epoch_sample(
                pipe, dataset_for_sampling, sample_dir, epoch_num, device
            )

            training_history["epochs"].append({
                "epoch": epoch_num,
                "train_loss": avg_loss,
                "valid_loss": valid_loss,
                "sample_prompt": sample_prompt,
                "sample_image_path": str(image_save_path),
                "lora_adapter_path": str(adapter_path),
                "full_unet_checkpoint_path": str(unet_ckpt_path),
            })
            with open(history_file, 'w', encoding='utf-8') as f:
                json.dump(training_history, f, indent=4)
            print(f"Updated training history saved to {history_file}")
    finally:
        # Release GPU memory even if training is interrupted or fails.
        unload_model(pipe)

    print("✨ Training complete! Model unloaded.")

In [8]:
# Preview: draw one metadata row at random and show the prompt it produces.
dataset = train_loader.dataset
_candidate_rows = dataset.df.to_dict(orient="records")
sample_row = random.choice(_candidate_rows)
del _candidate_rows
sample_prompt = dataset.build_prompt(sample_row)
print(sample_prompt)

A 37 years old caucasian female smiling slightly


In [None]:
# --- Train from scratch ---
# Each optimizer step accumulates gradient_accumulation_steps micro-batches of the
# DataLoader's batch_size; every run writes into its own <run_id> folder under output_base_dir.
train_lora_diffusion(
    num_epochs=5,
    train_dataloader=train_loader,
    valid_dataloader=valid_loader,
    dataset_for_sampling=train_dataset,
    learning_rate=3e-5,
    lora_r=128,
    lora_alpha=128,
    lora_dropout=0.05,
    gradient_checkpoint_enable=True,
    gradient_accumulation_steps=4, # Accumulate gradients over 4 micro-batches per optimizer step
    output_base_dir=Path("./lora_training_runs"), # Parent folder for all run directories
    resume_unet_checkpoint_path=None, # None -> start a new run
    verbose=False,
)

# --- Resume training ---
# Point resume_unet_checkpoint_path at a full U-Net checkpoint written by an earlier run:
#   <output_base_dir>/<run_id>/lora_checkpoints/full_unet_checkpoint_epoch_<N>.pth
# The run directory and its training history are recovered from the <run_id> path component.
# lora_r / lora_alpha must match the run being resumed, otherwise load_state_dict fails on shape.
# NOTE(review): check train_lora_diffusion's docstring for whether num_epochs includes
# epochs that were already completed.
# train_lora_diffusion(
#     num_epochs=20,
#     train_dataloader=train_loader,
#     valid_dataloader=valid_loader,
#     dataset_for_sampling=train_dataset,
#     learning_rate=1e-5, # A lower LR is often used when resuming
#     lora_r=128,
#     lora_alpha=128,
#     lora_dropout=0.05,
#     gradient_checkpoint_enable=True,
#     gradient_accumulation_steps=4,
#     output_base_dir=Path("./lora_training_runs"),
#     resume_unet_checkpoint_path=Path("./lora_training_runs/YOUR_RUN_ID_HERE/lora_checkpoints/full_unet_checkpoint_epoch_X.pth"),
#     verbose=False,
# )

✨ Starting new run with RUN_ID: 760a85_run_20250803-202647


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

✅ Base model loaded.
✅ LoRA adapter added.
✅ LoRA adapter enabled for training.
✅ VAE and Text Encoder frozen.

🔍 Checking LoRA Layers Injected into UNet...
✅ Only LoRA layers are trainable.

🚀 LoRA setup successful. Ready for training loop.
✅ Gradient checkpointing enabled for UNet: gradient_accumulation_steps = 4

🚀 Starting training for 5 epochs...


Epoch 1/5:   0%|          | 1/2033 [00:01<44:12,  1.31s/it]

  [Batch 0] Loss: 0.090369


Epoch 1/5:   3%|▎         | 69/2033 [01:05<30:59,  1.06it/s]

In [None]:
import matplotlib.pyplot as plt
from PIL import Image

# Release the U-Net of a pipeline created by an earlier run of this cell. `pipe` may not
# exist at notebook level yet (train_lora_diffusion keeps its pipeline local and unloads
# it itself), so look it up rather than referencing it directly (which raises NameError).
if globals().get("pipe") is not None:
    unload_model(pipe=pipe)


def generate_image_from_prompt(
    pipe: StableDiffusionPipeline,
    prompt: str,
    num_inference_steps: int = 75,
    generator: Optional[torch.Generator] = None,
) -> Image.Image:
    """Generate one image for ``prompt``.

    Args:
        pipe: Pipeline to sample from.
        prompt: Text prompt.
        num_inference_steps: Number of denoising steps (default 75).
        generator: Optional seeded generator for reproducible output.
    """
    with torch.no_grad():
        result = pipe(prompt, num_inference_steps=num_inference_steps, generator=generator)
    return result.images[0]


# ---- Configuration (must match the training run being loaded) ----
lora_adapter_name = "age_gender_lora"  # adapter name load_model() registers by default
lora_r, lora_alpha = 128, 128  # LoRA shape used for training; a mismatch makes load_state_dict fail
checkpoint_path = Path("./lora_training_runs/YOUR_RUN_ID_HERE/lora_checkpoints/full_unet_checkpoint_epoch_5.pth")

# ---- Load base model with a LoRA adapter of the same shape ----
pipe: StableDiffusionPipeline = load_model(lora_r=lora_r, lora_alpha=lora_alpha)
pipe.safety_checker = lambda images, **kwargs: (images, [False] * len(images))

# ---- Load trained weights ----
# Training saves the full U-Net state_dict (base + LoRA) as .pth and resumes from it with
# unet.load_state_dict, so the same call restores the trained adapter here. load_model()
# has already registered an adapter under lora_adapter_name, so pipe.load_lora_weights with
# that name would try to register it a second time.
unet_state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
pipe.unet.load_state_dict(unet_state_dict)
del unet_state_dict
pipe.unet.set_adapter(lora_adapter_name)
pipe.unet.eval()  # make sure dropout (including LoRA dropout) is inactive while generating

# ---- Access Dataset & Sample Prompt ----
dataset: Dataset = train_loader.dataset
sample_row: dict[str, Any] = random.choice(dataset.df.to_dict(orient="records"))
sample_prompt: str = dataset.build_prompt(sample_row)

print(f"🔹 Prompt: {sample_prompt}")
print(f"🔹 LoRA Adapter: {lora_adapter_name}")

# ---- Generate Image ----
img: Image.Image = generate_image_from_prompt(pipe, sample_prompt)

# ---- Print Image Resolution ----
print(f"🔹 Image Resolution: {img.width}x{img.height}")

# ---- Display Image ----
plt.imshow(img)
plt.axis('off')  # Hide axis ticks
plt.title(f"Prompt: {sample_prompt}\nLoRA: {lora_adapter_name} | Resolution: {img.width}x{img.height}", fontsize=10)
plt.show()
