# Quantformer: Transformer-based Quantitative Trading

This notebook demonstrates the implementation of the Quantformer model as described in the paper:
"Quantformer: from attention to profit with a quantitative transformer trading strategy"

## Overview

The Quantformer adapts the transformer architecture for quantitative trading by:
1. Replacing word embeddings with linear embeddings for numerical data
2. Removing positional encoding (time series have inherent order)
3. Simplifying the decoder for classification tasks
4. Using market sentiment information (returns and turnover rates)

## Contents
1. [Setup and Imports](#setup)
2. [Data Generation and Preprocessing](#data)
3. [Model Architecture](#model)
4. [Training](#training)
5. [Trading Strategy](#strategy)
6. [Backtesting](#backtesting)
7. [Results Analysis](#results)


In [None]:
# Install required packages if needed
# !pip install torch numpy pandas scikit-learn matplotlib seaborn plotly tqdm wandb

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
import warnings
warnings.filterwarnings('ignore')

# Plot appearance: apply the matplotlib style sheet first, then the seaborn
# colour palette (the palette call must come second so it is not overwritten).
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Seed both NumPy and PyTorch with the same value so reruns are repeatable.
for seed_fn in (np.random.seed, torch.manual_seed):
    seed_fn(42)

# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(f"Using device: {device}")

# Import our Quantformer modules
from quantformer import (
    Quantformer, QuantformerTrainer, create_quantformer_model,
    StockDataProcessor, prepare_training_data, create_sample_data,
    QuantformerTradingStrategy, TradingConfig,
    setup_training_experiment, setup_backtesting_experiment
)


In [None]:
## WandB Setup (Optional)

# Weights & Biases experiment tracking is opt-in through this flag.
# The flag is forced back to False below if the package cannot be imported.
ENABLE_WANDB = True  # Set to False to disable WandB logging

if not ENABLE_WANDB:
    print("📊 WandB logging disabled - experiments will run without tracking")
else:
    try:
        import wandb
    except ImportError:
        # Missing package: tell the user how to install it and carry on untracked.
        print("⚠️  WandB not installed. Install with: pip install wandb")
        ENABLE_WANDB = False
    else:
        # Package found: print the one-time account/login instructions.
        # (To verify credentials interactively, call wandb.login() here.)
        print("\n".join([
            "🔧 WandB is available for experiment tracking",
            "📝 To use WandB:",
            "   1. Create a free account at https://wandb.ai",
            "   2. Run 'wandb login' in terminal",
            "   3. Or set WANDB_API_KEY environment variable",
            "   4. The notebook will automatically log experiments",
        ]))

wandb_status = 'Enabled' if ENABLE_WANDB else 'Disabled'
print(f"WandB Status: {wandb_status}")


In [None]:
# Experiment-wide constants shared by the cells below
N_STOCKS = 200       # Number of stocks in our universe
N_TIMESTEPS = 1000   # Total time periods
SEQ_LEN = 20         # Input sequence length (as in paper)
N_CLASSES = 3        # Number of quantile classes (ρ in paper)
PHI = 0.2            # Fraction of stocks assigned to each quantile

print(
    "Generating synthetic stock data...",
    f"- {N_STOCKS} stocks",
    f"- {N_TIMESTEPS} time periods",
    f"- {SEQ_LEN}-day input sequences",
    sep="\n",
)

# Build the synthetic data set. The generator gets its own seed (123), which
# is different from the global seed of 42 set at the top of the notebook.
feature_data, return_data = create_sample_data(
    n_stocks=N_STOCKS,
    n_timesteps=N_TIMESTEPS,
    seq_len=SEQ_LEN,
    random_seed=123,
)

print(f"Feature data shape: {feature_data.shape}")
print(f"Return data shape: {return_data.shape}")
print("Features: [returns, turnover_rates]")

# Summary statistics. Index 1 on the last axis is reported as turnover,
# matching the feature order printed above.
turnover = feature_data[:, :, 1]
print("\nData Statistics:")
print(f"Returns - Mean: {return_data.mean():.6f}, Std: {return_data.std():.6f}")
print(f"Returns - Range: [{return_data.min():.6f}, {return_data.max():.6f}]")
print(f"Turnover - Mean: {turnover.mean():.6f}, Std: {turnover.std():.6f}")


In [None]:
# Model configuration.
# NOTE(review): the inline notes below record values that were increased or
# reduced for this notebook, so this is a tuned configuration; confirm against
# the paper before describing it as the paper's exact setup.
model_config = {
    'input_dim': 2,  # Returns and turnover rates
    'd_model': 32,   # Increased hidden dimension for better capacity
    'n_heads': 8,    # Number of attention heads
    'n_layers': 4,   # Reduced layers to prevent overfitting
    'd_ff': 128,     # Increased feed-forward dimension
    'n_classes': N_CLASSES,
    'seq_len': SEQ_LEN,
    'dropout': 0.2   # Increased dropout for regularization
}

# Create model and move it to the selected device
model = create_quantformer_model(model_config)
model = model.to(device)

print(f"Model created with {sum(p.numel() for p in model.parameters())} parameters")
print(f"Model configuration: {model_config}")

# Smoke-test the model with one random sequence. The input width is taken from
# the config so the check stays valid if 'input_dim' is changed.
sample_input = torch.randn(1, SEQ_LEN, model_config['input_dim']).to(device)

# Run the check in eval mode: torch.no_grad() only disables gradient tracking,
# it does not switch dropout off, so in training mode the printed output would
# be randomly perturbed. The previous mode is restored afterwards.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        sample_output = model(sample_input)
finally:
    model.train(was_training)

print(f"Sample output shape: {sample_output.shape}")
print(f"Sample output probabilities: {sample_output.cpu().numpy().flatten()}")
print(f"Sum of probabilities: {sample_output.sum().item():.6f} (should be ~1.0)")


In [None]:
# Build the train/validation loaders (80/20 split, batches of 64)
print("Preparing training data...")

train_loader, val_loader, processor = prepare_training_data(
    feature_data=feature_data, return_data=return_data,
    seq_len=SEQ_LEN, n_classes=N_CLASSES, phi=PHI,
    train_ratio=0.8, batch_size=64,
)

# Report how many samples ended up in each split
for split_name, loader in (("Training", train_loader), ("Validation", val_loader)):
    print(f"{split_name} samples: {len(loader.dataset)}")

# Peek at the first training batch only (nothing is printed for an empty loader)
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    batch_x, batch_y = first_batch
    print(f"Batch input shape: {batch_x.shape}")
    print(f"Batch label shape: {batch_y.shape}")
    # Column-wise sum of the labels within this batch
    print(f"Sample label distribution: {batch_y.sum(dim=0)}")


In [None]:
# Training configuration
EPOCHS = 25  # Slightly more epochs
LEARNING_RATE = 0.0005  # Lower learning rate for stability

# Set up WandB experiment tracking
wandb_logger = None
if ENABLE_WANDB:
    # Create experiment configuration (metadata that is logged with the run).
    # NOTE(review): "optimizer" and "loss_function" are descriptive labels only;
    # confirm they match what QuantformerTrainer actually uses.
    training_config = {
        "epochs": EPOCHS,
        "learning_rate": LEARNING_RATE,
        "batch_size": 64,  # keep in sync with the batch_size given to prepare_training_data
        "optimizer": "Adam",
        "loss_function": "MSE"
    }
    
    data_config = {
        "n_stocks": N_STOCKS,
        "n_timesteps": N_TIMESTEPS,
        "seq_len": SEQ_LEN,
        "n_classes": N_CLASSES,
        "phi": PHI
    }
    
    # Initialize WandB logger
    wandb_logger = setup_training_experiment(
        model_config=model_config,
        training_config=training_config,
        data_config=data_config,
        notes="Quantformer training with improved synthetic data and fixed label generation"
    )
    
    # Log data statistics. The logger is checked for None here as well, the
    # same way every other use of it in this notebook is guarded.
    if wandb_logger and wandb_logger.enabled:
        wandb_logger.log_data_statistics(feature_data, return_data)

# Initialize trainer with WandB logger (None when tracking is disabled)
trainer = QuantformerTrainer(
    model=model,
    device=device,
    learning_rate=LEARNING_RATE,
    wandb_logger=wandb_logger
)

print(f"Starting training for {EPOCHS} epochs...")
print(f"Learning rate: {LEARNING_RATE}")
print(f"Device: {device}")
print(f"Training samples: {len(train_loader.dataset)}")
print(f"Validation samples: {len(val_loader.dataset)}")
if wandb_logger and wandb_logger.enabled:
    print(f"🔗 WandB tracking: {wandb_logger.run.url}")

# Per-epoch history, consumed by the plotting and logging cells
train_losses = []
val_losses = []
val_accuracies = []

# Progress bar over epochs
pbar = tqdm(range(EPOCHS), desc="Training")

for epoch in pbar:
    # Training: accumulate the per-batch losses returned by the trainer
    epoch_train_loss = 0.0
    model.train()
    
    for batch_x, batch_y in train_loader:
        loss = trainer.train_step(batch_x, batch_y)
        epoch_train_loss += loss
    
    # Mean loss per batch for this epoch
    avg_train_loss = epoch_train_loss / len(train_loader)
    train_losses.append(avg_train_loss)
    
    # Validation
    val_loss, val_accuracy = trainer.evaluate(val_loader)
    val_losses.append(val_loss)
    val_accuracies.append(val_accuracy)
    
    # Update progress bar
    pbar.set_postfix({
        'Train Loss': f'{avg_train_loss:.4f}',
        'Val Loss': f'{val_loss:.4f}',
        'Val Acc': f'{val_accuracy:.4f}'
    })
    
    # Sanity warning only: training is NOT stopped here. `epoch` is 0-based,
    # so the check is active from the 7th epoch onwards.
    if val_accuracy > 0.95 and epoch > 5:
        print(f"\nWarning: Very high accuracy ({val_accuracy:.4f}) detected at epoch {epoch+1}")
        print("This might indicate data leakage or overfitting.")

print("\nTraining completed!")
print(f"Final training loss: {train_losses[-1]:.4f}")
print(f"Final validation loss: {val_losses[-1]:.4f}")
print(f"Final validation accuracy: {val_accuracies[-1]:.4f}")

# Coarse verdict based on the final validation accuracy
if val_accuracies[-1] > 0.9:
    print("⚠️  Very high accuracy - check for data issues!")
elif val_accuracies[-1] < 0.4:
    print("⚠️  Very low accuracy - model might need tuning!")
else:
    print("✅ Training completed with reasonable accuracy!")


In [None]:
# Plot training curves with detailed analysis
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Loss curves
axes[0, 0].plot(train_losses, label='Training Loss', linewidth=2, color='blue')
axes[0, 0].plot(val_losses, label='Validation Loss', linewidth=2, color='red')
axes[0, 0].set_title('Training and Validation Loss')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)
axes[0, 0].set_yscale('log')  # Log scale for better visualization

# Accuracy curve. The chance baseline is derived from the number of classes
# so it stays correct if N_CLASSES is changed (1/3 -> "33%" for 3 classes).
chance_level = 1.0 / N_CLASSES
axes[0, 1].plot(val_accuracies, label='Validation Accuracy', linewidth=2, color='green')
axes[0, 1].axhline(y=chance_level, color='gray', linestyle='--', alpha=0.7,
                   label=f'Random Guess ({chance_level:.0%})')
axes[0, 1].set_title('Validation Accuracy')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Accuracy')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)
axes[0, 1].set_ylim([0, 1])

# Loss difference (overfitting indicator): positive means validation is worse
loss_diff = np.array(val_losses) - np.array(train_losses)
axes[1, 0].plot(loss_diff, linewidth=2, color='purple')
axes[1, 0].axhline(y=0, color='gray', linestyle='--', alpha=0.7)
axes[1, 0].set_title('Validation - Training Loss (Overfitting Indicator)')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Loss Difference')
axes[1, 0].grid(True, alpha=0.3)

# Learning curve analysis (x-axis is 1-based here)
epochs = range(1, len(train_losses) + 1)
axes[1, 1].plot(epochs, train_losses, 'o-', label='Training', alpha=0.7)
axes[1, 1].plot(epochs, val_losses, 's-', label='Validation', alpha=0.7)
axes[1, 1].set_title('Learning Curves (Detailed)')
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('Loss')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Training analysis
print("\n" + "="*50)
print("TRAINING ANALYSIS")
print("="*50)

final_train_loss = train_losses[-1]
final_val_loss = val_losses[-1]
final_accuracy = val_accuracies[-1]
# Signed gap, same sign convention as the plot above
loss_gap = final_val_loss - final_train_loss

print(f"Final Training Loss: {final_train_loss:.6f}")
print(f"Final Validation Loss: {final_val_loss:.6f}")
print(f"Loss Difference: {loss_gap:.6f}")
print(f"Final Accuracy: {final_accuracy:.4f}")

# Check for common issues
if final_accuracy > 0.95:
    print("\n🚨 POTENTIAL ISSUES DETECTED:")
    print("- Accuracy too high (>95%) - possible data leakage")
    print("- Check if future information is leaking into features")
    print("- Verify label generation is correct")
elif final_accuracy < 0.4:
    print("\n🚨 POTENTIAL ISSUES DETECTED:")
    print("- Accuracy too low (<40%) - model not learning")
    print("- Try increasing learning rate or model capacity")
    print("- Check if data preprocessing is correct")
elif loss_gap > 0.1:
    # Only a validation loss that is HIGHER than the training loss is treated
    # as overfitting. A lower validation loss (which can happen when dropout
    # is active only during training) must not trigger this branch.
    print("\n⚠️  OVERFITTING DETECTED:")
    print("- Large gap between training and validation loss")
    print("- Consider increasing dropout or reducing model complexity")
else:
    print("\n✅ TRAINING LOOKS HEALTHY:")
    print("- Reasonable accuracy range")
    print("- No significant overfitting")
    print("- Model is learning properly")


In [None]:
# Push the finished training run to WandB (skipped when there is no active logger)
if not (wandb_logger and wandb_logger.enabled):
    print("📊 WandB logging disabled - results not logged")
else:
    print("📊 Logging training results to WandB...")

    # Last-epoch values, reused for both the metrics and the artifact metadata
    last_train_loss = train_losses[-1]
    last_val_loss = val_losses[-1]
    last_val_accuracy = val_accuracies[-1]
    epochs_trained = len(train_losses)

    # Full per-epoch curves
    wandb_logger.log_training_curves(train_losses, val_losses, val_accuracies)

    # Final summary metrics
    final_metrics = {
        "final/train_loss": last_train_loss,
        "final/val_loss": last_val_loss,
        "final/val_accuracy": last_val_accuracy,
        "final/epochs_trained": epochs_trained,
    }
    wandb_logger.log_metrics(final_metrics)

    # Trained model, stored as an artifact together with its headline numbers
    artifact_metadata = {
        "final_train_loss": last_train_loss,
        "final_val_loss": last_val_loss,
        "final_val_accuracy": last_val_accuracy,
        "epochs_trained": epochs_trained,
        "model_config": model_config,
    }
    wandb_logger.log_model_artifact(
        model,
        name="quantformer_trained_model",
        metadata=artifact_metadata,
    )

    print("✅ Training results logged to WandB")


## Fixes Applied for Training Issues

If you previously saw perfect accuracy (1.0) and very low loss, here are the fixes that have been applied:

### 1. **Fixed Quantile Label Generation**
- Corrected the quantile range calculation in `create_quantile_labels()`
- Now properly creates non-overlapping quantile ranges
- For 3 classes: bottom 20%, middle 20%, and top 20% of stocks, with gaps between the ranges (as in the paper)

### 2. **Improved Synthetic Data Generation**
- More realistic stock returns with varying volatility per stock
- Added momentum and mean-reversion effects
- Correlated turnover rates with volatility
- Market-wide effects to simulate real market conditions
- Different random seed to avoid overfitting to specific patterns

### 3. **Model Architecture Adjustments**
- Increased model capacity (d_model=32, d_ff=128)
- Reduced layers to prevent overfitting (n_layers=4)
- Increased dropout for regularization (0.2)
- Lower learning rate for stability (0.0005)

### 4. **Enhanced Debugging**
- Added data statistics and distribution checks
- Training progress monitoring with early warning systems
- Detailed analysis of potential overfitting or underfitting

### Expected Results
With these fixes, you should now see:
- **Accuracy**: 50-80% (reasonable for financial prediction)
- **Loss**: Gradually decreasing but not reaching zero
- **Learning**: Steady improvement without perfect predictions

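The previous perfect accuracy was most likely caused by the incorrect quantile label generation (fix 1) combined with overly simplistic synthetic data patterns (fix 2). If it reappears, check for data leakage.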


In [None]:
# Trading strategy configuration
trading_config = TradingConfig(
    n_classes=N_CLASSES,
    decision_factor=1,            # Select top 1 quantile (b=1 in paper)
    phi=PHI,
    transaction_fee=0.003,        # 0.3% as in paper
    initial_capital=1000000.0,    # $1M initial capital
    rebalance_frequency='monthly',
)

# Echo the settings so they are visible in the notebook output
print(
    "Trading Strategy Configuration:",
    f"- Number of classes: {trading_config.n_classes}",
    f"- Decision factor: {trading_config.decision_factor}",
    f"- Phi (quantile size): {trading_config.phi}",
    f"- Transaction fee: {trading_config.transaction_fee:.1%}",
    f"- Initial capital: ${trading_config.initial_capital:,.0f}",
    sep="\n",
)

# Optional WandB run for the backtest; stays None when tracking is disabled
backtest_logger = None
if ENABLE_WANDB:
    # Copy the logged settings off the config object, in a fixed order
    logged_fields = (
        "n_classes",
        "decision_factor",
        "phi",
        "transaction_fee",
        "initial_capital",
        "rebalance_frequency",
    )
    strategy_config_dict = {
        field_name: getattr(trading_config, field_name)
        for field_name in logged_fields
    }

    backtest_logger = setup_backtesting_experiment(
        strategy_config=strategy_config_dict,
        notes="Quantformer backtesting with trained model"
    )

# Wire the trained model, the config, and the (possibly None) logger together
strategy = QuantformerTradingStrategy(model, trading_config, wandb_logger=backtest_logger)
print("\nTrading strategy initialized")
if backtest_logger and backtest_logger.enabled:
    print(f"🔗 Backtesting WandB: {backtest_logger.run.url}")


In [None]:
# Backtest window: start at 80% of the timeline and use at most 100 periods,
# leaving room for one full input sequence plus the following return.
backtest_start = int(0.8 * N_TIMESTEPS)
backtest_periods = min(100, N_TIMESTEPS - backtest_start - SEQ_LEN)  # Limit for demo

print("Preparing backtest data...")
print(f"Backtest start: timestep {backtest_start}")
print(f"Backtest periods: {backtest_periods}")

# Pre-allocated containers:
#   features: (period, stock, step within sequence, feature)
#   returns:  (period, stock)
features_for_backtest = np.zeros((backtest_periods, N_STOCKS, SEQ_LEN, 2))
returns_for_backtest = np.zeros((backtest_periods, N_STOCKS))

for t in range(backtest_periods):
    window_start = backtest_start + t
    window_end = window_start + SEQ_LEN

    # Take the SEQ_LEN-long slice for the first N_STOCKS stocks in one go and
    # swap the first two axes so each stock's sequence lands in its own row.
    window = np.asarray(feature_data[window_start:window_end, :N_STOCKS])
    features_for_backtest[t] = np.swapaxes(window, 0, 1)

    # Return realised in the period right after the window (left at zero if
    # that period would fall outside the data)
    if window_end < N_TIMESTEPS:
        returns_for_backtest[t] = return_data[window_end]

print(f"Features for backtest shape: {features_for_backtest.shape}")
print(f"Returns for backtest shape: {returns_for_backtest.shape}")

# Hand the prepared arrays to the strategy
print("\nRunning backtest...")
backtest_results = strategy.backtest(
    features_data=features_for_backtest,
    returns_data=returns_for_backtest,
)

print("Backtest completed!")


In [None]:
# Print the headline metrics returned by the backtest
banner = "=" * 50
print(banner)
print("QUANTFORMER TRADING STRATEGY RESULTS")
print(banner)

for metric, value in backtest_results.items():
    # Only float-valued entries are shown; everything else is skipped
    if not isinstance(value, float):
        continue

    key = metric.lower()
    label = metric.replace('_', ' ').title()

    # Pick the number format from the metric name:
    #   returns / alpha -> percentage, "...value" (non-ratio) -> dollars,
    #   anything else (including ratios) -> four decimals.
    if 'return' in key or 'alpha' in key:
        rendered = f"{value:>10.2%}"
    elif 'value' in key and 'ratio' not in key:
        rendered = f"${value:>10,.0f}"
    else:
        rendered = f"{value:>10.4f}"
    print(f"{label:<25}: {rendered}")

# Get portfolio history for visualization
portfolio_df = strategy.get_portfolio_history()

if portfolio_df.empty:
    print("No portfolio history available for visualization.")
else:
    print(f"\nPortfolio history shape: {portfolio_df.shape}")

    # 2x2 dashboard: value, cumulative return, return histogram, turnover
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    (ax_value, ax_cumulative), (ax_hist, ax_turnover) = axes

    # Portfolio value over time, with the starting capital as a reference line
    ax_value.plot(portfolio_df['portfolio_value'], linewidth=2, color='blue')
    ax_value.axhline(
        y=trading_config.initial_capital,
        color='red', linestyle='--', alpha=0.7, label='Initial Capital',
    )
    ax_value.set_title('Portfolio Value Over Time')
    ax_value.set_xlabel('Time Period')
    ax_value.set_ylabel('Portfolio Value ($)')
    ax_value.legend()
    ax_value.grid(True, alpha=0.3)

    # Cumulative returns in percent, with a break-even line at zero
    ax_cumulative.plot(portfolio_df['cumulative_return'] * 100, linewidth=2, color='green')
    ax_cumulative.axhline(y=0, color='red', linestyle='--', alpha=0.7)
    ax_cumulative.set_title('Cumulative Returns')
    ax_cumulative.set_xlabel('Time Period')
    ax_cumulative.set_ylabel('Cumulative Return (%)')
    ax_cumulative.grid(True, alpha=0.3)

    # Histogram of per-period returns in percent, with the mean marked
    daily_returns = portfolio_df['returns'][1:] * 100  # Skip first zero return
    mean_return = daily_returns.mean()
    ax_hist.hist(daily_returns, bins=20, alpha=0.7, edgecolor='black', color='purple')
    ax_hist.axvline(x=mean_return, color='red', linestyle='--', label=f'Mean: {mean_return:.3f}%')
    ax_hist.set_title('Distribution of Daily Returns')
    ax_hist.set_xlabel('Daily Return (%)')
    ax_hist.set_ylabel('Frequency')
    ax_hist.legend()
    ax_hist.grid(True, alpha=0.3)

    # Portfolio turnover in percent
    ax_turnover.plot(portfolio_df['turnover'] * 100, linewidth=2, color='orange')
    ax_turnover.set_title('Portfolio Turnover Over Time')
    ax_turnover.set_xlabel('Time Period')
    ax_turnover.set_ylabel('Turnover Rate (%)')
    ax_turnover.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()


In [None]:
## Finish WandB Runs


def _finish_run(run_logger, done_message):
    """Close one WandB run and confirm it, if the logger exists and is enabled."""
    if run_logger and run_logger.enabled:
        run_logger.finish()
        print(done_message)


# Close the training run first, then the backtesting run
if not ENABLE_WANDB:
    print("📊 WandB was disabled - no runs to finish")
else:
    _finish_run(wandb_logger, "✅ Training WandB run finished")
    _finish_run(backtest_logger, "✅ Backtesting WandB run finished")

    print("🎉 All WandB experiments completed!")
    print("📊 Check your WandB dashboard for detailed results and comparisons")
