In [None]:
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm import tqdm
from datasets import load_dataset, concatenate_datasets
from transformers import CLIPTextModel, CLIPTokenizer
from PIL import Image
from torchvision import transforms

In [None]:
# -------------------------------
# 1. Load & Prepare Dataset with DreamBooth Prompts
# -------------------------------
# DreamBooth binds a rare identifier token to one specific subject of a coarse class.
use_dreambooth_prompts = True
subject_identifier, subject_class = "sks", "dog"  # (rare identifier token, coarse class noun)

# Instance prompt used for the subject images, and the generic class prompt that
# the prior-preservation branch generates samples from.
subject_prompt = f"a {subject_identifier} {subject_class}"
class_prompt = f"a {subject_class}"

In [None]:
# Load the dataset (using your chosen subset) and merge all of its splits into one.
# NOTE(review): trust_remote_code=True executes the dataset's loading script from the Hub.
dataset = load_dataset('poloclub/diffusiondb', 'large_random_1k', trust_remote_code=True)
full_dataset = concatenate_datasets(list(dataset.values()))

# Replace every prompt with the DreamBooth subject prompt if desired.
# Only the 'prompt' column is returned: Dataset.map merges the returned columns into
# each example and keeps the rest, so the image column is left untouched instead of
# being decoded to PIL and re-encoded for every row.
if use_dreambooth_prompts:
    filtered_dataset = full_dataset.map(lambda example: {'prompt': subject_prompt})
else:
    # Keep the original prompts; an identity map would only copy the data.
    filtered_dataset = full_dataset

In [None]:
# Sanity check: show the first training image next to the prompt it will be trained with.
first_item = filtered_dataset[0]
fig, ax = plt.subplots()
ax.imshow(first_item['image'])
ax.set_axis_off()
plt.show()
print("Subject Prompt:", first_item['prompt'])

In [None]:
# -------------------------------
# 2. Define Tokenizer, Text Encoder and Image Transforms
# -------------------------------
# Load the tokenizer and the text encoder from the same Stable Diffusion checkpoint,
# so the token ids (vocabulary, special/pad tokens, max length) are exactly the ones
# this text encoder and the pipelines built later expect.
tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="text_encoder")

# Image preprocessing: resize to a fixed 256x256 (aspect ratio is not preserved),
# convert to a [0, 1] tensor, then map to [-1, 1] via (x - 0.5) / 0.5 on every channel.
# NOTE(review): 256 is below the resolution SD v1-4 was trained at; raise it if memory allows.
train_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

In [None]:
class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(image_tensor, input_ids, attention_mask)`` triples.

    Each record of the wrapped ``dataset`` must provide an ``'image'`` object with a
    ``.convert(mode)`` method (e.g. a PIL image) and a ``'prompt'`` string. Images go
    through ``transforms``; prompts are tokenized to the tokenizer's max length.
    """

    def __init__(self, dataset, tokenizer, transforms):
        # Only store references; all per-item work happens lazily in __getitem__.
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.transforms = transforms

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]
        pixel_tensor = self.transforms(record['image'].convert("RGB"))
        encoded = self.tokenizer(
            record['prompt'],
            return_tensors="pt",
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
        )
        # The tokenizer returns a batch of one prompt; squeeze that leading dimension away.
        token_ids, token_mask = (encoded[key].squeeze() for key in ("input_ids", "attention_mask"))
        return pixel_tensor, token_ids, token_mask

In [None]:
# Wrap the prompt-rewritten dataset so each item is (pixel tensor, token ids, attention mask).
custom_dataset = CustomDataset(
    dataset=filtered_dataset,
    tokenizer=tokenizer,
    transforms=train_transforms,
)

In [None]:
def collate_fn(examples):
    """Stack a list of ``(pixel_values, input_ids, attention_mask)`` tuples into a batch dict.

    ``pixel_values`` is made contiguous and cast to float32; the token tensors are
    stacked as-is.
    """
    field_positions = {"pixel_values": 0, "input_ids": 1, "attention_mask": 2}
    batch = {
        field: torch.stack([sample[position] for sample in examples])
        for field, position in field_positions.items()
    }
    batch["pixel_values"] = batch["pixel_values"].to(memory_format=torch.contiguous_format).float()
    return batch

In [None]:
# Shuffled mini-batches of subject images and tokenized prompts, assembled by collate_fn.
train_batch_size = 10
train_dataloader = torch.utils.data.DataLoader(
    dataset=custom_dataset,
    batch_size=train_batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)

In [None]:
# -------------------------------
# 3. Load Pretrained Diffusion Components
# -------------------------------
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, StableDiffusionPipeline
from diffusers.optimization import get_scheduler

In [None]:
# Hub checkpoint for the diffusion components; None means no explicit revision/variant is requested.
pretrained_model_name_or_path = "CompVis/stable-diffusion-v1-4"
revision, variant = None, None

In [None]:
# Pull the noise scheduler, VAE and UNet from their subfolders of the same checkpoint.
noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
vae = AutoencoderKL.from_pretrained(
    pretrained_model_name_or_path,
    subfolder="vae",
    revision=revision,
    variant=variant,
)
unet = UNet2DConditionModel.from_pretrained(
    pretrained_model_name_or_path,
    subfolder="unet",
    revision=revision,
    variant=variant,
)

In [None]:
# Mixed-precision mode; later cells derive the weight dtype from it.
mixed_precision = "fp16"

# Freeze every pretrained network (UNet base weights, VAE, text encoder) for memory
# and stability; any trainable weights must be added on top of them afterwards.
for frozen_model in (unet, vae, text_encoder):
    frozen_model.requires_grad_(False)

In [None]:
# Pick the compute device and the dtype of the frozen weights, then move every model there.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
weight_dtype = torch.float32 if mixed_precision != "fp16" else torch.float16
for model_to_move in (unet, vae, text_encoder):
    model_to_move.to(device, dtype=weight_dtype)

In [None]:
# -------------------------------
# 4. Set Up LoRA Adapter
# -------------------------------
# Required as written: the base UNet was frozen, so the adapter weights injected here
# are the only UNet parameters left with requires_grad=True for the optimizer.
from peft import LoraConfig
from diffusers.training_utils import cast_training_params

# Low-rank adapters on the attention projection layers named to_q / to_k / to_v / to_out.0.
unet_lora_config = LoraConfig(
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    init_lora_weights="gaussian",
    lora_alpha=16,
    r=40,
)
unet.add_adapter(unet_lora_config)

# In fp16 mode, cast the trainable (adapter) parameters back to fp32 for stable updates.
if mixed_precision == "fp16":
    cast_training_params(unet, dtype=torch.float32)

In [None]:
# -------------------------------
# 5. Set Up Prior Preservation Mechanism
# -------------------------------
# When enabled, training adds a class-prompt denoising loss on samples drawn from the
# frozen base model, scaled by prior_loss_weight (lambda in the DreamBooth paper).
use_prior_preservation, prior_loss_weight = True, 1.0

In [None]:
if use_prior_preservation:
    import gc

    # Create a frozen copy of the original UNet (without the LoRA adapter) for sampling.
    unet_pretrained = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path,
                                                           subfolder="unet", revision=revision, variant=variant)
    unet_pretrained.to(device, dtype=weight_dtype)
    unet_pretrained.eval()

    # Build a separate pipeline around the frozen UNet; VAE, text encoder, tokenizer
    # and scheduler are shared with the training setup.
    pipeline_pretrained = StableDiffusionPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet_pretrained,
        scheduler=noise_scheduler,
        safety_checker=None,
        feature_extractor=None,
    ).to(device)

    # Pre-generate a small bank of class-prompt latents.
    num_prior_images = 50  # For demo; in practice, many more samples are used.
    prior_latents_list = []
    print("Generating prior preservation latents...")
    for _ in range(num_prior_images):
        # Autocast on the actual device type instead of hard-coding "cuda"
        # (consistent with the inference cell).
        with torch.autocast(device.type):
            output = pipeline_pretrained(class_prompt, num_inference_steps=50, guidance_scale=7.5)
        gen_image = output.images[0].convert("RGB")
        # Same preprocessing as the subject images, then encode to scaled VAE latents.
        gen_image_tensor = train_transforms(gen_image).unsqueeze(0).to(device, dtype=weight_dtype)
        with torch.no_grad():
            latent = vae.encode(gen_image_tensor).latent_dist.sample() * vae.config.scaling_factor
        prior_latents_list.append(latent)

    # The frozen UNet copy is only needed for sampling; release it and its pipeline so
    # a second full set of UNet weights does not occupy device memory during training.
    # (The shared VAE / text encoder / scheduler stay alive through their own names.)
    del pipeline_pretrained, unet_pretrained
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()

    # Class-prompt embeddings, reused every training step. The text encoder is called
    # without an attention mask because that is how StableDiffusionPipeline encodes
    # prompts for this checkpoint, so the prior targets use inference-time conditioning.
    prior_inputs = tokenizer(class_prompt, return_tensors="pt", max_length=tokenizer.model_max_length,
                             padding="max_length", truncation=True)
    prior_input_ids = prior_inputs["input_ids"].to(device)
    # Not used for encoding here; still exported as a global for any cell that wants it.
    prior_attention_mask = prior_inputs["attention_mask"].to(device)
    with torch.no_grad():
        prior_embeddings = text_encoder(input_ids=prior_input_ids).last_hidden_state

In [None]:
# -------------------------------
# 6. Set Up Training Hyperparameters & Optimizer
# -------------------------------
num_train_epochs = 3
learning_rate = 1e-4

# Optimize only the UNet parameters that still require gradients.
trainable_params = [param for param in unet.parameters() if param.requires_grad]
optimizer = torch.optim.AdamW(
    trainable_params,
    lr=learning_rate,
    betas=(0.9, 0.999),
    weight_decay=1e-2,
    eps=1e-08,
)

In [None]:
# Total optimizer steps = epochs x batches per epoch. len(train_dataloader) counts the
# final partial batch (drop_last is not set), so this equals the number of
# lr_scheduler.step() calls the training loop makes. The previous
# `epochs * len(dataset) // batch_size` undercounted whenever the size was not divisible.
num_training_steps = num_train_epochs * len(train_dataloader)

# "constant" keeps the LR fixed at `learning_rate`; diffusers' get_scheduler ignores
# num_warmup_steps for this schedule, so 0 states the actual behavior. To warm up,
# switch to "constant_with_warmup" and keep the warmup shorter than num_training_steps.
lr_scheduler = get_scheduler(
    "constant",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

In [None]:
# -------------------------------
# 7. Training Loop with Dual (Subject + Prior) Losses
# -------------------------------
# Text conditioning is built the way StableDiffusionPipeline builds it at inference for
# this checkpoint: no attention mask for the text encoder and no encoder_attention_mask
# for the UNet. The LoRA weights are therefore trained on the conditioning they will
# see when generating images.
# NOTE(review): with fp16 base weights and no loss scaling, small gradients may underflow;
# consider torch.amp.GradScaler if the loss stalls.
for epoch in tqdm(range(num_train_epochs), desc="Epochs"):
    unet.train()
    epoch_losses = []

    for batch in tqdm(train_dataloader, desc="Batches", leave=False):
        # Move the subject batch to the device.
        pixel_values = batch["pixel_values"].to(dtype=weight_dtype, device=device)
        input_ids = batch["input_ids"].to(device)

        # The VAE and text encoder are frozen, so no autograd graph is needed for them.
        with torch.no_grad():
            latents = vae.encode(pixel_values).latent_dist.sample() * vae.config.scaling_factor
            embeddings = text_encoder(input_ids=input_ids).last_hidden_state

        # Sample noise and timesteps for the subject images.
        noise = torch.randn_like(latents)
        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                                  (latents.shape[0],), device=device).long()
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # Predict the noise residual for the subject branch.
        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states=embeddings,
                          return_dict=False)[0]
        loss_subject = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")

        # ----- Prior Preservation Loss -----
        if use_prior_preservation:
            # Draw a batch of pre-generated class latents (with replacement).
            indices = torch.randint(0, len(prior_latents_list), (latents.shape[0],)).tolist()
            prior_batch = torch.cat([prior_latents_list[i] for i in indices], dim=0)

            # Independent noise and timesteps for the prior branch.
            noise_prior = torch.randn_like(prior_batch)
            t_prior = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                                    (prior_batch.shape[0],), device=device).long()
            noisy_prior_latents = noise_scheduler.add_noise(prior_batch, noise_prior, t_prior)

            # Broadcast the single class-prompt embedding across the prior batch.
            prior_embeddings_expanded = prior_embeddings.expand(noisy_prior_latents.shape[0], -1, -1)

            model_pred_prior = unet(noisy_prior_latents, t_prior,
                                    encoder_hidden_states=prior_embeddings_expanded,
                                    return_dict=False)[0]
            loss_prior = F.mse_loss(model_pred_prior.float(), noise_prior.float(), reduction="mean")
        else:
            loss_prior = 0.0

        # Total loss is the subject loss plus the weighted prior loss.
        loss = loss_subject + prior_loss_weight * loss_prior

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        epoch_losses.append(loss.item())

    # Guard against an empty dataloader instead of dividing by zero.
    avg_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else float("nan")
    print(f"Epoch {epoch + 1}/{num_train_epochs}, Loss: {avg_loss:.4f}")

In [None]:
# -------------------------------
# 8. Inference: Generate Images using the Fine-Tuned Model
# -------------------------------
# StableDiffusionPipeline comes from the diffusers import cell in section 3.
# The training loop leaves the UNet in training mode; switch it to eval for sampling.
unet.eval()

pipeline = StableDiffusionPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=noise_scheduler,
    safety_checker=None,
    feature_extractor=None,
).to(device)

# Use a prompt with the unique subject identifier to generate subject images.
prompt = subject_prompt + " in a beautiful landscape"
with torch.autocast(device.type):
    images = pipeline(prompt, num_inference_steps=100, guidance_scale=7.5).images

# Show and save each generated image; close each figure so repeated runs don't accumulate them.
for idx, img in enumerate(images):
    fig, ax = plt.subplots()
    ax.imshow(img)
    ax.set_axis_off()
    ax.set_title(prompt)
    plt.show()
    plt.close(fig)
    img.save(f"generated_image_{idx}.png")
