In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from diffusers import UNet2DConditionModel, AutoencoderKL
from peft import LoraConfig, get_peft_model
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import gc
from transformers import CLIPTokenizer, CLIPTextModel
from PIL import Image
import numpy as np
from transformers import CLIPTokenizer, CLIPTextModel
from peft import LoraConfig, PeftModel

### Prompt(GPT 4o): Fine tunning Stable Difussion model using Lora .

In [2]:
# Enable cuDNN optimization for faster training
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision("high")

# Ensure PyTorch uses GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Free GPU memory
torch.cuda.empty_cache()
gc.collect()

Using device: cuda


0

In [3]:
# Dataset Class
class EmojiDataset(Dataset):
    def __init__(self, parquet_file):
        self.data = pd.read_parquet(parquet_file)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image_path = self.data.iloc[idx]["image_path"]
        image_tensor = torch.load(image_path).float() / 127.5 - 1  # Normalize to [-1,1]
        text_embedding = torch.tensor(self.data.iloc[idx]["combined_embedding"], dtype=torch.float32)
        return image_tensor, text_embedding

In [4]:
# Load Dataset
parquet_file = "../data/processed_emoji_dataset.parquet"
dataset = EmojiDataset(parquet_file)
batch_size = 16  
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Load Stable Diffusion 2 Base Model (Trained on 256x256 images)
vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-2-base", subfolder="vae").to(device, dtype=torch.float16)
unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2-base", subfolder="unet").to(device, dtype=torch.float16)

print(unet.config) 

FrozenDict({'sample_size': 64, 'in_channels': 4, 'out_channels': 4, 'center_input_sample': False, 'flip_sin_to_cos': True, 'freq_shift': 0, 'down_block_types': ['CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D'], 'mid_block_type': 'UNetMidBlock2DCrossAttn', 'up_block_types': ['UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'], 'only_cross_attention': False, 'block_out_channels': [320, 640, 1280, 1280], 'layers_per_block': 2, 'downsample_padding': 1, 'mid_block_scale_factor': 1, 'dropout': 0.0, 'act_fn': 'silu', 'norm_num_groups': 32, 'norm_eps': 1e-05, 'cross_attention_dim': 1024, 'transformer_layers_per_block': 1, 'reverse_transformer_layers_per_block': None, 'encoder_hid_dim': None, 'encoder_hid_dim_type': None, 'attention_head_dim': [5, 10, 20, 20], 'num_attention_heads': None, 'dual_cross_attention': False, 'use_linear_projection': True, 'class_embed_type': None, 'addition_embed_type': None, 'addition_time_embed_dim': 

In [5]:

# Apply LoRA to UNet
lora_config = LoraConfig(
    r=4,  # LoRA rank
    lora_alpha=16,  # Scaling factor
    target_modules=["to_q", "to_k", "to_v", "proj_out", "proj_in"],  
    lora_dropout=0.05,  # Dropout for regularization
    bias="none"
)

unet = get_peft_model(unet, lora_config)
unet.print_trainable_parameters()  # Print trainable parameters (should be very small)

trainable params: 829,952 || all params: 866,740,676 || trainable%: 0.0958


In [6]:
# Enable memory optimization
unet.enable_gradient_checkpointing()

# Embedding Projector (CLIP 512 → UNet 768)
class EmbeddingProjector(nn.Module):
    def __init__(self, input_dim=512, output_dim=1024):  # Changed to 768
        super().__init__()
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.fc(x)

embedding_projector = EmbeddingProjector().to(device, dtype=torch.float16)

# Freeze VAE
for param in vae.parameters():
    param.requires_grad = False  

# Define optimizer
optimizer = AdamW(filter(lambda p: p.requires_grad, unet.parameters()), lr=0.001)
scaler = torch.amp.GradScaler()

In [None]:
# Training Loop
num_epochs = 5
losses = []
for epoch in range(num_epochs):
    epoch_loss = 0
    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
    
    for images, embeddings in progress_bar:
        images = images.to(device, dtype=torch.float16, non_blocking=True)
        embeddings = embeddings.to(device, dtype=torch.float16, non_blocking=True)
        optimizer.zero_grad()

        # Project CLIP embeddings
        with torch.no_grad():
            projected_embeddings = embedding_projector(embeddings).unsqueeze(1)
        
        with torch.no_grad():
            latents = vae.encode(images).latent_dist.sample() * 0.18215

        # Convert latents to bfloat16 to save memory
        latents = latents.to(torch.float16)

        # Generate noise
        noise = torch.randn_like(latents, dtype=torch.float16)
        timesteps = torch.randint(0, 1000, (latents.shape[0],), device=device).long()

        # Forward pass
        with torch.amp.autocast("cuda"):
            noise_pred = unet(latents, timesteps, encoder_hidden_states=projected_embeddings).sample
            loss = F.mse_loss(noise_pred, noise)

        # Backpropagation
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # Avoid NaN issues
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())

    avg_epoch_loss = epoch_loss / len(train_dataloader)
    losses.append(avg_epoch_loss)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_epoch_loss:.4f}")


Epoch 1/5:  33%|███▎      | 56/170 [26:49<48:57, 25.77s/it, loss=1]      

In [None]:

# Save Model
torch.save({
    "unet": unet.state_dict(),
    "embedding_projector": embedding_projector.state_dict()
}, "emoji_generator.pth")
print("Model saved successfully!")

# Save LoRA Weights
unet.save_pretrained("lora_emoji_unet")
torch.save(embedding_projector.state_dict(), "embedding_projector.pth")
print("LoRA adapters saved successfully!")


In [None]:

# Plot Loss Curve
plt.plot(range(1, num_epochs + 1), losses, marker="o", linestyle="-")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Training Loss Curve")
plt.grid()
plt.show()


In [None]:
unet = PeftModel.from_pretrained(unet, "lora_emoji_unet").to(device).to(torch.float16)
embedding_projector.load_state_dict(torch.load("embedding_projector.pth"))
unet.eval()
embedding_projector.eval()


In [None]:

# Assuming the necessary models are already loaded
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
text_encoder = text_encoder.to(device)

In [None]:
# Define your text prompt
text_description = "htdjfgvg iukf iltog"  # This is where we specify "dog"

# Tokenize and encode the text prompt
tokens = tokenizer(text_description, return_tensors="pt").to(device)
text_embedding = text_encoder(**tokens).last_hidden_state.mean(dim=1)  # Aggregate token embeddings
text_embedding = text_embedding.to(torch.float16)  # Ensure it's float16 for compatibility with UNet

In [None]:
# Project the embedding to match UNet’s expected format
projected_embedding = embedding_projector(text_embedding).unsqueeze(0).to(torch.float16)  # Ensure float16

# Generate noise in latent space (fixed size 96x96)
latents = torch.randn(1, 4, 96, 96).to(device).to(torch.float16)  # Ensure latents are in float16
timesteps = torch.tensor([500], device=device).long()
torch.cuda.empty_cache()
gc.collect()

# Generate noise in latent space (smaller size)
latents = torch.randn(1, 4, 64, 64, device=device, dtype=torch.float16)
timesteps = torch.tensor([500], device=device).long()

In [None]:
# Generate Emoji
with torch.no_grad():
    denoised_latents = unet(latents, timesteps, encoder_hidden_states=projected_embedding).sample

# Move to CPU and Decode
denoised_latents = denoised_latents / 0.18215

with torch.no_grad():
    decoded_image = vae.decode(denoised_latents).sample

In [None]:


# Post-process Image
decoded_image = (decoded_image.clamp(-1, 1) + 1) / 2
decoded_image = decoded_image.squeeze(0).permute(1, 2, 0).cpu().numpy()
decoded_image = (decoded_image * 255).astype(np.uint8)
emoji_image = Image.fromarray(decoded_image)

# Display the image
emoji_image.show()
