# Arbitrage Window Detection with Deep Learning

In this notebook, we train a **Dual-Stream Transformer Encoder with Cross-Attention** on historical odds data to identify opportunities for arbitrage betting.

## Architecture Overview

The model processes paired sequences of odds from two bookmakers through:
1. **Feature Engineering** â€” raw odds, spread, implied probability difference, rate of change, classical arbitrage indicator
2. **Dual Transformer Encoders** â€” one per bookmaker stream, with self-attention + cross-attention
3. **Attention-Weighted Pooling** â€” aggregates variable-length sequences (longer = more accurate)
4. **Market Type Embedding** â€” conditions the scorer on MONEYLINE / POINTS_SPREAD / POINTS_TOTAL
5. **MLP Scoring Head** â€” outputs a scalar arbitrage opportunity score in [0, 1]

This tutorial covers the full lifecycle: data generation, model design, training, hyperparameter tuning, MLflow logging, registration, and deployment.

In [32]:
%pip install -Uqqq mlflow pytorch-lightning optuna skorch uv 

Note: you may need to restart the kernel to use updated packages.


In [34]:
from typing import Tuple, Optional, Dict, List, Any

import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    mean_squared_error, mean_absolute_error, r2_score,
    roc_auc_score, average_precision_score, f1_score, precision_score, recall_score
)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

import mlflow
from mlflow.models import infer_signature
from mlflow.tracking import MlflowClient
from mlflow.entities import Metric, Param

import time
import warnings
warnings.filterwarnings("ignore")


## 0. Configure the Model Registry with Unity Catalog

Configure MLflow to use Unity Catalog for model registration.

In [None]:
mlflow.set_registry_uri("databricks-uc")

VOLUME_PATH = "/Volumes/workspace/default/hacklytics_project_storage"

## 0a. Create Delta Tables and Load Parquet Data

Run the table creation script, then read parquet files from the Databricks Volume
and write them into the `upcoming_games` and `game_odds` Delta tables.

In [None]:
# â”€â”€ Create Delta tables â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€
%run ../scripts/create_delta_tables

In [None]:
import os
from functools import reduce
from pyspark.sql import DataFrame, functions as F

def read_volume_parquets(spark, volume_path: str, prefix: str) -> DataFrame:
    """Read all parquet files in a volume whose names start with `prefix`.

    Handles the case where there are multiple files sharing a common prefix
    (e.g. events.parquet, events_nba.parquet, events_nfl_2024-25.parquet).
    """
    all_files = dbutils.fs.ls(volume_path)
    matched = [
        f.path for f in all_files
        if f.name.startswith(prefix) and f.name.endswith(".parquet")
    ]
    if not matched:
        raise FileNotFoundError(
            f"No parquet files with prefix '{prefix}' found in {volume_path}"
        )
    print(f"  Found {len(matched)} file(s) for '{prefix}': {[f.split('/')[-1] for f in matched]}")
    dfs = [spark.read.parquet(path) for path in matched]
    return reduce(DataFrame.unionByName, dfs)

# â”€â”€ Read parquet files from the volume â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€
print(f"Reading parquet files from {VOLUME_PATH} ...")
events_df = read_volume_parquets(spark, VOLUME_PATH, "events")
opening_df = read_volume_parquets(spark, VOLUME_PATH, "opening")
closing_df = read_volume_parquets(spark, VOLUME_PATH, "closing")

print(f"\nEvents:  {events_df.count()} rows")
print(f"Opening: {opening_df.count()} rows")
print(f"Closing: {closing_df.count()} rows")

In [None]:
# â”€â”€ Cast timestamp columns from string â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€
timestamp_cols = ["event_start_time", "competition_instance_start", "competition_instance_end"]

for col_name in timestamp_cols:
    events_df = events_df.withColumn(col_name, F.to_timestamp(col_name))
    opening_df = opening_df.withColumn(col_name, F.to_timestamp(col_name))
    closing_df = closing_df.withColumn(col_name, F.to_timestamp(col_name))

odds_timestamp_cols = ["read_at", "last_found_at"]
for col_name in odds_timestamp_cols:
    opening_df = opening_df.withColumn(col_name, F.to_timestamp(col_name))
    closing_df = closing_df.withColumn(col_name, F.to_timestamp(col_name))

# â”€â”€ Write events â†’ upcoming_games â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€
upcoming_games_df = events_df.filter(
    F.col("event_start_time") > F.current_timestamp()
)

upcoming_games_df.write.mode("overwrite").saveAsTable("default.upcoming_games")
print(f"Wrote {upcoming_games_df.count()} rows to upcoming_games")

# â”€â”€ Combine opening + closing odds â†’ game_odds â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€
opening_tagged = opening_df.withColumn("odds_type", F.lit("opening"))
closing_tagged = closing_df.withColumn("odds_type", F.lit("closing"))
all_odds_df = opening_tagged.unionByName(closing_tagged)

all_odds_df.write.mode("overwrite").saveAsTable("default.game_odds")
print(f"Wrote {all_odds_df.count()} rows to game_odds")

In [None]:
# â”€â”€ Verify Delta tables â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€
print("=== upcoming_games ===")
spark.table("default.upcoming_games").printSchema()
spark.table("default.upcoming_games").show(5, truncate=False)

print("=== game_odds ===")
spark.table("default.game_odds").printSchema()
spark.table("default.game_odds").show(5, truncate=False)

In [None]:
# â”€â”€ Load Delta tables into pandas for model training â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€
X_events = spark.table("default.upcoming_games").toPandas()
X_odds = spark.table("default.game_odds").toPandas()

X_open_odds = X_odds[X_odds["odds_type"] == "opening"].copy()
X_end_odds = X_odds[X_odds["odds_type"] == "closing"].copy()

print(f"Events (upcoming): {len(X_events)} rows")
print(f"Opening odds:      {len(X_open_odds)} rows")
print(f"Closing odds:      {len(X_end_odds)} rows")

## 1. Build Training Data

We combine **real odds from Delta tables** (opening â†’ closing sequences per bookmaker pair)
with **synthetic data** to produce a robust training set. Each sample is a variable-length
time series of decimal odds from two bookmakers for a single event and market type.

In [None]:
from itertools import combinations

def build_real_odds_samples(
    open_odds: pd.DataFrame,
    close_odds: pd.DataFrame,
) -> List[Dict[str, Any]]:
    """Convert real opening/closing odds into training samples.

    For each (event, market_type) group, we find all bookmaker sources,
    form every pair, and build a 2-point time series [opening, closing]
    for each bookmaker in the pair. The arbitrage label is determined by
    whether the sum of implied probabilities drops below 1 at either the
    opening or closing snapshot.
    """
    # Join opening and closing odds on the natural key
    join_cols = ["event_key", "market_type", "source", "participant_key"]
    merged = open_odds.merge(
        close_odds,
        on=join_cols,
        suffixes=("_open", "_close"),
        how="inner",
    )

    if merged.empty:
        print("Warning: no matching open/close odds pairs found.")
        return []

    samples = []
    # Group by event + market_type to find bookmaker pairs
    for (event_key, market_type), group in merged.groupby(["event_key", "market_type"]):
        sources = group["source"].unique()
        if len(sources) < 2:
            continue

        # Map market_type strings to the model's expected categories
        mt = market_type.upper()
        if mt not in ("MONEYLINE", "POINTS_SPREAD", "POINTS_TOTAL"):
            continue

        for src_a, src_b in combinations(sources, 2):
            rows_a = group[group["source"] == src_a]
            rows_b = group[group["source"] == src_b]

            if rows_a.empty or rows_b.empty:
                continue

            # Use the first participant row for each source to get payout
            row_a = rows_a.iloc[0]
            row_b = rows_b.iloc[0]

            # Build 2-point sequences: [opening_payout, closing_payout]
            odds_a = np.array(
                [row_a["payout_open"], row_a["payout_close"]], dtype=np.float32
            )
            odds_b = np.array(
                [row_b["payout_open"], row_b["payout_close"]], dtype=np.float32
            )

            # Skip invalid odds (zero or negative payouts)
            if (odds_a <= 0).any() or (odds_b <= 0).any():
                continue

            # Determine arbitrage label: sum of implied probs < 1 at any timestep
            impl_sum = (1.0 / odds_a) + (1.0 / odds_b)
            has_arb = float((impl_sum < 1.0).any())

            samples.append({
                "odds_a": odds_a,
                "odds_b": odds_b,
                "market_type": mt,
                "label": has_arb,
                "seq_len": 2,
            })

    return samples


real_data = build_real_odds_samples(X_open_odds, X_end_odds)
n_real_arb = sum(1 for s in real_data if s["label"] == 1.0)
print(f"Real data samples: {len(real_data)}")
print(f"  Arbitrage:     {n_real_arb} ({n_real_arb / max(len(real_data), 1):.1%})")
print(f"  No arbitrage:  {len(real_data) - n_real_arb}")

In [None]:
def generate_synthetic_odds_data(
    n_samples: int = 3000,
    min_seq_len: int = 10,
    max_seq_len: int = 120,
    seed: int = 42,
    arb_fraction: float = 0.3,
) -> List[Dict[str, Any]]:
    """Generate synthetic paired odds sequences for two bookmakers.

    Each sample contains:
        - odds_a: sequence of American odds from bookmaker A
        - odds_b: sequence of American odds from bookmaker B (complementary side)
        - market_type: one of {MONEYLINE, POINTS_SPREAD, POINTS_TOTAL}
        - label: 1 if an arbitrage window existed, 0 otherwise
        - seq_len: length of the sequence

    The generator simulates realistic odds dynamics where bookmakers move
    asynchronously and occasionally create short-lived arbitrage windows.
    """
    rng = np.random.RandomState(seed)
    market_types = ["MONEYLINE", "POINTS_SPREAD", "POINTS_TOTAL"]
    samples = []

    for i in range(n_samples):
        seq_len = rng.randint(min_seq_len, max_seq_len + 1)
        market = rng.choice(market_types)
        has_arb = rng.random() < arb_fraction

        # Base implied probabilities for the two sides of a bet
        # (e.g., Team A win vs Team B win for moneyline)
        true_prob_a = rng.uniform(0.30, 0.70)
        true_prob_b = 1.0 - true_prob_a

        # Bookmaker vigorish (overround) â€” typically 3-8%
        vig_a = rng.uniform(0.03, 0.08)
        vig_b = rng.uniform(0.03, 0.08)

        # Generate odds sequences as implied probabilities with random walk
        implied_a = np.zeros(seq_len)  # bookmaker A implied prob for side A
        implied_b = np.zeros(seq_len)  # bookmaker B implied prob for side B

        implied_a[0] = true_prob_a + vig_a + rng.normal(0, 0.02)
        implied_b[0] = true_prob_b + vig_b + rng.normal(0, 0.02)

        # Correlated random walks with bookmaker-specific lag
        drift_a = rng.normal(0, 0.005, seq_len)
        drift_b = rng.normal(0, 0.005, seq_len)

        # Bookmaker B may lag behind market moves
        lag = rng.randint(1, 4)
        common_shock = rng.normal(0, 0.008, seq_len)

        for t in range(1, seq_len):
            implied_a[t] = implied_a[t-1] + drift_a[t] + common_shock[t]
            lagged_shock = common_shock[max(0, t - lag)]
            implied_b[t] = implied_b[t-1] + drift_b[t] + lagged_shock

        # Clamp to valid probability range
        implied_a = np.clip(implied_a, 0.15, 0.95)
        implied_b = np.clip(implied_b, 0.15, 0.95)

        # Inject arbitrage window for positive samples
        if has_arb:
            arb_start = rng.randint(seq_len // 3, 2 * seq_len // 3)
            arb_duration = rng.randint(2, min(8, seq_len - arb_start))
            # Make sum of implied probs < 1 (arbitrage condition)
            for t in range(arb_start, min(arb_start + arb_duration, seq_len)):
                gap = rng.uniform(0.01, 0.04)
                implied_a[t] = true_prob_a - gap / 2
                implied_b[t] = true_prob_b - gap / 2

        # Convert implied probabilities to decimal odds
        odds_a = 1.0 / implied_a
        odds_b = 1.0 / implied_b

        samples.append({
            "odds_a": odds_a.astype(np.float32),
            "odds_b": odds_b.astype(np.float32),
            "market_type": market,
            "label": float(has_arb),
            "seq_len": seq_len,
        })

    return samples


synthetic_data = generate_synthetic_odds_data(n_samples=4000, seed=42)
print(f"Synthetic samples: {len(synthetic_data)}")
print(f"  Label distribution: {sum(s['label'] for s in synthetic_data) / len(synthetic_data):.1%} positive")
print(f"  Sequence length range: {min(s['seq_len'] for s in synthetic_data)} - {max(s['seq_len'] for s in synthetic_data)}")

# â”€â”€ Combine real + synthetic data â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€
raw_data = real_data + synthetic_data
n_total_arb = sum(1 for s in raw_data if s["label"] == 1.0)
print(f"\nCombined training data: {len(raw_data)} samples")
print(f"  Real:      {len(real_data)}")
print(f"  Synthetic: {len(synthetic_data)}")
print(f"  Arbitrage: {n_total_arb} ({n_total_arb / len(raw_data):.1%})")

## 2. Feature Engineering

At each timestep we compute:
- Raw decimal odds from both bookmakers
- Odds spread (A âˆ’ B)
- Implied probability difference
- Classical arbitrage indicator:  (< 1 means arbitrage)
- Rate of change for each bookmaker
- Time-normalized position in sequence

In [38]:
MARKET_TYPE_MAP = {"MONEYLINE": 0, "POINTS_SPREAD": 1, "POINTS_TOTAL": 2}


def engineer_features(sample: Dict[str, Any]) -> Dict[str, Any]:
    """Convert raw odds pair into an engineered feature tensor.

    Returns a dict with:
        features: np.ndarray of shape (seq_len, n_features)
        market_type: int index
        label: float
        seq_len: int
    """
    odds_a = sample["odds_a"]
    odds_b = sample["odds_b"]
    T = len(odds_a)

    # Implied probabilities
    impl_a = 1.0 / odds_a
    impl_b = 1.0 / odds_b

    # Classical arbitrage indicator (sum of implied probs)
    arb_indicator = impl_a + impl_b  # < 1 means arbitrage

    # Spread
    spread = odds_a - odds_b

    # Implied probability difference
    impl_diff = impl_a - impl_b

    # Rate of change (pad first element with 0)
    delta_a = np.concatenate([[0], np.diff(odds_a)])
    delta_b = np.concatenate([[0], np.diff(odds_b)])

    # Relative rate of change
    rel_delta_a = delta_a / (odds_a + 1e-8)
    rel_delta_b = delta_b / (odds_b + 1e-8)

    # Time position (normalized 0 to 1)
    time_pos = np.linspace(0, 1, T).astype(np.float32)

    # Stack features: (T, 10)
    features = np.stack([
        odds_a, odds_b,           # raw odds
        spread, impl_diff,        # spreads
        arb_indicator,            # arbitrage signal
        delta_a, delta_b,         # rate of change
        rel_delta_a, rel_delta_b, # relative rate of change
        time_pos,                 # temporal position
    ], axis=1).astype(np.float32)

    return {
        "features": features,
        "market_type": MARKET_TYPE_MAP[sample["market_type"]],
        "label": sample["label"],
        "seq_len": T,
    }


engineered_data = [engineer_features(s) for s in raw_data]
print(f"Feature dimensionality per timestep: {engineered_data[0]['features'].shape[1]}")


Feature dimensionality per timestep: 10


## 3. Exploratory Data Analysis

Visualize the odds dynamics and feature distributions.

In [39]:
def plot_sample_odds(raw_data, idx=0):
    """Plot odds movement for a single sample."""
    s = raw_data[idx]
    fig, axes = plt.subplots(1, 2, figsize=(14, 4))

    axes[0].plot(s["odds_a"], label="Bookmaker A", alpha=0.8)
    axes[0].plot(s["odds_b"], label="Bookmaker B", alpha=0.8)
    axes[0].set_title(f'Decimal Odds â€” {s["market_type"]} (label={s["label"]:.0f})')
    axes[0].set_xlabel("Timestep")
    axes[0].set_ylabel("Decimal Odds")
    axes[0].legend()

    impl_sum = 1.0 / s["odds_a"] + 1.0 / s["odds_b"]
    axes[1].plot(impl_sum, color="red", alpha=0.8)
    axes[1].axhline(y=1.0, color="green", linestyle="--", label="Arbitrage threshold")
    axes[1].set_title("Sum of Implied Probabilities")
    axes[1].set_xlabel("Timestep")
    axes[1].set_ylabel("Î£(1/odds)")
    axes[1].legend()

    plt.tight_layout()
    plt.close(fig)
    return fig


def plot_feature_distributions_odds(engineered_data, n_cols=3):
    """Plot distributions of engineered features across all timesteps."""
    feature_names = [
        "odds_a", "odds_b", "spread", "impl_diff",
        "arb_indicator", "delta_a", "delta_b",
        "rel_delta_a", "rel_delta_b", "time_pos"
    ]
    # Collect all timestep features
    all_features = np.concatenate([s["features"] for s in engineered_data], axis=0)
    n_features = all_features.shape[1]
    n_rows = (n_features + n_cols - 1) // n_cols

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 4 * n_rows))
    axes = axes.flatten()

    for i in range(n_features):
        sns.histplot(all_features[:, i], ax=axes[i], kde=True, color="skyblue", bins=50)
        axes[i].set_title(f"Distribution of {feature_names[i]}")

    for i in range(n_features, len(axes)):
        axes[i].set_visible(False)

    plt.tight_layout()
    fig.suptitle("Engineered Feature Distributions", y=1.02, fontsize=16)
    plt.close(fig)
    return fig


def plot_seq_length_distribution(engineered_data):
    """Plot distribution of sequence lengths by label."""
    fig, ax = plt.subplots(figsize=(8, 4))
    lens_pos = [s["seq_len"] for s in engineered_data if s["label"] == 1.0]
    lens_neg = [s["seq_len"] for s in engineered_data if s["label"] == 0.0]
    ax.hist(lens_neg, bins=30, alpha=0.6, label="No arbitrage", color="steelblue")
    ax.hist(lens_pos, bins=30, alpha=0.6, label="Arbitrage", color="coral")
    ax.set_xlabel("Sequence Length")
    ax.set_ylabel("Count")
    ax.set_title("Sequence Length Distribution by Label")
    ax.legend()
    plt.tight_layout()
    plt.close(fig)
    return fig


# Generate EDA plots
sample_plot_pos = plot_sample_odds(raw_data, idx=next(i for i, s in enumerate(raw_data) if s["label"] == 1.0))
sample_plot_neg = plot_sample_odds(raw_data, idx=next(i for i, s in enumerate(raw_data) if s["label"] == 0.0))
feat_dist_plot = plot_feature_distributions_odds(engineered_data)
seq_len_plot = plot_seq_length_distribution(engineered_data)
print("EDA plots generated.")


EDA plots generated.


## 4. PyTorch Dataset and Variable-Length Collation

We implement a custom Dataset and a  that pads sequences to the
longest in the batch and returns attention masks. This is critical for
the transformer to ignore padding tokens.

In [40]:
class ArbitrageOddsDataset(Dataset):
    """PyTorch Dataset for variable-length odds sequences."""

    def __init__(self, samples: List[Dict[str, Any]]):
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        s = self.samples[idx]
        return {
            "features": torch.tensor(s["features"], dtype=torch.float32),
            "market_type": torch.tensor(s["market_type"], dtype=torch.long),
            "label": torch.tensor(s["label"], dtype=torch.float32),
            "seq_len": s["seq_len"],
        }


def collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """Pad variable-length sequences and create attention masks."""
    features = [item["features"] for item in batch]
    market_types = torch.stack([item["market_type"] for item in batch])
    labels = torch.stack([item["label"] for item in batch])
    seq_lens = torch.tensor([item["seq_len"] for item in batch], dtype=torch.long)

    # Pad sequences to max length in batch
    features_padded = pad_sequence(features, batch_first=True, padding_value=0.0)
    B, T, _ = features_padded.shape

    # Create attention mask (1 = valid, 0 = padding)
    mask = torch.arange(T).unsqueeze(0).expand(B, T) < seq_lens.unsqueeze(1)

    return {
        "features": features_padded,   # (B, T, F)
        "mask": mask,                  # (B, T)
        "market_type": market_types,   # (B,)
        "label": labels,               # (B,)
        "seq_len": seq_lens,           # (B,)
    }

## 5. Dual-Stream Transformer Encoder with Cross-Attention

The model architecture:



Key design choices:
- **ALiBi positional bias** for length generalization
- **Cross-attention** in later layers for inter-bookmaker dynamics
- **Attention-weighted pooling** so longer sequences contribute more evidence

In [41]:
class ALiBiAttention(nn.Module):
    """Multi-head attention with ALiBi (Attention with Linear Biases) positional encoding.

    ALiBi adds a linear bias to attention scores based on query-key distance,
    enabling strong length generalization without learned positional embeddings.
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.scale = self.head_dim ** -0.5

        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

        # Compute ALiBi slopes (one per head)
        slopes = self._compute_slopes(n_heads)
        self.register_buffer("slopes", slopes)  # (n_heads,)

    @staticmethod
    def _compute_slopes(n_heads: int) -> torch.Tensor:
        """Compute ALiBi slopes following the geometric sequence from the paper."""
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * (ratio ** i) for i in range(n)]

        if math.log2(n_heads).is_integer():
            return torch.tensor(get_slopes_power_of_2(n_heads), dtype=torch.float32)
        else:
            closest_power = 2 ** math.floor(math.log2(n_heads))
            slopes = get_slopes_power_of_2(closest_power)
            extra = get_slopes_power_of_2(2 * closest_power)[0::2][:n_heads - closest_power]
            return torch.tensor(slopes + extra, dtype=torch.float32)

    def _alibi_bias(self, T: int) -> torch.Tensor:
        """Create ALiBi bias matrix of shape (n_heads, T, T)."""
        positions = torch.arange(T, device=self.slopes.device)
        distance = positions.unsqueeze(0) - positions.unsqueeze(1)  # (T, T)
        bias = self.slopes.unsqueeze(-1).unsqueeze(-1) * distance.unsqueeze(0).abs().neg()
        return bias  # (n_heads, T, T)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            query: (B, T_q, d_model)
            key:   (B, T_k, d_model)
            value: (B, T_k, d_model)
            mask:  (B, T_k) boolean mask, True=valid
        Returns:
            (B, T_q, d_model)
        """
        B, T_q, _ = query.shape
        T_k = key.shape[1]

        Q = self.q_proj(query).view(B, T_q, self.n_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(key).view(B, T_k, self.n_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(value).view(B, T_k, self.n_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention + ALiBi bias
        attn = (Q @ K.transpose(-2, -1)) * self.scale

        # Only add ALiBi bias for self-attention (T_q == T_k)
        if T_q == T_k:
            attn = attn + self._alibi_bias(T_q).unsqueeze(0)

        # Apply padding mask
        if mask is not None:
            # mask: (B, T_k) -> (B, 1, 1, T_k)
            mask_expanded = mask.unsqueeze(1).unsqueeze(2)
            attn = attn.masked_fill(~mask_expanded, float("-inf"))

        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        out = (attn @ V).transpose(1, 2).contiguous().view(B, T_q, self.d_model)
        return self.out_proj(out)


class TransformerEncoderLayer(nn.Module):
    """Single transformer encoder layer with optional cross-attention."""

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1, use_cross_attn: bool = False):
        super().__init__()
        self.self_attn = ALiBiAttention(d_model, n_heads, dropout)
        self.norm1 = nn.LayerNorm(d_model)

        self.use_cross_attn = use_cross_attn
        if use_cross_attn:
            self.cross_attn = ALiBiAttention(d_model, n_heads, dropout)
            self.norm_cross = nn.LayerNorm(d_model)

        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )
        self.norm2 = nn.LayerNorm(d_model)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        cross_kv: Optional[torch.Tensor] = None,
        cross_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Self-attention
        residual = x
        x = self.norm1(x)
        x = residual + self.self_attn(x, x, x, mask)

        # Cross-attention (attend to the other bookmaker stream)
        if self.use_cross_attn and cross_kv is not None:
            residual = x
            x = self.norm_cross(x)
            x = residual + self.cross_attn(x, cross_kv, cross_kv, cross_mask)

        # Feed-forward
        residual = x
        x = self.norm2(x)
        x = residual + self.ffn(x)

        return x


class AttentionPooling(nn.Module):
    """Learnable attention-weighted pooling over the sequence dimension.

    Longer sequences provide more evidence vectors for the pooling to attend to,
    naturally producing more confident (and accurate) scores.
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.query = nn.Parameter(torch.randn(1, 1, d_model))
        self.scale = d_model ** -0.5

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: (B, T, d_model)
            mask: (B, T) boolean
        Returns:
            (B, d_model)
        """
        # Compute attention scores
        scores = (self.query * x).sum(dim=-1) * self.scale  # (B, T)

        if mask is not None:
            scores = scores.masked_fill(~mask, float("-inf"))

        weights = F.softmax(scores, dim=-1).unsqueeze(-1)  # (B, T, 1)
        return (weights * x).sum(dim=1)  # (B, d_model)


class TemporalArbitrageScorer(nn.Module):
    """Dual-Stream Transformer Encoder with Cross-Attention for arbitrage detection.

    The input feature tensor contains both bookmaker streams stacked together.
    The model splits them internally, processes each through its own transformer
    stream with cross-attention in later layers, pools with learned attention
    weights, and produces a scalar arbitrage score.
    """

    def __init__(
        self,
        n_input_features: int = 10,
        d_model: int = 64,
        n_heads: int = 4,
        n_layers: int = 3,
        d_ff: int = 256,
        dropout: float = 0.1,
        n_market_types: int = 3,
        cross_attn_start_layer: int = 1,
    ):
        super().__init__()
        self.d_model = d_model
        self.n_input_features = n_input_features
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.d_ff = d_ff
        self.dropout_rate = dropout

        # Split features: first 5 features relate more to bookmaker A, rest to B
        # But we project the full feature vector for each stream (shared context)
        self.proj_a = nn.Linear(n_input_features, d_model)
        self.proj_b = nn.Linear(n_input_features, d_model)

        # Market type embedding
        self.market_emb = nn.Embedding(n_market_types, d_model)

        # Transformer layers for stream A
        self.layers_a = nn.ModuleList([
            TransformerEncoderLayer(
                d_model, n_heads, d_ff, dropout,
                use_cross_attn=(i >= cross_attn_start_layer)
            )
            for i in range(n_layers)
        ])

        # Transformer layers for stream B
        self.layers_b = nn.ModuleList([
            TransformerEncoderLayer(
                d_model, n_heads, d_ff, dropout,
                use_cross_attn=(i >= cross_attn_start_layer)
            )
            for i in range(n_layers)
        ])

        # Attention pooling for each stream
        self.pool_a = AttentionPooling(d_model)
        self.pool_b = AttentionPooling(d_model)

        # MLP scoring head
        self.head = nn.Sequential(
            nn.Linear(d_model * 2 + d_model, d_ff),  # concat streams + market emb
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_ff // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff // 2, 1),
        )

    def forward(
        self,
        features: torch.Tensor,
        mask: torch.Tensor,
        market_type: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            features:    (B, T, n_input_features) padded feature tensor
            mask:        (B, T) boolean, True = valid timestep
            market_type: (B,) integer market type index
        Returns:
            (B,) arbitrage score in [0, 1]
        """
        # Project into two streams
        h_a = self.proj_a(features)  # (B, T, d_model)
        h_b = self.proj_b(features)  # (B, T, d_model)

        # Process through transformer layers with cross-attention
        for layer_a, layer_b in zip(self.layers_a, self.layers_b):
            h_a_new = layer_a(h_a, mask=mask, cross_kv=h_b, cross_mask=mask)
            h_b_new = layer_b(h_b, mask=mask, cross_kv=h_a, cross_mask=mask)
            h_a, h_b = h_a_new, h_b_new

        # Pool each stream
        pooled_a = self.pool_a(h_a, mask)  # (B, d_model)
        pooled_b = self.pool_b(h_b, mask)  # (B, d_model)

        # Market type embedding
        m_emb = self.market_emb(market_type)  # (B, d_model)

        # Concatenate and score
        combined = torch.cat([pooled_a, pooled_b, m_emb], dim=-1)
        score = torch.sigmoid(self.head(combined).squeeze(-1))  # (B,)

        return score

    def get_params(self) -> Dict[str, Any]:
        """Return model parameters for MLflow logging."""
        return {
            "n_input_features": self.n_input_features,
            "d_model": self.d_model,
            "n_heads": self.n_heads,
            "n_layers": self.n_layers,
            "d_ff": self.d_ff,
            "dropout": self.dropout_rate,
        }

In [42]:
class ArbitrageLightningModule(pl.LightningModule):
    """PyTorch Lightning wrapper for the TemporalArbitrageScorer."""

    def __init__(
        self,
        n_input_features: int = 10,
        d_model: int = 64,
        n_heads: int = 4,
        n_layers: int = 3,
        d_ff: int = 256,
        dropout: float = 0.1,
        learning_rate: float = 1e-3,
        weight_decay: float = 1e-5,
        cross_attn_start_layer: int = 1,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.model = TemporalArbitrageScorer(
            n_input_features=n_input_features,
            d_model=d_model,
            n_heads=n_heads,
            n_layers=n_layers,
            d_ff=d_ff,
            dropout=dropout,
            cross_attn_start_layer=cross_attn_start_layer,
        )

        self.loss_fn = nn.BCELoss()

    def forward(self, features, mask, market_type):
        return self.model(features, mask, market_type)

    def _shared_step(self, batch, stage: str):
        scores = self(batch["features"], batch["mask"], batch["market_type"])
        loss = self.loss_fn(scores, batch["label"])

        # Metrics
        preds = (scores > 0.5).float()
        acc = (preds == batch["label"]).float().mean()

        self.log(f"{stage}_loss", loss, prog_bar=True, batch_size=len(batch["label"]))
        self.log(f"{stage}_acc", acc, prog_bar=True, batch_size=len(batch["label"]))
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, "test")

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay,
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=10, T_mult=2
        )
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    def get_params(self) -> Dict[str, Any]:
        return self.model.get_params()


## 6. Train / Validation / Test Split & DataLoaders

We split the engineered data and create DataLoaders with our custom collate function.

In [None]:
# Split indices
indices = list(range(len(engineered_data)))
labels_for_split = [s["label"] for s in engineered_data]

train_idx, temp_idx = train_test_split(indices, test_size=0.3, random_state=42, stratify=labels_for_split)

print(train_idx[0])
temp_labels = [labels_for_split[i] for i in temp_idx]
val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, random_state=42, stratify=temp_labels)

train_data = [engineered_data[i] for i in train_idx]

val_data = [engineered_data[i] for i in val_idx]
test_data = [engineered_data[i] for i in test_idx]

print(f"Train: {len(train_data)}, Val: {len(val_data)}, Test: {len(test_data)}")
print(f"Train positives: {sum(1 for s in train_data if s['label'] == 1.0) / len(train_data):.1%}")

# Create DataLoaders
BATCH_SIZE = 32

train_loader = DataLoader(
    ArbitrageOddsDataset(train_data), batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn
)
val_loader = DataLoader(
    ArbitrageOddsDataset(val_data), batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn
)
test_loader = DataLoader(
    ArbitrageOddsDataset(test_data), batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn
)


2901
Train: 2800, Val: 600, Test: 600
Train positives: 28.5%


## 7. Standard Training Workflow

Train the Temporal Arbitrage Scorer with default hyperparameters as a baseline.
We use curriculum training â€” starting with shorter sequences and gradually including longer ones.

In [None]:
# Model hyperparameters
N_INPUT_FEATURES = 10
D_MODEL = 32
N_HEADS = 4
N_LAYERS = 3
D_FF = 256
DROPOUT = 0.1
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-5

# Create the model
model = ArbitrageLightningModule(
    n_input_features=N_INPUT_FEATURES,
    d_model=D_MODEL,
    n_heads=N_HEADS,
    n_layers=N_LAYERS,
    d_ff=D_FF,
    dropout=DROPOUT,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

# Callbacks
early_stopping = EarlyStopping(monitor="val_loss", patience=1, mode="min")
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    dirpath="./checkpoints",
    filename="arb-scorer-{epoch:02d}-{val_loss:.4f}",
    save_top_k=1,
    mode="min",
)

# Trainer
trainer = pl.Trainer(
    max_epochs=30,
    callbacks=[early_stopping, checkpoint_callback],
    enable_progress_bar=True,
    log_every_n_steps=5,
)

# Train
trainer.fit(model, train_loader, val_loader)

# Test
test_results = trainer.test(model, test_loader)

# Collect predictions for evaluation
model.eval()
all_scores = []
all_labels = []

with torch.no_grad():
    for batch in test_loader:
        scores = model(batch["features"], batch["mask"], batch["market_type"])
        all_scores.extend(scores.cpu().numpy())
        all_labels.extend(batch["label"].cpu().numpy())

all_scores = np.array(all_scores)
all_labels = np.array(all_labels)

# Calculate classification metrics
preds_binary = (all_scores > 0.5).astype(float)
auc_roc = roc_auc_score(all_labels, all_scores)
auc_pr = average_precision_score(all_labels, all_scores)
f1 = f1_score(all_labels, preds_binary)
precision = precision_score(all_labels, preds_binary)
recall = recall_score(all_labels, preds_binary)

print(f"AUC-ROC: {auc_roc:.4f}")
print(f"AUC-PR:  {auc_pr:.4f}")
print(f"F1:      {f1:.4f}")
print(f"Prec:    {precision:.4f}")
print(f"Recall:  {recall:.4f}")

# Analyze accuracy by sequence length (key: longer = better)
test_seq_lens = np.array([s["seq_len"] for s in test_data])
length_bins = [(10, 30), (30, 60), (60, 90), (90, 121)]
print("\nAccuracy by sequence length:")
for lo, hi in length_bins:
    mask = (test_seq_lens >= lo) & (test_seq_lens < hi)
    if mask.sum() > 0:
        bin_acc = (preds_binary[mask] == all_labels[mask]).mean()
        print(f"  [{lo:3d}, {hi:3d}): {bin_acc:.3f}  (n={mask.sum()})")

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
ðŸ’¡ Tip: For seamless cloud logging and experiment tracking, try installing [litlogger](https://pypi.org/project/litlogger/) to enable LitLogger, which logs metrics and artifacts automatically to the Lightning Experiments platform.

  | Name    | Type                    | Params | Mode  | FLOPs
--------------------------------------------------------------------
0 | model   | TemporalArbitrageScorer | 4.5 K  | train | 0    
1 | loss_fn | BCELoss                 | 0      | train | 0    
--------------------------------------------------------------------
4.5 K     Trainable params
0         Non-trainable params
4.5 K     Total params
0.018     Total estimated model params size (MB)
135       Modules in train mode
0         Modules in eval mode
0         Total Flops


Epoch 0: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 175/175 [00:11<00:00, 15.24it/s, v_num=7, train_loss=0.707, train_acc=0.438, val_loss=0.692, val_acc=0.715]

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 175/175 [00:11<00:00, 15.08it/s, v_num=7, train_loss=0.707, train_acc=0.438, val_loss=0.692, val_acc=0.715]
Testing DataLoader 0: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 38/38 [00:00<00:00, 65.25it/s]
â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€
       Test metric             DataLoader 0
â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€

In [28]:
def plot_score_distribution(scores, labels):
    """Plot score distribution by class."""
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.hist(scores[labels == 0], bins=40, alpha=0.6, label="No arb", color="steelblue")
    ax.hist(scores[labels == 1], bins=40, alpha=0.6, label="Arbitrage", color="coral")
    ax.set_xlabel("Predicted Score")
    ax.set_ylabel("Count")
    ax.set_title("Score Distribution by Class")
    ax.legend()
    plt.tight_layout()
    plt.close(fig)
    return fig


def plot_accuracy_by_length(scores, labels, seq_lens):
    """Plot accuracy vs sequence length to verify longer = better."""
    fig, ax = plt.subplots(figsize=(8, 4))
    preds = (scores > 0.5).astype(float)
    correct = (preds == labels).astype(float)

    # Bin by sequence length
    bins = np.arange(10, 130, 10)
    bin_indices = np.digitize(seq_lens, bins)
    bin_accs = []
    bin_centers = []
    for b in range(1, len(bins)):
        mask = bin_indices == b
        if mask.sum() > 5:
            bin_accs.append(correct[mask].mean())
            bin_centers.append((bins[b-1] + bins[b]) / 2)

    ax.plot(bin_centers, bin_accs, "o-", color="teal", markersize=8)
    ax.set_xlabel("Sequence Length")
    ax.set_ylabel("Accuracy")
    ax.set_title("Accuracy vs Sequence Length (Longer â†’ Better)")
    ax.set_ylim(0, 1.05)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.close(fig)
    return fig


score_dist_plot = plot_score_distribution(all_scores, all_labels)
acc_by_len_plot = plot_accuracy_by_length(all_scores, all_labels, test_seq_lens)
print("Evaluation plots generated.")


Evaluation plots generated.


## 8. Log the Model with MLflow

Log all metrics, parameters, artifacts, and the model to MLflow.

In [30]:
from mlflow.models import infer_signature

with mlflow.start_run() as run:
    mlflow_client = MlflowClient()
    run_id = run.info.run_id

    # Log model architecture params
    model_params = model.get_params()
    mlflow.log_params(model_params)

    # Batch log metrics
    current_time = int(time.time() * 1000)
    metrics_list = [
        Metric("test_auc_roc", auc_roc, current_time, 0),
        Metric("test_auc_pr", auc_pr, current_time, 0),
        Metric("test_f1", f1, current_time, 0),
        Metric("test_precision", precision, current_time, 0),
        Metric("test_recall", recall, current_time, 0),
    ]

    train_loss = trainer.callback_metrics.get("train_loss")
    val_loss = trainer.callback_metrics.get("val_loss")
    if train_loss is not None:
        metrics_list.append(Metric("train_loss", train_loss.item(), current_time, 0))
    if val_loss is not None:
        metrics_list.append(Metric("val_loss", val_loss.item(), current_time, 0))

    params_list = [
        Param("batch_size", str(BATCH_SIZE)),
        Param("max_epochs", str(trainer.max_epochs)),
        Param("actual_epochs", str(trainer.current_epoch)),
        Param("early_stopping_patience", str(10)),
    ]

    mlflow_client.log_batch(run_id, metrics=metrics_list, params=params_list)

    print(f"Test AUC-ROC: {auc_roc:.4f}")
    print(f"Test F1: {f1:.4f}")

Test AUC-ROC: 0.5410
Test F1: 0.0000


## 10. Pre-deployment Validation

Validate that the registered model can be loaded and produces predictions.

In [0]:
model_uri = "models:/m-24afa884e98d444fbd545903c849420b"

# Load and test with a sample batch
loaded_model = mlflow.pytorch.load_model(model_uri)
loaded_model.eval()

sample_batch = next(iter(test_loader))
with torch.no_grad():
    sample_scores = loaded_model(
        sample_batch["features"], sample_batch["mask"], sample_batch["market_type"]
    )

print(f"Sample predictions shape: {sample_scores.shape}")
print(f"Sample scores: {sample_scores[:5].numpy()}")
print(f"Sample labels: {sample_batch['label'][:5].numpy()}")


com.databricks.backend.common.rpc.CommandSkippedException
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3(SequenceExecutionState.scala:134)
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3$adapted(SequenceExecutionState.scala:129)
	at scala.collection.immutable.Range.foreach(Range.scala:158)
	at com.databricks.spark.chauffeur.SequenceExecutionState.cancel(SequenceExecutionState.scala:129)
	at com.databricks.spark.chauffeur.ExecContextState.cancelRunningSequence(ExecContextState.scala:715)
	at com.databricks.spark.chauffeur.ExecContextState.$anonfun$cancel$1(ExecContextState.scala:435)
	at scala.Option.getOrElse(Option.scala:189)
	at com.databricks.spark.chauffeur.ExecContextState.cancel(ExecContextState.scala:435)
	at com.databricks.spark.chauffeur.ExecutionContextManagerV1.cancelExecution(ExecutionContextManagerV1.scala:466)
	at com.databricks.spark.chauffeur.ChauffeurState.$anonfun$process$1(ChauffeurState.scala:757)
	at com.data

## 9. Hyperparameter Tuning with Optuna

Search over model size, learning rate, dropout, and architecture choices.

In [None]:
class PruningCallback(pl.Callback):
    """Prune unpromising Optuna trials during training."""

    def __init__(self, trial, monitor):
        super().__init__()
        self._trial = trial
        self.monitor = monitor

    def on_validation_end(self, trainer, pl_module):
        current = trainer.callback_metrics.get(self.monitor)
        if current is not None:
            self._trial.report(current.item(), trainer.current_epoch)
            if self._trial.should_prune():
                raise optuna.TrialPruned(f"Pruned at epoch {trainer.current_epoch}")


def objective(trial):
    """Optuna objective: minimize validation loss."""
    # Search space
    d_model = trial.suggest_categorical("d_model", [32, 64, 128])
    n_heads = trial.suggest_categorical("n_heads", [2, 4, 8])
    n_layers = trial.suggest_int("n_layers", 2, 5)
    d_ff = trial.suggest_categorical("d_ff", [128, 256, 512])
    dropout = trial.suggest_float("dropout", 0.0, 0.4)
    lr = trial.suggest_float("learning_rate", 1e-4, 5e-3, log=True)
    wd = trial.suggest_float("weight_decay", 1e-6, 1e-3, log=True)
    cross_start = trial.suggest_int("cross_attn_start_layer", 0, max(0, n_layers - 1))

    # Ensure d_model divisible by n_heads
    if d_model % n_heads != 0:
        raise optuna.TrialPruned("d_model not divisible by n_heads")

    with mlflow.start_run(nested=True) as child_run:
        mlflow_client = MlflowClient()
        run_id = child_run.info.run_id

        param_dict = {
            "d_model": d_model, "n_heads": n_heads, "n_layers": n_layers,
            "d_ff": d_ff, "dropout": dropout, "learning_rate": lr,
            "weight_decay": wd, "cross_attn_start_layer": cross_start,
        }
        params_list = [Param(k, str(v)) for k, v in param_dict.items()]

        trial_model = ArbitrageLightningModule(
            n_input_features=N_INPUT_FEATURES,
            d_model=d_model,
            n_heads=n_heads,
            n_layers=n_layers,
            d_ff=d_ff,
            dropout=dropout,
            learning_rate=lr,
            weight_decay=wd,
            cross_attn_start_layer=cross_start,
        )

        trial_trainer = pl.Trainer(
            max_epochs=40,
            callbacks=[
                EarlyStopping(monitor="val_loss", patience=5, mode="min"),
                PruningCallback(trial, monitor="val_loss"),
            ],
            enable_progress_bar=False,
            log_every_n_steps=10,
        )

        trial_trainer.fit(trial_model, train_loader, val_loader)

        best_val_loss = trial_trainer.callback_metrics.get("val_loss").item()

        current_time = int(time.time() * 1000)
        mlflow_client.log_batch(
            run_id,
            metrics=[Metric("val_loss", best_val_loss, current_time, 0)],
            params=params_list,
        )

    trial.set_user_attr("model", trial_model)
    return best_val_loss

## 11. Export Checkpoint for Local Inference

Save the trained model as a `.ckpt` file that can be loaded by the FastAPI backend
without any Databricks or MLflow dependencies.

In [31]:
# Save the best checkpoint from the baseline trainer for local serving.
# After running this cell, copy the file to Backend/models/model.ckpt
#
#   cp model.ckpt  ../Backend/models/model.ckpt

trainer.save_checkpoint("model.ckpt")
print(f"Checkpoint saved to model.ckpt")
print("Copy it to Backend/models/model.ckpt for local FastAPI inference.")

`weights_only` was not set, defaulting to `False`.


Checkpoint saved to model.ckpt
Copy it to Backend/models/model.ckpt for local FastAPI inference.
