# Conditional Variational Flow Matching

This notebook demonstrates conditional Variational Flow Matching (VFM) for controlled, class-conditioned generation on 2-D toy datasets.

## Overview
Conditional VFM extends basic VFM to incorporate conditioning information, allowing us to control the generation process.

In [None]:
import sys
sys.path.append("../..")

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from src.models import ConditionalVFM
from src.utils import generate_toy_data, plot_samples, train_conditional_vfm

# Seed the global RNGs up front so that re-running the notebook from a fresh
# kernel starts from the same random state.
# NOTE(review): the RNG used inside generate_toy_data is not visible here;
# seeding both torch and numpy is meant to cover it — confirm.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

print("PyTorch version:", torch.__version__)
# Prefer the GPU when one is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

## 1. Generate Conditional Data

Create multiple datasets with different characteristics, each associated with a class label.

In [None]:
# Build one toy dataset per class and tag every point with its class index.
num_samples_per_class = 500
datasets = ["two_moons", "circles", "gaussian"]
# Derive the class count from the dataset list so the two cannot drift apart
# when a dataset is added or removed.
num_classes = len(datasets)
class_spacing = 3.0  # x-offset added per class index

data_list = []
labels_list = []

for class_idx, dataset_name in enumerate(datasets):
    data = generate_toy_data(dataset_name, num_samples=num_samples_per_class, noise=0.1)
    # Shift each class along x by a class-specific offset; y is left unchanged.
    data = data + torch.tensor([class_idx * class_spacing, 0.0])
    data_list.append(data)
    # Size the label vector from the rows actually returned, so data and labels
    # stay aligned even if the generator returns a different count than requested.
    labels_list.append(torch.full((data.shape[0],), class_idx, dtype=torch.long))

# Stack the per-class pieces into a single data tensor and a single label tensor.
all_data = torch.cat(data_list, dim=0)
all_labels = torch.cat(labels_list, dim=0)

print(f"Total samples: {len(all_data)}, Classes: {num_classes}")

## 2. Visualize Conditional Data

In [None]:
# Scatter the training points on a single set of axes, one colour per class.
colors = ["red", "green", "blue"]
fig, ax = plt.subplots(figsize=(12, 6))
for class_id in range(num_classes):
    # Select the rows belonging to this class and convert them for matplotlib.
    class_points = all_data[all_labels == class_id].numpy()
    ax.scatter(
        class_points[:, 0],
        class_points[:, 1],
        c=colors[class_id],
        s=10,
        alpha=0.5,
        label=f"Class {class_id}",
    )
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_title("Conditional Training Data")
ax.legend()
ax.grid(True, alpha=0.3)
ax.axis("equal")
plt.show()

## 3. Initialize Conditional Model

In [None]:
# One-hot encode the integer class labels; the resulting float tensor is the
# conditioning input paired with each training sample.
conditions = F.one_hot(all_labels, num_classes=num_classes).float()

# Collect the hyperparameters in one place, then build the model from them.
model_config = {
    "input_dim": 2,
    "condition_dim": num_classes,  # matches the width of the one-hot vectors
    "hidden_dim": 128,
    "num_layers": 3,
    "time_embedding_dim": 32,
}
model = ConditionalVFM(**model_config)

# Report the total number of parameter elements in the model.
num_parameters = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {num_parameters:,}")

## 4. Train Conditional Model

In [None]:
# Pair every sample with its one-hot condition and serve them in shuffled
# mini-batches of 128.
dataset = torch.utils.data.TensorDataset(all_data, conditions)
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=128)

# Training options, gathered so they are easy to tweak in one spot.
training_options = {
    "num_epochs": 50,
    "learning_rate": 1e-3,
    "device": device,
    "verbose": True,
}

# Run training and keep whatever the helper returns as the training history.
history = train_conditional_vfm(model=model, dataloader=dataloader, **training_options)

## 5. Generate Conditional Samples

Generate samples conditioned on specific class labels.

In [None]:
# Draw samples for each class separately and show them in side-by-side panels.
model.eval()

# Number of points to generate for each class.
num_generated_per_class = 500

# squeeze=False always yields a 2-D array of Axes, so indexing by class also
# works when there is only a single class; take the one row for 1-D access.
fig, axes = plt.subplots(1, num_classes, figsize=(18, 6), squeeze=False)
axes = axes[0]

with torch.no_grad():
    for class_idx in range(num_classes):
        # One-hot condition repeated once per requested sample.
        class_ids = torch.full((num_generated_per_class,), class_idx, dtype=torch.long)
        condition = F.one_hot(class_ids, num_classes=num_classes).float()
        # Put the condition on the sampling device explicitly.
        # NOTE(review): model.sample is not visible here; if it already moves
        # `condition` itself this is a harmless no-op.
        condition = condition.to(device)

        # Generate samples and bring them back to the CPU for plotting.
        samples = model.sample(condition=condition, num_steps=100, device=device)
        samples_np = samples.cpu().numpy()

        # Plot this class in its own panel.
        ax = axes[class_idx]
        ax.scatter(samples_np[:, 0], samples_np[:, 1], alpha=0.5, s=10, c=colors[class_idx])
        ax.set_title(f"Class {class_idx} Samples")
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.grid(True, alpha=0.3)
        ax.axis("equal")

plt.tight_layout()
plt.show()

## 6. Compare All Classes

In [None]:
# Overlay freshly generated samples from every class on one set of axes.
num_generated_per_class = 500  # points to generate for each class

fig, ax = plt.subplots(figsize=(12, 6))

with torch.no_grad():
    for class_idx in range(num_classes):
        # One-hot condition repeated once per requested sample.
        class_ids = torch.full((num_generated_per_class,), class_idx, dtype=torch.long)
        condition = F.one_hot(class_ids, num_classes=num_classes).float()
        # Put the condition on the sampling device explicitly.
        # NOTE(review): model.sample is not visible here; if it already moves
        # `condition` itself this is a harmless no-op.
        condition = condition.to(device)

        samples = model.sample(condition=condition, num_steps=100, device=device)
        samples_np = samples.cpu().numpy()
        ax.scatter(
            samples_np[:, 0],
            samples_np[:, 1],
            alpha=0.5,
            s=10,
            c=colors[class_idx],
            label=f"Class {class_idx}",
        )

ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_title("All Generated Classes")
ax.legend()
ax.grid(True, alpha=0.3)
ax.axis("equal")
plt.show()

## Conclusion

This notebook demonstrated:
1. Creating conditional training data with multiple classes
2. Training a conditional VFM model
3. Generating samples conditioned on specific classes
4. Visualizing the conditional generation results

You can experiment with:
- More classes or different conditioning variables
- Different conditioning strategies (continuous values, multiple attributes)
- Combining multiple types of conditions