In [129]:
import time

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pickle

import diffrax
import equinox as eqx
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import jax.lax as lax
import optax
from functools import partial

import sys, os
sys.path.append(os.path.join(os.getcwd(), "../"))
from src.data.data_reader import DataReader

key_number = 8

def key():
    global key_number
    key_number += 1
    return jax.random.PRNGKey(key_number)

In [130]:
def get_data(processed_file):
    with open(processed_file, 'rb') as f:
        sequences = pickle.load(f)

    sorted_train_data = []
    for idx, (ts, xs, ts_eval, ys_eval) in enumerate(sequences):
        t0 = ts[0]
        t_ms = (ts - t0) / 1e6  # Convert nanoseconds to milliseconds
        te_ms = (ts_eval - t0) / 1e6

        if len(t_ms) != len(xs):
            print(f"Warning: Inconsistent lengths in sequence {idx}: len(t_ms)={len(t_ms)}, len(xs)={len(xs)}")
            continue  

        if len(te_ms) != len(ys_eval):
            print(f"Warning: Inconsistent lengths in ts_eval and ys_eval in sequence {idx}")
            continue 

        # Keep data as arrays
        sorted_train_data.append((t_ms, xs, te_ms, ys_eval))

    if not sorted_train_data:
        raise ValueError("No valid sequences found in the processed data.")

    return sorted_train_data

def dataloader(sequences, batch_size, subset_size, *, key):
    dataset_size = len(sequences[0])
    assert all(len(seq) == dataset_size for seq in sequences)
    indices = np.arange(dataset_size)

    while True:
        subset_perm = np.random.choice(indices, size=subset_size, replace=False)

        start = 0
        end = batch_size

        while start < subset_size:
            batch_perm = subset_perm[start:end]
            # Ensure data remains as arrays
            batch = [ [seq[i] for i in batch_perm] for seq in sequences ]
            yield batch
            start = end
            end = start + batch_size

In [131]:
class Func(eqx.Module):
    mlp: eqx.nn.MLP  
    def __init__(self, mlp):
        self.mlp = mlp
    def __call__(self, t, y, args):
        return self.mlp(y) 

# NeuralCDE Model
class NeuralCDE(eqx.Module):
    latent_to_latent: eqx.nn.Linear
    func: Func
    control_to_latent: eqx.nn.Linear
    linear: eqx.nn.Linear

    def __init__(self, func, data_size, hidden_size, width_size, depth, *, key):
        lkey, ckey = jax.random.split(key, 2)
        self.latent_to_latent = eqx.nn.Linear(in_features=hidden_size, out_features=hidden_size, key=lkey)
        self.control_to_latent = eqx.nn.Linear(in_features=data_size, out_features=hidden_size, key=ckey)
        self.func = func
        self.linear = eqx.nn.Linear(hidden_size, 1, key=ckey)

    def predict(self, ts, xs, ts_eval):
        def single_sample_prediction(ts_single, xs_single, ts_eval_single):
            ts_interp, xs_interp = diffrax.rectilinear_interpolation(ts_single, xs_single)

            control_data = jnp.concatenate([ts_interp[:, None], xs_interp], axis=-1)

            control_latent = jax.vmap(self.control_to_latent)(control_data)
            
            ts_initial = ts_single[0]
            xs_initial = xs_single[0, :]
            initial_control_value = jnp.concatenate([jnp.array([ts_initial]), xs_initial], axis=-1)
            y0 = self.control_to_latent(initial_control_value)
            
            saveat = diffrax.SaveAt(ts=ts_eval_single)
            control = diffrax.LinearInterpolation(ts_interp, control_latent)
            term = diffrax.ControlTerm(self.func, control).to_ode()
            solution = diffrax.diffeqsolve(
                term,
                solver=diffrax.Tsit5(),
                t0=ts_single[0],
                t1=ts_single[-1],
                y0=y0,
                dt0=None,
                stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
                saveat=saveat
            )
            outputs = jnn.sigmoid(self.linear(solution.ys))
            return outputs

        predictions = jax.vmap(single_sample_prediction, in_axes=(0, 0, None))(ts, xs, ts_eval)
        return predictions.squeeze()

    def compute_loss(self, ts, xs, ts_eval, y_true):
        y_pred = self.predict(ts, xs, ts_eval)
        return jnp.mean((y_pred - y_true) ** 2)
           

In [132]:
def train(model, train_data, optimizer, steps, batch_size, seq_length, eval_point_ratio, key, subset_size, patience=500, plot_every=500):
    opt_state = optimizer.init(eqx.filter(model, eqx.is_inexact_array))
    losses = []
    best_loss = float('inf')
    best_model = model
    last_best_step = 0
    step_times = []

    @eqx.filter_value_and_grad
    def loss(model, ts, xs, ts_test, y_true):
        return model.compute_loss(ts, xs, ts_test, y_true)

    @eqx.filter_jit
    def make_step(model, opt_state, ts, xs, ts_test, y_true):
        value, grads = loss(model, ts, xs, ts_test, y_true)
        updates, opt_state = optimizer.update(grads, opt_state, model)
        model = eqx.apply_updates(model, updates)
        return value, model, opt_state
    
    num_params = sum(p.size for p in jax.tree_util.tree_leaves(eqx.filter(model, eqx.is_inexact_array)))
    num_train_points = len(train_data) * seq_length * eval_point_ratio

    print("-" * 60)
    print(f"Number of parameters: {num_params}")
    print(f"Number of data points: {int(num_train_points)}")
    print("-" * 60)

    ts, xs, ts_eval, y_test = zip(*train_data)

    arrays = (jnp.array(ts), jnp.array(xs), jnp.array(ts_eval), jnp.array(y_test))
    data_gen = dataloader(arrays, batch_size, subset_size=subset_size, key=key)

    for step in range(steps):
        start_time = time.time()

        # Get the next batch from the dataloader
        ts_batch, xs_batch, ts_test_batch, y_true_batch = next(data_gen)
        ts_batch, xs_batch, ts_test_batch, y_true_batch = jnp.array(ts_batch), jnp.array(xs_batch), jnp.array(ts_test_batch), jnp.array(y_true_batch)

        # Perform a training step
        loss_value, model, opt_state = make_step(model, opt_state, ts_batch, xs_batch, ts_test_batch, y_true_batch)
        losses.append(loss_value)

        step_time = time.time() - start_time
        step_times.append(step_time)
        if len(step_times) > 100:
            step_times.pop(0)

        if loss_value < best_loss:
            best_loss = loss_value
            best_model = model
            last_best_step = step

        if step % 100 == 0:
            avg_step_time = sum(step_times) / len(step_times) if step_times else 0
            estimated_time_remaining = avg_step_time * (steps - step - 1)
            if step == 0:
                print(f"Step {step}, Loss: {loss_value:.4f}, Best Loss: {best_loss:.4f}, Estimated Time Remaining: -- ")
            else:
                print(f"Step {step}, Loss: {loss_value:.4f}, Best Loss: {best_loss:.4f}, Estimated Time Remaining: {estimated_time_remaining / 60:.2f} minutes")

        if step % plot_every == 0:
            _, subkey = jr.split(key)
            random_index = jr.randint(subkey, (1,), 0, batch_size).item()
            plt.figure(figsize=(4, 3))
            plt.plot(ts_test_batch[random_index], y_true_batch[random_index], label='Actual', marker='o')
            y_pred = best_model.predict(jnp.expand_dims(ts_batch[random_index], axis=0), jnp.expand_dims(xs_batch[random_index], axis=0), jnp.expand_dims(ts_test_batch[random_index], axis=0))
            y_pred = y_pred.squeeze()
            plt.plot(ts_test_batch[random_index], y_pred, label='Predicted', marker='x')
            plt.xlabel('Time (ts_test)')
            plt.ylabel('Y values')
            plt.title(f'Training Sequence at Step {step}: Actual vs Predicted Y over Time')
            plt.legend()
            plt.grid(True)
            plt.show()

        if step - last_best_step >= patience:
            print(f"Stopping early at step {step}, no improvement for {patience} steps.")
            return best_model

    plt.figure(figsize=(10, 6))
    plt.plot(losses, label="Training Loss")
    plt.xlabel("Iteration")
    plt.ylabel("Loss")
    plt.title("Training Loss Over Time")
    plt.legend()
    plt.grid(True)
    plt.show()

    return best_model

In [133]:
batch_size = 8
subset_size = 128
train_steps = 2000
x_size = 6
eta = 5e-3
num_train_seq = 1000
train_seq_len = 32
train_test_ratio = 0.25
func_mlp_width = 4
func_mlp_depth = 2
latent_size = 10
num_test_seq = 5
test_seq_len = 200
test_test_ratio = 0.3
train_patience = 200
plot_every = 100

mlp = eqx.nn.MLP(
    in_size=latent_size,
    out_size=latent_size,
    width_size=func_mlp_width,
    depth=func_mlp_depth,
    key=key()
)

func = Func(mlp=mlp)

model = NeuralCDE(
    func=func,
    hidden_size=latent_size,
    width_size=func_mlp_width,
    depth=func_mlp_depth,
    data_size=x_size + 1,
    key=key()
)

optimizer = optax.adam(learning_rate=eta)

train_data = get_data(
    processed_file = "../data/processed/Visa_2024-09-06.pkl"
)

trained_model = train(
    model,
    train_data,
    optimizer,
    steps=train_steps,
    batch_size=batch_size,
    seq_length=train_seq_len,
    eval_point_ratio=train_test_ratio,
    key=key(),
    subset_size=subset_size,
    patience=train_patience,
    plot_every=plot_every
)

------------------------------------------------------------
Number of parameters: 315
Number of data points: 8192
------------------------------------------------------------


TypeError: dot_general requires contracting dimensions to have the same shape, got (8,) and (7,).