In [1]:
import time

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pickle

import diffrax
import equinox as eqx
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import jax.lax as lax
import optax
from functools import partial

import sys, os
sys.path.append(os.path.join(os.getcwd(), "../"))
from src.data.data_reader import DataReader

key_number = 8

def key():
    global key_number
    key_number += 1
    return jax.random.PRNGKey(key_number)

In [2]:
def get_data(processed_file):
    with open(processed_file, 'rb') as f:
        sequences = pickle.load(f)

    sorted_train_data = []
    for idx, (ts, xs, ts_eval, ys_eval) in enumerate(sequences):
        t0 = ts[0]
        t_ms = (ts - t0) / 1e6  # Convert nanoseconds to milliseconds
        te_ms = (ts_eval - t0) / 1e6

        if len(t_ms) != len(xs):
            print(f"Warning: Inconsistent lengths in sequence {idx}: len(t_ms)={len(t_ms)}, len(xs)={len(xs)}")
            continue  

        if len(te_ms) != len(ys_eval):
            print(f"Warning: Inconsistent lengths in ts_eval and ys_eval in sequence {idx}")
            continue 

        # Keep data as arrays
        sorted_train_data.append((t_ms, xs, te_ms, ys_eval))

    if not sorted_train_data:
        raise ValueError("No valid sequences found in the processed data.")

    return sorted_train_data

def dataloader(sequences, batch_size, subset_size, *, key):
    dataset_size = len(sequences[0])
    assert all(len(seq) == dataset_size for seq in sequences)
    indices = np.arange(dataset_size)

    while True:
        subset_perm = np.random.choice(indices, size=subset_size, replace=False)

        start = 0
        end = batch_size

        while start < subset_size:
            batch_perm = subset_perm[start:end]
            # Ensure data remains as arrays
            batch = [ [seq[i] for i in batch_perm] for seq in sequences ]
            yield batch
            start = end
            end = start + batch_size

In [8]:
def train(model, train_data, optimizer, steps, batch_size, key, subset_size):
    opt_state = optimizer.init(eqx.filter(model, eqx.is_inexact_array))
    losses = []
    best_model = model

    @eqx.filter_value_and_grad
    def loss(model, ts, xs, ts_test, y_true):
        return model.compute_loss(ts, xs, ts_test, y_true)

    @eqx.filter_jit
    def make_step(model, opt_state, ts, xs, ts_test, y_true):
        value, grads = loss(model, ts, xs, ts_test, y_true)
        updates, opt_state = optimizer.update(grads, opt_state, model)
        model = eqx.apply_updates(model, updates)
        return value, model, opt_state

    ts, xs, ts_eval, y_test = zip(*train_data)

    arrays = (jnp.array(ts), jnp.array(xs), jnp.array(ts_eval), jnp.array(y_test))
    data_gen = dataloader(arrays, batch_size, subset_size=subset_size, key=key)

    for _ in range(steps):

        # Get the next batch from the dataloader
        ts_batch, xs_batch, ts_test_batch, y_true_batch = next(data_gen)

        # Perform a training step
        loss_value, model, opt_state = make_step(model, opt_state, ts_batch, xs_batch, ts_test_batch, y_true_batch)
        losses.append(loss_value)

    return best_model

In [9]:
import equinox.nn as nn

class SimpleRNN(eqx.Module):
    rnn: nn.GRUCell
    linear: nn.Linear
    key: jax.random.PRNGKey

    def __init__(self, input_size, hidden_size, output_size, *, key):
        self.key = key
        rnn_key, linear_key = jax.random.split(key)
        self.rnn = nn.GRUCell(input_size=input_size, hidden_size=hidden_size, key=rnn_key)
        self.linear = nn.Linear(in_features=hidden_size, out_features=output_size, key=linear_key)

    def __call__(self, xs, xs_mask):
        # xs shape: (batch_size, seq_length, input_size)
        batch_size, seq_length, input_size = xs.shape
        h = jnp.zeros((batch_size, self.rnn.hidden_size))
        outputs = []
        for t in range(seq_length):
            x_t = xs[:, t, :]
            mask_t = xs_mask[:, t][:, None]
            h = self.rnn(x_t, h)
            h = h * mask_t  # Apply mask
            outputs.append(h)
        outputs = jnp.stack(outputs, axis=1)  # (batch_size, seq_length, hidden_size)
        return outputs

    def compute_loss(self, ts_batch, xs_batch, ts_test_batch, y_true_batch,
                     ts_mask, xs_mask, ts_test_mask, y_true_mask):
        # xs_batch shape: (batch_size, seq_length)
        # Reshape xs_batch to include input_size dimension
        xs_batch = xs_batch[..., None] 
        outputs = self(xs_batch, xs_mask)
        last_hidden_states = outputs[:, -1, :]
        predictions = self.linear(last_hidden_states)
     
        loss = jnp.mean((predictions - y_true_batch[:, 0]) ** 2)
        return loss

In [11]:
eta = 1e-3
train_steps=1000
batch_size = 8


model = SimpleRNN(
    input_size = 6,
    hidden_size = 16,
    output_size = 1,
    key=key()
)

optimizer = optax.adam(learning_rate=eta)

train_data = get_data(
    processed_file = "../data/processed/Visa_2024-09-06.pkl"
)

trained_model = train(
    model,
    train_data,
    optimizer,
    steps=train_steps,
    batch_size=batch_size,
    key=key(),
    subset_size=128,
)

TypeError: compute_loss() missing 4 required positional arguments: 'ts_mask', 'xs_mask', 'ts_test_mask', and 'y_true_mask'