In [9]:
import time

import diffrax
import equinox as eqx
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import optax


import sys, os
sys.path.append(os.path.join(os.getcwd(), "../"))
from src.data.data_reader import DataReader

In [69]:
import jax.numpy as jnp
import jax.random as jr
import pandas as pd

# Function to generate a latent process using a simple ODE
def generate_latent_process(num_steps=1000, dt=0.01, key=None):
    key, subkey = jr.split(key)
    
    # Initialize the latent state
    latent_state = jnp.zeros((num_steps,))
    for i in range(1, num_steps):
        # Simple ODE: dx/dt = -x + noise
        noise = jr.normal(subkey, (1,)) * 0.1  # Adjust noise scale as needed
        latent_state = latent_state.at[i].set(latent_state[i-1] - latent_state[i-1] * dt + noise)
        subkey, noise_key = jr.split(subkey)

    return latent_state

def get_data(num_samples=1000, test_ratio=0.2, key=None):
    # Generate a latent process
    latent_process = generate_latent_process(num_steps=num_samples, key=key)

    # Randomly sample timestamps from the latent process
    ts = jr.choice(key, jnp.arange(num_samples), shape=(num_samples,))
    ts_test = jr.choice(key, jnp.arange(num_samples), shape=(num_samples // test_ratio,))

    # Generate noisy observations (xs) from the latent process
    xs = latent_process[ts] + jr.normal(key, shape=(len(ts),)) * 0.05  # Add noise

    # Generate true outputs (ys) as a function of the latent state
    # Here, we can define a simple function like y = latent_state^2 + noise
    ys = latent_process[ts] ** 2 + jr.normal(key, shape=(len(ts),)) * 0.1
    
    # Get the y_true values at the sampled test times
    y_true = latent_process[ts_test]

    return jnp.expand_dims(ts, axis=0), jnp.expand_dims(xs, axis=0), jnp.expand_dims(ys, axis=0), jnp.expand_dims(ts_test, axis=0), jnp.expand_dims(y_true, axis=0)

def dataloader(arrays, batch_size, *, key):
    dataset_size = arrays[0].shape[1]  # Adjusted for shape (1, N)
    assert all(array.shape[1] == dataset_size for array in arrays)
    
    indices = jnp.arange(dataset_size)
    while True:
        perm = jr.permutation(key, indices)
        (key,) = jr.split(key, 1)
        start = 0
        end = batch_size
        while start < dataset_size:
            batch_perm = perm[start:end]
            yield tuple(array[:, batch_perm] for array in arrays)  # Adjust for shape (1, N)
            start = end
            end = start + batch_size


In [11]:
def dataloader(arrays, batch_size, *, key):
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    indices = jnp.arange(dataset_size)
    while True:
        perm = jr.permutation(key, indices)
        (key,) = jr.split(key, 1)
        start = 0
        end = batch_size
        while start < dataset_size:
            batch_perm = perm[start:end]
            yield tuple(array[batch_perm] for array in arrays)
            start = end
            end = start + batch_size

In [25]:
class NeuralODEModel(eqx.Module):
    func: eqx.Module
    rnn_cell: eqx.nn.GRUCell
    latent_to_y: eqx.nn.MLP
    x_to_latent: eqx.nn.Linear
    latent_size: int

    def __init__(self, func, rnn_cell, latent_to_y, latent_size, x_size):
        self.func = func
        self.rnn_cell = rnn_cell
        self.latent_to_y = latent_to_y
        self.x_to_latent = eqx.nn.Linear(x_size, latent_size, key=jax.random.PRNGKey(6))
        self.latent_size = latent_size

    def _propagate_ode(self, latent, t0, t1):
        sol = diffrax.diffeqsolve(
            diffrax.ODETerm(self.func),
            diffrax.Tsit5(),
            t0=t0,
            t1=t1,
            dt0=0.1,
            y0=latent,
        )
        return sol.ys[-1]

    def predict(self, ts, xs, ts_test):
        ts_combined = jnp.sort(jnp.concatenate([ts, ts_test]))
        is_test_time = jnp.isin(ts_combined, ts_test)
        latent = jnp.zeros((self.latent_size,))
        ys_pred = []
        last_t = ts_combined[0]

        for i, t in enumerate(ts_combined):
            if i > 0:
                latent = self._propagate_ode(latent, last_t, t)
            if t in ts:
                idx = jnp.where(ts == t)[0][0]
                x = xs[idx]
                x_latent = self.x_to_latent(x)
                latent = self.rnn_cell(latent, x_latent)
            if is_test_time[i]:
                y_pred = self.latent_to_y(latent)
                ys_pred.append(y_pred)
            last_t = t

        return jnp.stack(ys_pred)

    
    def compute_loss(self, ts, xs, ts_test, y_true):
        y_pred = self.predict(ts, xs, ts_test)
        return jnp.mean((y_pred - y_true) ** 2)


In [44]:
import jax
import jax.numpy as jnp
import equinox as eqx
import optax
import pandas as pd
import jax.random as jr

# Assume NeuralODEModel, get_data, and dataloader are defined as discussed

def get_data(train_split=0.8):
    d_path = "../data/processed/msft.csv"
    df = pd.read_csv(d_path)
    df['timestamp'] = pd.to_datetime(df['timestamp'])
    df = df.drop_duplicates(subset='timestamp', keep='last')

    # Convert timestamp to seconds and create xs (features)
    ts = pd.to_datetime(df["timestamp"]).astype(int) / 1_000_000_000  # Convert to seconds
    ts = jnp.array(ts)
    xs = jnp.column_stack((df["bid_price"].to_numpy(), df["offer_price"].to_numpy()))  # Features
    ys = df["next_price"].to_numpy()  # Next price as the target variable

    # Normalize ts, xs, and ys
    ts = (ts - ts.min()) / (ts.max() - ts.min())  # Normalize ts between 0 and 1
    xs = (xs - xs.min(axis=0)) / (xs.max(axis=0) - xs.min(axis=0))  # Normalize each feature in xs
    ys = (ys - ys.min()) / (ys.max() - ys.min())  # Normalize ys between 0 and 1

    # Split data into training and test sets
    num_points = len(ts)
    split_idx = int(train_split * num_points)
    train_ts, test_ts = ts[:split_idx], ts[split_idx:]
    train_xs, test_xs = xs[:split_idx], xs[split_idx:]
    train_ys, test_ys = ys[:split_idx], ys[split_idx:]

    return (train_ts, train_xs, train_ys), (test_ts, test_xs, test_ys)


def dataloader(ts, xs, ys, batch_size, seq_length, key):
    """
    Dataloader for sampling batches of time series.

    ts: Timestamps for all series
    xs: Corresponding feature values for all series
    ys: Corresponding target values for all series
    batch_size: Number of time series in each batch
    seq_length: Length of each time series in the batch
    key: PRNGKey for random number generation
    """
    dataset_size = len(ts)
    indices = jnp.arange(dataset_size - seq_length)  # Indices where sequences of seq_length can start

    while True:
        perm = jr.choice(key, indices, shape=(batch_size,))  # Randomly sample start indices
        (key,) = jr.split(key, 1)
        ts_batch = []
        xs_batch = []
        ys_batch = []

        for idx in perm:
            ts_batch.append(ts[idx:idx + seq_length])
            xs_batch.append(xs[idx:idx + seq_length])
            ys_batch.append(ys[idx:idx + seq_length])

        yield jnp.stack(ts_batch), jnp.stack(xs_batch), jnp.stack(ys_batch)



In [65]:
def train(model, train_data, optimizer, steps, batch_size, seq_length, key):
    # Initialize optimizer state
    opt_state = optimizer.init(eqx.filter(model, eqx.is_inexact_array))
    
    (train_ts, train_xs, train_ys) = train_data  # Unpack train and test data
    train_key = key

    # Create data loader for training data
    loader = dataloader(train_ts, train_xs, train_ys, batch_size, seq_length, key=train_key)

    @eqx.filter_value_and_grad
    def loss(model, ts_i, xs_i, ys_i, key_i):
        batch_size, _ = ts_i.shape
        key_i = jr.split(key_i, batch_size)
        loss = jax.vmap(model.compute_loss)(ts_i, xs_i, ys_i, key=key_i)  # Updated to include xs
        return jnp.mean(loss)

    @eqx.filter_jit
    def make_step(model, opt_state, ts_i, xs_i, ys_i, key_i):
        value, grads = loss(model, ts_i, xs_i, ys_i, key_i)
        key_i = jr.split(key_i, 1)[0]
        updates, opt_state = optimizer.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return value, model, opt_state, key_i

    for step, (ts_i, xs_i, ys_i) in zip(
        range(steps), loader  # Use the data loader directly
    ):
        opt_key, subkey = jr.split(train_key)
        loss_value, model, opt_state, key_i = make_step(model, opt_state, ts_i, xs_i, ys_i, subkey)

        if step % 100 == 0:
            print(f"Step {step}, Loss: {loss_value:.4f}")

    return model


In [66]:

# Hyperparameters
batch_size = 32
seq_length = 100  # Length of each time series
steps = 1000
latent_size = 5
x_size = 2  # Number of features (bid_price and offer_price)

# Define model
func = eqx.nn.MLP(in_size=latent_size, out_size=latent_size, width_size=32, depth=2, key=jax.random.PRNGKey(0))
rnn_cell = eqx.nn.GRUCell(x_size, latent_size, key=jax.random.PRNGKey(1))
latent_to_y = eqx.nn.MLP(in_size=latent_size, out_size=1, width_size=32, depth=2, key=jax.random.PRNGKey(2))  # Update output size to 1

model = NeuralODEModel(func, rnn_cell, latent_to_y, latent_size, x_size)

# Adam optimizer
optimizer = optax.adam(learning_rate=1e-3)

# Get data, split into training and test
data = get_data(train_split=0.8)





In [67]:
train_data, test_data = data


In [68]:
# Train the model
trained_model = train(model, train_data, optimizer, steps=steps, batch_size=batch_size, seq_length=seq_length, key=jax.random.PRNGKey(0))

TypeError: Cannot interpret value of type <class 'jax._src.custom_derivatives.custom_jvp'> as an abstract array; it does not have a dtype attribute