# FinDiffusion Demo

This notebook demonstrates how to:
1. Load and preprocess financial data
2. Train a conditional diffusion model
3. Generate synthetic financial time series
4. Evaluate the quality of generated data

In [None]:
import sys
from pathlib import Path

# Make the repository root importable so the `src` package resolves.
sys.path.insert(0, '..')

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn as sns

from src.models import FinancialDiffusion
from src.data import FinancialDataModule
from src.evaluation import validate_stylized_facts, compute_all_metrics
from src.training import Trainer, TrainingConfig

plt.style.use('seaborn-v0_8-whitegrid')
%matplotlib inline

## 1. Load Data

In [None]:
# Build the data module: five tickers, 2015-2024, cut into overlapping windows
data_module = FinancialDataModule(
    tickers=["AAPL", "MSFT", "GOOGL", "AMZN", "META"],
    start_date="2015-01-01",
    end_date="2024-01-01",
    seq_len=252,  # 1 trading year
    stride=21,    # ~1 month
    batch_size=32,
)

data_module.setup()

# Report the size of each split
for split_label, dataset_attr in (
    ("Training", "train_dataset"),
    ("Validation", "val_dataset"),
    ("Test", "test_dataset"),
):
    split_size = len(getattr(data_module, dataset_attr))
    print(f"{split_label} samples: {split_size}")

In [None]:
# Pull one batch and undo the data module's normalisation before plotting
sample_batch = data_module.get_sample_batch()
sample_batch_denorm = data_module.denormalize(sample_batch.numpy())

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(12, 8))

# One panel per sequence: compound the daily returns into a growth path
for panel_idx, panel in enumerate(axes.ravel()):
    daily_returns = sample_batch_denorm[panel_idx]
    growth_path = np.cumprod(1 + daily_returns)
    total_return_pct = (growth_path[-1] - 1) * 100
    panel.plot(growth_path)
    panel.set_title(f'Sample {panel_idx + 1}: Total Return = {total_return_pct:.1f}%')
    panel.set(xlabel='Days', ylabel='Cumulative Return')

plt.tight_layout()
plt.show()

## 2. Create Model

In [None]:
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Demo-sized hyperparameters, collected in one place for readability
model_hparams = dict(
    seq_len=252,
    input_dim=1,
    d_model=128,      # Smaller for demo
    n_layers=4,
    n_heads=4,
    d_ff=256,
    d_cond=64,
    timesteps=500,    # Fewer timesteps for demo
)

model = FinancialDiffusion(**model_hparams).to(device)

# Count every parameter element in the model
n_params = 0
for weight in model.parameters():
    n_params += weight.numel()
print(f"Model parameters: {n_params:,}")

## 3. Quick Training (Demo)

In [None]:
# Trainer and TrainingConfig come from the notebook's import cell, so this cell
# only holds configuration.

# Demo-sized schedule: a few epochs, no experiment tracking, checkpoints kept
# in a separate demo directory.
config = TrainingConfig(
    epochs=5,           # Just a few epochs for demo
    lr=1e-4,
    use_amp=True,       # NOTE(review): presumably mixed precision; confirm Trainer copes with this when device is CPU
    log_every=50,
    save_every=5,
    use_wandb=False,
    checkpoint_dir='../checkpoints_demo',
)

# Wire the model to the train/validation loaders on the selected device.
trainer = Trainer(
    model=model,
    train_loader=data_module.train_dataloader(),
    val_loader=data_module.val_dataloader(),
    config=config,
    device=device,
)

In [None]:
# Train (this will take a few minutes)
# Runs the trainer built in the previous cell with its demo configuration.
trainer.train()

## 4. Generate Synthetic Data

In [None]:
# Put the model in evaluation mode before sampling
model.eval()

# DDIM with 50 steps is used for faster sampling
ddim_kwargs = dict(use_ddim=True, ddim_steps=50, device=device)

# Unconditional samples: no `conditions` argument is passed
with torch.no_grad():
    synthetic_normalized = model.generate(n_samples=100, **ddim_kwargs)

# Move to host memory as NumPy, then undo the data module's normalisation
synthetic = data_module.denormalize(synthetic_normalized.cpu().numpy())

n_generated = len(synthetic)
mean_daily_return = synthetic.mean()
std_daily_return = synthetic.std()
print(f"Generated {n_generated} samples")
print(f"Mean daily return: {mean_daily_return:.6f}")
print(f"Std daily return: {std_daily_return:.6f}")

In [None]:
# Conditioning values for the high-volatility scenario
high_vol_conditions = {"trend": 0.0, "volatility": 0.40, "regime": "bear"}

with torch.no_grad():
    high_vol_normalized = model.generate(
        n_samples=50,
        conditions=high_vol_conditions,
        use_ddim=True,
        ddim_steps=50,
        device=device,
    )

high_vol = data_module.denormalize(high_vol_normalized.cpu().numpy())

# Scale the pooled daily standard deviation by sqrt(252) to annualise it
realized_vol = high_vol.std() * np.sqrt(252)
print(f"High vol samples - Realized vol: {realized_vol:.1%}")

In [None]:
# Plot the first six generated sequences as compounded growth paths
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(15, 8))

for path_idx, panel in enumerate(axes.ravel()):
    path_returns = synthetic[path_idx]
    growth_path = np.cumprod(1 + path_returns)
    total_return_pct = 100 * (growth_path[-1] - 1)
    annualized_vol = path_returns.std() * np.sqrt(252)
    panel.plot(growth_path, color='blue', alpha=0.8)
    panel.set_title(f'Synthetic {path_idx + 1}: Return={total_return_pct:.1f}%, Vol={annualized_vol:.0%}')
    panel.set_xlabel('Days')

plt.suptitle('Generated Synthetic Price Paths', fontsize=14)
plt.tight_layout()
plt.show()

## 5. Evaluate Stylized Facts

In [None]:
# Collect up to 100 real test sequences and undo the normalisation
n_real = min(100, len(data_module.test_dataset))
real_samples = np.array(
    [data_module.test_dataset[idx].numpy() for idx in range(n_real)]
)
real_samples = data_module.denormalize(real_samples)


def _report_stylized_facts(banner, samples, leading_blank=False):
    """Print a banner, validate `samples`, then print one PASS/FAIL line per test.

    The banner is printed before validation runs. The 'summary' entry of the
    results is skipped when printing. Returns the results dict unchanged.
    """
    rule = "=" * 50
    print(("\n" + rule) if leading_blank else rule)
    print(banner)
    print(rule)
    results = validate_stylized_facts(samples)
    for test_name, outcome in results.items():
        if test_name == 'summary':
            continue
        verdict = 'PASS' if outcome['passed'] else 'FAIL'
        print(f"{test_name}: {verdict} - {outcome['interpretation']}")
    return results


# Validate stylized facts on real data first, then on the generated data
real_results = _report_stylized_facts("REAL DATA STYLIZED FACTS", real_samples)
syn_results = _report_stylized_facts(
    "SYNTHETIC DATA STYLIZED FACTS", synthetic, leading_blank=True
)

In [None]:
# Distribution comparison
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Histogram of daily returns pooled over every sequence
axes[0].hist(real_samples.flatten(), bins=100, alpha=0.7, label='Real', density=True)
axes[0].hist(synthetic.flatten(), bins=100, alpha=0.7, label='Synthetic', density=True)
axes[0].set_xlabel('Daily Returns')
axes[0].set_ylabel('Density')
axes[0].set_title('Return Distribution')
axes[0].legend()
axes[0].set_xlim(-0.1, 0.1)


def squared_return_acf(samples, max_lag):
    """Autocorrelation of squared returns for lags 1..max_lag.

    `samples` holds one sequence per row. Lagged pairs (r_t^2, r_{t+lag}^2)
    are formed inside each sequence and then pooled across sequences.
    Flattening the array first would also pair the end of one sequence with
    the start of the next, which are unrelated windows, and would bias the
    estimate.

    Assumes a single return series per sample (trailing axes are collapsed
    into the time axis) - confirm if multi-feature inputs are ever used.

    Returns a list of `max_lag` correlation coefficients.
    Raises ValueError if `max_lag` is not shorter than the sequence length.
    """
    squared = np.asarray(samples)
    squared = squared.reshape(len(squared), -1) ** 2
    if max_lag >= squared.shape[1]:
        raise ValueError(
            f"max_lag={max_lag} must be smaller than the sequence length {squared.shape[1]}"
        )
    return [
        np.corrcoef(squared[:, :-lag].ravel(), squared[:, lag:].ravel())[0, 1]
        for lag in range(1, max_lag + 1)
    ]


# ACF of squared returns
max_lag = 20
acf_real = squared_return_acf(real_samples, max_lag)
acf_syn = squared_return_acf(synthetic, max_lag)

# Side-by-side bars: real shifted left, synthetic shifted right of each lag
x = np.arange(1, max_lag+1)
axes[1].bar(x - 0.2, acf_real, width=0.4, label='Real', alpha=0.7)
axes[1].bar(x + 0.2, acf_syn, width=0.4, label='Synthetic', alpha=0.7)
axes[1].set_xlabel('Lag')
axes[1].set_ylabel('ACF')
axes[1].set_title('Autocorrelation of Squared Returns\n(Volatility Clustering)')
axes[1].legend()

plt.tight_layout()
plt.show()

## 6. Comprehensive Metrics

In [None]:
# Compare real and synthetic samples across all metric groups
metrics = compute_all_metrics(real_samples, synthetic)

divider = "-" * 30

# Each group is printed as: heading, divider, then one indented line per metric
metric_groups = (
    ("\nDISTRIBUTION METRICS", 'distribution'),
    ("\nTEMPORAL METRICS", 'temporal'),
    ("\nDIVERSITY METRICS", 'diversity'),
)
for heading, group_key in metric_groups:
    print(heading)
    print(divider)
    for metric_name, metric_value in metrics[group_key].items():
        print(f"  {metric_name}: {metric_value:.6f}")

# The overall score lives under the 'summary' entry and uses 4 decimals
overall_score = metrics['summary']['overall_score']
print("\nOVERALL SCORE")
print(divider)
print(f"  Overall: {overall_score:.4f}")

## 7. Save Generated Data

In [None]:
# Save synthetic data for later use
output_path = Path('../outputs/synthetic_returns_demo.csv')

# to_csv does not create missing directories, so make sure the folder exists;
# otherwise a fresh checkout fails here after the whole training run.
output_path.parent.mkdir(parents=True, exist_ok=True)

# One row per generated sequence, one column per time step (t_0, t_1, ...)
df = pd.DataFrame(
    synthetic,
    columns=[f't_{i}' for i in range(synthetic.shape[1])],
)
df.to_csv(output_path, index=False)
print("Saved synthetic returns to outputs/synthetic_returns_demo.csv")

## Summary

This demo showed:
- Loading and preprocessing financial data from Yahoo Finance
- Training a conditional diffusion model on return sequences
- Generating synthetic data with controllable conditions (trend, volatility, regime)
- Evaluating data quality using stylized facts tests

For production use:
- Train for more epochs (50-100)
- Use larger model dimensions
- Include more tickers for diversity
- Enable Weights & Biases for experiment tracking