# Phase 1: Pretrained VQ-VAE Experimentation

This notebook validates the LatentGPT pipeline using a pretrained VQGAN from Hugging Face.

**Goals:**
1. Test VQ-VAE encoding/decoding
2. Initialize LatentGPT transformer
3. Test text encoding with CLIP
4. Experiment with generation
5. Track with MLflow

**MLflow Experiment:** `latent-gpt-pretrained-vqvae`

## 1. Setup and Imports

In [None]:
from __future__ import annotations

import sys
from pathlib import Path

# Add project root to path so that the `src.*` imports below resolve.
# The LATENTGPT_PROJECT_ROOT environment variable overrides the default
# location, which lets the notebook run from a different checkout or user
# account without editing this cell.
import os

PROJECT_ROOT = Path(
    os.environ.get("LATENTGPT_PROJECT_ROOT", "/home/doshlom4/work/final_project")
)
# Guard the insert: re-running this cell must not stack duplicate entries.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from datetime import datetime

# Report the PyTorch build and, when a CUDA device is visible, its name and
# total memory; then choose the device used by the rest of the notebook.
has_cuda = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {has_cuda}")
if has_cuda:
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {gpu_props.total_memory / 1e9:.1f} GB")

# Prefer the GPU, fall back to CPU
device = torch.device("cuda") if has_cuda else torch.device("cpu")
print(f"\nUsing device: {device}")

## 2. MLflow Configuration

Set up MLflow tracking for experiment organization.

In [None]:
import os
import mlflow
from mlflow.tracking import MlflowClient

# Tracking server address: the MLFLOW_TRACKING_URI environment variable wins,
# otherwise fall back to a server on 127.0.0.1:5000.
MLFLOW_TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", "http://127.0.0.1:5000")
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)

# Short key -> MLflow experiment name
EXPERIMENTS = dict(
    vqvae_training="vqvae-training",
    pretrained_vqvae="latent-gpt-pretrained-vqvae",
    custom_vqvae="latent-gpt-custom-vqvae",
)

print(f"MLflow Tracking URI: {MLFLOW_TRACKING_URI}")
print(f"\nConfigured Experiments:")
print("\n".join(f"  - {key}: {name}" for key, name in EXPERIMENTS.items()))

In [None]:
# Select the Phase 1 experiment by its short key and make it the active one
CURRENT_EXPERIMENT = EXPERIMENTS["pretrained_vqvae"]
mlflow.set_experiment(CURRENT_EXPERIMENT)

print(f"Active experiment: {CURRENT_EXPERIMENT}")
# Reminder of how to bring up the tracking server
print(
    "\nNote: Make sure MLflow server is running:",
    "  sbatch slurm/mlflow_server.sh",
    "  # or locally: mlflow server --port 5000",
    sep="\n",
)

## 3. Load Configuration

In [None]:
from src.utils.config import Config

# Read the Phase 1 settings from configs/vqvae_pretrained.yaml under the project root
config_path = PROJECT_ROOT / "configs" / "vqvae_pretrained.yaml"
config = Config.from_yaml(config_path)

# Echo the fields the rest of the notebook relies on
print("Configuration loaded:")
config_summary = (
    ("VQ-VAE", config.vqvae.checkpoint),
    ("Codebook size", config.vqvae.codebook_size),
    ("Downsample factor", config.vqvae.downsample_factor),
    ("Transformer layers", config.transformer.num_layers),
    ("Transformer hidden", config.transformer.hidden_size),
    ("Image size", config.data.image_size),
)
for label, value in config_summary:
    print(f"  {label}: {value}")

## 4. Load Models

Load the pretrained VQ-VAE and CLIP text encoder.

In [None]:
from src.models.vqvae import VQVAEWrapper
from src.models.clip_encoder import CLIPEncoder

# Pretrained VQ-VAE: checkpoint identifier comes from the config
print("Loading VQ-VAE...")
vqvae_kwargs = {"checkpoint": config.vqvae.checkpoint, "device": device}
vqvae = VQVAEWrapper.from_pretrained(**vqvae_kwargs)
print(f"VQ-VAE loaded: vocab_size={vqvae.vocab_size}")

# CLIP text encoder: model name comes from the config, same device as the VQ-VAE
print("\nLoading CLIP encoder...")
clip_kwargs = {"model_name": config.text_encoder.model_name, "device": device}
clip_encoder = CLIPEncoder(**clip_kwargs)
print(f"CLIP hidden size: {clip_encoder.hidden_size}")

## 5. Test VQ-VAE Encoding/Decoding

Verify that the VQ-VAE can encode and decode images correctly.

In [None]:
# Round-trip a synthetic image through the VQ-VAE (image -> tokens -> image)
# and report shapes plus the reconstruction error.

# Create a random test image and min-max rescale it to [0, 1]
test_image = torch.randn(1, 3, config.data.image_size, config.data.image_size).to(device)
test_image = (test_image - test_image.min()) / (test_image.max() - test_image.min())  # Normalize to [0, 1]

# This is an inference-only check, so run it without autograd tracking,
# the same way the mini training loop calls vqvae.encode.
with torch.no_grad():
    # Encode to tokens
    tokens = vqvae.encode(test_image)
    print(f"Input image shape: {test_image.shape}")
    print(f"Encoded tokens shape: {tokens.shape}")
    print(f"Token range: [{tokens.min().item()}, {tokens.max().item()}]")

    # Decode back to image
    reconstructed = vqvae.decode(tokens)
    print(f"Reconstructed image shape: {reconstructed.shape}")

    # Calculate reconstruction error
    mse = torch.nn.functional.mse_loss(reconstructed, test_image)

print(f"Reconstruction MSE: {mse.item():.4f}")

In [None]:
# Side-by-side view of the input and its VQ-VAE reconstruction.
# Tensors are moved to channel-last layout for imshow; the reconstruction
# is clamped to [0, 1] before display.
panels = {
    "Original": test_image[0].permute(1, 2, 0).cpu().numpy(),
    "Reconstructed": reconstructed[0].permute(1, 2, 0).detach().cpu().clamp(0, 1).numpy(),
}

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for ax, (title, pixels) in zip(axes, panels.items()):
    ax.imshow(pixels)
    ax.set_title(title)
    ax.axis("off")

plt.tight_layout()
plt.show()

## 6. Test CLIP Text Encoding

Verify that CLIP can encode text prompts correctly.

In [None]:
# Prompts used to exercise the text encoder
test_prompts = [
    "A beautiful sunset over the ocean",
    "A cat sitting on a couch",
    "A mountain landscape with snow",
]

# Encode all prompts in one call and compare against the expected layout
text_embeddings = clip_encoder.encode_text(test_prompts)
num_prompts = len(test_prompts)
clip_width = clip_encoder.hidden_size
print(f"Text embeddings shape: {text_embeddings.shape}")
print(f"Expected: [batch_size, seq_len, hidden_size] = [{num_prompts}, *, {clip_width}]")

# Null embedding, as used for classifier-free guidance (CFG)
null_embedding = clip_encoder.get_null_embedding(batch_size=2)
print(f"\nNull embedding shape: {null_embedding.shape}")

## 7. Initialize LatentGPT Transformer

Create a small transformer for testing the pipeline.

In [None]:
from src.models.latent_gpt import LatentGPT

# Build a reduced-size LatentGPT for quick notebook experimentation.
import copy

# Work on a private copy of the transformer settings: a plain assignment would
# alias config.transformer, so the overrides below would silently rewrite the
# loaded configuration for every later cell.
small_config = copy.deepcopy(config.transformer)
small_config.num_layers = 4  # Reduce layers for testing
small_config.hidden_size = 256  # Reduce hidden size for testing

# Create model: vocabulary size comes from the VQ-VAE, the conditioning width
# from the CLIP encoder; max_seq_len and dropout keep their configured values.
model = LatentGPT(
    vocab_size=vqvae.vocab_size,
    max_seq_len=config.transformer.max_seq_len,
    hidden_size=small_config.hidden_size,
    num_layers=small_config.num_layers,
    num_heads=4,  # Reduce heads for testing
    context_dim=clip_encoder.hidden_size,
    dropout=config.transformer.dropout,
).to(device)

# Count parameters
num_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {num_params / 1e6:.2f}M")

## 8. Test Forward Pass

Verify the model can process tokens and text conditioning.

In [None]:
# Smoke-test the model's forward pass on random token ids with a dummy prompt
batch_size, seq_len = 2, 64  # Small sequence for testing

# Random token ids drawn from the VQ-VAE vocabulary
dummy_tokens = torch.randint(0, vqvae.vocab_size, (batch_size, seq_len)).to(device)

# Text conditioning: the same prompt repeated once per batch element
dummy_prompts = ["A test prompt" for _ in range(batch_size)]
dummy_context = clip_encoder.encode_text(dummy_prompts)

# The model returns both the logits and a loss for the given tokens
logits, loss = model(dummy_tokens, context=dummy_context)

shape_report = (
    ("Input tokens", dummy_tokens),
    ("Context", dummy_context),
    ("Output logits", logits),
)
for label, tensor in shape_report:
    print(f"{label} shape: {tensor.shape}")
print(f"Training loss: {loss.item():.4f}")

## 9. Load Flickr30k Dataset

Load the Flickr30k test split and inspect a few samples.

In [None]:
from src.data.flickr30k import Flickr30kDataset, create_dataloader

# Reuse the dataset cache that already exists elsewhere in the project tree
cache_dir = PROJECT_ROOT.joinpath("stable_diffusion", "notebooks-old", "dataset_cache")

# Test split at the configured resolution
dataset_kwargs = dict(
    split="test",
    image_size=config.data.image_size,
    cache_dir=str(cache_dir),
    cfg_dropout=0.0,  # No dropout for testing
)
dataset = Flickr30kDataset(**dataset_kwargs)

print(f"Dataset size: {len(dataset)}")

# Inspect the first item; the caption preview is cut to 100 characters
sample = dataset[0]
caption_preview = sample['caption'][:100]
print(f"Image shape: {sample['image'].shape}")
print(f"Caption: {caption_preview}...")

In [None]:
# Draw one shuffled batch of four and show it as a 2x2 grid with caption titles
dataloader = create_dataloader(dataset, batch_size=4, shuffle=True, num_workers=0)
batch = next(iter(dataloader))

fig, axes = plt.subplots(2, 2, figsize=(12, 12))
for idx, ax in enumerate(axes.ravel()):
    # Channel-last layout for imshow
    pixels = batch["image"][idx].permute(1, 2, 0).numpy()
    lo, hi = pixels.min(), pixels.max()
    ax.imshow((pixels - lo) / (hi - lo))  # Min-max rescale for display
    ax.set_title(batch["caption"][idx][:50] + "...", fontsize=10)
    ax.axis("off")

plt.tight_layout()
plt.show()

## 10. Test Generation Pipeline

Test token generation with the untrained model. Generation is capped at 64 new tokens for speed, and the tokens are not decoded back to an image here.

In [None]:
# Calculate sequence length for target resolution:
# latent grid side = image size // downsample factor, one token per grid cell.
downsample_factor = config.vqvae.downsample_factor
latent_size = config.data.image_size // downsample_factor
target_seq_len = latent_size * latent_size

print(f"Image size: {config.data.image_size}x{config.data.image_size}")
print(f"Latent size: {latent_size}x{latent_size}")
print(f"Target sequence length: {target_seq_len}")

# Limit for testing speed: never sample more than 64 tokens in the notebook
num_new_tokens = min(target_seq_len, 64)

# Generate tokens (note: model is untrained, so output will be random)
prompt = "A beautiful sunset over the ocean"
context = clip_encoder.encode_text([prompt])

# Report the number of tokens actually requested, not the full target length
print(f"\nGenerating {num_new_tokens} of {target_seq_len} tokens for prompt: '{prompt}'")
print("Note: Model is untrained, output will be random noise")

# Sample in inference mode: dropout off and no autograd graph. generate() may
# already do this internally; doing it here is harmless either way. The
# previous train/eval flag is restored so later cells see the model unchanged.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        generated_tokens = model.generate(
            context=context,
            max_new_tokens=num_new_tokens,
            temperature=1.0,
            top_k=50,
        )
finally:
    model.train(was_training)

print(f"Generated tokens shape: {generated_tokens.shape}")

## 11. MLflow Logging Example

Example of how to log a training run to MLflow.

In [None]:
from datetime import datetime
from src.utils.logging import setup_mlflow, log_config

# Start a test run, named with a timestamp so repeated executions stay distinct
run_name = f"notebook_test_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

# Keep a handle on the run so its experiment id can be used for the UI link
with mlflow.start_run(run_name=run_name) as run:
    # Set tags for this run
    mlflow.set_tags({
        "resolution": str(config.data.image_size),
        "conditioning": "text_conditional",
        "vqvae_source": "pretrained_hf",
        "phase": "1",
        "experiment_type": "notebook_test",
    })

    # Log configuration
    mlflow.log_params({
        "vqvae_checkpoint": config.vqvae.checkpoint,
        "codebook_size": config.vqvae.codebook_size,
        "downsample_factor": config.vqvae.downsample_factor,
        "transformer_layers": small_config.num_layers,
        "transformer_hidden": small_config.hidden_size,
        "model_params_millions": num_params / 1e6,
        "image_size": config.data.image_size,
    })

    # Log some example metrics (synthetic values, only to exercise the API)
    for step in range(5):
        mlflow.log_metrics({
            "loss": 10.0 - step * 1.5,
            "perplexity": 1000.0 - step * 150,
        }, step=step)

    print(f"Logged test run: {run_name}")
    # CURRENT_EXPERIMENT holds an experiment *name* (a value of EXPERIMENTS),
    # so indexing EXPERIMENTS with it raises KeyError. The UI route also needs
    # the experiment id, which the active run exposes.
    print(f"View at: {MLFLOW_TRACKING_URI}/#/experiments/{run.info.experiment_id}")

## 12. Mini Training Loop

Quick training loop to validate the pipeline works end-to-end.

In [None]:
from torch.optim import AdamW
from tqdm.auto import tqdm

# Run a handful of optimizer steps to confirm the data -> tokens -> model ->
# loss -> backward path works end to end. Not meant to train a useful model.

# Create optimizer
optimizer = AdamW(model.parameters(), lr=1e-4)

# Mini training loop (just a few steps to validate)
num_steps = 10
model.train()

print("Running mini training loop...")
losses = []

# Build the iterator once instead of calling iter(dataloader) on every step,
# which restarts the loader each time and only ever yields a first batch.
batch_iter = iter(dataloader)

for step in tqdm(range(num_steps)):
    # Get a batch, restarting the loader if it runs out before num_steps
    try:
        batch = next(batch_iter)
    except StopIteration:
        batch_iter = iter(dataloader)
        batch = next(batch_iter)
    images = batch["image"].to(device)
    captions = batch["caption"]

    # Encode images to tokens; the VQ-VAE and the text encoder are used
    # without gradient tracking
    with torch.no_grad():
        tokens = vqvae.encode(images)
        # Flatten to sequence; reshape also accepts non-contiguous tensors
        tokens = tokens.reshape(tokens.size(0), -1)
        context = clip_encoder.encode_text(captions)

    # Forward pass
    logits, loss = model(tokens, context=context)

    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    losses.append(loss.item())

print(f"\nTraining complete!")
print(f"Initial loss: {losses[0]:.4f}")
print(f"Final loss: {losses[-1]:.4f}")

In [None]:
# Loss per step from the mini training loop, drawn with the explicit Axes interface
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(losses, marker="o")
ax.set_xlabel("Step")
ax.set_ylabel("Loss")
ax.set_title("Mini Training Loss")
ax.grid(True)
plt.show()

## 13. Next Steps

Once the pipeline is validated:

1. **For full training**, use the Python scripts with multi-GPU:
   ```bash
   # Start MLflow server
   sbatch slurm/mlflow_server.sh
   
   # Submit training job
   sbatch slurm/train_8gpu.sh
   ```

2. **Monitor training** in MLflow UI

3. **After Phase 1**, move to Phase 2 (custom VQ-VAE) and Phase 3 (full transformer training)