# Guided Variational Flow Matching

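This notebook demonstrates guided Variational Flow Matching (VFM). Generation is steered with hand-crafted guidance functions — a constant directional push, an attraction toward a point, and an energy gradient — in the same spirit as classifier guidance.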

## Overview
Guided VFM incorporates guidance signals (e.g., from a classifier) to steer the generation process toward desired characteristics.

In [None]:
import sys
sys.path.append("../..")

import torch
import matplotlib.pyplot as plt
import numpy as np
from src.models import GuidedVFM
from src.utils import generate_toy_data, plot_samples, train_vfm, create_dataloader

# Fix random seeds so data generation, training and sampling give the same
# results under Restart Kernel -> Run All (CUDA kernels may still add minor
# nondeterminism).
SEED = 42
torch.manual_seed(SEED)  # seeds the RNG on all devices (CPU and CUDA)
np.random.seed(SEED)  # in case generate_toy_data draws from NumPy's global RNG — TODO confirm

print("PyTorch version:", torch.__version__)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

## 1. Generate Target Data

In [None]:
# Sample the "two moons" toy dataset; this is the distribution the model learns.
target_data = generate_toy_data(
    "two_moons",
    num_samples=2000,
    noise=0.1,
)
plot_samples(target_data, title="Target Distribution")

## 2. Initialize Guided Model

In [None]:
# Build the guided VFM network for 2-D inputs. guidance_scale is only the
# initial strength; later cells overwrite model.guidance_scale before sampling.
model = GuidedVFM(
    input_dim=2,
    hidden_dim=128,
    num_layers=3,
    time_embedding_dim=32,
    guidance_scale=1.0,
)

print(f"Model parameters: {sum(param.numel() for param in model.parameters()):,}")

## 3. Train the Model

In [None]:
# Batch the target samples into shuffled mini-batches of 128.
dataloader = create_dataloader(target_data, shuffle=True, batch_size=128)

# Fit the model for 50 epochs; `history` keeps whatever train_vfm returns.
history = train_vfm(
    model=model,
    dataloader=dataloader,
    num_epochs=50,
    learning_rate=1e-3,
    device=device,
    verbose=True,
)

## 4. Generate Samples Without Guidance

In [None]:
# Put the model in evaluation mode and draw unguided samples as a baseline.
model.eval()
with torch.no_grad():
    samples_no_guidance = model.sample(
        num_samples=2000,
        num_steps=100,
        device=device,
    )

plot_samples(samples_no_guidance, title="Generated Samples (No Guidance)")

## 5. Define Guidance Function

We define two simple guidance functions. The first pushes every sample toward positive $y$ values. The second is an alternative that pulls samples toward a fixed target point.

In [None]:
def guidance_function(x):
    """Push every sample in the +y direction.

    Returns a tensor with the same shape, dtype and device as ``x``. It is zero
    everywhere except column 1 (the y-coordinate), which is 1.0. The result is
    a constant upward push that does not depend on where the sample is.
    """
    push = torch.zeros_like(x)
    push.select(dim=1, index=1).fill_(1.0)
    return push

# Alternative: Guide toward a specific point
def guidance_to_point(x, target_point=(0.5, 0.5), strength=0.1):
    """Guide samples toward a target point.

    Returns ``strength * (target_point - x)``: a pull toward the target that
    grows linearly with each sample's distance from it.

    Args:
        x: Tensor of sample positions whose last dimension holds the
            coordinates.
        target_point: Point to attract samples toward. Accepts any tensor-like
            value (tensor, list, tuple) that broadcasts against ``x``: a single
            point of shape ``(D,)`` or one target per sample of shape ``(N, D)``.
            Defaults to ``(0.5, 0.5)``.
        strength: Multiplier applied to the displacement. Defaults to 0.1.

    Returns:
        Tensor on ``x``'s device. When ``x`` is floating point, the result also
        has ``x``'s dtype.
    """
    # Cast the target to x's floating dtype as well as its device, so that
    # half-precision samples are not silently promoted to float32 by the
    # subtraction. Non-float x keeps the usual type promotion.
    dtype = x.dtype if x.is_floating_point() else None
    target = torch.as_tensor(target_point, dtype=dtype, device=x.device)
    # Plain broadcasting handles both a single (D,) point and per-sample targets.
    return (target - x) * strength

## 6. Generate Samples With Guidance

In [None]:
# Set the guidance weight, then sample using the upward-push guidance function.
model.guidance_scale = 0.5  # Adjust guidance strength

with torch.no_grad():
    samples_with_guidance = model.sample(
        guidance_fn=guidance_function,
        num_samples=2000,
        num_steps=100,
        device=device,
    )

plot_samples(samples_with_guidance, title="Generated Samples (With Guidance)")

## 7. Compare Different Guidance Scales

In [None]:
# Sweep the guidance weight and plot each result in its own panel.
guidance_scales = [0.0, 0.5, 1.0, 2.0]
n_cols = 2
n_rows = (len(guidance_scales) + n_cols - 1) // n_cols  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(8 * n_cols, 8 * n_rows), squeeze=False)
axes = axes.flatten()

# The sweep mutates model.guidance_scale. Remember the current value and
# restore it afterwards, so later cells do not inherit the last sweep value
# (hidden state).
original_scale = model.guidance_scale
try:
    with torch.no_grad():
        for ax, scale in zip(axes, guidance_scales):
            model.guidance_scale = scale
            samples = model.sample(
                num_samples=1000,
                num_steps=100,
                guidance_fn=guidance_function if scale > 0 else None,
                device=device,
            )
            samples_np = samples.cpu().numpy()

            ax.scatter(samples_np[:, 0], samples_np[:, 1], alpha=0.5, s=10)
            ax.set_title(f"Guidance Scale: {scale}")
            ax.set_xlabel("x")
            ax.set_ylabel("y")
            ax.grid(True, alpha=0.3)
            ax.axis("equal")
finally:
    model.guidance_scale = original_scale

# Hide leftover panels when the number of scales does not fill the grid.
for ax in axes[len(guidance_scales):]:
    ax.set_visible(False)

plt.tight_layout()
plt.show()

## 8. Energy-Based Guidance

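Here the guidance signal comes from an energy function: each sample is pushed along the negative gradient of the energy, toward regions of lower energy.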

In [None]:
def energy_based_guidance(x):
    """Guide using a simple quadratic energy function.

    The energy is E(x) = sum(x**2) over the last dimension. It is lowest at the
    origin, so the returned direction -dE/dx = -2x pulls samples toward the
    center of the coordinate system.

    Args:
        x: Tensor of sample positions; the energy is summed over its last
            dimension.

    Returns:
        Tensor with the same shape, dtype and device as ``x``, holding the
        negative energy gradient. The result does not itself require grad.
    """
    # Sampling is run under ``torch.no_grad()``. If we relied on
    # ``x.requires_grad``, we would either get no gradient at all (x not
    # tracked) or a failure (no graph is recorded inside no_grad). Re-enable
    # autograd locally on a detached leaf copy so the gradient is always
    # computed.
    with torch.enable_grad():
        x_leaf = x.detach().requires_grad_(True)
        energy = (x_leaf ** 2).sum(dim=-1)
        (grad,) = torch.autograd.grad(energy.sum(), x_leaf)
    return -grad  # Negative gradient moves toward lower energy

# Weight the energy-gradient guidance at 0.3, then sample and plot.
model.guidance_scale = 0.3

with torch.no_grad():
    samples_energy = model.sample(
        guidance_fn=energy_based_guidance,
        num_samples=2000,
        num_steps=100,
        device=device,
    )

plot_samples(samples_energy, title="Generated Samples (Energy-Based Guidance)")

## Conclusion

This notebook demonstrated:
1. Training a guided VFM model
2. Defining custom guidance functions
3. Generating samples with and without guidance
4. Comparing different guidance scales
5. Using energy-based guidance

You can experiment with:
- Different guidance functions (classifier-based, energy-based, etc.)
- Multiple guidance signals simultaneously
- Adaptive guidance scales during generation
- Conditional guidance for specific attributes