In [1]:
import torch
from torch.utils.data import DataLoader, random_split
from diffusers import StableDiffusionPipeline
from peft import LoraConfig, get_peft_model
from accelerate import Accelerator
from tqdm.auto import tqdm
import pandas as pd
import os
import matplotlib.pyplot as plt
from torch.cuda.amp import autocast, GradScaler


In [2]:
# Device Configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True

In [3]:

# Data Preparation
class EmojiDiffusionDataset(torch.utils.data.Dataset):
    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        image_tensor = torch.load(self.df.iloc[idx]['image_path']).float()
        prompt = self.df.iloc[idx]['prompt']
        return {"pixel_values": image_tensor, "prompt": prompt}

In [4]:

# Load Dataset
df = pd.read_parquet('../data/processed_emoji_dataset.parquet')
dataset = EmojiDiffusionDataset(df)
train_size = int(0.9 * len(dataset))
train_set, val_set = random_split(dataset, [train_size, len(dataset) - train_size])
train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
val_loader = DataLoader(val_set, batch_size=16)

In [5]:
# Load Model
model_id = "sd-legacy/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device)
pipe.enable_attention_slicing()
pipe.unet.enable_gradient_checkpointing()

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

In [6]:
# Disable Safety Checker
if pipe.safety_checker:
    pipe.safety_checker = lambda images, clip_input, **kwargs: (images, [False] * len(images))

# LoRA Configuration
lora_config = LoraConfig(
    r=8, lora_alpha=16,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    lora_dropout=0.1, bias="none"
)
pipe.unet = get_peft_model(pipe.unet, lora_config)
pipe.unet.print_trainable_parameters()

trainable params: 1,594,368 || all params: 861,115,332 || trainable%: 0.1852


In [7]:
# Accelerator Setup
accelerator = Accelerator(mixed_precision='fp16', gradient_accumulation_steps=4)
optimizer = torch.optim.AdamW(pipe.unet.parameters(), lr=1e-4)
pipe.unet, optimizer, train_loader = accelerator.prepare(pipe.unet, optimizer, train_loader)

In [None]:
# Training Loop
num_epochs = 30
training_losses, validation_losses = [], []
output_dir = "../evaluation/lora"
os.makedirs(output_dir, exist_ok=True)

scaler = torch.amp.GradScaler()

for epoch in range(num_epochs):
    pipe.unet.train()
    total_train_loss = 0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training"):
        optimizer.zero_grad()
        pixel_values = batch["pixel_values"].to(device, dtype=torch.float16, non_blocking=True)
        prompts = batch["prompt"]
        
        text_inputs = pipe.tokenizer(prompts, padding="max_length", max_length=pipe.tokenizer.model_max_length, return_tensors="pt")
        text_inputs = {k: v.to(device, non_blocking=True) for k, v in text_inputs.items()}

        with torch.no_grad():
            encoder_hidden_states = pipe.text_encoder(**text_inputs).last_hidden_state
            latents = pipe.vae.encode(pixel_values).latent_dist.sample() * pipe.vae.config.scaling_factor

        noise = torch.randn_like(latents, device=device, dtype=torch.float16)
        timesteps = torch.randint(0, pipe.scheduler.config.num_train_timesteps, (latents.size(0),), device=device).long()
        noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)

        with torch.amp.autocast("cuda"):
            noise_pred = pipe.unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
            loss = torch.nn.functional.mse_loss(noise_pred, noise)
        
        accelerator.backward(loss)
        optimizer.step()

        total_train_loss += loss.item()

    avg_train_loss = total_train_loss / len(train_loader)
    training_losses.append(avg_train_loss)
    print(f"Epoch [{epoch+1}/{num_epochs}] - Avg Train Loss: {avg_train_loss:.6f}")

    # Validation Loop
    pipe.unet.eval()
    total_val_loss = 0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Validation"):
            pixel_values = batch["pixel_values"].to(device, dtype=torch.float16, non_blocking=True)
            prompts = batch["prompt"]
            
            text_inputs = pipe.tokenizer(prompts, padding="max_length", max_length=pipe.tokenizer.model_max_length, return_tensors="pt")
            text_inputs = {k: v.to(device, non_blocking=True) for k, v in text_inputs.items()}

            encoder_hidden_states = pipe.text_encoder(**text_inputs).last_hidden_state
            latents = pipe.vae.encode(pixel_values).latent_dist.sample() * pipe.vae.config.scaling_factor

            noise = torch.randn_like(latents, device=device, dtype=torch.float16)
            timesteps = torch.randint(0, pipe.scheduler.config.num_train_timesteps, (latents.size(0),), device=device).long()
            noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)

            noise_pred = pipe.unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
            loss = torch.nn.functional.mse_loss(noise_pred, noise)

            total_val_loss += loss.item()

    avg_val_loss = total_val_loss / len(val_loader)
    validation_losses.append(avg_val_loss)
    print(f"Epoch [{epoch+1}/{num_epochs}] - Avg Val Loss: {avg_val_loss:.6f}")

    # Save Model Checkpoint
    checkpoint_dir = f"../evaluation/emoji_diffusion_qlora/checkpoint_epoch_{epoch+1}"
    os.makedirs(checkpoint_dir, exist_ok=True)
    pipe.unet.save_pretrained(checkpoint_dir)


Epoch 1/30 - Training:   0%|          | 0/170 [00:00<?, ?it/s]

Epoch [1/30] - Avg Train Loss: 0.031754


Epoch 1/30 - Validation:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch [1/30] - Avg Val Loss: 0.030519


Epoch 2/30 - Training:   0%|          | 0/170 [00:00<?, ?it/s]

Epoch [2/30] - Avg Train Loss: 0.028452


Epoch 2/30 - Validation:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch [2/30] - Avg Val Loss: 0.027492


Epoch 3/30 - Training:   0%|          | 0/170 [00:00<?, ?it/s]

Epoch [3/30] - Avg Train Loss: 0.027433


Epoch 3/30 - Validation:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch [3/30] - Avg Val Loss: 0.027492


Epoch 4/30 - Training:   0%|          | 0/170 [00:00<?, ?it/s]

Epoch [4/30] - Avg Train Loss: 0.027082


Epoch 4/30 - Validation:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch [4/30] - Avg Val Loss: 0.025108


Epoch 5/30 - Training:   0%|          | 0/170 [00:00<?, ?it/s]

Epoch [5/30] - Avg Train Loss: 0.025984


Epoch 5/30 - Validation:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch [5/30] - Avg Val Loss: 0.027190


Epoch 6/30 - Training:   0%|          | 0/170 [00:00<?, ?it/s]

Epoch [6/30] - Avg Train Loss: 0.025301


Epoch 6/30 - Validation:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch [6/30] - Avg Val Loss: 0.026348


Epoch 7/30 - Training:   0%|          | 0/170 [00:00<?, ?it/s]

Epoch [7/30] - Avg Train Loss: 0.025519


Epoch 7/30 - Validation:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch [7/30] - Avg Val Loss: 0.027276


Epoch 8/30 - Training:   0%|          | 0/170 [00:00<?, ?it/s]

Epoch [8/30] - Avg Train Loss: 0.024862


Epoch 8/30 - Validation:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch [8/30] - Avg Val Loss: 0.027526


Epoch 9/30 - Training:   0%|          | 0/170 [00:00<?, ?it/s]

Epoch [9/30] - Avg Train Loss: 0.024734


Epoch 9/30 - Validation:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch [9/30] - Avg Val Loss: 0.026685


Epoch 10/30 - Training:   0%|          | 0/170 [00:00<?, ?it/s]

Epoch [10/30] - Avg Train Loss: 0.023860


Epoch 10/30 - Validation:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch [10/30] - Avg Val Loss: 0.025377


Epoch 11/30 - Training:   0%|          | 0/170 [00:00<?, ?it/s]

Epoch [11/30] - Avg Train Loss: 0.023943


Epoch 11/30 - Validation:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch [11/30] - Avg Val Loss: 0.024321


Epoch 12/30 - Training:   0%|          | 0/170 [00:00<?, ?it/s]

Epoch [12/30] - Avg Train Loss: 0.023262


Epoch 12/30 - Validation:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch [12/30] - Avg Val Loss: 0.026578


Epoch 13/30 - Training:   0%|          | 0/170 [00:00<?, ?it/s]

Epoch [13/30] - Avg Train Loss: 0.023765


Epoch 13/30 - Validation:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch [13/30] - Avg Val Loss: 0.023214


Epoch 14/30 - Training:   0%|          | 0/170 [00:00<?, ?it/s]

Epoch [14/30] - Avg Train Loss: 0.024009


Epoch 14/30 - Validation:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch [14/30] - Avg Val Loss: 0.024935


Epoch 15/30 - Training:   0%|          | 0/170 [00:00<?, ?it/s]

Epoch [15/30] - Avg Train Loss: 0.023464


Epoch 15/30 - Validation:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch [15/30] - Avg Val Loss: 0.024807


Epoch 16/30 - Training:   0%|          | 0/170 [00:00<?, ?it/s]

Epoch [16/30] - Avg Train Loss: 0.023963


Epoch 16/30 - Validation:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch [16/30] - Avg Val Loss: 0.024489


Epoch 17/30 - Training:   0%|          | 0/170 [00:00<?, ?it/s]

Epoch [17/30] - Avg Train Loss: 0.023459


Epoch 17/30 - Validation:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch [17/30] - Avg Val Loss: 0.023869


Epoch 18/30 - Training:   0%|          | 0/170 [00:00<?, ?it/s]

Epoch [18/30] - Avg Train Loss: 0.023979


Epoch 18/30 - Validation:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch [18/30] - Avg Val Loss: 0.023815


Epoch 19/30 - Training:   0%|          | 0/170 [00:00<?, ?it/s]

Epoch [19/30] - Avg Train Loss: 0.023686


Epoch 19/30 - Validation:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch [19/30] - Avg Val Loss: 0.023155


Epoch 20/30 - Training:   0%|          | 0/170 [00:00<?, ?it/s]

Epoch [20/30] - Avg Train Loss: 0.023080


Epoch 20/30 - Validation:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch [20/30] - Avg Val Loss: 0.027011


Epoch 21/30 - Training:   0%|          | 0/170 [00:00<?, ?it/s]

Epoch [21/30] - Avg Train Loss: 0.022518


Epoch 21/30 - Validation:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch [21/30] - Avg Val Loss: 0.022466


Epoch 22/30 - Training:   0%|          | 0/170 [00:00<?, ?it/s]

Epoch [22/30] - Avg Train Loss: 0.022219


Epoch 22/30 - Validation:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch [22/30] - Avg Val Loss: 0.022638


Epoch 23/30 - Training:   0%|          | 0/170 [00:00<?, ?it/s]

Epoch [23/30] - Avg Train Loss: 0.022667


Epoch 23/30 - Validation:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch [23/30] - Avg Val Loss: 0.021064


Epoch 24/30 - Training:   0%|          | 0/170 [00:00<?, ?it/s]

Epoch [24/30] - Avg Train Loss: 0.022052


Epoch 24/30 - Validation:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch [24/30] - Avg Val Loss: 0.020661


Epoch 25/30 - Training:   0%|          | 0/170 [00:00<?, ?it/s]

Epoch [25/30] - Avg Train Loss: 0.021986


Epoch 25/30 - Validation:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch [25/30] - Avg Val Loss: 0.021621


Epoch 26/30 - Training:   0%|          | 0/170 [00:00<?, ?it/s]

Epoch [26/30] - Avg Train Loss: 0.021841


Epoch 26/30 - Validation:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch [26/30] - Avg Val Loss: 0.021855


Epoch 27/30 - Training:   0%|          | 0/170 [00:00<?, ?it/s]

Epoch [27/30] - Avg Train Loss: 0.021868


Epoch 27/30 - Validation:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch [27/30] - Avg Val Loss: 0.022187


Epoch 28/30 - Training:   0%|          | 0/170 [00:00<?, ?it/s]

Epoch [28/30] - Avg Train Loss: 0.022624


Epoch 28/30 - Validation:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch [28/30] - Avg Val Loss: 0.023085


Epoch 29/30 - Training:   0%|          | 0/170 [00:00<?, ?it/s]

Epoch [29/30] - Avg Train Loss: 0.021924


Epoch 29/30 - Validation:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch [29/30] - Avg Val Loss: 0.022795


Epoch 30/30 - Training:   0%|          | 0/170 [00:00<?, ?it/s]

Epoch [30/30] - Avg Train Loss: 0.021486


Epoch 30/30 - Validation:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch [30/30] - Avg Val Loss: 0.019199


In [1]:
# Save Final Model
pipe.unet.save_pretrained("../evaluation/emoji_diffusion_qlora/final_model")

NameError: name 'pipe' is not defined

In [10]:
# Plot Training vs Validation Loss
plt.figure(figsize=(8, 6))
plt.plot(range(1, num_epochs+1), training_losses, label='Training Loss', marker='o')
plt.plot(range(1, num_epochs+1), validation_losses, label='Validation Loss', marker='o')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training vs Validation Loss')
plt.legend()
plt.grid(True)
plt.savefig(os.path.join(output_dir, "training_validation_loss.png"))
plt.close()

In [7]:
import torch
from diffusers import StableDiffusionPipeline
from peft import PeftModel, LoraConfig
from PIL import Image
import os

# Set device to CUDA (GPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the base Stable Diffusion pipeline
model_id = "sd-legacy/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16  # Use float16 for faster inference on GPU
).to(device)

# Enable attention slicing to save memory (optional but recommended)
pipe.enable_attention_slicing()

# Disable NSFW safety checker (if needed)
if pipe.safety_checker is not None:
    pipe.safety_checker = lambda images, clip_input, **kwargs: (images, [False] * len(images))

# Load LoRA weights (make sure the directory is correct)
final_model_dir = "../evaluation/emoji_diffusion_qlora/final_model"
pipe.unet = PeftModel.from_pretrained(pipe.unet, final_model_dir)

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

In [12]:
# Inference parameters
prompt = "monkey"
num_inference_steps = 50
guidance_scale = 7.5

# Generate the image
with torch.inference_mode():
    result = pipe(
        prompt=prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        height=256,
        width=256
    )
    image = result.images[0]

# Save the image
output_path = "outputs/monkey.png"
os.makedirs(os.path.dirname(output_path), exist_ok=True)
image.save(output_path)
print(f"✅ Image saved at: {output_path}")

# Optionally display the image
image.show()

  0%|          | 0/50 [00:00<?, ?it/s]

✅ Image saved at: outputs/monkey.png
