<https://github.com/PolymathicAI/xVal>


In [1]:
import os

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.95"

In [2]:
import icecream
from icecream import ic

icecream.install()
ic_disable = False  # Global variable to disable ic
if ic_disable:
    ic.disable()
ic.configureOutput(includeContext=True, contextAbsPath=True)

In [3]:
import ast
import re

from datetime import datetime as dt
from torch.utils.data import DataLoader
import torch
import seaborn as sns
import matplotlib.pyplot as plt
from hephaestus.models import simple_time_series_pt_claude as hp
import numpy as np
import pandas as pd
from icecream import ic

from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm, trange
from hephaestus.models.simple_time_series_pt_claude import SimpleDS

pd.options.mode.copy_on_write = True

In [4]:
def line2df(line, idx):
    data_rows = []
    line = ast.literal_eval(line)
    for i, time_step in enumerate(line["data"]):
        row = {"time_step": i}
        # Add position data for each planet
        for j, position in enumerate(time_step):
            row[f"planet{j}_x"] = position[0]
            row[f"planet{j}_y"] = position[1]
        data_rows.append(row)

    df = pd.DataFrame(data_rows)
    description = line.pop("description")
    step_size = description.pop("stepsize")
    for k, v in description.items():
        for k_prop, v_prop in v.items():
            df[f"{k}_{k_prop}"] = v_prop
    df["time_step"] = df["time_step"] * step_size
    df.insert(0, "idx", idx)

    return df

In [5]:
files = os.listdir("data")
if "planets.parquet" not in files:
    with open("data/planets.data") as f:
        data = f.read().splitlines()

        dfs = []
        for idx, line in enumerate(tqdm(data)):
            dfs.append(line2df(line, idx))
        print("Concatenating dfs...")
        df = pd.concat(dfs)
    df.to_parquet("data/planets.parquet")
else:
    df = pd.read_parquet("data/planets.parquet")


# Combine total mass of all planets into one column `planet<n>_m`
mass_regex = re.compile(r"planet(\d+)_m")
mass_cols = [col for col in df.columns if mass_regex.match(col)]
df["total_mass"] = df[mass_cols].sum(axis=1)

# Introduce categorical columns for the number of planets choose non null columns with mass
df["n_planets"] = df[mass_cols].notnull().sum(axis=1).astype("category")
# Create category acceleration if the sum of plane/d_[x,y, z] is greater than 0
df["acceleration_x"] = df[
    [col for col in df.columns if "planet" in col and "_x" in col]
].sum(axis=1)
# Set acceleration_x to "increasing" if greater than 0 else "decreasing"
df["acceleration_x"] = df["acceleration_x"].apply(
    lambda x: "increasing" if x > 0 else "decreasing"
)
df["acceleration_y"] = df[
    [col for col in df.columns if "planet" in col and "_y" in col]
].sum(axis=1)
df["acceleration_y"] = df["acceleration_y"].apply(
    lambda x: "increasing" if x > 0 else "decreasing"
)


df.describe()

Unnamed: 0,idx,time_step,planet0_x,planet0_y,planet1_x,planet1_y,planet2_x,planet2_y,planet0_m,planet0_a,...,planet3_y,planet3_m,planet3_a,planet3_e,planet4_x,planet4_y,planet4_m,planet4_a,planet4_e,total_mass
count,5563957.0,5563957.0,5563957.0,5563957.0,5563957.0,5563957.0,4165044.0,4165044.0,5563957.0,5563957.0,...,2783627.0,2783627.0,2783627.0,2783627.0,1392864.0,1392864.0,1392864.0,1392864.0,1392864.0,5563957.0
mean,62486.35,9.748911,-0.1339198,0.07391138,-0.134014,0.07291389,-0.1305344,0.07065633,2.999306,1.624756,...,0.0655915,2.996303,1.623874,0.9980576,-0.1276881,0.06519469,3.002531,1.625815,1.001317,10.49149
std,36079.49,5.993534,1.228071,1.213232,1.22795,1.21265,1.217229,1.203678,1.157182,0.5876632,...,1.200148,1.15319,0.5270725,0.5764675,1.211648,1.199625,1.156856,0.5167198,0.5779763,3.99178
min,0.0,0.0,-3.294763,-2.997514,-3.284004,-2.998546,-3.28979,-2.99805,1.000003,1.0,...,-2.997621,1.000054,1.0,9.369537e-05,-3.273603,-2.998913,1.000103,1.0,6.720938e-05,2.014597
25%,31244.0,4.655172,-1.030131,-0.9020907,-1.030516,-0.9028009,-1.050662,-0.9211662,1.993948,1.0,...,-0.9321272,1.996853,1.191548,0.4980967,-1.071974,-0.9394428,2.00424,1.215927,0.5032645,7.282371
50%,62491.0,9.52381,-0.1542335,0.1117099,-0.1538916,0.1099474,-0.152552,0.1118031,2.994477,1.543047,...,0.1067355,3.001879,1.535683,0.995504,-0.1507184,0.1055276,3.000454,1.51904,1.003118,10.28205
75%,93728.0,14.4,0.8583344,0.9784987,0.8581045,0.9783998,0.8762358,0.9912902,4.005747,2.020672,...,0.9956064,3.986406,1.969644,1.497746,0.890678,0.999122,4.003332,1.950518,1.503457,13.45633
max,124999.0,24.0,2.99637,2.999014,2.993319,2.999536,2.990464,2.998478,4.999994,2.999984,...,3.000881,4.99999,2.999909,1.999957,2.98518,2.998936,4.999679,2.999497,1.999999,23.82455


In [6]:
has_cuda = torch.cuda.is_available()
has_mps = torch.backends.mps.is_available()

if has_cuda:
    device = torch.device("cuda")
elif has_mps:
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [7]:
# Get train test split at 80/20
train_idx = int(df.idx.max() * 0.8)
train_df = df.loc[df.idx < train_idx].copy()
test_df = df.loc[df.idx >= train_idx].copy()
# del df

train_ds = SimpleDS(train_df, device)
test_ds = SimpleDS(test_df, device)
len(train_ds), len(test_ds)

(99999, 25001)

In [8]:
type(train_ds.numeric_indices)

torch.Tensor

In [9]:
test_df

Unnamed: 0,idx,time_step,planet0_x,planet0_y,planet1_x,planet1_y,planet2_x,planet2_y,planet0_m,planet0_a,...,planet3_e,planet4_x,planet4_y,planet4_m,planet4_a,planet4_e,total_mass,n_planets,acceleration_x,acceleration_y
0,0,0.000000,-0.274094,1.658928,-1.598680,1.237278,-0.072378,1.334127,3.092371,1.67039,...,0.265969,,,,,,10.247159,4,decreasing,increasing
1,0,0.733333,-0.810119,1.516448,-1.860540,0.797326,-0.675005,1.164327,3.092371,1.67039,...,0.265969,,,,,,10.247159,4,decreasing,increasing
2,0,1.466667,-1.261577,1.214381,-2.002381,0.305935,-1.131812,0.742120,3.092371,1.67039,...,0.265969,,,,,,10.247159,4,decreasing,increasing
3,0,2.200000,-1.587840,0.791168,-2.015313,-0.205141,-1.347517,0.161522,3.092371,1.67039,...,0.265969,,,,,,10.247159,4,decreasing,increasing
4,0,2.933333,-1.762252,0.291976,-1.898518,-0.702988,-1.278262,-0.453284,3.092371,1.67039,...,0.265969,,,,,,10.247159,4,decreasing,decreasing
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
39,25000,18.139535,0.859903,-0.357431,-1.780663,-0.841087,,,2.983244,1.00000,...,,,,,,,7.554701,2,decreasing,decreasing
40,25000,18.604651,0.917215,0.133881,-1.609198,-1.111904,,,2.983244,1.00000,...,,,,,,,7.554701,2,decreasing,decreasing
41,25000,19.069767,0.732278,0.590216,-1.391311,-1.350590,,,2.983244,1.00000,...,,,,,,,7.554701,2,decreasing,decreasing
42,25000,19.534884,0.360838,0.898121,-1.132250,-1.549225,,,2.983244,1.00000,...,,,,,,,7.554701,2,decreasing,decreasing


In [10]:
test_df.shape, train_df.shape

((1110975, 31), (4452982, 31))

In [11]:
df.idx.max()

124999

In [12]:
train_ds[0]

(tensor([[ 0.0000,  0.4651,  0.9302,  ...,     nan,     nan,     nan],
         [ 1.5601,  1.6899,  1.7536,  ...,     nan,     nan,     nan],
         [-0.8544, -0.5144, -0.1542,  ...,     nan,     nan,     nan],
         ...,
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         [ 6.9741,  6.9741,  6.9741,  ...,     nan,     nan,     nan]],
        device='mps:0'),
 tensor([[33., 33., 33., 33., 33., 33., 35., 35., 35., 35., 35., 35., 35., 35.,
          35., 35., 35., 35., 35., 35., 35., 35., 35., 35., 35., 35., 35., 35.,
          35., 35., 35., 35., 35., 35., 35., 33., 33., 33., 33., 33., 33., 33.,
          33., 33., nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan],
         [33., 33., 33., 33., 33., 33., 33., 33., 33., 33., 33., 33., 33., 33.,
          33., 33., 33., 33., 33., 33., 35., 35., 35., 35., 35., 35., 35., 35.,
          35., 35., 35., 35.

In [13]:
def make_batch(ds: SimpleDS, start: int, length: int):
    # Get the shapes of numeric and categorical data
    numeric_shape = ds[start][0].shape
    categorical_shape = ds[start][1].shape if ds[start][1] is not None else None

    # Initialize tensors with the correct shape
    numeric = torch.empty((length,) + numeric_shape, dtype=torch.float32)
    categorical = (
        torch.empty((length,) + categorical_shape, dtype=torch.long)
        if categorical_shape is not None
        else None
    )

    for i in range(length):
        idx = start + i
        numeric[i] = ds[idx][0]
        if categorical is not None:
            categorical[i] = ds[idx][1]

    return {"numeric": numeric, "categorical": categorical}


# Usage
batch = make_batch(train_ds, 0, 4)

In [14]:
batch["categorical"]

tensor([[[33, 33, 33, 33, 33, 33, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35,
          35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35,
          35, 33, 33, 33, 33, 33, 33, 33, 33, 33,  0,  0,  0,  0,  0,  0,  0,
           0,  0,  0,  0,  0,  0,  0,  0],
         [33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33,
          33, 33, 33, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35,
          35, 35, 35, 35, 35, 35, 35, 35, 35, 35,  0,  0,  0,  0,  0,  0,  0,
           0,  0,  0,  0,  0,  0,  0,  0]],

        [[35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35,
          35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35,
          35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 33, 33, 33, 33, 33,
          33,  0,  0,  0,  0,  0,  0,  0],
         [33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33,
          33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 35, 35, 35, 35,
          3

In [15]:
# time_series_regressor = hp.simple_time_series.SimplePred(
#     train_ds, d_model=2048, n_heads=16 # large
# )
multiplier = 4
time_series_regressor = hp.SimplePred(train_ds, d_model=512, n_heads=8 * multiplier)

In [16]:
# x = torch.rand(4, 10)
# x[2, 2] = torch.nan
# y = torch.rand(15)

# # Reshape x and y to allow broadcasting
# x_reshaped = x.unsqueeze(-1)  # Shape becomes (4, 10, 1)
# y_reshaped = y.unsqueeze(0).unsqueeze(0)  # Shape becomes (1, 1, 15)

# # Multiply x and y
# result = x_reshaped * y_reshaped
# result
# nan_maks = torch.isnan(result)
# result[nan_maks] = torch.ones(15) * 10_1000
# result[nan_maks]

In [17]:
# # replace places a nan occurs with a tensor of 10 of shape 15
# result = torch.where(torch.isnan(result), torch.zeros_like(result), result)
# result.shape

# replacement = torch.rand(15) * 10_000  # This is your specific tensor of shape (15,)

# # Find the indices where the second dimension (dim=1) is all NaN
# nan_indices = torch.isnan(result).all(dim=2)

# # Replace the NaN slice with the replacement tensor
# result[nan_indices] = replacement

# # Verify the replacement
# print(result[2, 2, :])

In [18]:
ic("HI")

ic| 3461734802.py:1 in <module>- 'HI'


'HI'

In [19]:
time_series_regressor(batch["numeric"], batch["categorical"].to(torch.int))

ic| simple_time_series_pt_claude.py:299 in forward()
    x.shape: torch.Size([4, 27, 59, 512])
    self.encoding.shape: torch.Size([1, 10000, 512])


AssertionError: query should be unbatched 2D or batched 3D tensor but received 4-D query tensor

In [None]:
import torch.nn as nn

embedding = nn.Embedding(5, 10)
embedding(torch.tensor([0, 2, 3]))

tensor([[-0.3160, -0.8257, -0.2434,  1.6396,  0.1669, -0.7829, -1.1622,  1.1063,
          0.9746, -0.5865],
        [ 0.2052, -0.2699,  0.0855,  0.0163,  0.0386, -0.2648, -1.1139, -1.6070,
         -0.7035,  0.7924],
        [-1.5251,  0.7666,  0.9173,  0.8353, -0.3284,  0.0325, -0.4460, -2.0415,
         -0.0039,  0.5959]], grad_fn=<EmbeddingBackward0>)

In [None]:
type(my_tensor)

jaxlib.xla_extension.ArrayImpl

In [None]:
my_tensor.shape
my_tensor.unsqueeze(0).repeat(4, 0)

AttributeError: 'jaxlib.xla_extension.ArrayImpl' object has no attribute 'unsqueeze'

In [None]:
batch["categorical"]

Array([[[33., 33., 33., 33., 33., 33., 35., 35., 35., 35., 35., 35.,
         35., 35., 35., 35., 35., 35., 35., 35., 35., 35., 35., 35.,
         35., 35., 35., 35., 35., 35., 35., 35., 35., 35., 35., 33.,
         33., 33., 33., 33., 33., 33., 33., 33., nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [33., 33., 33., 33., 33., 33., 33., 33., 33., 33., 33., 33.,
         33., 33., 33., 33., 33., 33., 33., 33., 35., 35., 35., 35.,
         35., 35., 35., 35., 35., 35., 35., 35., 35., 35., 35., 35.,
         35., 35., 35., 35., 35., 35., 35., 35., nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]],

       [[35., 35., 35., 35., 35., 35., 35., 35., 35., 35., 35., 35.,
         35., 35., 35., 35., 35., 35., 35., 35., 35., 35., 35., 35.,
         35., 35., 35., 35., 35., 35., 35., 35., 35., 35., 35., 35.,
         35., 35., 35., 35., 35., 35., 35., 35., 35., 35., 33., 33.,
         33., 33., 33., 33., nan, nan, n

In [None]:
test_arr = jnp.array([1.0, 2.0, 3.0, 4.0])
# Convert to int
test_arr = test_arr.astype(jnp.int32)
test_arr

Array([1, 2, 3, 4], dtype=int32)

In [None]:
key = random.PRNGKey(0)
init_key, dropout_key = random.split(key)
vars = time_series_regressor.init(
    {"params": init_key, "dropout": dropout_key},
    batch["numeric"],
    categorical_inputs=batch["categorical"].astype(jnp.int32),
    deterministic=False,
)
dropout_key, original_dropout_key = random.split(dropout_key)

ic| simple_time_series.py:292 in __call__()
    numeric_inputs.shape: (4, 27, 59)
ic| simple_time_series.py:294 in __call__()- 'Here Again???'
ic| simple_time_series.py:547 in __call__()
    "pe before tiling": 'pe before tiling'
    pe.shape: (1, 59, 512, 1)
ic| simple_time_series.py:549 in __call__()
    "pe after tiling": 'pe after tiling'
    pe.shape: (4, 59, 512, 27)
ic| simple_time_series.py:551 in __call__()
    "pe after transpose": 'pe after transpose'
    pe.shape: (4, 27, 59, 512)
ic| simple_time_series.py:555 in __call__()
    "PE Result shape": 'PE Result shape'
    result.shape: (4, 27, 59, 512)
ic| simple_time_series.py:352 in __call__()
    numeric_broadcast.shape: (4, 27, 59, 512)
    numeric_col_embeddings.shape: (4, 27, 59, 512)
ic| simple_time_series.py:364 in __call__()
    "Masking for categorical data": 'Masking for categorical data'
ic| simple_time_series.py:370 in __call__()
    mask_input.shape: (4, 29, 59)
ic| simple_time_series.py:377 in __call__()
    mask

In [None]:
df.shape

(5563957, 31)

In [None]:
# ic.disable()

In [None]:
x = time_series_regressor.apply(
    vars,
    batch["numeric"],
    batch["categorical"].astype(jnp.int32),
    deterministic=False,
    rngs={"dropout": dropout_key},
)
print(x.get("numeric_out").shape)
# Check if categorical input is None and print None or it's shape
print(x.get("categorical_out").shape if x.get("categorical_out") is not None else None)

ic| simple_time_series.py:292 in __call__()
    numeric_inputs.shape: (4, 27, 59)
ic| simple_time_series.py:294 in __call__()- 'Here Again???'
ic| simple_time_series.py:547 in __call__()
    "pe before tiling": 'pe before tiling'
    pe.shape: (1, 59, 512, 1)
ic| simple_time_series.py:549 in __call__()
    "pe after tiling": 'pe after tiling'
    pe.shape: (4, 59, 512, 27)
ic| simple_time_series.py:551 in __call__()
    "pe after transpose": 'pe after transpose'
    pe.shape: (4, 27, 59, 512)
ic| simple_time_series.py:555 in __call__()
    "PE Result shape": 'PE Result shape'
    result.shape: (4, 27, 59, 512)
ic| simple_time_series.py:352 in __call__()
    numeric_broadcast.shape: (4, 27, 59, 512)
    numeric_col_embeddings.shape: (4, 27, 59, 512)
ic| simple_time_series.py:364 in __call__()
    "Masking for categorical data": 'Masking for categorical data'
ic| simple_time_series.py:370 in __call__()
    mask_input.shape: (4, 29, 59)
ic| simple_time_series.py:377 in __call__()
    mask

(4, 27, 59)
(4, 2, 59)


In [None]:
len(train_ds.categorical_indices)

2

In [None]:
# time_series_regressor.tabulate(
#     {"params": init_key, "dropout": dropout_key},
#     batch["numeric"],
#     console_kwargs={"force_jupyter": True, "width": 120},
# )

In [None]:
def calculate_memory_footprint(params):
    """Calculate total memory footprint of JAX model parameters and total
    number of parameters."""
    total_bytes = 0
    # Flatten the parameter tree structure into a list of arrays
    flat_params, _ = tree_flatten(params)
    for param in flat_params:
        # Calculate bytes: number of elements * size of each element
        bytes_per_param = param.size * param.dtype.itemsize
        total_bytes += bytes_per_param
    return total_bytes


def count_parameters(params):
    return sum(jnp.prod(jnp.array(p.shape)) for p in jax.tree_util.tree_leaves(params))


mem = calculate_memory_footprint(vars)
total_params = count_parameters(vars)


print(f"Memory of custom: {mem / 1e6:.2f} MB with {total_params:,} parameters")

Memory of custom: 202.17 MB with 50,542,877 parameters


In [None]:
mts_root_key = random.PRNGKey(44)
mts_main_key, ts_params_key, ts_data_key = random.split(mts_root_key, 3)

mask_data = False


def clip_gradients(gradients, max_norm):
    total_norm = jnp.sqrt(sum(jnp.sum(jnp.square(grad)) for grad in gradients.values()))
    scale = max_norm / (total_norm + 1e-6)
    clipped_gradients = jax.tree_map(
        lambda grad: jnp.where(total_norm > max_norm, grad * scale, grad), gradients
    )
    return clipped_gradients


def add_time_shifts(inputs: jnp.array, outputs: jnp.array) -> jnp.array:
    inputs_offset = 1
    inputs = inputs[:, :, inputs_offset:]
    tmp_null = jnp.full((inputs.shape[0], inputs.shape[1], inputs_offset), jnp.nan)
    inputs = jnp.concatenate([inputs, tmp_null], axis=2)
    nan_mask = jnp.isnan(inputs)
    inputs = jnp.where(nan_mask, jnp.zeros_like(inputs), inputs)
    print(f"{outputs.shape=}, {inputs.shape=}")
    outputs = jnp.where(nan_mask, jnp.zeros_like(outputs), outputs)

    return inputs, outputs, nan_mask


def numeric_loss(inputs, outputs):
    inputs, outputs, nan_mask = add_time_shifts(inputs, outputs)
    # TODO make loss SSL for values greater than 0.5 and MSE for values less than 0.5
    raw_loss = jnp.abs(outputs - inputs)
    masked_loss = jnp.where(nan_mask, 0.0, raw_loss)
    loss = masked_loss.sum() / (~nan_mask).sum()
    return loss


def categorical_loss(inputs, outputs):
    inputs, outputs, nan_mask = add_time_shifts(inputs, outputs)

    raw_loss = optax.squared_error(outputs, inputs)
    masked_loss = jnp.where(nan_mask, 0.0, raw_loss).mean()
    return masked_loss


def base_loss(
    numeric_inputs,
    categorical_inputs,
    outputs,
):
    numeric_out = outputs["numeric_out"]
    categorical_out = outputs["categorical_out"]
    print("Base Loss", numeric_inputs.shape, numeric_out.shape)
    numeric = numeric_loss(numeric_inputs, numeric_out)
    categorical = categorical_loss(categorical_inputs, categorical_out)
    return numeric + categorical


def base_loss_old(inputs, outputs):
    """TODO HERE IS THE SHIT"""
    # Remove the first value and add a jnp.nan to the end
    # inputs = inputs * 3
    inputs_offset = 1
    inputs = inputs[:, :, inputs_offset:]
    print(f"Inputs shape: {inputs.shape=}")
    # Add a jnp.nan to the end
    temp_null = jnp.full((inputs.shape[0], inputs.shape[1], inputs_offset), jnp.nan)
    inputs = jnp.concatenate([inputs, temp_null], axis=2)
    print(f"Inputs shape after addition: {inputs.shape=}")
    nan_mask = jnp.isnan(inputs)
    inputs = jnp.where(nan_mask, jnp.zeros_like(inputs), inputs)

    # outputs = outputs[:, :, :-inputs_offset]
    outputs = jnp.where(nan_mask, jnp.zeros_like(outputs), outputs)

    # raw_loss = optax.squared_error(outputs, inputs)
    # compute manually
    # raw_loss = jnp.square(outputs - inputs)
    # Abs loss
    raw_loss = jnp.abs(outputs - inputs)
    masked_loss = jnp.where(nan_mask, 0.0, raw_loss)
    loss = masked_loss.sum() / (~nan_mask).sum()

    return loss


def calculate_loss(
    params,
    state,
    numeric_inputs,
    categorical_inputs,
    dropout_key,
    mask_data: bool = True,
):
    outputs = state.apply_fn(
        {"params": params},
        # hp.mask_tensor(inputs, dataset, prng_key=mask_key),
        numeric_inputs=numeric_inputs,
        categorical_inputs=categorical_inputs.astype(jnp.int32),
        rngs={"dropout": dropout_key},
        deterministic=False,
        mask_data=mask_data,
    )
    loss = base_loss(
        numeric_inputs=numeric_inputs,
        categorical_inputs=categorical_inputs,
        outputs=outputs,
    )
    # Create mask for nan inputs

    return loss


@jax.jit
def train_step(
    state: train_state.TrainState,
    numeric_inputs,
    categorical_inputs,
    base_key,
    # mask_data=True,
):
    # print("In train step")
    dropout_key, mask_key, new_key = jax.random.split(
        base_key, 3
    )  # TODO Figure out mask key
    # print("Making masks")
    # numeric_inputs = batch["numeric"]
    # categorical_inputs = batch["categorical"]
    # print("Made masks")

    def calculate_loss_with_mask(params):
        return calculate_loss(
            params,
            state,
            numeric_inputs=numeric_inputs,
            categorical_inputs=categorical_inputs,
            dropout_key=dropout_key,
            mask_data=True,
        )

    def loss_fn(params):
        return calculate_loss_with_mask(params)

    # def calculate_loss_without_mask(params):
    #     return calculate_loss(params, state, batch, dropout_key, mask_data=False)

    # def loss_fn(params):
    #     return jax.lax.cond(
    #         mask_data,
    #         lambda _: calculate_loss_with_mask(params),
    #         lambda _: calculate_loss_without_mask(params),
    #         operand=None,
    #     )

    # def loss_fn(params):
    #     return calculate_loss(params, state, batch, dropout_key, mask_data=mask_data)

    grad_fn = jax.value_and_grad(loss_fn)

    # (loss, individual_losses), grad = grad_fn(state.params)
    loss, grad = grad_fn(state.params)
    # grad = replace_nans(grad)
    # grad = clip_gradients(grad, 1.0)
    state = state.apply_gradients(grads=grad)

    return state, loss, new_key


def evaluate(params, state, inputs, mask_data: bool = True):
    outputs = state.apply_fn(
        {"params": params},
        # hp.mask_tensor(inputs, dataset, prng_key=mask_key),
        inputs,
        deterministic=True,
        mask_data=mask_data,
    )
    loss = base_loss(inputs, outputs)
    return loss


@jax.jit
def eval_step(
    state: train_state.TrainState, numeric_inputs, categorical_inputs, base_key
):
    # mask_data=True
    mask_key, dropout_key, new_key = jax.random.split(base_key, 3)

    def calculate_loss_with_mask(params):
        return calculate_loss(
            params,
            state,
            numeric_inputs=numeric_inputs,
            categorical_inputs=categorical_inputs,
            dropout_key=dropout_key,
            mask_data=True,
        )

    def calculate_loss_without_mask(params):
        return calculate_loss(
            params, state=state, batch=batch, dropout_key=dropout_key, mask_data=False
        )

    def loss_fn(params):
        return calculate_loss_with_mask(params)

    # TODO Reimplement this...
    # def loss_fn(params):
    #     return jax.lax.cond(
    #         mask_data,
    #         lambda _: calculate_loss_with_mask(params),
    #         lambda _: calculate_loss_without_mask(params),
    #         operand=None,
    #     )

    # def loss_fn(params):
    #     return evaluate(params, state, batch, mask_data=mask_data)

    # (loss, individual_losses), grad = grad_fn(state.params)
    loss = loss_fn(state.params)
    return loss, new_key


def create_train_state(model, prng, batch, lr):
    init_key, dropout_key = random.split(prng)
    params = model.init(
        {"params": init_key, "dropout": dropout_key},
        batch["numeric"],
        batch["categorical"],
        deterministic=False,
    )
    # optimizer = optax.chain(optax.adam(lr))
    optimizer = optax.chain(optax.clip_by_global_norm(0.4), optax.adam(lr))
    # optimizer_state = optimizer.init(params)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=params["params"],
        tx=optimizer,
        # tx_state=optimizer_state,
    )


batch_size = 2
# batch = train_ds[0]
# state = create_train_state(time_series_regressor, mts_main_key, batch, 0.0001)
state = create_train_state(time_series_regressor, mts_main_key, batch, 0.0001)

ic| simple_time_series.py:292 in __call__()
    numeric_inputs.shape: (4, 27, 59)
ic| simple_time_series.py:294 in __call__()- 'Here Again???'
ic| simple_time_series.py:547 in __call__()
    "pe before tiling": 'pe before tiling'
    pe.shape: (1, 59, 512, 1)
ic| simple_time_series.py:549 in __call__()
    "pe after tiling": 'pe after tiling'
    pe.shape: (4, 59, 512, 27)
ic| simple_time_series.py:551 in __call__()
    "pe after transpose": 'pe after transpose'
    pe.shape: (4, 27, 59, 512)
ic| simple_time_series.py:555 in __call__()
    "PE Result shape": 'PE Result shape'
    result.shape: (4, 27, 59, 512)
ic| simple_time_series.py:352 in __call__()
    numeric_broadcast.shape: (4, 27, 59, 512)
    numeric_col_embeddings.shape: (4, 27, 59, 512)
ic| simple_time_series.py:364 in __call__()
    "Masking for categorical data": 'Masking for categorical data'
ic| simple_time_series.py:370 in __call__()
    mask_input.shape: (4, 29, 59)
ic| simple_time_series.py:377 in __call__()
    mask

In [None]:
writer_name = "BERT_Embeddings"

writer_time = dt.now().strftime("%Y-%m-%dT%H:%M:%S")
model_name = writer_time + writer_name
train_summary_writer = SummaryWriter("runs/" + model_name)

MASK_DATA = True

test_set_key = random.PRNGKey(4454)

batch_size = 16
train_data_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_data_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=True)

# train_data_loader = DataLoader(train_ds, batch_size=256 // 2, shuffle=True)
# test_data_loader = DataLoader(test_ds, batch_size=256 // 2, shuffle=True)

batch_count = 0
base_key = random.PRNGKey(42)

# Disable IC for training
max_iters = 200
ic.disable()
for j in trange(1, desc=f"epochs for {train_summary_writer.log_dir}"):
    # arrs = train_data_loader()
    for i in tqdm(train_data_loader, leave=False, desc="batches"):
        # for i in trange(len(pre_train) // batch_size, leave=False):
        # for i in trange(len(pre_train) // batch_size //10, leave=False):
        # batch = make_batch(train_ds, i[0], 4)
        state, loss, base_key = train_step(
            state,
            jnp.array(i[0]),
            jnp.array(i[1]),
            base_key,
            # mask_data=MASK_DATA,
        )
        if jnp.isnan(loss):
            raise ValueError("Nan Value in loss, stopping")
        batch_count += 1

        if batch_count % 1 == 0:
            train_summary_writer.add_scalar(
                "loss/loss", np.array(loss.item()), batch_count
            )
        if batch_count % 10 == 0:
            numeric_eval, categorical_eval = next(iter(test_data_loader))
            test_loss, base_key = eval_step(
                state,
                jnp.array(numeric_eval),
                jnp.array(categorical_eval),
                base_key,
                # mask_data=MASK_DATA,
            )
            train_summary_writer.add_scalar(
                "loss/test_loss", np.array(test_loss.item()), batch_count
            )
            train_summary_writer.flush()
        # if batch_count > 200:
        #     break
        if batch_count > max_iters:
            break

train_summary_writer.close()

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


epochs for runs/2024-08-27T20:23:29BERT_Embeddings:   0%|          | 0/1 [00:00<?, ?it/s]

batches:   0%|          | 0/6250 [00:00<?, ?it/s]

Base Loss (16, 27, 59) (16, 27, 59)
outputs.shape=(16, 27, 59), inputs.shape=(16, 27, 59)
outputs.shape=(16, 2, 59), inputs.shape=(16, 2, 59)
Base Loss (16, 27, 59) (16, 27, 59)
outputs.shape=(16, 27, 59), inputs.shape=(16, 27, 59)
outputs.shape=(16, 2, 59), inputs.shape=(16, 2, 59)


KeyboardInterrupt: 

In [None]:
len(i)

2

In [None]:
from flax.training import checkpoints
from flax.training import checkpoints


def save_model(model, optimizer, step, checkpoint_dir):
    checkpoint = {
        "model": model,
        "optimizer": optimizer,
        "step": step,
    }
    checkpoints.save_checkpoint(checkpoint_dir, checkpoint, step, keep=3)

In [None]:
checkpoint_dir = f"ckpt/{model_name}"
ckpt_dir = os.path.abspath(checkpoint_dir)

save_model(
    model=time_series_regressor,
    optimizer=state,
    step=batch_count,
    checkpoint_dir=ckpt_dir,
)

ValueError: TypeHandler lookup failed for: type=<class 'hephaestus.models.simple_time_series.SimplePred'>, keypath=(DictKey(key='model'),), ParamInfo=ParamInfo(name='model', path=PosixGPath('/Users/kailukowiak/Hephaestus/ckpt/2024-08-27T20:23:29BERT_Embeddings/checkpoint_13.orbax-checkpoint-tmp-0/model'), parent_dir=PosixGPath('/Users/kailukowiak/Hephaestus/ckpt/2024-08-27T20:23:29BERT_Embeddings/checkpoint_13.orbax-checkpoint-tmp-0'), skip_deserialize=False, byte_limiter=<orbax.checkpoint.serialization.LimitInFlightBytes object at 0x178054eb0>, is_ocdbt_checkpoint=True, use_zarr3=False, ocdbt_target_data_file_size=None, ts_context=<tensorstore.Context object at 0x403dccd30>, value_typestr='None'), RestoreArgs=SaveArgs(aggregate=False, dtype=None, write_chunk_shape=None, read_chunk_shape=None, chunk_byte_size=None), value=SimplePred(
    # attributes
    dataset = <hephaestus.models.simple_time_series.SimpleDS object at 0x33cd90040>
    d_model = 512
    n_heads = 32
)

In [None]:
# Save the model
ckpt_dir = f"ckpts/{model_name}/"
# make absolute path
ckpt_dir = os.path.abspath(ckpt_dir)

# if os.path.exists(ckpt_dir):
#     shutil.rmtree(ckpt_dir)

ckpt = {"state": state, "step": batch_count}
# Save to disk
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(ckpt)
orbax_checkpointer.save(ckpt_dir, ckpt, save_args=save_args)

new_state = create_train_state(time_series_regressor, mts_main_key, batch, 0.0001)
# Load from disk
ckpt1 = orbax_checkpointer.restore(ckpt_dir)
new_state = new_state.replace(params=ckpt1["state"]["params"])



In [None]:
model_name = "2024-06-13T03:10:09MAE_Loss_Large"
ckpt_dir = f"ckpts/{model_name}/"
# make absolute path
ckpt_dir = os.path.abspath(ckpt_dir)
new_state = create_train_state(time_series_regressor, mts_main_key, batch, 0.0001)
# Load from disk
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()

ckpt1 = orbax_checkpointer.restore(ckpt_dir)
new_state = new_state.replace(params=ckpt1["state"]["params"])
# state = new_state

FileNotFoundError: Checkpoint at /Users/kailukowiak/Hephaestus/ckpts/2024-06-13T03:10:09MAE_Loss_Large not found.

In [None]:
def return_results(state, dataset, idx=0, mask_start: int = None):
    inputs = dataset[idx]
    if mask_start:
        inputs = inputs[:, :mask_start]
    inputs = jnp.array([inputs])
    outputs = state.apply_fn(
        {"params": state.params},
        # hp.mask_tensor(jnp.array([train_ds[0]]), dataset, prng_key=key),
        inputs,
        deterministic=True,
        mask_data=MASK_DATA,
    )
    return outputs, inputs

In [None]:
train_ds[0][0].shape, train_ds[0][1].shape

((27, 59), (2, 59))

In [None]:
mask_data = False
MASK_DATA = True


def show_results_df(state, base_df, dataset, idx: int = 0, mask_start: int = None):
    outputs, inputs = return_results(state, dataset, idx=idx, mask_start=mask_start)

    outputs = jnp.squeeze(outputs)
    df_pred = pd.DataFrame(outputs.T)
    df_pred.columns = base_df.columns[1:]

    inputs = jnp.squeeze(inputs)
    df_actual_masked = pd.DataFrame(inputs.T)
    df_actual_masked.columns = base_df.columns[1:]
    # remove the first row to match the prediction
    # df_actual_masked = df_actual_masked.iloc[1:].reset_index()
    diff_df = df_pred - df_actual_masked

    inputs_no_mask = jnp.array([dataset[idx]])
    df_no_mask = pd.DataFrame(jnp.squeeze(inputs_no_mask).T)
    df_no_mask.columns = base_df.columns[1:]
    # df_no_mask = df_no_mask.iloc[1:].reset_index()  # rm first row
    diff_df_no_mask = df_pred - df_no_mask
    return {
        "pred": df_pred,
        "actual_masked": df_actual_masked,
        "actual_no_mask": df_no_mask,
        "diff_masked": diff_df,
        "diff_no_mask": diff_df_no_mask,
    }


res = show_results_df(state, train_df, train_ds, idx=0, mask_start=10)

TypeError: tuple indices must be integers or slices, not tuple

In [None]:
def show_heatmap(df, title):
    """Shows heatmap for a dataframe
    excludes all columns that are only nan and all rows that are only nan"""

    df = df.dropna(axis=1, how="all")
    df = df.dropna(axis=0, how="all")
    plt.figure(figsize=(15, 10))
    cmap = sns.diverging_palette(220, 20, as_cmap=True)
    sns.heatmap(df, cmap=cmap, center=0, annot=True, fmt=".2f")
    plt.title(title)
    plt.show()


show_heatmap(res["diff_masked"], "Diff Masked")

NameError: name 'res' is not defined

In [None]:
show_heatmap(res["diff_no_mask"].head(14), "Diff Masked")

In [None]:
test_key = random.PRNGKey(4454)
x = jax.random.normal(test_key, (4, 26, 59, 256))

In [None]:
no_mask_out = state.apply_fn(
    {"params": state.params},
    # jnp.array([test_ds[0][:10, :]]),
    jnp.array([test_ds[0][:, :10]]),
    deterministic=True,
    mask_data=True,
)
mask_out = state.apply_fn(
    {"params": state.params},
    jnp.array([test_ds[0][:, :20]]),
    deterministic=True,
    mask_data=True,
)
mask_out_df = pd.DataFrame(jnp.squeeze(mask_out).T)
mask_out_df.columns = test_df.columns[1:]
no_mask_out_df = pd.DataFrame(jnp.squeeze(no_mask_out).T)
no_mask_out_df.columns = test_df.columns[1:]

test_diff = mask_out_df - no_mask_out_df
test_diff

In [None]:
def plot_planets(df_pred: pd.DataFrame, df_actual: pd.DataFrame, column: str, offset=0):
    plt.figure(figsize=(15, 10))
    plt.plot(df_pred[column], label="Autogregressive")
    plt.plot(df_actual[column], label="Actual")
    plt.title(f"{column} Predictions")
    plt.legend()
    # Show ticks and grid lines every 1 step
    plt.xticks(np.arange(0, len(df_pred), 1))
    plt.grid()
    # add black line at 0 on the y axis to show the difference
    plt.axhline(0, color="black")
    plt.show()

In [None]:
def auto_regressive_predictions(
    state: train_state.TrainState, inputs: jnp.ndarray
) -> np.ndarray:
    # get the first row that contains all nan vales
    # if nan_rows_start >= stop_idx:
    #     return inputs
    nan_columns = jnp.isnan(inputs).all(axis=1)
    outputs = state.apply_fn(
        {"params": state.params},
        jnp.array([inputs]),
        # jnp.array([inputs]),
        deterministic=True,
        mask_data=MASK_DATA,
    )
    outputs = jnp.squeeze(outputs)
    final_row = np.array(outputs[:, -1])
    final_row = final_row[:, None]
    inputs = jnp.concatenate([inputs, final_row], axis=1)
    inputs = np.array(inputs)
    inputs[nan_columns] = np.nan
    return inputs
    # return auto_regressive_predictions(state, inputs, stop_idx)

In [None]:
base_inputs = test_ds[300]
inputs_test = base_inputs[:, :10]
print(inputs_test.shape)
for i in trange(21):
    inputs_test = auto_regressive_predictions(state, inputs_test)

# x = auto_regressive_predictions(state, test_ds[0], 10)

In [None]:
df_auto = pd.DataFrame(inputs_test.T)
df_actual = pd.DataFrame(base_inputs.T)
df_auto.columns = train_df.columns[1:]
df_actual.columns = train_df.columns[1:]
df_diff = df_auto - df_actual

# Drop rows that are all nan
df_diff = df_diff.dropna(axis=0, how="all")

In [None]:
plot_planets(df_auto, df_actual, "time_step")

In [None]:
plot_planets(df_auto, df_actual, "planet3_x")

In [None]:
# inputs = jnp.array(next(iter(test_data_loader)))
inputs_test = make_batch(test_ds, 0, 1)

outputs = new_state.apply_fn(
    {"params": state.params},
    # hp.mask_tensor(inputs, dataset, prng_key=mask_key),
    inputs_test,
    deterministic=True,
    mask_data=True,
)
df_actual = pd.DataFrame(jnp.squeeze(inputs_test).T)
df_actual.columns = test_df.columns[1:]

df_pred = pd.DataFrame(jnp.squeeze(outputs).T)
df_pred.columns = test_df.columns[1:]
plot_planets(df_pred, df_actual, "time_step")
plot_planets(df_pred, df_actual, "planet2_y")

In [None]:
inputs_test.shape

In [None]:
show_heatmap(df_diff, "Auto Regressive Predictions")

In [None]:
# plot planet0_x from df_auto and df_actual


res = show_results_df(state, train_df, train_ds, idx=299, mask_start=20)
plot_planets(res["pred"], res["actual_masked"], "planet2_y", offset=0)
# plot_planets(df_auto, df_actual, "planet2_y")

In [None]:
# plot planet0_x from df_auto and df_actual
res = show_results_df(state, train_df, train_ds, idx=20, mask_start=20)
plot_planets(res["pred"], res["actual_no_mask"], "planet1_y", offset=0)

In [None]:
res = show_results_df(state, train_df, test_ds, idx=0, mask_start=30)

plot_planets(res["pred"], res["actual_masked"], "planet2_x", offset=0)

In [None]:
loss, key = eval_step(state, jnp.array(next(iter(test_data_loader))), base_key)
loss, key