In [18]:
import torch
from torch.utils.data import DataLoader, random_split
from diffusers import StableDiffusionPipeline
from peft import LoraConfig, get_peft_model
from accelerate import Accelerator
from tqdm.auto import tqdm
import pandas as pd
import os
import matplotlib.pyplot as plt
from torch.cuda.amp import autocast, GradScaler


In [19]:
# Device Configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True

In [20]:

# Data Preparation
class EmojiDiffusionDataset(torch.utils.data.Dataset):
    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        image_tensor = torch.load(self.df.iloc[idx]['image_path']).float()
        prompt = self.df.iloc[idx]['prompt']
        return {"pixel_values": image_tensor, "prompt": prompt}

In [21]:

# Load Dataset
df = pd.read_parquet('../data/processed_sticker_dataset.parquet')
dataset = EmojiDiffusionDataset(df)
train_size = int(0.9 * len(dataset))
train_set, val_set = random_split(dataset, [train_size, len(dataset) - train_size])
train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
val_loader = DataLoader(val_set, batch_size=16)

In [22]:
# Load Model
model_id = "sd-legacy/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device)
pipe.enable_attention_slicing()
pipe.unet.enable_gradient_checkpointing()

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

In [23]:
# Disable Safety Checker
if pipe.safety_checker:
    pipe.safety_checker = lambda images, clip_input, **kwargs: (images, [False] * len(images))

# LoRA Configuration
lora_config = LoraConfig(
    r=16, lora_alpha=16,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    lora_dropout=0.1, bias="none"
)
pipe.unet = get_peft_model(pipe.unet, lora_config)
pipe.unet.print_trainable_parameters()

trainable params: 3,188,736 || all params: 862,709,700 || trainable%: 0.3696


In [24]:
# Accelerator Setup
accelerator = Accelerator(mixed_precision='fp16', gradient_accumulation_steps=4)
optimizer = torch.optim.AdamW(pipe.unet.parameters(), lr=1e-4)
pipe.unet, optimizer, train_loader = accelerator.prepare(pipe.unet, optimizer, train_loader)

In [25]:
print(df.columns)


Index(['filename', 'prompt', 'image_path'], dtype='object')


In [None]:
# Training Loop
num_epochs = 50
training_losses, validation_losses = [], []
output_dir = "../evaluation/sticker_lora"
os.makedirs(output_dir, exist_ok=True)

scaler = torch.amp.GradScaler()

for epoch in range(num_epochs):
    pipe.unet.train()
    total_train_loss = 0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training"):
        optimizer.zero_grad()
        pixel_values = batch["pixel_values"].to(device, dtype=torch.float16, non_blocking=True)
        prompts = batch["prompt"]
        
        text_inputs = pipe.tokenizer(prompts, padding="max_length", max_length=pipe.tokenizer.model_max_length, return_tensors="pt")
        text_inputs = {k: v.to(device, non_blocking=True) for k, v in text_inputs.items()}

        with torch.no_grad():
            encoder_hidden_states = pipe.text_encoder(**text_inputs).last_hidden_state
            latents = pipe.vae.encode(pixel_values).latent_dist.sample() * pipe.vae.config.scaling_factor

        noise = torch.randn_like(latents, device=device, dtype=torch.float16)
        timesteps = torch.randint(0, pipe.scheduler.config.num_train_timesteps, (latents.size(0),), device=device).long()
        noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)

        with torch.amp.autocast("cuda"):
            noise_pred = pipe.unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
            loss = torch.nn.functional.mse_loss(noise_pred, noise)
        
        accelerator.backward(loss)
        optimizer.step()

        total_train_loss += loss.item()

    avg_train_loss = total_train_loss / len(train_loader)
    training_losses.append(avg_train_loss)
    print(f"Epoch [{epoch+1}/{num_epochs}] - Avg Train Loss: {avg_train_loss:.6f}")

    # Validation Loop
    pipe.unet.eval()
    total_val_loss = 0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Validation"):
            pixel_values = batch["pixel_values"].to(device, dtype=torch.float16, non_blocking=True)
            prompts = batch["prompt"]
            
            text_inputs = pipe.tokenizer(prompts, padding="max_length", max_length=pipe.tokenizer.model_max_length, return_tensors="pt")
            text_inputs = {k: v.to(device, non_blocking=True) for k, v in text_inputs.items()}

            encoder_hidden_states = pipe.text_encoder(**text_inputs).last_hidden_state
            latents = pipe.vae.encode(pixel_values).latent_dist.sample() * pipe.vae.config.scaling_factor

            noise = torch.randn_like(latents, device=device, dtype=torch.float16)
            timesteps = torch.randint(0, pipe.scheduler.config.num_train_timesteps, (latents.size(0),), device=device).long()
            noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)

            noise_pred = pipe.unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
            loss = torch.nn.functional.mse_loss(noise_pred, noise)

            total_val_loss += loss.item()

    avg_val_loss = total_val_loss / len(val_loader)
    validation_losses.append(avg_val_loss)
    print(f"Epoch [{epoch+1}/{num_epochs}] - Avg Val Loss: {avg_val_loss:.6f}")

    # Save Model Checkpoint
    checkpoint_dir = f"../evaluation/sticker_diffusion_qlora/checkpoint_epoch_{epoch+1}"
    os.makedirs(checkpoint_dir, exist_ok=True)
    pipe.unet.save_pretrained(checkpoint_dir)


Epoch 1/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [1/50] - Avg Train Loss: 0.128103


Epoch 1/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [1/50] - Avg Val Loss: 0.118131


Epoch 2/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [2/50] - Avg Train Loss: 0.122929


Epoch 2/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [2/50] - Avg Val Loss: 0.119799


Epoch 3/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [3/50] - Avg Train Loss: 0.123691


Epoch 3/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [3/50] - Avg Val Loss: 0.124828


Epoch 4/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [4/50] - Avg Train Loss: 0.122926


Epoch 4/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [4/50] - Avg Val Loss: 0.130248


Epoch 5/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [5/50] - Avg Train Loss: 0.121419


Epoch 5/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [5/50] - Avg Val Loss: 0.124330


Epoch 6/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [6/50] - Avg Train Loss: 0.123890


Epoch 6/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [6/50] - Avg Val Loss: 0.122641


Epoch 7/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [7/50] - Avg Train Loss: 0.120536


Epoch 7/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [7/50] - Avg Val Loss: 0.117842


Epoch 8/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [8/50] - Avg Train Loss: 0.122098


Epoch 8/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [8/50] - Avg Val Loss: 0.129809


Epoch 9/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [9/50] - Avg Train Loss: 0.126174


Epoch 9/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [9/50] - Avg Val Loss: 0.119048


Epoch 10/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [10/50] - Avg Train Loss: 0.121862


Epoch 10/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [10/50] - Avg Val Loss: 0.124171


Epoch 11/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [11/50] - Avg Train Loss: 0.119304


Epoch 11/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [11/50] - Avg Val Loss: 0.121879


Epoch 12/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [12/50] - Avg Train Loss: 0.121047


Epoch 12/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [12/50] - Avg Val Loss: 0.118589


Epoch 13/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [13/50] - Avg Train Loss: 0.119622


Epoch 13/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [13/50] - Avg Val Loss: 0.124509


Epoch 14/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [14/50] - Avg Train Loss: 0.124783


Epoch 14/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [14/50] - Avg Val Loss: 0.119992


Epoch 15/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [15/50] - Avg Train Loss: 0.121958


Epoch 15/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [15/50] - Avg Val Loss: 0.123765


Epoch 16/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [16/50] - Avg Train Loss: 0.122160


Epoch 16/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [16/50] - Avg Val Loss: 0.119982


Epoch 17/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [17/50] - Avg Train Loss: 0.120321


Epoch 17/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [17/50] - Avg Val Loss: 0.122322


Epoch 18/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [18/50] - Avg Train Loss: 0.121435


Epoch 18/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [18/50] - Avg Val Loss: 0.128804


Epoch 19/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [19/50] - Avg Train Loss: 0.118559


Epoch 19/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [19/50] - Avg Val Loss: 0.118932


Epoch 20/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [20/50] - Avg Train Loss: 0.122680


Epoch 20/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [20/50] - Avg Val Loss: 0.121417


Epoch 21/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [21/50] - Avg Train Loss: 0.122072


Epoch 21/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [21/50] - Avg Val Loss: 0.113584


Epoch 22/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [22/50] - Avg Train Loss: 0.122283


Epoch 22/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [22/50] - Avg Val Loss: 0.114873


Epoch 23/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [23/50] - Avg Train Loss: 0.121403


Epoch 23/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [23/50] - Avg Val Loss: 0.117414


Epoch 24/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [24/50] - Avg Train Loss: 0.118744


Epoch 24/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [24/50] - Avg Val Loss: 0.121230


Epoch 25/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [25/50] - Avg Train Loss: 0.123253


Epoch 25/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [25/50] - Avg Val Loss: 0.123720


Epoch 26/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [26/50] - Avg Train Loss: 0.121433


Epoch 26/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [26/50] - Avg Val Loss: 0.118716


Epoch 27/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [27/50] - Avg Train Loss: 0.120103


Epoch 27/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [27/50] - Avg Val Loss: 0.109413


Epoch 28/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [28/50] - Avg Train Loss: 0.119832


Epoch 28/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [28/50] - Avg Val Loss: 0.113234


Epoch 29/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [29/50] - Avg Train Loss: 0.121104


Epoch 29/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [29/50] - Avg Val Loss: 0.112715


Epoch 30/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [30/50] - Avg Train Loss: 0.122694


Epoch 30/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [30/50] - Avg Val Loss: 0.122581


Epoch 31/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [31/50] - Avg Train Loss: 0.122364


Epoch 31/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [31/50] - Avg Val Loss: 0.111412


Epoch 32/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [32/50] - Avg Train Loss: 0.120115


Epoch 32/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [32/50] - Avg Val Loss: 0.114552


Epoch 33/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [33/50] - Avg Train Loss: 0.120001


Epoch 33/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [33/50] - Avg Val Loss: 0.120324


Epoch 34/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [34/50] - Avg Train Loss: 0.122957


Epoch 34/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [34/50] - Avg Val Loss: 0.114532


Epoch 35/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [35/50] - Avg Train Loss: 0.118807


Epoch 35/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [35/50] - Avg Val Loss: 0.113871


Epoch 36/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [36/50] - Avg Train Loss: 0.117410


Epoch 36/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [36/50] - Avg Val Loss: 0.119500


Epoch 37/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

Epoch [37/50] - Avg Train Loss: 0.120149


Epoch 37/50 - Validation:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch [37/50] - Avg Val Loss: 0.117086


Epoch 38/50 - Training:   0%|          | 0/244 [00:00<?, ?it/s]

In [None]:
# Save Final Model
pipe.unet.save_pretrained("../evaluation/sticker_diffusion_qlora/final_model")

NameError: name 'pipe' is not defined

In [None]:
# Plot Training vs Validation Loss
plt.figure(figsize=(8, 6))
plt.plot(range(1, num_epochs+1), training_losses, label='Training Loss', marker='o')
plt.plot(range(1, num_epochs+1), validation_losses, label='Validation Loss', marker='o')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training vs Validation Loss')
plt.legend()
plt.grid(True)
plt.savefig(os.path.join(output_dir, "training_validation_loss.png"))
plt.close()