# Text-to-Image Model Training for Clothing Generation

This notebook fine-tunes Stable Diffusion on the H&M Clothes Descriptions dataset.

**Dataset:** [wbensvage/clothes_desc](https://huggingface.co/datasets/wbensvage/clothes_desc)


## Step 1: Install Dependencies


In [None]:
!pip install -q torch torchvision --index-url https://download.pytorch.org/whl/cu118
!pip install -q transformers diffusers accelerate datasets pillow numpy tqdm huggingface-hub

# Set environment variable for better memory management
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
print("‚úÖ Set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True for better memory management")


## Step 2: Check GPU Availability


In [None]:
import torch

# Report whether a CUDA device is visible; when one is, describe device 0.
print(f"CUDA available: {torch.cuda.is_available()}")
if not torch.cuda.is_available():
    print("‚ö†Ô∏è No GPU detected. Training will be very slow on CPU!")
else:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA Version: {torch.version.cuda}")
    # total_memory is reported in bytes; show it in (decimal) gigabytes.
    _total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"GPU Memory: {_total_bytes / 1e9:.2f} GB")


## Step 3: Import Libraries


In [None]:
import os
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from diffusers import StableDiffusionPipeline, DDPMScheduler, UNet2DConditionModel, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer
from datasets import load_dataset
from PIL import Image
import numpy as np
from tqdm.auto import tqdm
from accelerate import Accelerator
import pickle
from pathlib import Path


## Step 4: Configuration

**Note:** The configuration is optimized for Colab's T4 GPU (15GB). If you have more memory, you can increase `train_batch_size`.


In [None]:
# Central training configuration; the later cells read their settings from here.
CONFIG = dict(
    pretrained_model="runwayml/stable-diffusion-v1-5",
    dataset_name="wbensvage/clothes_desc",
    output_dir="./models/clothes-diffusion",
    cache_dir="./data/cached_latents",
    resolution=512,
    train_batch_size=1,  # kept small for Colab GPU memory (T4 has ~15GB)
    gradient_accumulation_steps=8,  # raised to keep the effective batch size up
    learning_rate=1e-5,
    max_train_steps=1000,  # drop to 100-200 for a quick smoke test
    gradient_checkpointing=True,
    mixed_precision="fp16",  # fp16 for faster training on GPU
    seed=42,
    preprocess_batch_size=8,  # images per VAE-encoding batch during preprocessing
)

# Make sure the output and cache directories exist before anything writes to them.
for _dir_key in ("output_dir", "cache_dir"):
    os.makedirs(CONFIG[_dir_key], exist_ok=True)

# Echo the active settings, one per line.
print("Configuration:")
for key, value in CONFIG.items():
    print(f"  {key}: {value}")


## Step 5: Dataset Class


In [None]:
class ClothesDataset(Dataset):
    """Image latents paired with tokenized captions, for diffusion training.

    The dataset runs in one of two modes, chosen in ``__init__``:

    * cached ("fast") mode - selected when both ``cached_latents`` and
      ``cached_texts`` are given; items are read straight from them.
    * on-the-fly ("slow") mode - otherwise; each item's image is converted to
      RGB, resized to ``size`` x ``size``, normalized to [-1, 1], encoded with
      ``vae`` and multiplied by ``vae.config.scaling_factor``.

    Args:
        dataset: Indexable collection whose items have ``'image'`` (PIL image)
            and ``'text'`` entries. Only read in on-the-fly mode; may be
            ``None`` in cached mode.
        tokenizer: Callable tokenizer exposing ``model_max_length``; invoked
            with ``padding="max_length"``, ``truncation=True`` and
            ``return_tensors="pt"``.
        vae: Encoder used in on-the-fly mode (must provide ``encode`` and
            ``config.scaling_factor``).
        size: Target square resolution for on-the-fly mode.
        device: Device the image is moved to before VAE encoding.
        cached_latents: Pre-computed latents, indexable by item position.
        cached_texts: Captions aligned with ``cached_latents``.

    Raises:
        ValueError: If the cached inputs have different lengths, or if
            on-the-fly mode is selected without both ``dataset`` and ``vae``.
    """

    def __init__(self, dataset, tokenizer, vae=None, size=512, device="cpu", cached_latents=None, cached_texts=None):
        self.tokenizer = tokenizer
        self.size = size

        if cached_latents is not None and cached_texts is not None:
            # Fail fast on a corrupt cache instead of mis-pairing or
            # running off the end of one of the two sequences mid-training.
            if len(cached_latents) != len(cached_texts):
                raise ValueError(
                    f"cached_latents has {len(cached_latents)} entries but "
                    f"cached_texts has {len(cached_texts)}; they must match."
                )
            print("Using pre-cached latents (fast mode)")
            self.latents = cached_latents
            self.texts = cached_texts
            self.use_cache = True
        else:
            # Without a full cache every item is encoded through the VAE, so
            # both the raw dataset and the VAE are mandatory.
            if dataset is None or vae is None:
                raise ValueError(
                    "On-the-fly encoding needs both `dataset` and `vae`; "
                    "alternatively pass both `cached_latents` and `cached_texts`."
                )
            print("Using on-the-fly encoding (slower)")
            self.dataset = dataset
            self.vae = vae
            self.device = device
            self.use_cache = False

    def __len__(self):
        """Number of items: cached captions in fast mode, raw items otherwise."""
        return len(self.texts) if self.use_cache else len(self.dataset)

    def _encode_item(self, idx):
        """Encode raw item ``idx`` with the VAE; returns ``(latents_on_cpu, text)``."""
        item = self.dataset[idx]
        image = item['image']
        text = item['text']

        if image.mode != "RGB":
            image = image.convert("RGB")
        if image.size != (self.size, self.size):
            image = image.resize((self.size, self.size), Image.LANCZOS)

        # HWC uint8 -> float32 in [-1, 1] -> CHW tensor.
        pixels = np.array(image).astype(np.float32) / 255.0
        pixels = (pixels - 0.5) / 0.5
        pixels = torch.from_numpy(pixels).permute(2, 0, 1)

        with torch.no_grad():
            image_batch = pixels.unsqueeze(0).to(self.device)
            latents = self.vae.encode(image_batch).latent_dist.sample()
            latents = latents * self.vae.config.scaling_factor
            latents = latents.squeeze(0)

        # Hand latents back on the CPU regardless of where the VAE ran.
        return latents.cpu(), text

    def __getitem__(self, idx):
        """Return ``{"latents": tensor, "input_ids": 1-D token id tensor}`` for ``idx``."""
        if self.use_cache:
            latents, text = self.latents[idx], self.texts[idx]
        else:
            latents, text = self._encode_item(idx)

        # Pad/truncate every caption to the tokenizer's maximum length so the
        # resulting id tensors can be stacked into a batch.
        text_inputs = self.tokenizer(
            text,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt"
        )

        return {
            "latents": latents,
            "input_ids": text_inputs.input_ids.flatten(),
        }

def collate_fn(examples):
    """Merge per-item dicts from the dataset into a single batch dict.

    Each example's ``"latents"`` tensor is stacked along a new leading
    dimension, made contiguous and cast to float32; the ``"input_ids"``
    tensors are stacked the same way with their dtype left untouched.
    """
    per_item_latents = [ex["latents"] for ex in examples]
    per_item_ids = [ex["input_ids"] for ex in examples]

    batch_latents = (
        torch.stack(per_item_latents)
        .to(memory_format=torch.contiguous_format)
        .float()
    )

    return {"latents": batch_latents, "input_ids": torch.stack(per_item_ids)}


In [None]:
def _image_to_tensor(image, resolution):
    """Convert a PIL image to a float32 CHW tensor with values in [-1, 1]."""
    if image.mode != "RGB":
        image = image.convert("RGB")
    if image.size != (resolution, resolution):
        image = image.resize((resolution, resolution), Image.LANCZOS)

    image_array = np.array(image).astype(np.float32) / 255.0
    image_array = (image_array - 0.5) / 0.5  # [0, 1] -> [-1, 1]
    return torch.from_numpy(image_array).permute(2, 0, 1)  # HWC -> CHW


def preprocess_dataset():
    """Encode every training image to VAE latents once and cache the result.

    If ``latents.pt`` and ``texts.pkl`` already exist under
    ``CONFIG["cache_dir"]`` they are loaded and returned immediately, without
    loading the VAE or the dataset. Otherwise the dataset named by
    ``CONFIG["dataset_name"]`` is encoded in batches of
    ``CONFIG["preprocess_batch_size"]`` and both files are written.

    Returns:
        tuple: ``(latents, texts)`` - a CPU tensor of scaled latents
        concatenated along dim 0, and the list of captions in the same order.

    Raises:
        ValueError: If the dataset contains no items.
    """
    latents_path = os.path.join(CONFIG["cache_dir"], "latents.pt")
    texts_path = os.path.join(CONFIG["cache_dir"], "texts.pkl")

    # Check the cache first: on a hit there is no reason to put the VAE on
    # the GPU or to load the dataset at all.
    if os.path.exists(latents_path) and os.path.exists(texts_path):
        print(f"Cache already exists at {CONFIG['cache_dir']}")
        print("Loading cached latents...")
        cached_latents = torch.load(latents_path, map_location="cpu")
        # NOTE(security): pickle can execute arbitrary code on load. This file
        # is written by this function; only reuse a cache dir you trust.
        with open(texts_path, "rb") as f:
            cached_texts = pickle.load(f)
        print(f"Loaded {len(cached_texts)} pre-cached latents!")
        return cached_latents, cached_texts

    print("Loading VAE encoder...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    vae = AutoencoderKL.from_pretrained(
        CONFIG["pretrained_model"], subfolder="vae"
    )
    vae = vae.to(device)
    vae.eval()
    vae.requires_grad_(False)

    print("Loading dataset...")
    dataset = load_dataset(CONFIG["dataset_name"], split="train")
    if len(dataset) == 0:
        # torch.cat on an empty list would fail later with a less helpful error.
        raise ValueError(f"Dataset {CONFIG['dataset_name']!r} is empty; nothing to encode.")

    all_latents = []
    all_texts = []
    batch_size = CONFIG["preprocess_batch_size"]
    resolution = CONFIG["resolution"]

    print(f"Encoding {len(dataset)} images to latent space...")
    print(f"This may take a few minutes, but will speed up training significantly!")

    with torch.no_grad():
        for i in tqdm(range(0, len(dataset), batch_size)):
            batch_end = min(i + batch_size, len(dataset))
            batch_images = []
            batch_texts = []

            for j in range(i, batch_end):
                item = dataset[j]
                batch_images.append(_image_to_tensor(item['image'], resolution))
                batch_texts.append(item['text'])

            batch_tensor = torch.stack(batch_images).to(device)

            # Sample from the VAE posterior and apply the model's latent scale.
            latents = vae.encode(batch_tensor).latent_dist.sample()
            latents = latents * vae.config.scaling_factor

            # Keep results on the CPU so GPU memory stays flat across batches.
            all_latents.append(latents.cpu())
            all_texts.extend(batch_texts)

    print("Concatenating latents...")
    all_latents = torch.cat(all_latents, dim=0)

    print(f"Saving cached latents to {CONFIG['cache_dir']}...")
    torch.save(all_latents, latents_path)

    with open(texts_path, "wb") as f:
        pickle.dump(all_texts, f)

    print(f"‚úÖ Pre-processing complete!")
    print(f"Cached {len(all_texts)} images")
    print(f"Latent shape: {all_latents.shape}")

    # Drop the VAE and release the CUDA blocks it was holding.
    del vae
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return all_latents, all_texts

# Build the latent cache (or load an existing one); the dataset cell consumes
# both of these names.
cached_latents, cached_texts = preprocess_dataset()

# On GPU runtimes, collect garbage and then return cached CUDA blocks to the
# driver before the training models are loaded.
_gpu_present = torch.cuda.is_available()
if _gpu_present:
    import gc
    gc.collect()
    torch.cuda.empty_cache()
    print("‚úÖ GPU cache cleared after preprocessing")


## Step 7: Load Models and Setup Training


In [None]:
# Start from a clean CUDA cache so the model loads below have as much room
# as possible.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    import gc
    gc.collect()
    print("üßπ Cleared GPU cache before loading models")

# Accelerator setup: mixed precision is disabled when no GPU is present.
use_cpu = not torch.cuda.is_available()
mixed_precision = CONFIG["mixed_precision"] if not use_cpu else "no"

accelerator = Accelerator(
    mixed_precision=mixed_precision,
    gradient_accumulation_steps=CONFIG["gradient_accumulation_steps"],
)

print(f"Accelerator device: {accelerator.device}")
print(f"Mixed precision: {mixed_precision}")

# Seed torch's RNG from the configured seed.
torch.manual_seed(CONFIG["seed"])

# Text side: tokenizer + CLIP text encoder from the base checkpoint.
print("Loading tokenizer and text encoder...")
tokenizer = CLIPTokenizer.from_pretrained(
    CONFIG["pretrained_model"],
    subfolder="tokenizer",
)
text_encoder = CLIPTextModel.from_pretrained(
    CONFIG["pretrained_model"],
    subfolder="text_encoder",
)

# Diffusion side: noise scheduler + the UNet that will be fine-tuned.
print("Loading scheduler and UNet...")
noise_scheduler = DDPMScheduler.from_pretrained(
    CONFIG["pretrained_model"],
    subfolder="scheduler",
)
unet = UNet2DConditionModel.from_pretrained(
    CONFIG["pretrained_model"],
    subfolder="unet",
)

# Optional gradient checkpointing on the UNet (controlled from CONFIG).
if CONFIG["gradient_checkpointing"]:
    unet.enable_gradient_checkpointing()
    print("Gradient checkpointing enabled")

# The text encoder is frozen and put in eval mode. It is never moved to the
# GPU in this cell; the training loop encodes captions on the CPU and moves
# only the resulting embeddings to the accelerator device.
text_encoder.requires_grad_(False)
text_encoder.eval()
print("üìå Text encoder kept on CPU to save GPU memory")

# Final cache flush (GPU only) and status message.
if not torch.cuda.is_available():
    print("‚úÖ Models loaded successfully!")
else:
    torch.cuda.empty_cache()
    print("‚úÖ Models loaded successfully! GPU cache cleared.")


## Step 8: Create Dataset and DataLoader


In [None]:
# Wrap the cached latents and captions. In cached mode ClothesDataset never
# reads the raw dataset, so None is passed instead of re-loading the whole
# dataset only to throw it away again before training.
train_dataset = ClothesDataset(
    None, tokenizer, cached_latents=cached_latents, cached_texts=cached_texts
)

# Single-process loading without pinned memory keeps the footprint small.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=CONFIG["train_batch_size"],
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=0,  # Reduced to 0 to save memory
    pin_memory=False,  # Don't pin memory to save GPU memory
)

print(f"Dataset size: {len(train_dataset)}")
print(f"Number of batches: {len(train_dataloader)}")

# Note: Gradient checkpointing is already enabled, which helps with memory
# Attention slicing is for inference pipelines, not needed during training
print("‚úÖ Dataset and DataLoader ready")


## Step 9: Initialize Optimizer


In [None]:
# AdamW over all UNet parameters, with the learning rate taken from CONFIG.
optimizer = torch.optim.AdamW(unet.parameters(), lr=CONFIG["learning_rate"])

# Hand the model, optimizer and dataloader to Accelerate and rebind the
# names to the prepared objects it returns.
_prepared = accelerator.prepare(unet, optimizer, train_dataloader)
unet, optimizer, train_dataloader = _prepared

print("‚úÖ Optimizer initialized!")


## Step 9.5: Memory Optimization (Important!)

**If you get OutOfMemoryError, restart the runtime and run all cells from Step 1 again.**


In [None]:
# Pre-training cleanup: drop objects that are no longer needed and report
# how much GPU memory is in use.
import gc

# At notebook top level locals() is the global namespace, so these checks
# make the deletions safe whether or not the names were ever defined.
if 'vae' in locals():
    del vae  # the VAE is only needed for encoding, which is already done

if 'dataset' in locals():
    del dataset  # raw dataset is redundant once latents are cached

gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()  # start peak tracking from here
    print(f"üíæ GPU Memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB allocated")
    print(f"üíæ GPU Memory: {torch.cuda.memory_reserved() / 1e9:.2f} GB reserved")
    print("‚úÖ Memory optimized before training")

    # Free memory = device total minus what the caching allocator has
    # reserved, in decimal GB; warn when under half a gigabyte remains.
    _total_bytes = torch.cuda.get_device_properties(0).total_memory
    free_memory = (_total_bytes - torch.cuda.memory_reserved()) / 1e9
    if free_memory < 0.5:
        print("‚ö†Ô∏è  WARNING: Very little free GPU memory! Training may fail.")
        print("   Consider: 1) Restart runtime, 2) Reduce max_train_steps, 3) Use Colab Pro")


## Step 10: Training Loop


In [None]:
# Training loop
#
# `global_step` counts processed batches (micro-steps), not optimizer updates:
# with gradient accumulation, Accelerate applies an optimizer update only once
# every CONFIG["gradient_accumulation_steps"] batches.
import gc

print("üöÄ Starting training...")
if use_cpu:
    print("‚ö†Ô∏è  Training on CPU - this will be slow. Consider using GPU.")

# Clear cache before training
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    gc.collect()

# An empty dataloader would make the epoch loop below spin forever.
if len(train_dataloader) == 0:
    raise ValueError("train_dataloader is empty - there is nothing to train on.")

global_step = 0
save_interval = max(100, CONFIG["max_train_steps"] // 10)  # Save every 100 steps or 10% of total

unet.train()
progress_bar = tqdm(total=CONFIG["max_train_steps"], desc="Training")

# Keep cycling through the dataloader (one pass = one epoch) until the step
# budget is used up. A single pass would silently stop early whenever the
# dataset has fewer batches than CONFIG["max_train_steps"].
while global_step < CONFIG["max_train_steps"]:
    for step, batch in enumerate(train_dataloader):
        with accelerator.accumulate(unet):
            # Latents were pre-encoded and cached; just make sure they are on
            # the training device.
            latents = batch["latents"].to(accelerator.device)

            # Draw Gaussian noise and one random timestep per sample.
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
            ).long()

            # Forward diffusion: noise the clean latents to the sampled timestep.
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # The text encoder lives on the CPU, so token ids are moved there,
            # encoded without gradients, and only the embeddings go to the GPU.
            with torch.no_grad():
                input_ids_cpu = batch["input_ids"].cpu()
                encoder_hidden_states = text_encoder(input_ids_cpu)[0]
                encoder_hidden_states = encoder_hidden_states.to(accelerator.device)
                del input_ids_cpu

            # Predict the noise that was added.
            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

            # MSE between predicted and true noise, computed in float32.
            loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")

            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()

            # Drop references to intermediates so their memory can be reused.
            del model_pred, noisy_latents, encoder_hidden_states
            del latents, noise, timesteps

        progress_bar.update(1)
        progress_bar.set_postfix({"loss": f"{loss.item():.4f}", "step": global_step})
        global_step += 1

        # Periodic cache clearing (every 25 steps)
        if global_step % 25 == 0:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                gc.collect()

        # Periodic UNet checkpoint
        if global_step % save_interval == 0:
            checkpoint_dir = os.path.join(CONFIG["output_dir"], f"checkpoint-{global_step}")
            os.makedirs(checkpoint_dir, exist_ok=True)
            unwrapped_unet = accelerator.unwrap_model(unet)
            unwrapped_unet.save_pretrained(checkpoint_dir)
            print(f"\nüíæ Saved checkpoint at step {global_step} to {checkpoint_dir}")
            # Clear cache after saving
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        if global_step >= CONFIG["max_train_steps"]:
            break

progress_bar.close()
print("\n‚úÖ Training complete!")


## Step 11: Save Final Model


In [None]:
print(f"Saving final model to {CONFIG['output_dir']}...")

# Strip the Accelerate wrapper so the plain UNet is what gets serialized.
unwrapped_unet = accelerator.unwrap_model(unet)

# The VAE is not trained in this notebook; reload the base weights so the
# saved pipeline is complete.
vae = AutoencoderKL.from_pretrained(CONFIG["pretrained_model"], subfolder="vae")

# Assemble a full pipeline and let diffusers write it out. This produces the
# layout StableDiffusionPipeline.from_pretrained() expects: model_index.json
# plus one sub-folder per component (unet/, vae/, text_encoder/, tokenizer/,
# scheduler/). Saving the UNet and tokenizer straight into the root folder,
# with no model_index.json, leaves a directory the pipeline cannot load.
#
# The scheduler is intentionally not overridden: the pipeline keeps the base
# model's inference scheduler rather than the DDPM scheduler that was only
# used to add noise during training.
pipeline = StableDiffusionPipeline.from_pretrained(
    CONFIG["pretrained_model"],
    unet=unwrapped_unet,
    text_encoder=text_encoder,
    vae=vae,
    tokenizer=tokenizer,
    safety_checker=None,  # not saved; the test cell also loads without it
    requires_safety_checker=False,
)
pipeline.save_pretrained(CONFIG["output_dir"])

print("‚úÖ Model saved successfully!")
print(f"Model location: {CONFIG['output_dir']}")


## Step 12: Download Model (for use on your local machine)


In [None]:
# Bundle the saved model directory into a single zip for download.
import shutil

model_zip = "clothes-diffusion-model.zip"
print(f"Creating zip file: {model_zip}...")

# Start from a clean slate if an archive from a previous run is lying around.
_stale_archive_present = os.path.exists(model_zip)
if _stale_archive_present:
    os.remove(model_zip)

# make_archive appends the ".zip" suffix itself, so the base name is passed
# without an extension.
shutil.make_archive(
    "clothes-diffusion-model",
    "zip",
    CONFIG["output_dir"],
)

print(f"‚úÖ Zip file created: {model_zip}")
_zip_size_gb = os.path.getsize(model_zip) / 1e9
print(f"File size: {_zip_size_gb:.2f} GB")
print("\nüì• Download the zip file from Colab's file browser (left sidebar)")


## Step 13: Test Generation (Optional)


In [None]:
# Smoke-test the saved model by generating one image from a text prompt.
from diffusers import StableDiffusionPipeline

print("Loading trained model for testing...")
# Half precision when a GPU is present, full precision otherwise.
_pipe_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
pipe = StableDiffusionPipeline.from_pretrained(
    CONFIG["output_dir"],
    torch_dtype=_pipe_dtype,
    safety_checker=None,  # Disable safety checker for clothing images
)

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = pipe.to(device)

test_prompt = "Black boxer briefs with elasticated waist"
print(f"Generating image for: '{test_prompt}'...")

# Autocast on the GPU; on the CPU just disable gradient tracking.
_inference_ctx = torch.autocast(device) if device == "cuda" else torch.no_grad()
with _inference_ctx:
    _result = pipe(
        test_prompt,
        num_inference_steps=50,
        guidance_scale=7.5,
        height=512,
        width=512,
    )
    image = _result.images[0]

# Write the sample to disk, then leave `image` as the cell's last expression
# so the notebook renders it inline.
test_output_path = "test_generation.png"
image.save(test_output_path)
print(f"‚úÖ Test image saved to: {test_output_path}")
image
