<https://github.com/PolymathicAI/xVal>


In [1]:
from flax import nnx
import flax

flax.__version__

'0.10.2'

In [2]:
import ast
import re
from datetime import datetime as dt
import os

import icecream
import jax
import jax.numpy as jnp
import orbax.checkpoint as ocp
import matplotlib.pyplot as plt
import numpy as np
import orbax
import orbax.checkpoint
import pandas as pd
import seaborn as sns
import optax
from flax.struct import dataclass
from flax.training import orbax_utils, train_state
from flax import nnx
from icecream import ic
from jax import random
from jax.tree_util import tree_flatten
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm, trange
from transformers import BertTokenizerFast, FlaxBertModel

import hephaestus as hp
import hephaestus.models.time_series_decoder_training as tsd

icecream.install()
ic_disable = False  # Global variable to disable ic
if ic_disable:
    ic.disable()
ic.configureOutput(includeContext=True, contextAbsPath=True)
pd.options.mode.copy_on_write = True

In [3]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [4]:
# Load pre-trained BERT model and tokenizer
model_name = "bert-base-uncased"
model = FlaxBertModel.from_pretrained(model_name)
tokenizer = BertTokenizerFast.from_pretrained(model_name)

# Get the embeddings matrix
embeddings = model.params["embeddings"]["word_embeddings"]["embedding"]

# Now you can access specific embeddings like this:
# For example, to get embeddings for tokens 23, 293, and 993:
selected_embeddings = jnp.take(embeddings, jnp.array([23, 293, 993]), axis=0)

# If you want to get embeddings for specific words:
words = ["hello", "world", "example"]
tokens = tokenizer.convert_tokens_to_ids(words)
word_embeddings = jnp.take(embeddings, jnp.array(tokens), axis=0)
word_embeddings.shape

Some weights of FlaxBertModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: {('pooler', 'dense', 'kernel'), ('pooler', 'dense', 'bias')}
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


(3, 768)

In [5]:
def line2df(line, idx):
    data_rows = []
    line = ast.literal_eval(line)
    for i, time_step in enumerate(line["data"]):
        row = {"time_step": i}
        # Add position data for each planet
        for j, position in enumerate(time_step):
            row[f"planet{j}_x"] = position[0]
            row[f"planet{j}_y"] = position[1]
        data_rows.append(row)

    df = pd.DataFrame(data_rows)
    description = line.pop("description")
    step_size = description.pop("stepsize")
    for k, v in description.items():
        for k_prop, v_prop in v.items():
            df[f"{k}_{k_prop}"] = v_prop
    df["time_step"] = df["time_step"] * step_size
    df.insert(0, "idx", idx)

    return df

In [6]:
files = os.listdir("data")
if "planets.parquet" not in files:
    with open("data/planets.data") as f:
        data = f.read().splitlines()

        dfs = []
        for idx, line in enumerate(tqdm(data)):
            dfs.append(line2df(line, idx))
        df = pd.concat(dfs)
    df.to_parquet("data/planets.parquet")
else:
    df = pd.read_parquet("data/planets.parquet")


# Combine total mass of all planets into one column `planet<n>_m`
mass_regex = re.compile(r"planet(\d+)_m")
mass_cols = [col for col in df.columns if mass_regex.match(col)]
df["total_mass"] = df[mass_cols].sum(axis=1)
# df = df.reset_index(drop=True)
# Introduce categorical columns for the number of planets choose non null columns with mass
df["n_planets"] = df[mass_cols].notnull().sum(axis=1).astype("object")
df["n_planets"] = df["n_planets"].apply(lambda x: f"{x}_planets")
# Create category acceleration if the sum of plane/d_[x,y, z] is greater than 0
df["acceleration_x"] = df[
    [col for col in df.columns if "planet" in col and "_x" in col]
].sum(axis=1)
# Set acceleration_x to "increasing" if greater than 0 else "decreasing"
df["acceleration_x"] = (
    df["acceleration_x"]
    .apply(lambda x: "increasing" if x > 0 else "decreasing")
    .astype("object")
)
df["acceleration_y"] = df[
    [col for col in df.columns if "planet" in col and "_y" in col]
].sum(axis=1)
df["acceleration_y"] = df["acceleration_y"].apply(
    lambda x: "increasing" if x > 0 else "decreasing"
)


df.describe()

Unnamed: 0,idx,time_step,planet0_x,planet0_y,planet1_x,planet1_y,planet2_x,planet2_y,planet0_m,planet0_a,...,planet3_y,planet3_m,planet3_a,planet3_e,planet4_x,planet4_y,planet4_m,planet4_a,planet4_e,total_mass
count,5563957.0,5563957.0,5563957.0,5563957.0,5563957.0,5563957.0,4165044.0,4165044.0,5563957.0,5563957.0,...,2783627.0,2783627.0,2783627.0,2783627.0,1392864.0,1392864.0,1392864.0,1392864.0,1392864.0,5563957.0
mean,62486.35,9.748911,-0.1339198,0.07391138,-0.134014,0.07291389,-0.1305344,0.07065633,2.999306,1.624756,...,0.0655915,2.996303,1.623874,0.9980576,-0.1276881,0.06519469,3.002531,1.625815,1.001317,10.49149
std,36079.49,5.993534,1.228071,1.213232,1.22795,1.21265,1.217229,1.203678,1.157182,0.5876632,...,1.200148,1.15319,0.5270725,0.5764675,1.211648,1.199625,1.156856,0.5167198,0.5779763,3.99178
min,0.0,0.0,-3.294763,-2.997514,-3.284004,-2.998546,-3.28979,-2.99805,1.000003,1.0,...,-2.997621,1.000054,1.0,9.369537e-05,-3.273603,-2.998913,1.000103,1.0,6.720938e-05,2.014597
25%,31244.0,4.655172,-1.030131,-0.9020907,-1.030516,-0.9028009,-1.050662,-0.9211662,1.993948,1.0,...,-0.9321272,1.996853,1.191548,0.4980967,-1.071974,-0.9394428,2.00424,1.215927,0.5032645,7.282371
50%,62491.0,9.52381,-0.1542335,0.1117099,-0.1538916,0.1099474,-0.152552,0.1118031,2.994477,1.543047,...,0.1067355,3.001879,1.535683,0.995504,-0.1507184,0.1055276,3.000454,1.51904,1.003118,10.28205
75%,93728.0,14.4,0.8583344,0.9784987,0.8581045,0.9783998,0.8762358,0.9912902,4.005747,2.020672,...,0.9956064,3.986406,1.969644,1.497746,0.890678,0.999122,4.003332,1.950518,1.503457,13.45633
max,124999.0,24.0,2.99637,2.999014,2.993319,2.999536,2.990464,2.998478,4.999994,2.999984,...,3.000881,4.99999,2.999909,1.999957,2.98518,2.998936,4.999679,2.999497,1.999999,23.82455


In [7]:
df.head()

Unnamed: 0,idx,time_step,planet0_x,planet0_y,planet1_x,planet1_y,planet2_x,planet2_y,planet0_m,planet0_a,...,planet3_e,planet4_x,planet4_y,planet4_m,planet4_a,planet4_e,total_mass,n_planets,acceleration_x,acceleration_y
0,0,0.0,1.56006,-0.854437,0.720639,0.691729,0.944008,2.700632,1.312562,1.944263,...,,,,,,,6.974122,3_planets,increasing,increasing
1,0,0.465116,1.689858,-0.514359,0.333295,0.942289,0.681604,2.785811,1.312562,1.944263,...,,,,,,,6.974122,3_planets,increasing,increasing
2,0,0.930233,1.753589,-0.154209,-0.124995,0.992368,0.412951,2.845461,1.312562,1.944263,...,,,,,,,6.974122,3_planets,increasing,increasing
3,0,1.395349,1.748068,0.212022,-0.556775,0.831727,0.14054,2.879232,1.312562,1.944263,...,,,,,,,6.974122,3_planets,increasing,increasing
4,0,1.860465,1.673573,0.569904,-0.870579,0.494812,-0.133144,2.887018,1.312562,1.944263,...,,,,,,,6.974122,3_planets,increasing,increasing


In [8]:
df_categorical = df.select_dtypes(include=["object"]).astype(str)
unique_values_per_column = df_categorical.apply(
    pd.Series.unique
).values  # .flatten().tolist()
flattened_unique_values = np.concatenate(unique_values_per_column).tolist()
unique_values = list(set(flattened_unique_values))
unique_values

['5_planets',
 '2_planets',
 '3_planets',
 'decreasing',
 '4_planets',
 'increasing']

In [9]:
df.select_dtypes(include="object").groupby(
    df.select_dtypes(include="object").columns.tolist()
).size().reset_index(name="count")

Unnamed: 0,n_planets,acceleration_x,acceleration_y,count
0,2_planets,decreasing,decreasing,365720
1,2_planets,decreasing,increasing,410992
2,2_planets,increasing,decreasing,272240
3,2_planets,increasing,increasing,349961
4,3_planets,decreasing,decreasing,410092
5,3_planets,decreasing,increasing,404513
6,3_planets,increasing,decreasing,258536
7,3_planets,increasing,increasing,308276
8,4_planets,decreasing,decreasing,413451
9,4_planets,decreasing,increasing,416444


In [10]:
df = df.reset_index(drop=True)

In [11]:
# Get train test split at 80/20
time_series_config = hp.TimeSeriesConfig.generate(df=df)
train_idx = int(df.idx.max() * 0.8)
train_df = df.loc[df.idx < train_idx].copy()
test_df = df.loc[df.idx >= train_idx].copy()
# del df
train_ds = hp.TimeSeriesDS(train_df, time_series_config)
test_ds = hp.TimeSeriesDS(test_df, time_series_config)
len(train_ds), len(test_ds)

(99999, 25001)

In [12]:
len(time_series_config.numeric_col_tokens) + len(
    time_series_config.categorical_col_tokens
)

30

In [13]:
def make_batch(ds: hp.TimeSeriesDS, start: int, length: int):
    numeric = []
    categorical = []
    for i in range(start, length + start):
        numeric.append(ds[i][0])
        categorical.append(ds[i][1])
    # print index of None values
    return {"numeric": jnp.array(numeric), "categorical": jnp.array(categorical)}


batch = make_batch(train_ds, 0, 4)
# batch

In [14]:
multiplier = 4
time_series_regressor = hp.TimeSeriesDecoder(
    time_series_config, d_model=512, n_heads=8 * multiplier, rngs=nnx.Rngs(0)
)
# nnx.display(time_series_regressor)

In [15]:
res = time_series_regressor(
    numeric_inputs=batch["numeric"],
    categorical_inputs=batch["categorical"],
    deterministic=False,
)

ic| /Users/kailukowiak/Hephaestus/hephaestus/models/time_series_decoder.py:538 in process_numeric()
    numeric_embedding.shape: (512,)
ic| /Users/kailukowiak/Hephaestus/hephaestus/models/time_series_decoder.py:541 in process_numeric()
    numeric_embedding.shape: (4, 27, 59, 512)
ic| /Users/kailukowiak/Hephaestus/hephaestus/models/time_series_decoder.py:549 in process_numeric()
    numeric_embedding.shape: (4, 27, 59, 512)
ic| /Users/kailukowiak/Hephaestus/hephaestus/models/time_series_decoder.py:613 in combine_inputs()
    numeric.value_embeddings.shape: (4, 27, 59, 512)
    categorical.value_embeddings.shape: (4, 3, 59, 512)
ic| /Users/kailukowiak/Hephaestus/hephaestus/models/time_series_decoder.py:687 in __call__()
    combined_inputs.value_embeddings.shape: (4, 30, 59, 512)
    combined_inputs.column_embeddings.shape: (4, 30, 59, 512)
ic| /Users/kailukowiak/Hephaestus/hephaestus/models/time_series_decoder.py:355 in __call__()
    "Transformer Block": 'Transformer Block'
    q.shap

In [16]:
import orbax.checkpoint as ocp

ckpt_dir = ocp.test_utils.erase_and_create_empty("/tmp/my-checkpoints1/")
_, state = nnx.split(time_series_regressor)
state = state.to_pure_dict()
# nnx.display(state)

checkpointer = ocp.StandardCheckpointer()
checkpointer.save(ckpt_dir / "state", state)

In [20]:
nnx.display(state)

State({
  'categorical_dense1': {
    'bias': VariableState(
      type=Param,
      value=Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0.], dtype=float32)
    ),
    'kernel': VariableState(
      type=Param,
      value=Array([[ 0.03978157,  0.06490338,  0.04485879, ...,  0.01017239,
               0.05480217,  0.02496082],
             [ 0.05201926, -0.01393059,  0.01035063, ...,  0.0108574 ,
              -0.07468856,  0.0133664 ],
             [ 0.05493598, -0.02026741,  0.05329826, ..., -0.04810426,
              -0.01314572, -0.01126208],
             ...,
             [-0.06501927, -0.05778317, -0.0220789 , ...,  0.01792813,
               0.00671583, -0.0036991 ],
             [ 0.08548267, -0.03606193, -0.06776523, ...,  0.00459266,
               0.00335948, -0.01023912],
             [ 0.01790255, -0.0291177 , -0.01747387, ..., 

ERROR:absl:[process=0] Failed to run 1 Handler Commit operations or the Commit callback in background save thread, directory: /tmp/my-checkpoints1/state
Traceback (most recent call last):
  File "/Users/kailukowiak/Hephaestus/.venv/lib/python3.11/site-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py", line 132, in _thread_func
    future.result()
  File "/Users/kailukowiak/Hephaestus/.venv/lib/python3.11/site-packages/orbax/checkpoint/future.py", line 78, in result
    f.result(timeout=time_remaining)
  File "/Users/kailukowiak/Hephaestus/.venv/lib/python3.11/site-packages/orbax/checkpoint/_src/serialization/type_handlers.py", line 250, in result
    return self._t.join(timeout=timeout)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/kailukowiak/Hephaestus/.venv/lib/python3.11/site-packages/orbax/checkpoint/future.py", line 62, in join
    raise self._exception
  File "/Users/kailukowiak/Hephaestus/.venv/lib/python3.11/site-packages/orbax/checkpoint/future.py"

In [18]:
def print_state_structure(state, path=[], depth=0):
    """Print the complete structure of the state dictionary with detailed type information."""
    if isinstance(state, dict):
        for key, value in state.items():
            current_path = path + [key]
            path_str = " -> ".join(str(p) for p in current_path)

            # Get more detailed type information
            if hasattr(value, "dtype"):
                type_info = f"{type(value)} with dtype: {value.dtype}"
            else:
                type_info = type(value)

            print("  " * depth + f"{path_str}: {type_info}")
            print_state_structure(value, current_path, depth + 1)
    elif isinstance(state, (list, tuple)):
        for i, value in enumerate(state):
            print_state_structure(value, path + [i], depth + 1)
    else:
        # For leaf nodes, print their type and value
        path_str = " -> ".join(str(p) for p in path)
        if hasattr(state, "dtype"):
            print("  " * depth + f"{path_str}: {type(state)} with dtype: {state.dtype}")
        else:
            print("  " * depth + f"{path_str}: {type(state)}")


# Use it
_, state = nnx.split(time_series_regressor)
state = state.to_pure_dict()
print("Full state structure:")
print_state_structure(state)

Full state structure:
categorical_dense1: <class 'dict'>
  categorical_dense1 -> bias: <class 'jaxlib.xla_extension.ArrayImpl'> with dtype: float32
    categorical_dense1 -> bias: <class 'jaxlib.xla_extension.ArrayImpl'> with dtype: float32
  categorical_dense1 -> kernel: <class 'jaxlib.xla_extension.ArrayImpl'> with dtype: float32
    categorical_dense1 -> kernel: <class 'jaxlib.xla_extension.ArrayImpl'> with dtype: float32
categorical_dense2: <class 'dict'>
  categorical_dense2 -> bias: <class 'jaxlib.xla_extension.ArrayImpl'> with dtype: float32
    categorical_dense2 -> bias: <class 'jaxlib.xla_extension.ArrayImpl'> with dtype: float32
  categorical_dense2 -> kernel: <class 'jaxlib.xla_extension.ArrayImpl'> with dtype: float32
    categorical_dense2 -> kernel: <class 'jaxlib.xla_extension.ArrayImpl'> with dtype: float32
config: <class 'dict'>
  config -> numeric_indices: <class 'jaxlib.xla_extension.ArrayImpl'> with dtype: int32
    config -> numeric_indices: <class 'jaxlib.xla_ext

In [17]:
def print_state_structure(state, path=[], depth=0):
    """Print the complete structure of the state dictionary."""
    if isinstance(state, dict):
        for key, value in state.items():
            current_path = path + [key]
            print("  " * depth + f"{key}: {type(value)}")
            print_state_structure(value, current_path, depth + 1)
    elif isinstance(state, (list, tuple)):
        for i, value in enumerate(state):
            print_state_structure(value, path + [i], depth + 1)


# Use it
_, state = nnx.split(time_series_regressor)
state = state.to_pure_dict()
print_state_structure(state)

categorical_dense1: <class 'dict'>
  bias: <class 'jaxlib.xla_extension.ArrayImpl'>
  kernel: <class 'jaxlib.xla_extension.ArrayImpl'>
categorical_dense2: <class 'dict'>
  bias: <class 'jaxlib.xla_extension.ArrayImpl'>
  kernel: <class 'jaxlib.xla_extension.ArrayImpl'>
config: <class 'dict'>
  numeric_indices: <class 'jaxlib.xla_extension.ArrayImpl'>
  categorical_indices: <class 'jaxlib.xla_extension.ArrayImpl'>
  reservoir_encoded: <class 'jaxlib.xla_extension.ArrayImpl'>
numeric_linear1: <class 'dict'>
  bias: <class 'jaxlib.xla_extension.ArrayImpl'>
  kernel: <class 'jaxlib.xla_extension.ArrayImpl'>
numeric_linear2: <class 'dict'>
  bias: <class 'jaxlib.xla_extension.ArrayImpl'>
  kernel: <class 'jaxlib.xla_extension.ArrayImpl'>
time_series_transformer: <class 'dict'>
  embedding: <class 'dict'>
    embedding: <class 'dict'>
      embedding: <class 'jaxlib.xla_extension.ArrayImpl'>
  transformer_block_0: <class 'dict'>
    feed_forward_network: <class 'dict'>
      dense1: <class '

In [18]:
res["numeric_out"].shape, res["categorical_out"].shape

((4, 27, 59), (4, 3, 59, 41))

In [19]:
ic.disable()

In [20]:
causal_mask = True
# time_series_regressor.train()

In [21]:
def add_input_offsets(
    inputs: jnp.array, outputs: jnp.array, inputs_offset: int = 1
) -> jnp.array:
    inputs = inputs[:, :, inputs_offset:]
    tmp_null = jnp.full((inputs.shape[0], inputs.shape[1], inputs_offset), jnp.nan)
    inputs = jnp.concatenate([inputs, tmp_null], axis=2)
    nan_mask = jnp.isnan(inputs)
    inputs = jnp.where(nan_mask, jnp.zeros_like(inputs), inputs)
    # Add ext
    if outputs.ndim == inputs.ndim + 1:
        nan_mask_expanded = jnp.expand_dims(nan_mask, axis=-1)
        nan_mask_expanded = jnp.broadcast_to(nan_mask_expanded, outputs.shape)
    else:
        nan_mask_expanded = nan_mask

    # Apply mask to outputs
    outputs = jnp.where(nan_mask_expanded, jnp.zeros_like(outputs), outputs)

    return inputs, outputs, nan_mask
    # outputs = jnp.where(nan_mask, jnp.zeros_like(outputs), outputs)

    # return inputs, outputs, nan_mask


def numeric_loss(inputs, outputs, input_offset: int = 1):
    # print("Doing Numeric Loss")
    inputs, outputs, nan_mask = add_input_offsets(
        inputs=inputs, outputs=outputs, inputs_offset=input_offset
    )
    # print(f"Numeric: {inputs.shape=}, {outputs.shape=}, {nan_mask.shape=}")
    # TODO make loss SSL for values greater than 0.5 and MSE for values less than 0.5
    raw_loss = jnp.abs(outputs - inputs)
    masked_loss = jnp.where(nan_mask, 0.0, raw_loss)
    # print(f"{masked_loss.shape=}, {nan_mask.shape=}")
    loss = masked_loss.sum() / (~nan_mask).sum()
    return loss


def categorical_loss(inputs, outputs, input_offset: int = 1):
    # print("Doing Categorical Loss")
    inputs, outputs, nan_mask = add_input_offsets(
        inputs=inputs, outputs=outputs, inputs_offset=input_offset
    )
    # print(f"Categorical: {inputs.shape=}, {outputs.shape=}, {nan_mask.shape=}")
    inputs = inputs.astype(jnp.int32)
    # print(f"Categorical Loss Shapes: {inputs.shape=}, {outputs.shape=}")
    raw_loss = optax.softmax_cross_entropy_with_integer_labels(outputs, inputs)
    masked_loss = jnp.where(nan_mask, 0.0, raw_loss).mean()
    return masked_loss


@nnx.jit
def train_step(
    model: hp.TimeSeriesDecoder,
    inputs: dict,
    optimizer: nnx.Optimizer,
    metrics: nnx.MultiMetric,
):
    def loss_fn(model):
        res = model(
            numeric_inputs=inputs["numeric"],
            categorical_inputs=inputs["categorical"],
            deterministic=False,
        )

        numeric_loss_value = numeric_loss(inputs["numeric"], res["numeric_out"])
        categorical_loss_value = categorical_loss(
            inputs["categorical"], res["categorical_out"]
        )
        loss = numeric_loss_value + categorical_loss_value
        return loss, (numeric_loss_value, categorical_loss_value)

    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    # grad_fn = nnx.value_and_grad(loss_fn, has_aux=False)
    # loss, grads = grad_fn(model)
    (loss, (numeric_loss_value, categorical_loss_value)), grads = grad_fn(model)
    metrics.update(
        loss=loss,
        numeric_loss=numeric_loss_value,
        categorical_loss=categorical_loss_value,
    )

    optimizer.update(grads)


metric_history = {
    "loss": [],
    "numeric_loss": [],
    "categorical_loss": [],
}

learning_rate = 1e-3
momentum = 0.9
optimizer = nnx.Optimizer(time_series_regressor, optax.adamw(learning_rate, momentum))
metrics = nnx.MultiMetric(
    loss=nnx.metrics.Average("loss"),
    categorical_loss=nnx.metrics.Average("categorical_loss"),
    numeric_loss=nnx.metrics.Average("numeric_loss"),
)
writer_name = "nnx_test"

writer_time = dt.now().strftime("%Y-%m-%dT%H:%M:%S")
model_name = writer_time + writer_name
summary_writer = SummaryWriter("runs/" + model_name)


metrics_history = {"train_loss": []}
train_data_loader = DataLoader(train_ds, batch_size=16, shuffle=True)

for step, batch in enumerate(tqdm(train_data_loader)):
    batch = {"numeric": jnp.array(batch[0]), "categorical": jnp.array(batch[1])}
    for metric, value in metrics.compute().items():
        # Only shows `loss`

        metric_history[metric].append(value)
        if not jnp.isnan(value).any():
            #     raise ValueError("Nan Values")
            summary_writer.add_scalar(f"train/{metric}", np.array(value), step)
    metrics.reset()
    # print(f"Step {step}")
    train_step(time_series_regressor, batch, optimizer, metrics)

  0%|          | 0/6250 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [31]:
import orbax.checkpoint as ocp

ckpt_dir = ocp.test_utils.erase_and_create_empty("/tmp/my-checkpoints1/")
_, state = nnx.split(time_series_regressor)
state = state.to_pure_dict()
# nnx.display(state)

checkpointer = ocp.StandardCheckpointer()
checkpointer.save(ckpt_dir / "state", state)

ERROR:absl:[process=0] Failed to run 1 Handler Commit operations or the Commit callback in background save thread, directory: /tmp/my-checkpoints1/state
Traceback (most recent call last):
  File "/Users/kailukowiak/Hephaestus/.venv/lib/python3.11/site-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py", line 132, in _thread_func
    future.result()
  File "/Users/kailukowiak/Hephaestus/.venv/lib/python3.11/site-packages/orbax/checkpoint/future.py", line 78, in result
    f.result(timeout=time_remaining)
  File "/Users/kailukowiak/Hephaestus/.venv/lib/python3.11/site-packages/orbax/checkpoint/_src/serialization/type_handlers.py", line 250, in result
    return self._t.join(timeout=timeout)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/kailukowiak/Hephaestus/.venv/lib/python3.11/site-packages/orbax/checkpoint/future.py", line 62, in join
    raise self._exception
  File "/Users/kailukowiak/Hephaestus/.venv/lib/python3.11/site-packages/orbax/checkpoint/future.py"

In [32]:
# write out state to txt file
with open("State.txt", "w") as f:
    f.write(str(state))

In [None]:
import jax.numpy as jnp
import jax.tree_util as jtu


# Print the structure and dtypes
def print_dtype_tree(tree):
    def get_dtype(x):
        if hasattr(x, "dtype"):
            return x.dtype
        return type(x)

    return jtu.tree_map(get_dtype, tree)


print("State structure and dtypes:")
print(jtu.tree_map(lambda x: type(x), state))
print("\nDetailed dtype information:")
print(print_dtype_tree(state))

In [31]:
@dataclass
class Results:
    numeric_out: jnp.array
    categorical_out: jnp.array
    numeric_inputs: jnp.array
    categorical_inputs: jnp.array


def return_results(state, dataset, idx=0, mask_start: int = None):
    numeric_inputs, categorical_inputs = dataset[idx]
    if mask_start:
        numeric_inputs = numeric_inputs[:, :mask_start]
        categorical_inputs = categorical_inputs[:, :mask_start]
    numeric_inputs = jnp.array([numeric_inputs])
    categorical_inputs = jnp.array([categorical_inputs])
    out = time_series_regressor(
        numeric_inputs=numeric_inputs, categorical_inputs=categorical_inputs
    )
    numeric_out, categorical_out = out["numeric_out"], out["categorical_out"]
    return Results(numeric_out, categorical_out, numeric_inputs, categorical_inputs)


x = return_results(state, train_ds, 0)
x.categorical_out.shape

(1, 3, 59, 41)

In [32]:
causal_mask = False
causal_mask = True


def process_results(arr: jnp.array, col_names: list, config: hp.TimeSeriesConfig):
    arr = jnp.squeeze(arr)
    if arr.ndim == 3:
        # Check if there is a logit array for example if there are 3 dims then the
        # last dim is the logit array. We need to get the argmax of the last dim
        # to get the actual values of the array and replace the logit array with the
        # actual values
        arr = jnp.argmax(arr, axis=-1)
    df = pd.DataFrame(arr.T)
    df.columns = col_names
    return df


@dataclass
class DFComparison:
    input_df: pd.DataFrame
    output_df: pd.DataFrame


def show_results_df(
    state, time_series_config, dataset, idx: int = 0, mask_start: int = None
):
    results = return_results(state, dataset, idx=idx, mask_start=mask_start)

    input_categorical = process_results(
        results.categorical_inputs,
        time_series_config.categorical_col_tokens,
        time_series_config,
    )
    input_numeric = process_results(
        results.numeric_inputs,
        time_series_config.numeric_col_tokens,
        time_series_config,
    )
    output_categorical = process_results(
        results.categorical_out,
        time_series_config.categorical_col_tokens,
        time_series_config,
    )
    output_numeric = process_results(
        results.numeric_out, time_series_config.numeric_col_tokens, time_series_config
    )
    input_df = pd.concat([input_categorical, input_numeric], axis=1)
    output_df = pd.concat([output_categorical, output_numeric], axis=1)

    return DFComparison(input_df, output_df)


df_comp = show_results_df(
    state=state, time_series_config=time_series_config, dataset=train_ds, idx=0
)

In [33]:
df_comp.output_df.loc[:, time_series_config.categorical_col_tokens].tail()

Unnamed: 0,n_planets,acceleration_x,acceleration_y
54,35,32,32
55,36,37,37
56,36,32,32
57,36,32,32
58,36,37,37


In [34]:
df_comp.output_df.loc[:, time_series_config.categorical_col_tokens].tail()

Unnamed: 0,n_planets,acceleration_x,acceleration_y
54,35,32,32
55,36,37,37
56,36,32,32
57,36,32,32
58,36,37,37


In [35]:
def plot_planets(df_pred: pd.DataFrame, df_actual: pd.DataFrame, column: str, offset=0):
    plt.figure(figsize=(15, 10))
    plt.plot(df_pred[column], label="Autogregressive")
    plt.plot(df_actual[column], label="Actual")
    plt.title(f"{column} Predictions")
    plt.legend()
    # Show ticks and grid lines every 1 step
    plt.xticks(np.arange(0, len(df_pred), 1))
    plt.grid()
    # add black line at 0 on the y axis to show the difference
    plt.axhline(0, color="black")
    plt.show()

In [36]:
jnp.array([True, True, False, False, True]).shape

(5,)

In [38]:
x = jnp.ones((5, 20))
print(x.shape)
xx = x.at[jnp.array([True, True, False, False, True]), :].set(0)
xx

(5, 20)


Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.]], dtype=float32)

In [41]:
@dataclass
class AutoRegressiveResults:
    numeric_inputs: jnp.array
    categorical_inputs: jnp.array

    @classmethod
    def from_ds(cls, ds: hp.TimeSeriesDS, idx: int, stop_idx: int = 10):
        inputs = ds[idx]
        numeric_inputs = inputs[0][:, :stop_idx]
        categorical_inputs = inputs[1][:, :stop_idx]
        return cls(numeric_inputs, categorical_inputs)


def auto_regressive_predictions(
    state: train_state.TrainState,
    inputs: AutoRegressiveResults,
) -> jnp.array:
    numeric_inputs = inputs.numeric_inputs
    categorical_inputs = inputs.categorical_inputs
    # get the first row that contains all nan vales
    # if nan_rows_start >= stop_idx:
    #     return inputs
    # numeric_inputs = inputs.numeric_inputs
    # categorical_inputs = inputs.categorical_inputs
    numeric_nan_columns = jnp.isnan(numeric_inputs).all(axis=1)
    categorical_nan_columns = jnp.isnan(categorical_inputs).all(axis=1)
    outputs = time_series_regressor(
        numeric_inputs=numeric_inputs, categorical_inputs=categorical_inputs
    )
    numeric_out = jnp.squeeze(outputs["numeric_out"])
    categorical_out = jnp.squeeze(outputs["categorical_out"])
    categorical_out = jnp.argmax(categorical_out, axis=-1)

    final_numeric_row = np.array(numeric_out[:, -1])
    final_numeric_row = final_numeric_row[:, None]  # New axis

    final_categorical_row = np.array(categorical_out[:, -1])
    final_categorical_row = final_categorical_row[:, None]  # New axis
    numeric_inputs = jnp.concatenate([numeric_inputs, final_numeric_row], axis=1)
    categorical_inputs = jnp.concatenate(
        [categorical_inputs, final_categorical_row], axis=1
    )
    numeric_inputs = numeric_inputs.at[jnp.array(numeric_nan_columns)].set(jnp.nan)
    categorical_inputs = categorical_inputs.at[jnp.array(categorical_nan_columns)].set(
        jnp.nan
    )
    inputs = (numeric_inputs, categorical_inputs)

    return inputs
    # return auto_regressive_predictions(state, inputs, stop_idx)

In [42]:
test_inputs = AutoRegressiveResults.from_ds(train_ds, 0, 10)
# inputs_test = train_ds[0]
# test_numeric = inputs_test[0]
# test_categorical = inputs_test[1]
# print(inputs_test.shape)
for i in trange(21):
    inputs_test = auto_regressive_predictions(state, test_inputs)

# x = auto_regressive_predictions(state, test_ds[0], 10)

  0%|          | 0/21 [00:00<?, ?it/s]

IndexError: tuple index out of range

In [None]:
# res = show_results_df(state, train_df, test_ds, idx=0, mask_start=30)

# plot_planets(res["pred"], res["actual_masked"], "planet2_x", offset=0)