In [3]:
from __future__ import annotations

import dataclasses
import math
import time
from typing import Callable, List, Sequence, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm


from svcj.config import DATA_DIR, IS_START, IS_END, MONGO_URI



In [2]:
assert torch.cuda.is_available()
torch.cuda.get_device_name(torch.cuda.current_device())

'NVIDIA GeForce RTX 3090'

In [4]:

# ===================================================================================
# 1 ▸ Parameter container & realistic hyper-cube (crypto-calibrated)
# ===================================================================================
@dataclasses.dataclass
class SVCJParams:
    """All structural parameters of the Stochastic Volatility with Correlated Jumps (SVCJ) model."""

    mu: float
    kappa: float
    theta: float
    sigma_v: float
    rho: float
    lam: float  # jump intensity (λ)
    mu_s: float  # jump in log-returns
    sigma_s: float
    mu_v: float  # jump in variance
    sigma_vj: float
    rho_j: float

    # ──────────────────────────────────────────────────────────────────────────
    # Priors
    # ──────────────────────────────────────────────────────────────────────────
    @staticmethod
    def sample_crypto_prior(
        rng: np.random.Generator | None = None,
    ) -> "SVCJParams":
        """Draw one parameter set from a hand-crafted crypto prior hyper-cube."""
        rng = rng or np.random.default_rng()
        return SVCJParams(
            mu=rng.uniform(-0.0005, 0.0005),
            kappa=rng.uniform(0.01, 6.0),
            theta=rng.uniform(1e-6, 1e-2),
            sigma_v=rng.uniform(1e-4, 0.1),
            rho=rng.uniform(-0.4,0.4),
            lam=rng.uniform(0.0, 10.0),
            mu_s=rng.uniform(-0.1, 0.1),
            sigma_s=rng.uniform(0.01, 0.2),
            mu_v=rng.uniform(1e-6, 1e-2),
            sigma_vj=rng.uniform(1e-6, 1e-2),
            rho_j=rng.uniform(-0.4, 0.4),
        )

    # ──────────────────────────────────────────────────────────────────────────
    # Convenience helpers
    # ──────────────────────────────────────────────────────────────────────────
    def as_tensor(self, *, device: torch.device | str | None = None) -> torch.Tensor:
        """(22,) float32 tensor view of the parameters."""
        return torch.tensor(dataclasses.astuple(self), dtype=torch.float32, device=device)

    @staticmethod
    def fields() -> List[str]:
        return list(SVCJParams.__annotations__.keys())


# ===================================================================================
# 2 ▸ SVCJ path simulator (Euler–Maruyama + Poisson jumps)
# ===================================================================================
class SVCJSimulator:
    """Generates one path of log-returns under the SVCJ dynamics."""

    @staticmethod
    def simulate(
        params: SVCJParams,
        steps: int,
        *,
        dt: float = 1 / 24,  # default: hourly for 24/7 markets
        rng: np.random.Generator | None = None,
    ) -> np.ndarray:
        rng = rng or np.random.default_rng()

        (
            mu,
            kappa,
            theta,
            sigma_v,
            rho,
            lam,
            mu_s,
            sig_s,
            mu_v,
            sig_vj,
            _,
        ) = dataclasses.astuple(params)

        # Brownian shocks (correlated)
        dW_s = rng.normal(0.0, math.sqrt(dt), size=steps)
        dW_v = rng.normal(0.0, math.sqrt(dt), size=steps)
        dW_v = rho * dW_s + math.sqrt(max(1e-8, 1.0 - rho * rho)) * dW_v

        # Poisson jump counts & magnitudes
        dN = rng.poisson(lam * dt, size=steps)
        Z_s = rng.normal(mu_s, sig_s, size=steps) * dN
        Z_v = rng.normal(mu_v, sig_vj, size=steps) * dN

        # Path containers
        V: float = theta  # start at long-run level
        rets = np.empty(steps, dtype=np.float32)

        # Euler–Maruyama recursion
        for t in range(steps):
            V_pos = max(V, 0.0)
            rets[t] = (
                mu * dt
                - 0.5 * V_pos * dt
                + math.sqrt(V_pos) * dW_s[t]
                + Z_s[t]
            )
            V = max(
                V + kappa * (theta - V) * dt + sigma_v * math.sqrt(V_pos) * dW_v[t] + Z_v[t],
                1e-12,
            )
        return rets


# ===================================================================================
# 3 ▸ PyTorch Dataset that streams *fresh* simulations each epoch
# ===================================================================================
class SVCJDataset(Dataset):
    """An *in-finite* dataset: every __getitem__ produces brand-new synthetic data."""

    def __init__(
        self,
        n_samples: int,
        *,
        min_T: int = 240,
        max_T: int = 720,
        rng: np.random.Generator | None = None,
    ) -> None:
        self.n = n_samples
        self.min_T = min_T
        self.max_T = max_T
        self.rng = rng or np.random.default_rng()

    # Dataset protocol -------------------------------------------------------
    def __len__(self) -> int:  # type: ignore[override]
        return self.n

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:  # type: ignore[override]
        """Return one *variable-length* return series + ground-truth params."""
        T = int(self.rng.integers(self.min_T, self.max_T + 1))
        psi = SVCJParams.sample_crypto_prior(self.rng)
        r = SVCJSimulator.simulate(psi, T, rng=self.rng)
        x = torch.from_numpy(r).float().unsqueeze(0)  # (1, T)
        y = psi.as_tensor()
        return x, y


# ──────────────────────────────────────────────────────────────────────────────
# Collate & worker helpers
# ──────────────────────────────────────────────────────────────────────────────

def svcj_collate(batch: Sequence[Tuple[torch.Tensor, torch.Tensor]]):
    """Right-pad variable-length sequences so they can be stacked."""
    xs, ys = zip(*batch)
    lengths = [x.size(-1) for x in xs]
    max_len = max(lengths)
    padded = torch.zeros(len(xs), 1, max_len, dtype=xs[0].dtype)
    for i, x in enumerate(xs):
        padded[i, :, -x.size(-1) :] = x  # right-align keeps *most recent* info at the end
    return padded.contiguous(), torch.stack(ys)


def svcj_worker_init(worker_id: int):
    """Guarantee *independent* RNG streams across data-loader workers."""
    worker_info = torch.utils.data.get_worker_info()
    dataset: SVCJDataset = worker_info.dataset  # type: ignore[arg-type]
    seed = dataset.rng.integers(0, 2**32 - 1) + worker_id
    dataset.rng = np.random.default_rng(int(seed))


# ===================================================================================
# 4 ▸ Temporal Convolutional Network (variable-length input)
# ===================================================================================
class Chomp1d(nn.Module):
    def __init__(self, chomp: int):
        super().__init__()
        self.chomp = chomp

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        return x[..., : -self.chomp] if self.chomp else x


class ResidualBlock(nn.Module):
    def __init__(self, in_c: int, out_c: int, *, k: int = 3, d: int = 1, p: float = 0.1):
        super().__init__()
        pad = (k - 1) * d
        self.net = nn.Sequential(
            nn.utils.weight_norm(nn.Conv1d(in_c, out_c, k, padding=pad, dilation=d)),
            Chomp1d(pad),
            nn.ReLU(),
            nn.Dropout(p),
            nn.utils.weight_norm(nn.Conv1d(out_c, out_c, k, padding=pad, dilation=d)),
            Chomp1d(pad),
            nn.ReLU(),
            nn.Dropout(p),
        )
        self.down = nn.Conv1d(in_c, out_c, 1) if in_c != out_c else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        return torch.relu(self.net(x) + self.down(x))


class SVCJTCN(nn.Module):
    """Back-bone for amortised parameter inference."""

    def __init__(self, *, n_layers: int = 6, n_filters: int = 64, k: int = 3):
        super().__init__()
        layers: List[nn.Module] = []
        in_c = 1
        for i in range(n_layers):
            layers.append(ResidualBlock(in_c, n_filters, k=k, d=2**i))
            in_c = n_filters
        self.tcn = nn.Sequential(*layers)
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(n_filters, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 22),
        )
        self.register_buffer("var_floor", torch.tensor(1e-6))

    # ──────────────────────────────────────────────────────────────────────
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:  # type: ignore[override]
        h = self.tcn(x)
        stats = self.head(h)
        mean, log_var = stats.split(11, dim=-1)
        var = F.softplus(log_var) + self.var_floor  # strictly >0, better gradients
        return mean, var


# ===================================================================================
# 5 ▸ UKF Refinement via Unscented Kalman Filter for parameter smoothing
# ===================================================================================
class UKFRefiner:
    """Optional temporal smoothing of per-window estimates via a vanilla UKF."""

    def __init__(
        self,
        *,
        n_steps: int = 3,
        alpha: float = 1e-3,
        beta: float = 2.0,
        kappa: float = 0.0,
    ) -> None:
        self.n_steps = n_steps
        self.alpha = alpha
        self.beta = beta
        self.kappa = kappa

    # ------------------------------------------------------------------
    def _unscented_transform(
        self, f: Callable[[torch.Tensor], torch.Tensor], m: torch.Tensor, P: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        n = m.size(-1)
        lambda_ = self.alpha**2 * (n + self.kappa) - n
        c = n + lambda_
        Wm0 = lambda_ / c
        Wc0 = Wm0 + (1 - self.alpha**2 + self.beta)
        Wi = 0.5 / c
        sqrt_P = torch.linalg.cholesky(P * c)
        sigma = torch.cat([m.unsqueeze(0), m.unsqueeze(0) + sqrt_P, m.unsqueeze(0) - sqrt_P], dim=0)
        Y = f(sigma)  # shape: (2n+1, n)
        y_mean = Wm0 * Y[0] + Wi * Y[1:].sum(0)
        Yc = Y - y_mean
        Pyy = Wc0 * torch.outer(Yc[0], Yc[0]) + Wi * (Yc[1:].T @ Yc[1:])
        return y_mean, Pyy

    # ------------------------------------------------------------------
    def __call__(
        self,
        mean: torch.Tensor,
        var: torch.Tensor,
        x: torch.Tensor,
        m_prior: torch.Tensor | None = None,
        S_prior: torch.Tensor | None = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if m_prior is None or S_prior is None:
            return mean, var
        m = m_prior.clone()
        P = S_prior.clone()
        R = torch.diag(var)

        # Identity measurement function — amortiser already outputs state vector
        identity: Callable[[torch.Tensor], torch.Tensor] = lambda z: z

        for _ in range(self.n_steps):
            y_mean, P_zz = self._unscented_transform(identity, m, P)
            P_zz = P_zz + R
            K = P @ torch.linalg.inv(P_zz)
            m = m + K @ (mean - y_mean)
            P = P - K @ P_zz @ K.mT  # keep symmetric
        return m, torch.diagonal(P)


# ===================================================================================
# 6 ▸ Trainer wrapper
# ===================================================================================
class SVCJTrainer:
    def __init__(self, model: nn.Module, *, device: str = "cuda") -> None:
        self.model = model.to(device)
        self.device = device
        self.opt = torch.optim.AdamW(self.model.parameters(), lr=2e-4, weight_decay=1e-2)
        self.register_buffer = lambda *args, **kwargs: None  # placeholder so that .to() works

    # ------------------------------------------------------------------
    def train_epoch(self, loader: DataLoader, *, beta: float = 3.0) -> None:
        self.model.train()
        tot_loss = 0.0
        for x, y in loader:
            x = x.to(self.device, non_blocking=True)
            y = y.to(self.device, non_blocking=True)

            mean, var = self.model(x)
            inv_var = var.reciprocal()
            log_det = torch.log(var + 1e-9)

            nll = 0.5 * ((y - mean).pow(2) * inv_var + log_det).sum(-1).mean()
            penalty = beta * F.relu(((y - mean).abs() - 2.0 * torch.sqrt(var)).sum(-1)).mean()
            loss = nll + penalty

            self.opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.opt.step()
            tot_loss += loss.item()
        return tot_loss / len(loader)

    # ------------------------------------------------------------------
    def fit(
        self,
        *,
        n_epochs: int = 5,
        steps_per_epoch: int = 500,
        batch_size: int = 512,
        num_workers: int = 4,
    ) -> None:
        dataset = SVCJDataset(steps_per_epoch * batch_size)
        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=num_workers,
            collate_fn=svcj_collate,
            worker_init_fn=svcj_worker_init,
        )
        for ep in range(1, n_epochs + 1):
            t0 = time.time()
            loss = self.train_epoch(loader)
            print(f"[epoch {ep:02d}/{n_epochs}] loss={loss:8.5f}  ({time.time() - t0:4.1f}s)")


# ===================================================================================
# 7 ▸ Inference utility — rolling parameter time-series
# ===================================================================================
class SVCJInference:
    def __init__(
        self,
        model: nn.Module,
        *,
        refiner: UKFRefiner | None = None,
        device: str = "cuda",
    ) -> None:
        self.model = model.eval().to(device)
        self.refiner = refiner or UKFRefiner()
        self.device = device

    def predict(
        self,
        returns: np.ndarray | torch.Tensor,
        *,
        refine: bool = False,
        m_prev: torch.Tensor | None = None,
        S_prev: torch.Tensor | None = None,
        lambda_f: float = 2.0,
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Predict model output (mean, var) for a given (batched or unbatched) array of returns.
        Arguments:
            returns: shape (..., window_length)
        Returns:
            mean, var as numpy arrays (squeezed to 1d for a single sample, or 2d for batch)
        """
        # Convert input to tensor
        if isinstance(returns, np.ndarray):
            x = torch.from_numpy(returns).float().to(self.device)
        else:
            x = returns.float().to(self.device)
        # Reshape to [batch, 1, window]
        if x.ndim == 1:
            x = x.unsqueeze(0).unsqueeze(0)
        elif x.ndim == 2:
            x = x.unsqueeze(1)
        # Forward pass
        with torch.no_grad():
            mean, var = self.model(x)
        mean, var = mean.squeeze(0), var.squeeze(0)
        # Refinement if requested
        if refine:
            mean, var = self.refiner(mean, var, x.squeeze(0), m_prev, S_prev)
        return mean.cpu().numpy(), var.cpu().numpy()

    def rolling_estimates(
        self,
        prices: np.ndarray,
        timestamps: np.ndarray | None = None,
        *,
        window: int = 720,
        refine: bool = False,
        lambda_f: float = 2.0,
    ) -> "pd.DataFrame":
        import pandas as pd

        if timestamps is None:
            timestamps = np.arange(len(prices))

        logp = np.log(prices.astype(np.float64))
        returns = np.diff(logp).astype(np.float32)

        means: list[np.ndarray] = []
        vars_: list[np.ndarray] = []
        m_prev: torch.Tensor | None = None
        S_prev: torch.Tensor | None = None

        for t in range(window, len(returns) + 1):
            window_returns = returns[t - window : t]
            mean, var = self.predict(
                window_returns,
                refine=refine,
                m_prev=m_prev,
                S_prev=S_prev,
                lambda_f=lambda_f,
            )
            means.append(mean)
            vars_.append(var)
            # Update state for refiner if needed
            m_prev = torch.from_numpy(mean).to(self.device)
            S_prev = torch.diag(torch.from_numpy(var * lambda_f).to(self.device))

        # Timestamps for result index
        index = (
            pd.to_datetime(timestamps[window:], unit="s")
            if np.issubdtype(timestamps.dtype, np.integer)
            else pd.to_datetime(timestamps[window:])
        )
        cols = SVCJParams.fields()
        df_mean = pd.DataFrame(means, index=index, columns=cols)
        df_var = pd.DataFrame(vars_, index=index, columns=[f"{c}_var" for c in cols])
        return pd.concat([df_mean, df_var], axis=1)

In [82]:
ds = SVCJDataset(n_samples=10)
loader = torch.utils.data.DataLoader(
    ds,
    batch_size=2,
    collate_fn=svcj_collate,
    worker_init_fn=svcj_worker_init,
    num_workers=2,
)

# Check one batch
for x, y in loader:
    print("x:", x.shape, "y:", y.shape)
    break

# Instantiate model & trainer
model   = SVCJTCN(n_layers=4, n_filters=32)
trainer = SVCJTrainer(model, device="cpu")
print("Forward pass ok:", model(x)[0].shape, model(x)[1].shape)

x: torch.Size([2, 1, 624]) y: torch.Size([2, 11])
Forward pass ok: torch.Size([2, 11]) torch.Size([2, 11])


In [41]:
# Load model from checkpoint
# Instantiate model first
model = SVCJTCN(n_layers=6, n_filters=64, k=3)  # Reduce n_layers to 6 to match checkpoint

# Then load checkpoint
ckpt_path = "svcj_tcn.pt"
state_dict = torch.load(ckpt_path)

# Print missing keys before loading
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
print("Missing keys:", missing_keys)

# Load state dict with strict=False to ignore missing keys
model.load_state_dict(state_dict, strict=False)
model = model.to("cuda")
model.eval()


test_len = 2  # Can be > or < T_TRAIN
test_dataset = SVCJDataset(n_samples=1, min_T=test_len, max_T=test_len)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=1,
    collate_fn=svcj_collate,
    worker_init_fn=svcj_worker_init,
    num_workers=1,
)
x, y = list(test_loader)[0]

# Or: infer on any raw series directly!
infer = SVCJInference(model=model, device='cuda')
mean, var = infer.predict(returns=x[0])  # or returns=x.squeeze(0)
print("Predicted params (mean):", mean)
print("True params:", y[0].numpy())

import pandas as pd

print(
    pd.concat(
        [pd.Series(mean, index=SVCJParams.fields()), pd.Series(y[0].cpu().numpy(), index=SVCJParams.fields())],
        axis=1, keys=["predicted", "true"]
    )
)

  WeightNorm.apply(module, name, dim)
  state_dict = torch.load(ckpt_path)


Missing keys: set()
Predicted params (mean): [-0.0805743   0.04756197  2.744004   -0.10452211  0.05089341 -0.04276178
 -0.15153016 -0.06141418 -0.0877592   0.054356   -0.06689388]
True params: [-3.5826949e-04  5.1593947e+00  1.4571112e-03  8.5954919e-02
 -1.2422550e-01  3.9173887e+00  3.8847264e-02  1.6806702e-01
  9.1629885e-03  7.2722770e-03 -7.0708744e-02]
          predicted      true
mu        -0.080574 -0.000358
kappa      0.047562  5.159395
theta      2.744004  0.001457
sigma_v   -0.104522  0.085955
rho        0.050893 -0.124225
lam       -0.042762  3.917389
mu_s      -0.151530  0.038847
sigma_s   -0.061414  0.168067
mu_v      -0.087759  0.009163
sigma_vj   0.054356  0.007272
rho_j     -0.066894 -0.070709


Unnamed: 0,0,1
0,-277.757141,-3818.169922
1,-16.126276,3.095801
2,23931.474609,16067.860352
3,-18.298306,0.005242
4,10.967661,-0.097543
5,-34.681225,3.239611
6,-36.031876,0.037132
7,-120.537987,0.033082
8,119.918839,5.7e-05
9,27.665014,0.004818


tensor(-3818.1699)