In [None]:
import torch
from diffusers import StableDiffusionPipeline, DDPMScheduler
from torch.optim import AdamW
from torch.utils.data import DataLoader

In [None]:
# Load the pretrained pipeline that will be fine-tuned.
# Weights are kept in float32 because the optimizer below updates them directly.
# With lr=5e-6, AdamW steps are far below float16 resolution for typical weight
# magnitudes, and Adam's default eps (1e-8) is not representable in float16
# (it rounds to 0). Training on half-precision weights therefore silently drops
# updates or produces inf/NaN.
# NOTE(review): training the float32 UNet + text encoder needs a large GPU; if
# memory is tight, use mixed precision (torch.autocast + GradScaler) rather
# than float16 master weights.
model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32).to("cuda")
# Training noise scheduler, read from the repo's `scheduler/` subfolder.
# Passing a repo id to `from_config` is deprecated; `from_pretrained` is the
# supported way to load a scheduler config from the Hub.
scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")

In [None]:
# Unfreeze the UNet and the text encoder; both are fine-tuned jointly.
# `from_pretrained` hands back modules in eval mode, so they are switched to
# train mode explicitly. This matters for any dropout/normalization layers that
# behave differently during training.
trainable_modules = (pipe.unet, pipe.text_encoder)
for module in trainable_modules:
    module.requires_grad_(True)
    module.train()

# A single AdamW instance optimizes the parameters of both trainable modules.
optimizer = AdamW(
    [param for module in trainable_modules for param in module.parameters()],
    lr=5e-6,
)

In [None]:
# Dummy datasets (replace with actual data loading).
# Latents are created directly on the UNet's device and in its dtype.
# Float32 latents fed to a float16 UNet raise a dtype-mismatch error inside
# the UNet forward pass.
# NOTE(review): real latents presumably come from `pipe.vae.encode(...)`
# multiplied by the VAE scaling factor — confirm when wiring up actual data.
num_subject, num_prior = 3, 100
latent_kwargs = dict(device=pipe.unet.device, dtype=pipe.unet.dtype)
subject_latents = torch.randn(num_subject, 4, 64, 64, **latent_kwargs)  # Example subject latents
subject_prompts = ["a sks dog"] * num_subject
prior_latents = torch.randn(num_prior, 4, 64, 64, **latent_kwargs)  # Pre-generated prior latents
prior_prompts = ["a dog"] * num_prior

# The training loop slices latents and prompts with the same indices.
assert len(subject_latents) == len(subject_prompts)
assert len(prior_latents) == len(prior_prompts)

In [None]:
# Training parameters
lambda_prior = 1.0  # weight of the prior-preservation term in the total loss
batch_size = 1      # micro-batch size used for both the subject and prior sets
epochs = 1000       # number of passes over the subject set
seed = 0            # RNG seed so noise/timestep sampling is reproducible

# torch.manual_seed seeds the RNGs of all devices (CPU and CUDA).
torch.manual_seed(seed)
assert batch_size >= 1, "batch_size must be positive"

In [None]:
# Training loop (fine-tuning with a prior-preservation term).
#
# Each epoch makes ONE optimizer step, with gradients accumulated over every
# subject micro-batch:
# - Each subject micro-batch is paired with the next prior micro-batch, cycling
#   through the prior set across epochs.
# - Each pair contributes (subject_loss + lambda_prior * prior_loss) / num_steps,
#   so the accumulated gradient is that of the mean loss over the epoch.
#
# Previously each batch's loss variable was overwritten, so only the last
# subject batch and the last prior batch produced gradients. All other forward
# passes were computed and discarded.


def denoising_loss(latents, prompts):
    """Return the noise-prediction MSE for one micro-batch.

    Draws one random timestep per example, noises `latents` with the DDPM
    scheduler, encodes `prompts` with the text encoder, and compares the
    UNet's noise prediction with the injected noise.

    Args:
        latents: tensor whose first dimension indexes examples
            (one per prompt).
        prompts: list of prompt strings, one per latent.

    Returns:
        Scalar float32 loss tensor attached to the autograd graph.
    """
    # Match the UNet's device/dtype whether the model is float32 or float16.
    latents = latents.to(device=pipe.unet.device, dtype=pipe.unet.dtype)
    noise = torch.randn_like(latents)
    # Size by the actual slice length: the final slice may be shorter than batch_size.
    timesteps = torch.randint(
        0, scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device
    )
    noisy_latents = scheduler.add_noise(latents, noise, timesteps)

    input_ids = pipe.tokenizer(
        prompts,
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        truncation=True,  # keep over-long prompts within the encoder's max length
        return_tensors="pt",
    ).input_ids.to(latents.device)
    text_embeddings = pipe.text_encoder(input_ids)[0]

    noise_pred = pipe.unet(noisy_latents, timesteps, text_embeddings).sample
    # Compute the loss in float32 for numerical stability under half precision.
    return torch.nn.functional.mse_loss(noise_pred.float(), noise.float())


subject_starts = range(0, len(subject_latents), batch_size)
num_steps = len(subject_starts)
if num_steps == 0:
    raise ValueError("subject_latents is empty; nothing to train on")
prior_cursor = 0  # position in the prior set; persists across epochs

for epoch in range(epochs):
    optimizer.zero_grad()  # also clears gradients left over from an interrupted run
    epoch_loss = 0.0
    for i in subject_starts:
        loss = denoising_loss(
            subject_latents[i:i + batch_size], subject_prompts[i:i + batch_size]
        )
        if len(prior_latents) > 0:
            j = prior_cursor % len(prior_latents)
            batch_prior = prior_latents[j:j + batch_size]
            batch_prior_prompts = prior_prompts[j:j + batch_size]
            # Advance by the slice actually taken so wrap-around never skips samples.
            prior_cursor = j + batch_prior.shape[0]
            loss = loss + lambda_prior * denoising_loss(batch_prior, batch_prior_prompts)
        # Backward per micro-batch frees each graph immediately (no accumulation in memory).
        (loss / num_steps).backward()
        epoch_loss += loss.item() / num_steps
    optimizer.step()
    print(f"Epoch {epoch}, Loss: {epoch_loss}")