# Basic Variational Flow Matching

This notebook demonstrates how to train the basic VFM model (`BasicVFM`) on a toy 2D distribution and draw samples from it.

## Overview
Variational Flow Matching (VFM) learns a continuous normalizing flow: a time-dependent vector field whose ODE transports samples from a simple base distribution (a Gaussian) into samples from a complex target distribution.
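The sketch below shows the quantities flow matching works with. It assumes the common linear path between a Gaussian sample $x_0$ (at $t=0$) and a data sample $x_1$ (at $t=1$), $x_t = (1-t)\,x_0 + t\,x_1$. The notebook does not confirm which path `BasicVFM` uses. In the variational view, a network predicts the endpoint $x_1$ from $(x_t, t)$, and the velocity is recovered as $(\hat{x}_1 - x_t)/(1-t)$. The check confirms that plugging in the true $x_1$ gives back the conditional target velocity $x_1 - x_0$. The helper names are hypothetical.

```python
import torch

def linear_path(x0, x1, t):
    """Point on the straight line from noise x0 (t=0) to data x1 (t=1)."""
    # t has shape (batch, 1) so it broadcasts over the feature dimension
    return (1 - t) * x0 + t * x1

def velocity_from_endpoint(x1_hat, xt, t):
    """Velocity implied by a (predicted) endpoint x1_hat under the assumed linear path."""
    return (x1_hat - xt) / (1 - t)

torch.manual_seed(0)
x0 = torch.randn(4, 2)          # base samples from N(0, I)
x1 = torch.rand(4, 2)           # stand-in "data" samples
t = torch.rand(4, 1) * 0.9      # keep t away from 1 to avoid dividing by ~0

xt = linear_path(x0, x1, t)
u_target = x1 - x0                                  # conditional target velocity
u_from_x1 = velocity_from_endpoint(x1, xt, t)       # velocity recovered from the endpoint
print(torch.allclose(u_target, u_from_x1, atol=1e-4))  # True
```

Because $x_1 - x_t = (1-t)(x_1 - x_0)$, the two velocities agree exactly in exact arithmetic. Any mismatch here is floating-point rounding, which grows as $t \to 1$.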

In [None]:
import sys
sys.path.append("../..")

import torch
import matplotlib.pyplot as plt
from src.models import BasicVFM
from src.utils import generate_toy_data, plot_samples, train_vfm, create_dataloader

print("PyTorch version:", torch.__version__)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

## 1. Generate Target Data

First, we generate a toy 2D dataset that we want our model to learn.

In [None]:
# Generate two moons dataset
target_data = generate_toy_data("two_moons", num_samples=2000, noise=0.1)
plot_samples(target_data, title="Target Distribution: Two Moons")
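`generate_toy_data` is a project helper whose internals are not shown here. The following is a hypothetical stand-in that returns an `(N, 2)` tensor. It assumes the same two-half-circle layout as scikit-learn's `make_moons`, with `noise` as the standard deviation of isotropic Gaussian jitter.

```python
import math
import torch

def two_moons(num_samples=2000, noise=0.1, seed=0):
    """Hypothetical stand-in for generate_toy_data("two_moons", ...)."""
    gen = torch.Generator().manual_seed(seed)
    n_outer = num_samples // 2
    n_inner = num_samples - n_outer
    # Upper moon: upper half of a unit circle centred at (0, 0)
    theta_o = torch.linspace(0, math.pi, n_outer)
    outer = torch.stack([torch.cos(theta_o), torch.sin(theta_o)], dim=1)
    # Lower moon: lower half of a unit circle centred at (1, 0.5)
    theta_i = torch.linspace(0, math.pi, n_inner)
    inner = torch.stack([1 - torch.cos(theta_i), 0.5 - torch.sin(theta_i)], dim=1)
    data = torch.cat([outer, inner], dim=0)
    # Gaussian jitter controls how "thick" each moon looks
    return data + noise * torch.randn(data.shape, generator=gen)

moons = two_moons(2000, noise=0.1)
print(moons.shape)  # torch.Size([2000, 2])
```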

## 2. Initialize Model

Create a `BasicVFM` model with the architecture specified below.

In [None]:
# Initialize model
model = BasicVFM(
    input_dim=2,
    hidden_dim=128,
    num_layers=3,
    time_embedding_dim=32
)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
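The constructor arguments above suggest a time-conditioned MLP, but `BasicVFM`'s internals are not shown. The following is a hypothetical sketch of one common way to wire `input_dim`, `hidden_dim`, `num_layers` and `time_embedding_dim`. A sinusoidal embedding of $t$ is concatenated with $x_t$ and passed through an MLP that outputs a predicted endpoint of the same dimension. The class names are illustrative, and `time_embedding_dim` is assumed to be even.

```python
import math
import torch
import torch.nn as nn

class SinusoidalTimeEmbedding(nn.Module):
    """Maps times t of shape (batch, 1) to (batch, dim) sines and cosines (dim assumed even)."""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half = self.dim // 2
        # Geometrically spaced frequencies, as in transformer positional encodings
        freqs = torch.exp(
            -math.log(10000.0) * torch.arange(half, dtype=t.dtype, device=t.device) / half
        )
        angles = t * freqs                                  # (batch, half)
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

class TinyVFMNet(nn.Module):
    """Hypothetical network that predicts the endpoint x1 from (x_t, t)."""
    def __init__(self, input_dim=2, hidden_dim=128, num_layers=3, time_embedding_dim=32):
        super().__init__()
        self.time_embed = SinusoidalTimeEmbedding(time_embedding_dim)
        layers, in_dim = [], input_dim + time_embedding_dim
        for _ in range(num_layers):
            layers += [nn.Linear(in_dim, hidden_dim), nn.SiLU()]
            in_dim = hidden_dim
        layers.append(nn.Linear(in_dim, input_dim))         # output has the data dimension
        self.mlp = nn.Sequential(*layers)

    def forward(self, xt, t):
        return self.mlp(torch.cat([xt, self.time_embed(t)], dim=-1))

net = TinyVFMNet()
out = net(torch.randn(8, 2), torch.rand(8, 1))
print(out.shape)  # torch.Size([8, 2])
```

This sketch's parameter count will not necessarily match the number printed for `BasicVFM` above.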

## 3. Train the Model

Train the VFM model to learn the target distribution.

In [None]:
# Create dataloader
dataloader = create_dataloader(target_data, batch_size=128, shuffle=True)

# Train model
history = train_vfm(
    model=model,
    dataloader=dataloader,
    num_epochs=50,
    learning_rate=1e-3,
    device=device,
    verbose=True
)
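`train_vfm` hides the training objective. The sketch below shows a single gradient step. It assumes a linear path and a Gaussian variational posterior $q(x_1 \mid x_t) = \mathcal{N}(\hat{x}_1(x_t, t), \sigma^2 I)$ with fixed $\sigma$. Under that assumption, the negative log-likelihood reduces, up to constants, to a squared error between predicted and true endpoints. The small network, the stand-in data batch and the name `vfm_gaussian_loss` are hypothetical, not the project's implementation.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical endpoint predictor: input is [x_t, t] (3 features), output is x1_hat (2 features)
net = nn.Sequential(nn.Linear(3, 64), nn.SiLU(), nn.Linear(64, 64), nn.SiLU(), nn.Linear(64, 2))
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

def vfm_gaussian_loss(net, x1):
    """Endpoint-regression loss under an assumed fixed-variance Gaussian posterior."""
    x0 = torch.randn_like(x1)                        # base samples
    t = torch.rand(x1.shape[0], 1)                   # one time per sample, t ~ U(0, 1)
    xt = (1 - t) * x0 + t * x1                       # assumed linear interpolation path
    x1_hat = net(torch.cat([xt, t], dim=-1))         # predicted endpoint
    return ((x1_hat - x1) ** 2).sum(dim=-1).mean()   # Gaussian NLL up to constants

batch = torch.randn(128, 2) * 0.3 + torch.tensor([1.0, 0.5])  # stand-in data batch
optimizer.zero_grad()
loss = vfm_gaussian_loss(net, batch)
loss.backward()
optimizer.step()
print(f"loss: {loss.item():.4f}")
```

Repeating this step over every batch from the dataloader for `num_epochs` epochs is the usual shape of a training loop like `train_vfm`. The project's version may add logging, a different time distribution, or a learned variance.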

## 4. Visualize Training Loss

In [None]:
plt.figure(figsize=(10, 5))
plt.plot(history["losses"])
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training Loss")
plt.grid(True)
plt.show()

## 5. Generate Samples

Generate new samples from the trained model.

In [None]:
model.eval()
with torch.no_grad():
    generated_samples = model.sample(num_samples=2000, num_steps=100, device=device)

# Move to CPU first so plotting also works when sampling ran on a GPU
plot_samples(generated_samples.cpu(), title="Generated Samples from VFM")
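`num_steps` presumably sets how finely the flow ODE is discretised when sampling. The minimal sketch below assumes fixed-step Euler integration from $t=0$ to $t=1$; the notebook does not say which solver `model.sample` uses. `euler_sample` is a hypothetical helper. The toy check uses a VFM-style velocity $(\mu - x)/(1-t)$ with a constant predicted endpoint $\mu$, which drives every sample onto $\mu$ by the final step.

```python
import torch

def euler_sample(velocity_fn, num_samples, num_steps, dim=2, seed=0):
    """Integrate dx/dt = v(x, t) from t=0 to t=1 with fixed-step Euler."""
    gen = torch.Generator().manual_seed(seed)
    x = torch.randn(num_samples, dim, generator=gen)   # start from the Gaussian base
    dt = 1.0 / num_steps
    for k in range(num_steps):
        t = torch.full((num_samples, 1), k * dt)       # left endpoint, so t < 1 always
        x = x + dt * velocity_fn(x, t)
    return x

# Toy velocity: constant predicted endpoint mu, velocity (mu - x) / (1 - t)
mu = torch.tensor([1.0, -2.0])

def velocity(x, t):
    return (mu - x) / (1 - t)

samples = euler_sample(velocity, num_samples=5, num_steps=100)
print(samples)  # every row is approximately [1., -2.]
```

More steps reduce discretisation error for general velocity fields at a proportional cost in network evaluations. This is the trade-off `num_steps=100` controls.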

## 6. Compare Target and Generated Distributions

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(16, 7))

# Plot target
target_np = target_data.numpy()
axes[0].scatter(target_np[:, 0], target_np[:, 1], alpha=0.5, s=10)
axes[0].set_title("Target Distribution")
axes[0].set_xlabel("x")
axes[0].set_ylabel("y")
axes[0].grid(True, alpha=0.3)
axes[0].axis("equal")

# Plot generated
generated_np = generated_samples.cpu().numpy()
axes[1].scatter(generated_np[:, 0], generated_np[:, 1], alpha=0.5, s=10)
axes[1].set_title("Generated Distribution")
axes[1].set_xlabel("x")
axes[1].set_ylabel("y")
axes[1].grid(True, alpha=0.3)
axes[1].axis("equal")

plt.tight_layout()
plt.show()
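Side-by-side scatter plots are a qualitative check. One numeric complement is a kernel two-sample statistic. The sketch below computes a biased estimate of the squared maximum mean discrepancy (MMD) with an RBF kernel. The helper name and the bandwidth of 0.5 are illustrative choices, and the value is sensitive to the bandwidth. It could be applied as `rbf_mmd2(target_data, generated_samples.cpu())`, assuming both are float tensors of shape `(N, 2)`.

```python
import torch

def rbf_mmd2(x, y, bandwidth=0.5):
    """Biased estimate of squared MMD between sample sets x and y with an RBF kernel."""
    def kernel(a, b):
        sq_dists = torch.cdist(a, b) ** 2               # pairwise squared Euclidean distances
        return torch.exp(-sq_dists / (2 * bandwidth ** 2))
    return kernel(x, x).mean() + kernel(y, y).mean() - 2 * kernel(x, y).mean()

torch.manual_seed(0)
a = torch.randn(500, 2)
b = torch.randn(500, 2)          # same distribution as a
c = torch.randn(500, 2) + 2.0    # shifted distribution
print(rbf_mmd2(a, b).item(), rbf_mmd2(a, c).item())  # the first value should be much smaller
```

Lower is better. Values near zero indicate the two sample sets are hard to tell apart at this kernel scale.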

## Conclusion

This notebook demonstrated:
1. Setting up a basic VFM model
2. Training the model on a 2D toy dataset
3. Generating samples from the trained model
4. Comparing target and generated distributions

You can experiment with:
- Different datasets (`"swiss_roll"`, `"circles"`, `"gaussian"`)
- Different model architectures (`hidden_dim`, `num_layers`)
- Different training hyperparameters (`learning_rate`, `num_epochs`)