# Classifier-Free Guidance for Bitcoin Time Series Diffusion

This notebook demonstrates how classifier-free guidance (CFG) works in the context of time series prediction with diffusion models.

## Key Concepts

**Classifier-Free Guidance** allows us to control the strength of conditioning on historical data:
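- `guidance_scale = 0.0`: Purely unconditional generation (the history is ignored)
- `guidance_scale = 1.0`: Standard conditional generation
- `guidance_scale > 1.0`: Stronger conditioning on history (extrapolates beyond the plain conditional prediction)
- `guidance_scale < 1.0`: Weaker conditioning (more unconditional; interpolates between the unconditional and conditional predictions)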

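The formula is: `score = uncond_score + guidance_scale * (cond_score - uncond_score)`, i.e. $\hat{s} = s_{\text{uncond}} + w\,(s_{\text{cond}} - s_{\text{uncond}})$ with $w$ = `guidance_scale`. Setting $w = 1$ recovers the conditional score, and $w = 0$ recovers the unconditional one.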

In [None]:
import sys
import os
sys.path.append('..')

import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path

from src.evaluation import DiffusionPredictor, MetricCalculator
from src.data import DataPreprocessor
from src.utils import load_config, plot_predictions

# Reproducibility: pin both global RNGs this notebook draws from (NumPy and
# PyTorch) to a single fixed seed before any sampling happens.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

print("Libraries imported successfully!")

## Load Trained Model and Data

In [None]:
# Load the experiment configuration (later cells read config['model'][...] from it)
config = load_config('../configs/bitcoin_cfg_experiment.yaml')

# Load the trained model checkpoint (adjust path as needed).
# Every later cell uses `predictor`, so fail loudly here instead of printing a
# hint and letting downstream cells die with a confusing NameError.
model_path = '../models/checkpoints_cfg/best_model.pt'
if not os.path.exists(model_path):
    raise FileNotFoundError(
        f"Model not found at {model_path}\n"
        "Please train a model first using: "
        "python scripts/train.py --config configs/bitcoin_cfg_experiment.yaml"
    )

predictor = DiffusionPredictor.from_checkpoint(model_path)
print("Model loaded successfully!")

In [None]:
# Load the preprocessed test split.
# Later cells slice `test_data`, so stop with an actionable error if it is missing
# rather than printing a hint and failing later with a NameError.
test_data_path = '../data/processed/bitcoin_test.pt'
if not os.path.exists(test_data_path):
    raise FileNotFoundError(
        f"Test data not found at {test_data_path}. Please run data preparation first:\n"
        "python scripts/prepare_data.py --input data/raw/bitcoin.csv --output data/processed/"
    )

# Only `load_processed_data` is used here, so the preprocessor gets an empty config
preprocessor = DataPreprocessor({})
test_data, metadata = preprocessor.load_processed_data(test_data_path)
print(f"Test data shape: {test_data.shape}")

## Prepare Sample Data

In [None]:
# Extract one history/future window from the test split for demonstration
history_len = config['model']['history_len']
predict_len = config['model']['predict_len']
window_len = history_len + predict_len

# Start in the middle of the test data, but shift left when needed so the whole
# history + future window fits: slicing past the end would silently truncate
# the history and drop the ground truth.
start_idx = max(0, min(len(test_data) // 2, len(test_data) - window_len))
history_end = start_idx + history_len
sample_history = test_data[start_idx:history_end].unsqueeze(0)  # prepend a batch dim of 1
sample_future = test_data[history_end:history_end + predict_len]

# The model is conditioned on exactly `history_len` points; refuse a short window
if sample_history.shape[1] != history_len:
    raise ValueError(
        f"Test data has only {len(test_data)} points; "
        f"need at least history_len={history_len} to build a sample window."
    )

print(f"Sample history shape: {sample_history.shape}")
print(f"Sample future shape: {sample_future.shape}")
print(f"Ground truth available: {len(sample_future) == predict_len}")

## Compare Different Guidance Scales

In [None]:
# Test different guidance scales (see the CFG formula at the top of the notebook)
guidance_scales = [0.5, 1.0, 1.5, 2.0, 3.0, 5.0]
num_samples = 50  # Fewer samples for speed
num_steps = 500   # Fewer steps for speed

# Re-seed before every scale so each one starts from the same random stream
# (common random numbers). Without this, each call continues the RNG where the
# previous one stopped, and differences between scales mix guidance effects
# with sampling noise. Assumes predictor.predict draws from torch's global RNG
# — TODO confirm.
sampling_seed = 42

predictions = {}

for scale in guidance_scales:
    print(f"Generating predictions with guidance scale {scale}...")

    torch.manual_seed(sampling_seed)
    pred = predictor.predict(
        history=sample_history,
        num_samples=num_samples,
        num_steps=num_steps,
        guidance_scale=scale,
        denormalize=True,
        return_dict=True
    )

    predictions[scale] = pred
    print(f"  Mean prediction range: [{pred['mean'].min():.4f}, {pred['mean'].max():.4f}]")
    print(f"  Prediction std: {pred['std'].mean():.4f}")

print("\nPrediction generation complete!")

## Visualize the Effects of Classifier-Free Guidance

In [None]:
# Create a comprehensive comparison plot: one panel per guidance scale, 3 per row.
# The grid size follows `guidance_scales`, so adding/removing scales doesn't break it.
n_cols = 3
n_rows = (len(guidance_scales) + n_cols - 1) // n_cols
fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 5 * n_rows), squeeze=False)
axes = axes.flatten()

# Denormalize history and ground truth if the predictor carries a preprocessor.
# Ground truth is only plotted when a full forecast window is available.
has_ground_truth = len(sample_future) == predict_len
if predictor.preprocessor is not None:
    history_denorm = predictor.preprocessor.denormalize(sample_history.squeeze()).numpy()
    future_denorm = (
        predictor.preprocessor.denormalize(sample_future).numpy() if has_ground_truth else None
    )
else:
    history_denorm = sample_history.squeeze().numpy()
    future_denorm = sample_future.numpy() if has_ground_truth else None

# Time axis shared by every panel: history at t < 0, forecast from t = 0
history_time = np.arange(-history_len, 0)
future_time = np.arange(predict_len)

for i, scale in enumerate(guidance_scales):
    ax = axes[i]
    pred = predictions[scale]

    # Plot history
    ax.plot(history_time, history_denorm, 'k-', label='History', linewidth=2, alpha=0.7)

    # Plot mean prediction (index 0 selects the single batch element)
    ax.plot(future_time, pred['mean'][0], 'b-', label='Mean Prediction', linewidth=2)

    # Plot central prediction intervals when the predictor returns quantiles
    if 'quantiles' in pred:
        ax.fill_between(
            future_time,
            pred['quantiles'][0.1][0],
            pred['quantiles'][0.9][0],
            alpha=0.2, color='blue', label='80% CI'
        )
        ax.fill_between(
            future_time,
            pred['quantiles'][0.25][0],
            pred['quantiles'][0.75][0],
            alpha=0.3, color='blue', label='50% CI'
        )

    # Plot ground truth if available
    if future_denorm is not None:
        ax.plot(future_time, future_denorm, 'r-', label='Ground Truth', linewidth=2)

    # Formatting
    ax.set_title(f'Guidance Scale = {scale}', fontsize=12, fontweight='bold')
    ax.set_xlabel('Time Steps')
    ax.set_ylabel('Log Returns')
    ax.grid(True, alpha=0.3)
    ax.axvline(x=0, color='gray', linestyle='--', alpha=0.5)

    if i == 0:
        ax.legend()

# Hide leftover panels when the number of scales isn't a multiple of n_cols
for ax in axes[len(guidance_scales):]:
    ax.axis('off')

fig.suptitle('Effect of Classifier-Free Guidance on Bitcoin Predictions',
             fontsize=16, fontweight='bold')
# Lay out AFTER adding the suptitle, reserving the top strip for it, so the
# title is neither clipped nor overlapping the first row of panels.
fig.tight_layout(rect=(0, 0, 1, 0.96))
plt.show()

## Analyze Prediction Uncertainty

In [None]:
# Compare prediction uncertainty across guidance scales (explicit-axes API)
fig, ((ax_std, ax_range), (ax_box, ax_qw)) = plt.subplots(2, 2, figsize=(12, 8))
scale_labels = [f'{s}' for s in guidance_scales]

# Plot 1: predicted standard deviation (not variance), averaged over the horizon
mean_stds = [float(predictions[scale]['std'][0].mean()) for scale in guidance_scales]
ax_std.plot(guidance_scales, mean_stds, 'o-', linewidth=2, markersize=8)
ax_std.set_xlabel('Guidance Scale')
ax_std.set_ylabel('Mean Prediction Std')
ax_std.set_title('Prediction Uncertainty vs Guidance Scale')
ax_std.grid(True, alpha=0.3)

# Plot 2: range (max - min) of the MEAN trajectory over the forecast horizon
spreads = [float(predictions[scale]['mean'][0].max() - predictions[scale]['mean'][0].min())
           for scale in guidance_scales]
ax_range.plot(guidance_scales, spreads, 'o-', linewidth=2, markersize=8, color='orange')
ax_range.set_xlabel('Guidance Scale')
ax_range.set_ylabel('Prediction Range')
ax_range.set_title('Prediction Range vs Guidance Scale')
ax_range.grid(True, alpha=0.3)

# Plot 3: distribution across samples of the last forecast step (batch 0)
final_predictions = [np.asarray(predictions[scale]['samples'][0, :, -1])
                     for scale in guidance_scales]
ax_box.boxplot(final_predictions)
# Set tick labels explicitly: boxplot's `labels=` keyword is deprecated in
# Matplotlib >= 3.9. Boxes sit at positions 1..n by default.
ax_box.set_xticks(range(1, len(guidance_scales) + 1))
ax_box.set_xticklabels(scale_labels)
ax_box.set_xlabel('Guidance Scale')
ax_box.set_ylabel('Final Step Prediction')
ax_box.set_title('Distribution of Final Predictions')
ax_box.grid(True, alpha=0.3)

# Plot 4: width of the 10%-90% quantile band, averaged over the horizon.
# Scales without quantiles get NaN (a gap in the line) instead of 0, which
# would falsely read as "zero uncertainty".
quantile_widths = []
for scale in guidance_scales:
    if 'quantiles' in predictions[scale]:
        quantiles = predictions[scale]['quantiles']
        quantile_widths.append(float((quantiles[0.9][0] - quantiles[0.1][0]).mean()))
    else:
        quantile_widths.append(np.nan)

ax_qw.plot(guidance_scales, quantile_widths, 'o-', linewidth=2, markersize=8, color='green')
ax_qw.set_xlabel('Guidance Scale')
ax_qw.set_ylabel('80% Quantile Width')
ax_qw.set_title('Uncertainty Band Width vs Guidance Scale')
ax_qw.grid(True, alpha=0.3)

fig.tight_layout()
plt.show()

## Quantitative Evaluation

In [None]:
# Evaluate predictions quantitatively (if ground truth is available).
# NOTE: these metrics come from a single test window, so the "optimal" scales
# reported below are noisy. Average over many windows before drawing conclusions.
if future_denorm is not None:
    metric_calc = MetricCalculator()

    results = []

    for scale in guidance_scales:
        pred = predictions[scale]
        metrics = metric_calc.compute_all_metrics(pred, future_denorm)

        result = {'guidance_scale': scale}
        result.update(metrics)
        results.append(result)

    # Build the table once from the list of records
    results_df = pd.DataFrame(results)
    print("\nQuantitative Evaluation Results:")
    print("=" * 50)
    print(results_df.round(4))

    has_crps = 'crps' in results_df.columns

    # Plot key metrics
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    # MSE (lower is better)
    axes[0].plot(results_df['guidance_scale'], results_df['mse'], 'o-', linewidth=2)
    axes[0].set_xlabel('Guidance Scale')
    axes[0].set_ylabel('MSE')
    axes[0].set_title('Mean Squared Error')
    axes[0].grid(True, alpha=0.3)

    # Directional Accuracy (higher is better)
    axes[1].plot(results_df['guidance_scale'], results_df['directional_accuracy'], 'o-', linewidth=2, color='orange')
    axes[1].set_xlabel('Guidance Scale')
    axes[1].set_ylabel('Directional Accuracy')
    axes[1].set_title('Directional Accuracy')
    axes[1].grid(True, alpha=0.3)

    # CRPS (lower is better), if the metric calculator reports it
    if has_crps:
        axes[2].plot(results_df['guidance_scale'], results_df['crps'], 'o-', linewidth=2, color='green')
        axes[2].set_xlabel('Guidance Scale')
        axes[2].set_ylabel('CRPS')
        axes[2].set_title('Continuous Ranked Probability Score')
        axes[2].grid(True, alpha=0.3)
    else:
        # Don't leave an empty, unlabeled panel in the figure
        axes[2].axis('off')

    plt.tight_layout()
    plt.show()

    # Find the best-scoring guidance scale per metric
    optimal_mse_scale = results_df.loc[results_df['mse'].idxmin(), 'guidance_scale']
    optimal_dir_scale = results_df.loc[results_df['directional_accuracy'].idxmax(), 'guidance_scale']

    print("\nOptimal guidance scales:")
    print(f"  Best MSE: {optimal_mse_scale}")
    print(f"  Best Directional Accuracy: {optimal_dir_scale}")
    if has_crps:
        optimal_crps_scale = results_df.loc[results_df['crps'].idxmin(), 'guidance_scale']
        print(f"  Best CRPS: {optimal_crps_scale}")

else:
    print("Ground truth not available for quantitative evaluation.")

## Key Insights

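From this analysis, you will typically observe the following. Results from a single test window are noisy, so confirm these trends across many windows.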

1. **Guidance Scale Effects**:
   - `guidance_scale < 1.0`: More unconditional, potentially more diverse but less conditioned on history
   - `guidance_scale = 1.0`: Standard conditional generation
   - `guidance_scale > 1.0`: Stronger conditioning on historical patterns

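2. **Trade-offs**:
   - Higher guidance → More deterministic (lower sample diversity, narrower uncertainty bands), potentially better conditioning. Very large scales can over-sharpen the predictive distribution and hurt calibration.
   - Lower guidance → More uncertainty, potentially more diverse samples
   - Optimal scale depends on your specific use case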

3. **Practical Recommendations** (starting points only; tune on held-out windows using the metrics above):
   - For most applications: `guidance_scale = 1.5 to 2.5`
   - For high-confidence predictions: `guidance_scale = 3.0+`
   - For diverse scenario generation: `guidance_scale = 0.5 to 1.0`

## Advanced: Strong vs Weak Guidance Comparison (scale 3.0 vs 0.5, i.e. conditional vs closer-to-unconditional)

In [None]:
# Compare strong vs weak guidance directly.
# Note: scale 0.5 is *weak* guidance, not truly unconditional (that would be 0.0).
# The "unconditional" variable names and labels below are kept for continuity.
print("Generating conditional vs unconditional samples...")


def get_predictions_for_scale(scale):
    """Return predictions for `scale`, reusing the guidance-scale sweep when possible.

    The sweep used the same history, num_samples and num_steps, so reusing its
    output avoids two redundant full diffusion runs. Scales not in the sweep
    are generated fresh, after re-seeding torch for reproducibility.
    """
    if scale in predictions:
        return predictions[scale]
    torch.manual_seed(42)
    return predictor.predict(
        history=sample_history,
        num_samples=num_samples,
        num_steps=num_steps,
        guidance_scale=scale,
        denormalize=True,
        return_dict=True
    )


conditional_pred = get_predictions_for_scale(3.0)    # Strong conditioning
unconditional_pred = get_predictions_for_scale(0.5)  # Weak conditioning (more unconditional)

# Plot comparison
fig, axes = plt.subplots(1, 2, figsize=(15, 6))

# Time indices: history at t < 0, forecast from t = 0
history_time = np.arange(-history_len, 0)
future_time = np.arange(predict_len)

panels = [
    (unconditional_pred, 'Weak Conditioning (Scale=0.5)'),
    (conditional_pred, 'Strong Conditioning (Scale=3.0)'),
]
for ax, (pred, title) in zip(axes, panels):
    # Plot history
    ax.plot(history_time, history_denorm, 'k-', label='History', linewidth=2)

    # Plot a few individual sample trajectories. `samples` is indexed
    # batch-first elsewhere in this notebook (samples[0, :, -1]), so select
    # batch 0 first and then individual samples. Indexing samples[j, 0] would
    # raise IndexError for j >= 1 when the batch size is 1.
    batch_samples = pred['samples'][0]
    n_show = min(10, len(batch_samples))
    for trajectory in batch_samples[:n_show]:
        ax.plot(future_time, trajectory, alpha=0.3, color='blue', linewidth=1)

    # Plot mean
    ax.plot(future_time, pred['mean'][0], 'b-',
            label='Mean Prediction', linewidth=3)

    # Plot ground truth if available
    if future_denorm is not None:
        ax.plot(future_time, future_denorm, 'r-',
                label='Ground Truth', linewidth=2)

    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel('Time Steps')
    ax.set_ylabel('Log Returns')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.axvline(x=0, color='gray', linestyle='--', alpha=0.5)

plt.tight_layout()
plt.show()

print(f"\nUnconditional (scale=0.5) - Sample std: {unconditional_pred['std'][0].mean():.4f}")
print(f"Conditional (scale=3.0) - Sample std: {conditional_pred['std'][0].mean():.4f}")
print(f"Std ratio (uncond/cond): {unconditional_pred['std'][0].mean() / conditional_pred['std'][0].mean():.2f}")