In [None]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
from diffusers import DDPMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [None]:
# ----- Configuration -----
subject_images_path = "./subject_images"  # Folder with subject images
# DreamBooth binds the subject to a RARE identifier token followed by the class noun.
# "unique_dog" would be split by the CLIP tokenizer into common words ("unique", "_", "dog")
# that already carry strong meanings; "sks" is the conventional rare identifier.
subject_prompt = "a sks dog"  # rare identifier + class noun
class_prompt = "a dog"         # Class prompt for prior preservation
num_train_steps = 1000
batch_size = 1
learning_rate = 5e-6          # Example: 5e-6 for Stable Diffusion
lambda_prior = 1.0            # Weight for the class prior loss
device = "cuda" if torch.cuda.is_available() else "cpu"
seed = 0                      # Seeds noise, timesteps, shuffling and class-prior sampling
torch.manual_seed(seed)       # seeds the RNG on all devices

In [None]:
# ----- Set Up Data Transforms and Dataloader -----
# Resize the SHORTER side to 512 (keeping the aspect ratio), then center-crop to 512x512.
# Resizing straight to (512, 512) would squash non-square photos and make the crop a no-op.
transform = transforms.Compose([
    transforms.Resize(512),
    transforms.CenterCrop(512),
    transforms.ToTensor(),
    # Normalize images to [-1, 1] (the VAE expects inputs in this range)
    transforms.Normalize([0.5], [0.5]),
])

In [None]:
# Use ImageFolder assuming your images are organized under a class folder
# (e.g. ./subject_images/dog/*.jpg); the class label it assigns is not used for training.
subject_dataset = datasets.ImageFolder(subject_images_path, transform=transform)
# With drop_last=True, a dataset smaller than one batch yields no batches at all,
# leaving the training iterator with nothing to produce -- fail fast instead.
if len(subject_dataset) < batch_size:
    raise ValueError(
        f"Found {len(subject_dataset)} subject image(s) in {subject_images_path!r}, but "
        f"batch_size={batch_size}; drop_last=True needs at least one full batch."
    )
subject_dataloader = DataLoader(subject_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

In [None]:
# ----- Load Pretrained Models and Tokenizer -----
# Here we use the Stable Diffusion v1-5 model from Hugging Face.
# Load full-precision (float32) weights: the UNet and text encoder are fine-tuned with AdamW,
# and optimizer updates applied directly to float16 parameters are numerically unstable
# (e.g. AdamW's default eps=1e-8 rounds to 0 in float16); float16 is also poorly supported
# on CPU. The deprecated `revision="fp16"` branch is therefore not used.
# NOTE: full float32 fine-tuning needs substantially more GPU memory than float16.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float32,
).to(device)
# LMS sampler, used whenever images are generated with `pipe(...)`.
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)

In [None]:
# Extract components: these are the pipeline's own sub-modules (the same objects, not copies),
# so anything trained through them also changes what `pipe` generates.
unet, vae = pipe.unet, pipe.vae                                # denoising network, autoencoder
text_encoder, tokenizer = pipe.text_encoder, pipe.tokenizer    # prompt -> token ids -> embeddings

In [None]:
# Set train mode for modules we want to fine-tune.
unet.train()
text_encoder.train()
# The VAE is only used to encode images into latents. eval() alone does not freeze it
# (its parameters would still require gradients), so turn gradients off explicitly.
vae.requires_grad_(False)
vae.eval()

In [None]:
# ----- Training Setup -----
# Fine-tune both the UNet and the text encoder.
optimizer = optim.AdamW(list(unet.parameters()) + list(text_encoder.parameters()), lr=learning_rate)
# Training needs the DDPM forward process x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps.
# The pipeline's LMS scheduler is a sampler: its add_noise uses sigma scaling
# (x_0 + sigma_t * eps) and only accepts timesteps from its current inference schedule, which
# pipe(...) resets via set_timesteps. Build a dedicated DDPM scheduler from the same config.
scheduler = DDPMScheduler.from_config(pipe.scheduler.config)

In [None]:
# ----- Helper Function: Generate Class Prior Sample -----
def generate_class_prior_sample(prompt, num_inference_steps=50, generator=None):
    """
    Sample one class-prior image with `pipe` and preprocess it like a subject image.

    `pipe` is NOT frozen: it shares its UNet and text encoder with the modules being
    fine-tuned, so the sample reflects the original model only if this is called before
    training updates the weights. Prefer precomputing a bank of samples up front.

    Args:
        prompt: Class prompt to sample (e.g. ``class_prompt``).
        num_inference_steps: Number of sampling steps passed to the pipeline.
        generator: Optional ``torch.Generator`` for reproducible samples.

    Returns:
        float32 tensor of shape [1, C, 512, 512] on ``device``, normalized to [-1, 1].
    """
    # Sample in eval mode, then restore whatever train/eval mode the caller had set.
    modules = (pipe.unet, pipe.text_encoder)
    previous_modes = [module.training for module in modules]
    for module in modules:
        module.eval()
    try:
        with torch.no_grad():
            output = pipe(prompt, num_inference_steps=num_inference_steps, generator=generator)
    finally:
        for module, was_training in zip(modules, previous_modes):
            module.train(was_training)
    image = output.images[0]
    # Convert PIL image to tensor with same preprocessing as subject images.
    return transform(image).unsqueeze(0).to(device)

In [None]:
# Infinite iterator for the dataloader
def cycle(loader):
    """
    Yield items from `loader` forever, re-iterating it after each full pass.

    Re-iterating (instead of caching items like `itertools.cycle`) lets a DataLoader
    created with shuffle=True produce a fresh order every epoch.

    Raises:
        ValueError: if a full pass over `loader` yields nothing (an empty loader, or a
            one-shot iterator that is already exhausted), which would otherwise make this
            generator loop forever without ever yielding.
    """
    while True:
        yielded_any = False
        for item in loader:
            yielded_any = True
            yield item
        if not yielded_any:
            raise ValueError("cycle() got an empty loader; nothing to iterate over.")

In [None]:
# Endless stream of (images, labels) batches drawn from the subject dataloader.
subject_iter = cycle(loader=subject_dataloader)

In [None]:
# ----- Training Loop -----
# DreamBooth objective with prior preservation:
#   loss = MSE(eps_pred(subject), eps) + lambda_prior * MSE(eps_pred(class), eps)
num_class_images = 50  # size of the precomputed class-prior image bank


def _tokenize(prompt):
    """Return padded/truncated token ids of shape [1, tokenizer.model_max_length] on `device`."""
    return tokenizer(
        prompt,
        padding="max_length",
        truncation=True,
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    ).input_ids.to(device)


def _encode_to_latents(images):
    """Encode [-1, 1] pixel batches into scaled VAE latents; no gradients flow into the VAE."""
    with torch.no_grad():
        # Match the VAE's dtype: `transform` yields float32 while the VAE may be half precision.
        pixels = images.to(device=device, dtype=vae.dtype)
        return vae.encode(pixels).latent_dist.sample() * vae.config.scaling_factor


def _encode_prompt(token_ids, batch):
    """Embed `token_ids` WITH gradients (the text encoder is trained) and repeat to `batch` rows."""
    return text_encoder(token_ids)[0].expand(batch, -1, -1)


def _denoising_loss(latents, text_embeddings):
    """Noise `latents` at random timesteps and return the epsilon-prediction MSE in float32."""
    noise = torch.randn_like(latents)
    timesteps = torch.randint(
        0, scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device
    )
    noisy_latents = scheduler.add_noise(latents, noise, timesteps)
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states=text_embeddings).sample
    return nn.functional.mse_loss(model_pred.float(), noise.float())


# The prompts never change, so tokenize them once rather than every step.
subject_tokens = _tokenize(subject_prompt)
class_tokens = _tokenize(class_prompt)

# Generate the class-prior images BEFORE any weight update. `pipe` shares its UNet and text
# encoder with the modules being trained, so sampling inside the loop would use the drifting
# fine-tuned model and cost a full diffusion sampling run per step. Kept on CPU to save GPU memory.
# NOTE: re-running this cell after training regenerates the bank from the fine-tuned weights.
class_image_bank = [
    generate_class_prior_sample(class_prompt).cpu() for _ in range(num_class_images)
]

for step in range(num_train_steps):
    optimizer.zero_grad()

    # --- Subject Image and Loss ---
    subject_images, _ = next(subject_iter)  # [B, C, H, W]; ImageFolder labels are unused
    subject_latents = _encode_to_latents(subject_images)
    subject_text_embeddings = _encode_prompt(subject_tokens, subject_latents.shape[0])
    loss_subject = _denoising_loss(subject_latents, subject_text_embeddings)

    # --- Class Prior Loss ---
    class_image = class_image_bank[step % len(class_image_bank)]
    class_latents = _encode_to_latents(class_image)
    class_text_embeddings = _encode_prompt(class_tokens, class_latents.shape[0])
    loss_class = _denoising_loss(class_latents, class_text_embeddings)

    # Combine losses
    loss = loss_subject + lambda_prior * loss_class
    loss.backward()
    optimizer.step()

    if step % 100 == 0:
        print(f"Step {step}: Total Loss = {loss.item():.4f} | Subject Loss = {loss_subject.item():.4f} | Class Loss = {loss_class.item():.4f}")

print("Training complete.")
