<https://github.com/PolymathicAI/xVal>


In [1]:
import os

# Let XLA claim up to 95% of device memory (see JAX's GPU memory allocation
# docs). Set in the first cell so it is in place before `jax` is imported below.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.95"

In [2]:
# conda install nltk

In [3]:
from datetime import datetime as dt
from typing import Union

import icecream
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import orbax
import orbax.checkpoint
import pandas as pd
from flax.struct import dataclass
from flax.training import orbax_utils, train_state
from icecream import ic
from jax import random
from jax.tree_util import tree_flatten
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm, trange
from transformers import BertTokenizerFast, FlaxBertModel

import hephaestus as hp
import hephaestus.models.time_series_decoder_training as tsd
import hashlib
import random as py_random
import nltk
from nltk.corpus import words

# Predefined list of words


# Download the words corpus if not already downloaded
nltk.download("words")

# Make `ic` available as a builtin so library code can call it without importing.
icecream.install()
ic_disable = False  # Global variable to disable ic
if ic_disable:
    ic.disable()
# Prefix every ic() line with the calling file/line (absolute path).
ic.configureOutput(includeContext=True, contextAbsPath=True)
# Opt in to pandas copy-on-write semantics for the whole notebook.
pd.options.mode.copy_on_write = True

[nltk_data] Downloading package words to /home/ubuntu/nltk_data...
[nltk_data]   Package words is already up-to-date!


In [4]:
# Materialise the NLTK word corpus as a list; hash_to_words samples from it.
word_list = words.words()

In [5]:
# Load pre-trained BERT model and tokenizer
model_name = "bert-base-uncased"
model = FlaxBertModel.from_pretrained(model_name)
tokenizer = BertTokenizerFast.from_pretrained(model_name)

# Get the word-embedding matrix from the model parameters
embeddings = model.params["embeddings"]["word_embeddings"]["embedding"]

# Rows can be gathered by token id,
# for example the embeddings for tokens 23, 293, and 993:
selected_embeddings = jnp.take(embeddings, jnp.array([23, 293, 993]), axis=0)

# ...or by word, by first mapping each word to its token id.
# NOTE: this list is deliberately not called `words`: that name is the NLTK
# corpus imported at the top of the notebook, and rebinding it to a list makes
# `words.words()` fail when the earlier cell is re-run.
example_words = ["hello", "world", "example"]
tokens = tokenizer.convert_tokens_to_ids(example_words)
word_embeddings = jnp.take(embeddings, jnp.array(tokens), axis=0)
word_embeddings.shape

Some weights of FlaxBertModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: {('pooler', 'dense', 'bias'), ('pooler', 'dense', 'kernel')}
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


(3, 768)

In [6]:
data_path = "data/Untouched Data - 1 Oct 24"
# os.listdir returns entries in arbitrary order and may include non-CSV files
# (hidden files, editor checkpoints); filter and sort so the load is
# reproducible and does not crash on stray entries.
csvs = sorted(name for name in os.listdir(data_path) if name.lower().endswith(".csv"))
if not csvs:
    raise FileNotFoundError(f"No CSV files found in {data_path!r}")
dfs = [pd.read_csv(os.path.join(data_path, name)) for name in csvs]
# NOTE(review): concat keeps each file's own row index, so index labels repeat
# when more than one CSV is loaded; confirm the later `df.index // 10`
# grouping expects that.
df = pd.concat(dfs)
df.head()

Unnamed: 0,id,created_at,updated_at,deleted_at,date,in_amount,in_token,in_from,out_amount,out_token,...,exchange,type_id,transaction_hash,wallet_id,chain,usd_equivalent,user_id,additional_transaction_data,deleted,status
0,007aa1c9-d3a9-46bc-b953-5ea114774609,2024-10-01 20:03:44.465135+00,2024-10-01 20:03:44.465135+00,,2023-11-09T13:29:11-07:00,994.015602,USDC,0xa9d1e08c7793af67e9d92fe308d5697fb81d3e43,994.015602,USDC,...,,transfer,0xb9e455885d862e122f5e0ee69e347dd3da71b2ca3359...,6ee74231-8984-48e4-ad70-3bb8c8a8c1db,ethereum,,169885e1-baef-4163-aefe-47b664a1d375,,False,active
1,008be80c-9f5b-4660-aa57-507f783c2f80,2024-10-01 20:03:49.195505+00,2024-10-01 20:03:49.195505+00,,2021-06-07T00:23:50-06:00,63.510001,LQTY,0xffe5219a68f9e1a63b71a4c75de500ed864ae651,63.510001,LQTY,...,,transfer,0x6ba3cdbbc355a0d13bb551fd082e5cd17541695f1289...,6ee74231-8984-48e4-ad70-3bb8c8a8c1db,ethereum,,169885e1-baef-4163-aefe-47b664a1d375,,False,active
2,014cb49f-4d4c-47aa-8116-b54f005ef27f,2024-10-01 20:03:46.091891+00,2024-10-01 20:03:46.091891+00,,2021-07-26T08:38:34-06:00,10927.17396,USDC,0x503828976d22510aad0201ac7ec88293211d23da,10927.17396,USDC,...,,transfer,0x644190f200a1d6b9d6ac0001e5167ed261e4867eb96d...,6ee74231-8984-48e4-ad70-3bb8c8a8c1db,ethereum,,169885e1-baef-4163-aefe-47b664a1d375,,False,active
3,0195f0c8-026d-4b7a-abbe-8d1cd5e40589,2024-10-01 20:03:45.901012+00,2024-10-01 20:03:45.901012+00,,2021-08-06T11:10:20-06:00,18299.44618,USDC,0xffe5219a68f9e1a63b71a4c75de500ed864ae651,18299.44618,USDC,...,,transfer,0xba4d14e6b05051384750760fae78fbe0d4dadb46615d...,6ee74231-8984-48e4-ad70-3bb8c8a8c1db,ethereum,,169885e1-baef-4163-aefe-47b664a1d375,,False,active
4,022f5287-4346-4afa-8a16-75680253cab9,2024-10-01 20:03:44.465135+00,2024-10-01 20:03:44.465135+00,,2022-09-24T14:27:35-06:00,1.0,DFSFB22,0xac760d3ca949b3136d3e9ab645e9ec94ba3649f5,1.0,DFSFB22,...,,transfer,0xa2e6ac48f7a5b5a0357cff9d7f56da687102ebfc5f9a...,6ee74231-8984-48e4-ad70-3bb8c8a8c1db,ethereum,,169885e1-baef-4163-aefe-47b664a1d375,,False,active


In [7]:
# Timestamp columns intended to be expanded by enrich_datetimes.
datetime_cols = ["created_at", "updated_at", "date"]


def enrich_datetimes(df: pd.DataFrame, col: str) -> pd.DataFrame:
    """Replace a timestamp column with its year plus cyclical sin/cos features.

    The column is parsed as UTC datetimes. The result contains
    ``{col}_year`` and, for each of month, day, hour, minute and second, a
    ``{col}_{unit}_sin`` / ``{col}_{unit}_cos`` pair, so that values at either
    end of a cycle (e.g. minute 59 and minute 0) end up close together.

    The input frame is left untouched; a new frame is returned.

    Args:
        df: Frame containing ``col``.
        col: Name of the timestamp column to expand.

    Returns:
        A copy of ``df`` without ``col`` and with the derived columns appended
        in the order year, month, day, hour, minute, second (sin before cos).
    """
    timestamps = pd.to_datetime(df[col], utc=True)

    features = {f"{col}_year": timestamps.dt.year}
    # (datetime accessor attribute, cycle length used to scale onto the circle)
    for unit, period in (
        ("month", 12),
        ("day", 31),
        ("hour", 24),
        ("minute", 60),
        ("second", 60),
    ):
        angle = 2 * np.pi * getattr(timestamps.dt, unit) / period
        features[f"{col}_{unit}_sin"] = np.sin(angle)
        features[f"{col}_{unit}_cos"] = np.cos(angle)

    # Build on the dropped copy so the caller's frame is never mutated.
    return df.drop(columns=[col]).assign(**features)


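# NOTE(review): the rendered outputs further down already contain the derived
# columns (e.g. `date_month_sin`, `created_at_year`), so they were produced with
# this loop enabled. Re-enable it to reproduce them from a fresh kernel.
# for col in datetime_cols:
#     df = enrich_datetimes(df, col)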

In [8]:
# Use the words corpus from nltk
# words_list = words.words()


def hash_to_words(
    crypto_hash: Union[str, float, None], vocabulary: Union[list, None] = None
) -> Union[str, float]:
    """Map a hash/identifier string to a stable three-word phrase.

    The input is hashed with SHA-256 and the hex digest seeds a private random
    generator, so the same input always yields the same words and the global
    ``random`` state is left alone.

    Args:
        crypto_hash: The identifier to translate. Missing values (``None``,
            ``np.nan`` or any other NaN) are passed through as ``np.nan``.
        vocabulary: Words to sample from (at least three). Defaults to the
            notebook-level ``word_list``.

    Returns:
        ``"hash <w1> <w2> <w3>"``, or ``np.nan`` for a missing input.
    """
    # `is np.nan` only matches that one object; pd.isna also catches None and
    # NaN floats created elsewhere.
    if pd.isna(crypto_hash):
        return np.nan
    if vocabulary is None:
        vocabulary = word_list

    # Create a hash of the input
    hash_digest = hashlib.sha256(crypto_hash.encode()).hexdigest()

    # Use the hash to select three words from the list. A dedicated Random
    # instance seeded with the digest draws the same sample as seeding the
    # module-level generator, without the side effect.
    rng = py_random.Random(hash_digest)
    selected_words = rng.sample(vocabulary, 3)

    return "hash " + " ".join(selected_words)


# Example usage: the same input string always maps to the same phrase,
# because the word selection is seeded from the input's SHA-256 digest.
crypto_hash = "269CiLPaFK55QqiVbsJupN6BSPUHQ3x7kN6iHPhSMV2NDwHM2EhnwQ6hE6FEvbix6AVN2PLUMQyyhrKr2y514dRB"
print(hash_to_words(crypto_hash))

hash rabies placophoran isagogically


In [9]:
# Identifier-like columns whose raw values are replaced by word phrases.
hash_columns = [
    "id",
    "in_from",
    "out_to",
    "fee_paid_by",
    "transaction_hash",
    "wallet_id",
    "user_id",
]
# NOTE(review): this overwrites the columns in `df` in place, so running the
# cell twice hashes the already-hashed phrases again.
for col in hash_columns:
    df[col] = df[col].apply(hash_to_words)
df.head()

Unnamed: 0,id,deleted_at,in_amount,in_token,in_from,out_amount,out_token,out_to,fee_amount,fee_paid_by,...,date_month_sin,date_month_cos,date_day_sin,date_day_cos,date_hour_sin,date_hour_cos,date_minute_sin,date_minute_cos,date_second_sin,date_second_cos
0,hash auditor Ceyx ametria,,994.015602,USDC,hash synorchidism cercopid Chiasmodon,994.015602,USDC,hash gonimous premycotic weakling,,,...,-0.5,0.8660254,0.968077,-0.250653,-0.866025,0.5,0.104528,-0.994522,0.913545,0.406737
1,hash Nandina killeekillee whinner,,63.510001,LQTY,hash gonimous premycotic weakling,63.510001,LQTY,hash Goyetian intermission enunciatory,0.001218,hash gonimous premycotic weakling,...,1.224647e-16,-1.0,0.988468,0.151428,1.0,6.123234000000001e-17,0.669131,-0.743145,-0.866025,0.5
2,hash chrysohermidin accent bearlet,,10927.17396,USDC,hash Idalian Galician nymph,10927.17396,USDC,hash gonimous premycotic weakling,,,...,-0.5,-0.8660254,-0.848644,0.528964,-0.5,-0.8660254,-0.743145,-0.669131,-0.406737,-0.913545
3,hash Karroo teratism oboist,,18299.44618,USDC,hash gonimous premycotic weakling,18299.44618,USDC,hash coppering biconcave ramellose,0.004304,hash gonimous premycotic weakling,...,-0.8660254,-0.5,0.937752,0.347305,-0.965926,-0.258819,0.866025,0.5,0.866025,-0.5
4,hash Statehouse Stanly candlelighter,,1.0,DFSFB22,hash chipling tetraphyllous lemur,1.0,DFSFB22,hash gonimous premycotic weakling,,,...,-1.0,-1.83697e-16,-0.988468,0.151428,-0.866025,0.5,0.309017,-0.951057,-0.5,-0.866025


In [10]:
# Inspect column dtypes before filling missing values.
df.dtypes

id                              object
deleted_at                     float64
in_amount                      float64
in_token                        object
in_from                         object
out_amount                     float64
out_token                       object
out_to                          object
fee_amount                     float64
fee_paid_by                     object
exchange                       float64
type_id                         object
transaction_hash                object
wallet_id                       object
chain                           object
usd_equivalent                 float64
user_id                         object
additional_transaction_data    float64
deleted                           bool
status                          object
created_at_year                  int32
created_at_month_sin           float64
created_at_month_cos           float64
created_at_day_sin             float64
created_at_day_cos             float64
created_at_hour_sin      

In [11]:
# Fill gaps per dtype family: object columns get the literal "missing",
# numeric columns get 0. Columns of other dtypes (e.g. bool) are not touched.
obj_cols = df.select_dtypes(include=["object"]).columns
num_cols = df.select_dtypes(include=["number"]).columns
df[obj_cols] = df[obj_cols].fillna("missing")
df[num_cols] = df[num_cols].fillna(0)
df.head()

Unnamed: 0,id,deleted_at,in_amount,in_token,in_from,out_amount,out_token,out_to,fee_amount,fee_paid_by,...,date_month_sin,date_month_cos,date_day_sin,date_day_cos,date_hour_sin,date_hour_cos,date_minute_sin,date_minute_cos,date_second_sin,date_second_cos
0,hash auditor Ceyx ametria,0.0,994.015602,USDC,hash synorchidism cercopid Chiasmodon,994.015602,USDC,hash gonimous premycotic weakling,0.0,missing,...,-0.5,0.8660254,0.968077,-0.250653,-0.866025,0.5,0.104528,-0.994522,0.913545,0.406737
1,hash Nandina killeekillee whinner,0.0,63.510001,LQTY,hash gonimous premycotic weakling,63.510001,LQTY,hash Goyetian intermission enunciatory,0.001218,hash gonimous premycotic weakling,...,1.224647e-16,-1.0,0.988468,0.151428,1.0,6.123234000000001e-17,0.669131,-0.743145,-0.866025,0.5
2,hash chrysohermidin accent bearlet,0.0,10927.17396,USDC,hash Idalian Galician nymph,10927.17396,USDC,hash gonimous premycotic weakling,0.0,missing,...,-0.5,-0.8660254,-0.848644,0.528964,-0.5,-0.8660254,-0.743145,-0.669131,-0.406737,-0.913545
3,hash Karroo teratism oboist,0.0,18299.44618,USDC,hash gonimous premycotic weakling,18299.44618,USDC,hash coppering biconcave ramellose,0.004304,hash gonimous premycotic weakling,...,-0.8660254,-0.5,0.937752,0.347305,-0.965926,-0.258819,0.866025,0.5,0.866025,-0.5
4,hash Statehouse Stanly candlelighter,0.0,1.0,DFSFB22,hash chipling tetraphyllous lemur,1.0,DFSFB22,hash gonimous premycotic weakling,0.0,missing,...,-1.0,-1.83697e-16,-0.988468,0.151428,-0.866025,0.5,0.309017,-0.951057,-0.5,-0.866025


In [12]:
# Re-check dtypes after the fill step.
df.dtypes

id                              object
deleted_at                     float64
in_amount                      float64
in_token                        object
in_from                         object
out_amount                     float64
out_token                       object
out_to                          object
fee_amount                     float64
fee_paid_by                     object
exchange                       float64
type_id                         object
transaction_hash                object
wallet_id                       object
chain                           object
usd_equivalent                 float64
user_id                         object
additional_transaction_data    float64
deleted                           bool
status                          object
created_at_year                  int32
created_at_month_sin           float64
created_at_month_cos           float64
created_at_day_sin             float64
created_at_day_cos             float64
created_at_hour_sin      

In [13]:
# Group rows into blocks of ten by integer-dividing the row index label.
# NOTE(review): `df` comes from pd.concat without ignore_index, so index labels
# repeat across source files and rows from different files can share an `idx`
# -- confirm this is intended.
df["idx"] = df.index // 10
df.head(20)

Unnamed: 0,id,deleted_at,in_amount,in_token,in_from,out_amount,out_token,out_to,fee_amount,fee_paid_by,...,date_month_cos,date_day_sin,date_day_cos,date_hour_sin,date_hour_cos,date_minute_sin,date_minute_cos,date_second_sin,date_second_cos,idx
0,hash auditor Ceyx ametria,0.0,994.015602,USDC,hash synorchidism cercopid Chiasmodon,994.015602,USDC,hash gonimous premycotic weakling,0.0,missing,...,0.8660254,0.9680771,-0.250653,-0.866025,0.5,0.104528,-0.994522,0.913545,0.4067366,0
1,hash Nandina killeekillee whinner,0.0,63.510001,LQTY,hash gonimous premycotic weakling,63.510001,LQTY,hash Goyetian intermission enunciatory,0.001218,hash gonimous premycotic weakling,...,-1.0,0.9884683,0.151428,1.0,6.123234000000001e-17,0.669131,-0.743145,-0.866025,0.5,0
2,hash chrysohermidin accent bearlet,0.0,10927.17396,USDC,hash Idalian Galician nymph,10927.17396,USDC,hash gonimous premycotic weakling,0.0,missing,...,-0.8660254,-0.8486443,0.528964,-0.5,-0.8660254,-0.743145,-0.669131,-0.406737,-0.9135455,0
3,hash Karroo teratism oboist,0.0,18299.44618,USDC,hash gonimous premycotic weakling,18299.44618,USDC,hash coppering biconcave ramellose,0.004304,hash gonimous premycotic weakling,...,-0.5,0.9377521,0.347305,-0.965926,-0.258819,0.866025,0.5,0.866025,-0.5,0
4,hash Statehouse Stanly candlelighter,0.0,1.0,DFSFB22,hash chipling tetraphyllous lemur,1.0,DFSFB22,hash gonimous premycotic weakling,0.0,missing,...,-1.83697e-16,-0.9884683,0.151428,-0.866025,0.5,0.309017,-0.951057,-0.5,-0.8660254,0
5,hash dihybridism crannock grig,0.0,12.372807,RLC,hash siphonogam Nicolaitanism sulfohalite,12.372807,RLC,hash gonimous premycotic weakling,0.0,missing,...,-0.8660254,-0.8486443,0.528964,-0.965926,0.258819,0.669131,-0.743145,0.5,0.8660254,0
6,hash stroth cheeseburger patternmaker,0.0,0.058025,WETH,hash gonimous premycotic weakling,0.058025,WETH,hash Sienese maharao cornetcy,0.002689,hash gonimous premycotic weakling,...,-0.5,0.5712682,0.820763,0.5,-0.8660254,0.743145,0.669131,0.743145,-0.6691306,0
7,hash resurrect phraseologically hydropneumatosis,0.0,0.0627,ETH,hash overdiversity overregularity lollopy,0.0627,ETH,hash gonimous premycotic weakling,0.0,missing,...,-0.5,0.9680771,-0.250653,-0.866025,0.5,-0.669131,-0.743145,-0.207912,0.9781476,0
8,hash exhaustive dephlegmator nonvesture,0.0,1412.919511,MATIC,hash quincuncial semipalmate unempaneled,1412.919511,MATIC,hash gonimous premycotic weakling,0.0,missing,...,-0.8660254,0.6513725,-0.758758,-0.965926,-0.258819,-0.5,-0.866025,-0.809017,0.5877853,0
9,hash mesoskelic underadmiral epileptoid,0.0,0.0,Friends,hash shapometer home micropipette,0.0,Friends,hash gonimous premycotic weakling,0.0,missing,...,0.5,0.1011683,-0.994869,-0.5,0.8660254,0.5,-0.866025,-0.669131,0.7431448,0


In [14]:
# Get train test split at 80/20
# The config is generated from the full frame, i.e. before the split.
time_series_config = hp.TimeSeriesConfig.generate(df=df)
# Split on the `idx` group id (threshold at 80% of the largest id) so every row
# of a group lands on the same side.
train_idx = int(df.idx.max() * 0.8)
train_df = df.loc[df.idx < train_idx].copy()
test_df = df.loc[df.idx >= train_idx].copy()
# del df
train_ds = hp.TimeSeriesDS(train_df, time_series_config)
test_ds = hp.TimeSeriesDS(test_df, time_series_config)
len(train_ds), len(test_ds)

(601, 152)

In [15]:
def make_batch(ds: hp.TimeSeriesDS, start: int, length: int):
    """Stack ``length`` consecutive dataset items into one batch.

    Each item ``ds[i]`` is fetched exactly once; element 0 of the item is
    used as the numeric part and element 1 as the categorical part.

    Args:
        ds: Indexable dataset.
        start: Index of the first item to include.
        length: Number of consecutive items to include.

    Returns:
        ``{"numeric": ..., "categorical": ...}`` where each value is a
        ``jnp`` array with the items stacked along a new leading axis.
    """
    # Index once per item: the original looked every item up twice.
    items = [ds[i] for i in range(start, start + length)]
    return {
        "numeric": jnp.array([item[0] for item in items]),
        "categorical": jnp.array([item[1] for item in items]),
    }


# Small four-item batch; reused below to initialise the model and train state.
batch = make_batch(train_ds, 0, 4)
# batch

In [16]:
# `multiplier` scales the head count: 8 * 4 = 32 heads at d_model=512.
multiplier = 4
time_series_regressor = hp.TimeSeriesDecoder(
    time_series_config, d_model=512, n_heads=8 * multiplier
)

In [17]:
# Initialise the model variables from the sample batch, with separate PRNG
# streams for parameter init and dropout.
key = random.PRNGKey(0)
init_key, dropout_key = random.split(key)
# NOTE(review): `vars` shadows the Python builtin of the same name for the rest
# of the notebook.
vars = time_series_regressor.init(
    {"params": init_key, "dropout": dropout_key},
    batch["numeric"],
    categorical_inputs=batch["categorical"].astype(jnp.int32),
    deterministic=False,
    causal_mask=False,
)
# Derive a fresh dropout key so the forward pass below does not reuse the key
# consumed by init.
dropout_key, original_dropout_key = random.split(dropout_key)

ic| Error: Failed to access the underlying source code for analysis. Was ic() invoked in a REPL (e.g. from the command line), a frozen application (e.g. packaged with PyInstaller), or did the underlying source code change during execution?
ic| Error: Failed to access the underlying source code for analysis. Was ic() invoked in a REPL (e.g. from the command line), a frozen application (e.g. packaged with PyInstaller), or did the underlying source code change during execution?
ic| Error: Failed to access the underlying source code for analysis. Was ic() invoked in a REPL (e.g. from the command line), a frozen application (e.g. packaged with PyInstaller), or did the underlying source code change during execution?
ic| Error: Failed to access the underlying source code for analysis. Was ic() invoked in a REPL (e.g. from the command line), a frozen application (e.g. packaged with PyInstaller), or did the underlying source code change during execution?
ic| Error: Failed to access the underlyi

In [18]:
# Smoke-test a forward pass with dropout active and report the output shapes.
x = time_series_regressor.apply(
    vars,
    batch["numeric"],
    batch["categorical"].astype(jnp.int32),
    deterministic=False,
    rngs={"dropout": dropout_key},
)
print(x.get("numeric_out").shape)
# Check if categorical input is None and print None or it's shape
print(x.get("categorical_out").shape if x.get("categorical_out") is not None else None)

ic| Error: Failed to access the underlying source code for analysis. Was ic() invoked in a REPL (e.g. from the command line), a frozen application (e.g. packaged with PyInstaller), or did the underlying source code change during execution?
ic| Error: Failed to access the underlying source code for analysis. Was ic() invoked in a REPL (e.g. from the command line), a frozen application (e.g. packaged with PyInstaller), or did the underlying source code change during execution?
ic| Error: Failed to access the underlying source code for analysis. Was ic() invoked in a REPL (e.g. from the command line), a frozen application (e.g. packaged with PyInstaller), or did the underlying source code change during execution?
ic| Error: Failed to access the underlying source code for analysis. Was ic() invoked in a REPL (e.g. from the command line), a frozen application (e.g. packaged with PyInstaller), or did the underlying source code change during execution?
ic| Error: Failed to access the underlyi

(4, 40, 100)
(4, 12, 100, 45906)


In [19]:
# Eyeball the numeric head's raw output from the smoke-test forward pass.
x["numeric_out"]

Array([[[ 0.5391996 ,  0.25555545,  0.15995353, ..., -0.17448428,
          0.36823082, -0.5377096 ],
        [ 0.5936038 ,  0.6740494 ,  0.58564276, ..., -0.18124664,
          0.32163638, -0.26839712],
        [ 0.43207607, -0.23547438, -0.32862288, ..., -0.15446618,
          0.39610773, -0.38567042],
        ...,
        [ 0.6713651 ,  0.16688654,  0.7819479 , ...,  0.0260087 ,
         -0.24633735, -0.06804048],
        [-0.5502661 , -0.28047183,  0.13585824, ..., -0.6994779 ,
         -0.1427004 , -0.2607084 ],
        [ 0.65356493, -0.36569646, -0.2793846 , ..., -0.2888239 ,
          0.03901091, -0.37889773]],

       [[-0.61214757,  0.28611636, -0.09607197, ...,  0.04753201,
         -0.00873746, -0.22971836],
        [ 0.9134278 ,  0.09889889,  0.15836424, ...,  0.17750327,
          0.09653816, -0.4970063 ],
        [ 0.10563877, -0.05873691, -0.04184473, ...,  0.31792346,
          0.1956397 , -0.0652024 ],
        ...,
        [ 0.80708456,  0.73046404,  0.81317294, ...,  

In [20]:
def calculate_memory_footprint(params):
    """Return the total size in bytes of every array in a parameter pytree.

    The tree is flattened to its leaves and each leaf contributes
    ``size * dtype.itemsize`` bytes.

    Args:
        params: Pytree whose leaves expose ``.size`` and ``.dtype``.

    Returns:
        Total number of bytes across all leaves (0 for an empty tree).
    """
    leaves = tree_flatten(params)[0]
    # bytes per leaf = element count * bytes per element
    return sum(leaf.size * leaf.dtype.itemsize for leaf in leaves)


def count_parameters(params):
    """Return the total number of scalar elements across a parameter pytree.

    Counting is done on the host with 64-bit integers, so very large models do
    not overflow and the result is a plain Python ``int``.

    Args:
        params: Pytree whose leaves expose ``.shape``.

    Returns:
        Sum over all leaves of the product of the leaf's shape (a scalar leaf
        with shape ``()`` counts as 1).
    """
    return sum(
        int(np.prod(leaf.shape, dtype=np.int64))
        for leaf in jax.tree_util.tree_leaves(params)
    )


# Size report over the full variable collection returned by `init`.
mem = calculate_memory_footprint(vars)
total_params = count_parameters(vars)


print(f"Memory of custom: {mem / 1e6:.2f} MB with {total_params:,} parameters")

Memory of custom: 315.34 MB with 78,834,422 parameters


In [21]:
# Silence ic() output from here on (it was flooding the cells above).
ic.disable()

In [22]:
# PRNG keys for training. Only `mts_main_key` is used in this cell; the other
# two are split off but not referenced here.
mts_root_key = random.PRNGKey(44)
mts_main_key, ts_params_key, ts_data_key = random.split(mts_root_key, 3)

# Module-level flag; return_results reads it when called.
causal_mask = False


batch_size = 2

# Train state built from the sample batch; 0.0001 is passed as the fourth
# positional argument -- presumably the learning rate, confirm against
# tsd.create_train_state.
state = tsd.create_train_state(time_series_regressor, mts_main_key, batch, 0.0001)

In [23]:
@jax.jit
def my_train_step(
    state: train_state.TrainState,
    numerical_inputs,
    categorical_inputs,
    base_key,
):
    """Jit-compiled wrapper around ``tsd.train_step`` with ``input_offset=0``.

    Arguments are forwarded unchanged and the result of ``tsd.train_step`` is
    returned as-is; the training loop unpacks it as ``(state, loss, base_key)``.
    """
    return tsd.train_step(
        state,
        numerical_inputs,
        categorical_inputs,
        base_key,
        input_offset=0,
    )

In [24]:
# causal_mask = False

: 

In [None]:
writer_name = "NansWTF"

# Run name = timestamp + label; also used later as the checkpoint directory name.
writer_time = dt.now().strftime("%Y-%m-%dT%H:%M:%S")
model_name = writer_time + writer_name
train_summary_writer = SummaryWriter("runs/" + model_name)


test_set_key = random.PRNGKey(444)

batch_size = 2
train_data_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_data_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=True)

batch_count = 0
base_key = random.PRNGKey(42)

n_epochs = 1
eval_every = 10  # batches between test-loss evaluations
max_iters = None  # set to an int to stop training after that many batches

# Disable IC for training
ic.disable()
# try/finally so the TensorBoard writer is flushed and closed even when the
# NaN guard (or anything else) raises mid-training.
try:
    reached_max_iters = False
    for j in trange(n_epochs, desc=f"epochs for {train_summary_writer.log_dir}"):
        for i in tqdm(train_data_loader, leave=False, desc="batches"):
            state, loss, base_key = my_train_step(
                state,
                jnp.array(i[0]),
                jnp.array(i[1]),
                base_key,
            )
            if jnp.isnan(loss):
                raise ValueError("Nan Value in loss, stopping")
            batch_count += 1

            # Training loss is logged on every batch.
            train_summary_writer.add_scalar("loss/loss", loss.item(), batch_count)

            if batch_count % eval_every == 0:
                # One freshly shuffled test batch per evaluation.
                numeric_eval, categorical_eval = next(iter(test_data_loader))
                test_loss, base_key = tsd.eval_step(
                    state,
                    jnp.array(numeric_eval),
                    jnp.array(categorical_eval),
                    base_key,
                )
                train_summary_writer.add_scalar(
                    "loss/test_loss", test_loss.item(), batch_count
                )
                train_summary_writer.flush()

            if max_iters and batch_count > max_iters:
                reached_max_iters = True
                break
        # Leave the epoch loop too; a bare `break` above only exits the batch loop.
        if reached_max_iters:
            break
finally:
    train_summary_writer.close()

epochs for runs/2024-10-10T03:40:16NansWTF:   0%|          | 0/1 [00:00<?, ?it/s]

batches:   0%|          | 0/301 [00:00<?, ?it/s]

F1010 03:40:51.596403    4144 reduction.cc:147] Check failed: LayoutUtil::IsMonotonicWithDim0Major( hero.instruction().operand(i)->shape().layout()) reduction-layout-normalizer must run before code generation


In [None]:
# Marker cell: reaching it means the training cell above has finished.
print("Done")

In [None]:
# Checkpointer used to save and restore the train state as a pytree.
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()

In [None]:
# Bundle the train state with the number of batches seen so far.
ckpt = {"model": state, "step": batch_count}


# One checkpoint directory per run, named after the TensorBoard run.
checkpoint_dir = f"checkpoints/{model_name}"
checkpoint_dir = os.path.abspath(checkpoint_dir)

# os.makedirs(checkpoint_dir, exist_ok=True)

# NOTE(review): this re-creates the checkpointer defined in the previous cell.
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(ckpt)
orbax_checkpointer.save(checkpoint_dir, ckpt, save_args=save_args)

In [None]:
# Restore the most recent checkpoint. Run names start with an ISO-style
# timestamp, so the lexicographically last directory is the newest one.
# if model_name is None:
if True:
    all_checkpoints = sorted(os.listdir("checkpoints/"))
    if not all_checkpoints:
        raise FileNotFoundError("No checkpoints found under 'checkpoints/'")
    model_name = all_checkpoints[-1]
    checkpoint_dir = os.path.abspath(f"checkpoints/{model_name}")

new_checkpoint = orbax_checkpointer.restore(checkpoint_dir)
# Build a fresh train state, then swap in the restored parameters.
new_state = tsd.create_train_state(time_series_regressor, mts_main_key, batch, 0.0001)
new_state = new_state.replace(params=new_checkpoint["model"]["params"])

In [None]:
# Deterministic forward pass on the first training sample; each array is
# wrapped in a list to add a leading batch axis of size one.
# NOTE(review): this uses `state` (the in-memory training state), not the
# `new_state` restored from the checkpoint in the previous cell.
numeric_inputs, categorical_inputs = train_ds[0]
numeric_inputs = jnp.array([numeric_inputs])
categorical_inputs = jnp.array([categorical_inputs])
test_results = state.apply_fn(
    {"params": state.params},
    # jnp.array(i[0]),
    # jnp.array(i[1]),
    numeric_inputs,
    categorical_inputs,
    deterministic=True,
    causal_mask=False,
)
test_results["numeric_out"].shape

In [None]:
@dataclass
class Results:
    """Model outputs for one sample, kept alongside the inputs that produced them."""

    # Outputs taken from the model's "numeric_out" / "categorical_out" entries.
    numeric_out: jnp.array
    categorical_out: jnp.array
    # The (batched) inputs that were fed to the model.
    numeric_inputs: jnp.array
    categorical_inputs: jnp.array


def return_results(state, dataset, idx=0, mask_start: int = None):
    """Run the model on one dataset item and return outputs together with inputs.

    Args:
        state: Train state providing ``apply_fn`` and ``params``.
        dataset: Indexable dataset whose items are ``(numeric, categorical)``.
        idx: Index of the item to evaluate.
        mask_start: When truthy, both arrays are truncated along their second
            axis to the first ``mask_start`` positions before the forward pass.

    Returns:
        A ``Results`` holding the model's ``numeric_out`` and
        ``categorical_out`` plus the batched inputs that were used.

    Note:
        ``causal_mask`` is read from the notebook-level variable of that name
        at call time.
    """
    numeric_item, categorical_item = dataset[idx]
    if mask_start:
        window = slice(None, mask_start)
        numeric_item = numeric_item[:, window]
        categorical_item = categorical_item[:, window]

    # Wrap in a list to add the leading batch axis the model expects.
    numeric_inputs = jnp.array([numeric_item])
    categorical_inputs = jnp.array([categorical_item])

    out = state.apply_fn(
        {"params": state.params},
        numeric_inputs=numeric_inputs,
        categorical_inputs=categorical_inputs,
        deterministic=True,
        causal_mask=causal_mask,
    )
    return Results(
        out["numeric_out"],
        out["categorical_out"],
        numeric_inputs,
        categorical_inputs,
    )


# Sanity check on the first training sample (note: rebinds the scratch name `x`).
x = return_results(state, train_ds, 0)
x.categorical_out.shape

In [None]:
# Manual toggle: the first assignment is overwritten immediately, so the
# effective value is True. return_results reads this global when it is called.
causal_mask = False
causal_mask = True


def process_results(arr: jnp.array, col_names: list, config: hp.TimeSeriesConfig):
    """Turn a model input/output array into a DataFrame with named columns.

    Size-one axes are squeezed away first. If the squeezed array still has
    three axes, the last axis is treated as per-class logits and collapsed
    with ``argmax``. The result is transposed before being wrapped in a frame
    and labelled with ``col_names``.

    Args:
        arr: Array to convert.
        col_names: Column labels for the resulting frame.
        config: Accepted for signature compatibility; not used here.

    Returns:
        A ``pd.DataFrame`` whose columns are ``col_names``.
    """
    squeezed = jnp.squeeze(arr)
    # Rank 3 after squeezing means a trailing logit axis: keep the winning class.
    values = jnp.argmax(squeezed, axis=-1) if squeezed.ndim == 3 else squeezed
    frame = pd.DataFrame(values.T)
    frame.columns = col_names
    return frame


@dataclass
class DFComparison:
    """Pair of frames built by show_results_df for side-by-side comparison."""

    # Frame built from the arrays fed to the model.
    input_df: pd.DataFrame
    # Frame built from the model's outputs.
    output_df: pd.DataFrame


def show_results_df(
    state, time_series_config, dataset, idx: int = 0, mask_start: int = None
):
    """Evaluate one dataset item and return input/output frames for comparison.

    Args:
        state: Train state passed through to ``return_results``.
        time_series_config: Supplies ``categorical_col_tokens`` and
            ``numeric_col_tokens``, used as column labels.
        dataset: Dataset to take the item from.
        idx: Index of the item to evaluate.
        mask_start: Optional truncation point, forwarded to ``return_results``.

    Returns:
        A ``DFComparison`` whose frames each hold the categorical columns
        followed by the numeric columns.
    """
    results = return_results(state, dataset, idx=idx, mask_start=mask_start)
    categorical_cols = time_series_config.categorical_col_tokens
    numeric_cols = time_series_config.numeric_col_tokens

    def _as_frame(categorical, numeric):
        # Categorical block first, numeric block second, joined column-wise.
        parts = [
            process_results(categorical, categorical_cols, time_series_config),
            process_results(numeric, numeric_cols, time_series_config),
        ]
        return pd.concat(parts, axis=1)

    input_df = _as_frame(results.categorical_inputs, results.numeric_inputs)
    output_df = _as_frame(results.categorical_out, results.numeric_out)
    return DFComparison(input_df, output_df)


# Input-vs-output comparison for the first training sample.
df_comp = show_results_df(
    state=state, time_series_config=time_series_config, dataset=train_ds, idx=0
)

In [None]:
# Last rows of the predicted categorical columns.
df_comp.output_df.loc[:, time_series_config.categorical_col_tokens].tail()

In [None]:
# NOTE(review): identical to the previous cell -- possibly meant to show
# `input_df` for comparison; confirm.
df_comp.output_df.loc[:, time_series_config.categorical_col_tokens].tail()

In [None]:
def plot_planets(df_pred: pd.DataFrame, df_actual: pd.DataFrame, column: str, offset=0):
    """Plot one column of the predicted frame against the actual frame.

    Args:
        df_pred: Frame holding the model's predictions.
        df_actual: Frame holding the reference values.
        column: Name of the column to plot from both frames.
        offset: Accepted for signature compatibility; not used.
    """
    _, ax = plt.subplots(figsize=(15, 10))
    ax.plot(df_pred[column], label="Autoregressive")
    ax.plot(df_actual[column], label="Actual")
    ax.set_title(f"{column} Predictions")
    ax.legend()
    # Show ticks and grid lines every 1 step
    ax.set_xticks(np.arange(0, len(df_pred), 1))
    ax.grid()
    # add black line at 0 on the y axis to show the difference
    ax.axhline(0, color="black")
    plt.show()

In [None]:
# Scratch: shape of a 1-D boolean mask, as used for row selection below.
jnp.array([True, True, False, False, True]).shape

In [None]:
# Scratch: `.at[mask, :].set(0)` with a boolean mask over the first axis
# returns a copy in which the selected rows (0, 1 and 4) are zeroed.
x = jnp.ones((5, 20))
print(x.shape)
xx = x.at[jnp.array([True, True, False, False, True]), :].set(0)
xx

In [None]:
@dataclass
class AutoRegressiveResults:
    """Numeric and categorical arrays used as input to an autoregressive step."""

    numeric_inputs: jnp.array
    categorical_inputs: jnp.array

    @classmethod
    def from_ds(cls, ds: hp.TimeSeriesDS, idx: int, stop_idx: int = 10):
        """Build an instance from item ``idx`` of ``ds``.

        Element 0 of the item is taken as the numeric array and element 1 as
        the categorical array; both are cut along their second axis to the
        first ``stop_idx`` positions.
        """
        sample = ds[idx]
        window = slice(None, stop_idx)
        return cls(sample[0][:, window], sample[1][:, window])


def auto_regressive_predictions(
    state: train_state.TrainState,
    inputs: AutoRegressiveResults,
) -> tuple:
    """Run one prediction step and append the predicted position to the inputs.

    The model is applied to the given inputs (with a batch axis added), the
    last position along axis 1 of each output is taken as the prediction, and
    it is concatenated onto the inputs along axis 1.

    Args:
        state: Train state providing ``apply_fn`` and ``params``.
        inputs: Holder of ``numeric_inputs`` and ``categorical_inputs``.

    Returns:
        A tuple ``(numeric_inputs, categorical_inputs)`` of the extended
        arrays -- a plain tuple, not an ``AutoRegressiveResults``.
    """
    numeric_inputs = inputs.numeric_inputs
    categorical_inputs = inputs.categorical_inputs
    # get the first row that contains all nan values
    # if nan_rows_start >= stop_idx:
    #     return inputs
    # numeric_inputs = inputs.numeric_inputs
    # categorical_inputs = inputs.categorical_inputs
    # Despite the names, `.all(axis=1)` yields one flag per entry along axis 0:
    # True where every value along axis 1 is NaN.
    numeric_nan_columns = jnp.isnan(numeric_inputs).all(axis=1)
    categorical_nan_columns = jnp.isnan(categorical_inputs).all(axis=1)
    outputs = state.apply_fn(
        {"params": state.params},
        numeric_inputs=jnp.array([numeric_inputs]),
        categorical_inputs=jnp.array([categorical_inputs]),
        deterministic=True,
        causal_mask=False,
    )
    numeric_out = jnp.squeeze(outputs["numeric_out"])
    categorical_out = jnp.squeeze(outputs["categorical_out"])
    # Collapse the trailing logit axis to the most likely class id.
    categorical_out = jnp.argmax(categorical_out, axis=-1)

    # Last position along axis 1 = the newly predicted step.
    final_numeric_row = np.array(numeric_out[:, -1])
    final_numeric_row = final_numeric_row[:, None]  # New axis

    final_categorical_row = np.array(categorical_out[:, -1])
    final_categorical_row = final_categorical_row[:, None]  # New axis
    numeric_inputs = jnp.concatenate([numeric_inputs, final_numeric_row], axis=1)
    categorical_inputs = jnp.concatenate(
        [categorical_inputs, final_categorical_row], axis=1
    )
    # Re-blank the entries that were entirely NaN in the input, so the
    # appended prediction does not fill them in.
    numeric_inputs = numeric_inputs.at[jnp.array(numeric_nan_columns)].set(jnp.nan)
    # NOTE(review): NaN is only representable in floating dtypes -- confirm the
    # categorical array's dtype survives this assignment as intended.
    categorical_inputs = categorical_inputs.at[jnp.array(categorical_nan_columns)].set(
        jnp.nan
    )
    inputs = (numeric_inputs, categorical_inputs)

    return inputs
    # return auto_regressive_predictions(state, inputs, stop_idx)

In [None]:
test_inputs = AutoRegressiveResults.from_ds(train_ds, 0, 10)
# Roll the model forward 21 steps, feeding each step's extended arrays back in
# as the next step's inputs. (Previously the result went to `inputs_test`
# while the loop kept passing the unchanged `test_inputs`, so every iteration
# repeated the same first prediction.)
for i in trange(21):
    next_numeric, next_categorical = auto_regressive_predictions(state, test_inputs)
    test_inputs = AutoRegressiveResults(next_numeric, next_categorical)

# Keep the tuple form of the final step under the original name.
inputs_test = (test_inputs.numeric_inputs, test_inputs.categorical_inputs)

# x = auto_regressive_predictions(state, test_ds[0], 10)

In [None]:
# res = show_results_df(state, train_df, test_ds, idx=0, mask_start=30)

# plot_planets(res["pred"], res["actual_masked"], "planet2_x", offset=0)