# Neural MJD Demonstration Notebook

This notebook demonstrates the Neural MJD model with randomly generated data, following the same logic as the main training script.


## Python imports and initializations

In [1]:
# Import required libraries
import os
import sys
import numpy as np
import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data
import torch_geometric.loader
from torch_geometric.utils import to_dense_batch, to_dense_adj

# Add project root to path
project_root = os.getcwd()
sys.path.append(project_root)

# Project modules
from model.transformer import MJDTransformer
from model.mjd.neural_mjd import NeuralMJD

# Device and seeds
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(42)
np.random_seed = 42
np.random.seed(np.random_seed)


In [2]:
# Synthetic data parameters
SEQLEN = 10
PREDLEN = 3
PAST_LEN = SEQLEN - PREDLEN

NUM_NODES = 4
NUM_SAMPLES = 512

# Training parameters
BATCH_SIZE = 512
MAX_EPOCH = 100
LR = 1e-4
HUBER_DELTA = 10.0

# Model parameters
NUM_LAYERS = 4
FEATURE_DIMS = 256
NUM_HEADS = 8


## Create synthetic data

In [3]:
class SyntheticGraphDataset(Dataset):
    def __init__(self, num_samples=10, num_nodes=NUM_NODES, seqlen=SEQLEN, predlen=PREDLEN):
        self.num_samples = num_samples
        self.num_nodes = num_nodes
        self.seqlen = seqlen
        self.predlen = predlen
        self.past_len = seqlen - predlen
        
        # Pre-generate all samples ONCE
        self.data_list = [self._generate_sample() for _ in range(self.num_samples)]

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.data_list[idx]
    
    def _generate_sample(self):
        # Sinusoidal prices with same trend for past and future
        t_past = torch.arange(self.past_len).float().unsqueeze(0)
        t_future = torch.arange(self.past_len, self.past_len + self.predlen).float().unsqueeze(0)

        base = 20.0 + 15.0 * torch.rand(self.num_nodes, 1)   # [20, 35]
        amp = 5.0 + 10.0 * torch.rand(self.num_nodes, 1)     # [5, 15]
        freq = 0.05 + 0.15 * torch.rand(self.num_nodes, 1)   # [0.05, 0.20]
        phase = torch.rand(self.num_nodes, 1) * 2 * torch.pi

        past_prices = base + amp * torch.sin(freq * t_past + phase) + torch.randn(self.num_nodes, self.past_len)
        future_prices_raw = base + amp * torch.sin(freq * t_future + phase) + torch.randn(self.num_nodes, self.predlen)

        past_prices = past_prices.clamp(1.0, 50.0)
        future_prices_raw = future_prices_raw.clamp(1.0, 50.0)

        # Fully connected graph (all nodes connected to all nodes)
        if self.num_nodes > 1:
            all_pairs = torch.combinations(torch.arange(self.num_nodes), r=2, with_replacement=False)
            edge_index = torch.cat([all_pairs.t(), all_pairs.flip(1).t()], dim=1)
            # Add self-loops
            self_loops = torch.arange(self.num_nodes).unsqueeze(0).repeat(2, 1)
            edge_index = torch.cat([edge_index, self_loops], dim=1)
        else:
            # Self-loop for single node
            edge_index = torch.zeros(2, 1, dtype=torch.long)

        # Empty edge attributes for simplicity
        edge_attr = torch.zeros(edge_index.shape[1], 1)

        # Normalization
        last_prices = past_prices[:, -1:].clamp(1.0, 50.0)
        price_norm_coef = last_prices.squeeze()
        future_prices_norm = (future_prices_raw / last_prices).clamp(1.0, 50.0)

        return Data(
            dyn_feat=torch.zeros(self.num_nodes, 1, self.past_len),
            static_feat=torch.zeros(self.num_nodes, 1),
            edge_index=edge_index,
            edge_attr=edge_attr,
            price_past_raw=past_prices,
            price_norm_coef=price_norm_coef,
            price_future_norm=future_prices_norm,
            price_future_raw=future_prices_raw,
            num_nodes=self.num_nodes,
        )
    
full_dataset = SyntheticGraphDataset(num_samples=NUM_SAMPLES, num_nodes=NUM_NODES)
# split into train and test
train_dataset = full_dataset[:int(0.8 * len(full_dataset))]
test_dataset = full_dataset[int(0.8 * len(full_dataset)):]

print(f"Train dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

Train dataset size: 409
Test dataset size: 103


In [4]:
# Create data loader
train_dl = torch_geometric.loader.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)

test_dl = torch_geometric.loader.DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
)

print(f"Train Batch size: {train_dl.batch_size}, Num batches: {len(train_dl)}")
print(f"Test Batch size: {test_dl.batch_size}, Num batches: {len(test_dl)}")


Train Batch size: 512, Num batches: 1
Test Batch size: 512, Num batches: 1


## Create model and start training

In [5]:
# Initialize model
in_seq_dim = 1
num_static_features = 1
in_seq_length = PAST_LEN
out_seq_length = PREDLEN
output_dim = 5

network = MJDTransformer(
    in_seq_length=in_seq_length,
    in_seq_dim=in_seq_dim,
    out_seq_length=out_seq_length,
    out_seq_dim=output_dim,
    num_static_features=num_static_features,
    num_encoder_layers=NUM_LAYERS,
    embedding_dim=FEATURE_DIMS,
    ffn_embedding_dim=FEATURE_DIMS * 4,
    num_attention_heads=NUM_HEADS,
    pre_layernorm=False,
    activation_fn='relu',
    dropout=0.0,
    light_mode=False,
)

model = NeuralMJD(
    model=network,
    w_cond_mean_loss=1.0,
    steps_per_unit_time=5,
    jump_diffusion=True,
    s_0_from_avg=True,
    cond_mean_raw_scale=True,
).to(device)

print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")


Parameters: 6,693,999


In [6]:
# Optimizer and scheduler
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LR,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.0,
)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=MAX_EPOCH * len(train_dl),
    eta_min=1e-6
)

print("Optimizer ready.")


Optimizer ready.


In [7]:
def process_batch_data(batch_data, device):
    batch_data = batch_data.to(device)
    node_past_dyn_data, node_mask = to_dense_batch(batch_data.dyn_feat, batch=batch_data.batch)
    node_past_static_data, _ = to_dense_batch(batch_data.static_feat, batch=batch_data.batch)
    adj_matrix = to_dense_adj(batch_data.edge_index, batch=batch_data.batch)
    edge_attr_dist = to_dense_adj(batch_data.edge_index, batch=batch_data.batch, edge_attr=batch_data.edge_attr)
    node_target, _ = to_dense_batch(batch_data.price_future_raw, batch=batch_data.batch)
    node_norm_coef, _ = to_dense_batch(batch_data.price_norm_coef, batch=batch_data.batch)
    node_norm_coef = node_norm_coef.unsqueeze(-1)

    batch_size, num_nodes = node_mask.shape
    spatial_pos = torch.zeros((batch_size, num_nodes, num_nodes), dtype=torch.long, device=device)
    spatial_pos[adj_matrix > 0] = 1

    node_type = torch.ones((batch_size, num_nodes, 1), device=device)
    node_type[:, 0, 0] = 0

    batched_data = {
        'node_type': node_type.long(),
        'in_degree': adj_matrix.sum(dim=-1).long(),
        'out_degree': adj_matrix.sum(dim=-2).long(),
        'spatial_pos': spatial_pos.long(),
        'edge_attr': edge_attr_dist.float(),
        'adj_matrix': adj_matrix.long(),
        'node_mask': node_mask,
        'node_past_dyn_data': node_past_dyn_data.float(),
        'node_past_static_data': node_past_static_data.float(),
        'node_norm_coef': node_norm_coef,
        'data_norm': 'max',
        'huber_delta': HUBER_DELTA,
    }
    return batched_data, node_target, node_mask


In [8]:
def train_one_epoch(model, optimizer, scheduler, dataloader, device):
    model.train()
    total_loss = 0.0
    num_batches = 0
    for batch_data in dataloader:
        batched_data, node_target, node_mask = process_batch_data(batch_data, device)
        cond_mean_loss, likelihood_loss, _ = model(batched_data, target=node_target, flag_sample=False)
        loss = (cond_mean_loss + likelihood_loss)[node_mask].mean()
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
        optimizer.step()
        scheduler.step()
        total_loss += loss.item()
        num_batches += 1
    return total_loss / max(num_batches, 1)

In [9]:
print("Starting training...")
for epoch in range(MAX_EPOCH):
    avg_loss = train_one_epoch(model, optimizer, scheduler, train_dl, device)
    print(f"Epoch {epoch+1}/{MAX_EPOCH} - Loss: {avg_loss:.4f}")
print("Done.")

Starting training...
Epoch 1/100 - Loss: 648.4822
Epoch 2/100 - Loss: 384.7818
Epoch 3/100 - Loss: 250.1821
Epoch 4/100 - Loss: 244.7582
Epoch 5/100 - Loss: 242.8763
Epoch 6/100 - Loss: 241.7639
Epoch 7/100 - Loss: 240.6798
Epoch 8/100 - Loss: 240.0987
Epoch 9/100 - Loss: 239.0287
Epoch 10/100 - Loss: 238.3629
Epoch 11/100 - Loss: 237.7062
Epoch 12/100 - Loss: 237.0687
Epoch 13/100 - Loss: 236.5359
Epoch 14/100 - Loss: 236.0414
Epoch 15/100 - Loss: 235.5961
Epoch 16/100 - Loss: 235.1423
Epoch 17/100 - Loss: 234.6921
Epoch 18/100 - Loss: 234.2200
Epoch 19/100 - Loss: 233.8360
Epoch 20/100 - Loss: 233.3167
Epoch 21/100 - Loss: 232.8546
Epoch 22/100 - Loss: 232.3557
Epoch 23/100 - Loss: 231.7903
Epoch 24/100 - Loss: 231.1392
Epoch 25/100 - Loss: 230.3639
Epoch 26/100 - Loss: 229.4078
Epoch 27/100 - Loss: 228.1947
Epoch 28/100 - Loss: 226.6243
Epoch 29/100 - Loss: 224.6176
Epoch 30/100 - Loss: 221.9038
Epoch 31/100 - Loss: 218.1745
Epoch 32/100 - Loss: 213.0657
Epoch 33/100 - Loss: 206.091

## Model inference

In [10]:
from torch_geometric.utils import to_dense_batch

@torch.no_grad()
def evaluate_inference_modes(model, dataloader, device):
    metrics = {
        'mean': {'mae': 0.0},
        'winner': {'mae': 0.0},
        'prob': {'mae': 0.0},
    }
    total_count = 0

    model.eval()
    for batch in dataloader:
        batched_data, node_target_norm, node_mask = process_batch_data(batch, device)
        node_target_raw, _ = to_dense_batch(batch.price_future_raw, batch=batch.batch)

        _, _, samples = model(batched_data, target=node_target_raw, flag_sample=True)
        s_out_demean, s_out_winner_demean, s_out_prob_demean = samples

        mean_pred_raw = s_out_demean.mean(dim=2)

        total_count += node_mask.sum().item()

        for name, pred in {
            'mean': mean_pred_raw,
            'winner': s_out_winner_demean,
            'prob': s_out_prob_demean,
        }.items():
            diff = (pred - node_target_raw)
            metrics[name]['mae'] += diff.abs().sum().item()

    # Aggregate
    results = {}
    for name in metrics:
        mae = metrics[name]['mae'] / max(total_count, 1)
        results[name] = {'MAE': mae}
    return results

# Run evaluation
results = evaluate_inference_modes(model, train_dl, device)
print("\nInference-mode comparison (on train data):")
for mode, vals in results.items():
    print(f"- {mode:>6}: MAE={vals['MAE']:.4f}")

results = evaluate_inference_modes(model, test_dl, device)
print("\nInference-mode comparison (on test data):")
for mode, vals in results.items():
    print(f"- {mode:>6}: MAE={vals['MAE']:.4f}")



Inference-mode comparison (on train data):
-   mean: MAE=82.2180
- winner: MAE=53.4728
-   prob: MAE=77.0730

Inference-mode comparison (on test data):
-   mean: MAE=76.2225
- winner: MAE=53.9657
-   prob: MAE=77.8604
