<https://github.com/PolymathicAI/xVal>


In [1]:
import os

# Fraction of GPU memory the XLA client may preallocate; set here, ahead of the
# `import jax` in the next cell.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.95"

In [2]:
import ast
import re
from datetime import datetime as dt

import icecream
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import orbax
import orbax.checkpoint
import pandas as pd
import seaborn as sns
from flax.struct import dataclass
from flax.training import orbax_utils, train_state
from icecream import ic
from jax import random
from jax.tree_util import tree_flatten
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm, trange
from transformers import BertTokenizerFast, FlaxBertModel

import hephaestus as hp
import hephaestus.models.time_series_decoder_training as tsd

# Register `ic` as a builtin so it can be called without an explicit import.
icecream.install()
ic_disable = False  # Set to True to silence all ic() debug output
if ic_disable:
    ic.disable()
# Include the call-site context (with absolute path) in every ic() line.
ic.configureOutput(includeContext=True, contextAbsPath=True)
# Opt in to pandas copy-on-write behaviour for the whole notebook.
pd.options.mode.copy_on_write = True

In [3]:
# Load pre-trained BERT model and tokenizer
# NOTE(review): `model_name` is reassigned to the run name in the training cell
# further down, so it only refers to the BERT checkpoint until then.
model_name = "bert-base-uncased"
model = FlaxBertModel.from_pretrained(model_name)
tokenizer = BertTokenizerFast.from_pretrained(model_name)

# Get the embeddings matrix (one row per vocabulary token)
embeddings = model.params["embeddings"]["word_embeddings"]["embedding"]

# Now you can access specific embeddings like this:
# For example, to get embeddings for tokens 23, 293, and 993:
selected_embeddings = jnp.take(embeddings, jnp.array([23, 293, 993]), axis=0)

# If you want to get embeddings for specific words:
# (each word is looked up directly as a vocabulary token id)
words = ["hello", "world", "example"]
tokens = tokenizer.convert_tokens_to_ids(words)
word_embeddings = jnp.take(embeddings, jnp.array(tokens), axis=0)
# Displayed as the cell output: one embedding row per word
word_embeddings.shape

Some weights of FlaxBertModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: {('pooler', 'dense', 'bias'), ('pooler', 'dense', 'kernel')}
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


(3, 768)

In [4]:
def line2df(line, idx):
    """Parse one serialized simulation record into a tidy DataFrame.

    `line` is the text of a Python literal dict with two keys:

    * ``"data"``: a sequence of time steps, each a sequence of ``(x, y)``
      positions, one per planet.
    * ``"description"``: a dict holding ``"stepsize"`` plus one dict of
      properties per planet.

    The result has one row per time step with columns ``idx`` (the value
    passed in), ``time_step`` (step number times the step size),
    ``planet<j>_x`` / ``planet<j>_y`` and one constant column
    ``<planet>_<property>`` per description entry.
    """
    record = ast.literal_eval(line)

    def _positions(step_no, bodies):
        # One row: the step number followed by every planet's x/y position.
        entry = {"time_step": step_no}
        for body_no, coords in enumerate(bodies):
            entry[f"planet{body_no}_x"] = coords[0]
            entry[f"planet{body_no}_y"] = coords[1]
        return entry

    frame = pd.DataFrame(
        [_positions(step_no, bodies) for step_no, bodies in enumerate(record["data"])]
    )

    meta = record.pop("description")
    seconds_per_step = meta.pop("stepsize")
    # Everything left in the description is a per-planet property dict;
    # broadcast each property as a constant column.
    for body, props in meta.items():
        for prop, value in props.items():
            frame[f"{body}_{prop}"] = value

    frame["time_step"] = frame["time_step"] * seconds_per_step
    frame.insert(0, "idx", idx)
    return frame

In [5]:
# Build the parquet cache from the raw text file on first run; afterwards just
# read the cache.
files = os.listdir("data")
if "planets.parquet" not in files:
    with open("data/planets.data") as f:
        data = f.read().splitlines()

        # One DataFrame per record; the record's position in the file becomes its `idx`.
        dfs = []
        for idx, line in enumerate(tqdm(data)):
            dfs.append(line2df(line, idx))
        df = pd.concat(dfs)
    df.to_parquet("data/planets.parquet")
else:
    df = pd.read_parquet("data/planets.parquet")


# Sum the per-planet mass columns (`planet<n>_m`) into a single `total_mass` column
mass_regex = re.compile(r"planet(\d+)_m")
mass_cols = [col for col in df.columns if mass_regex.match(col)]
df["total_mass"] = df[mass_cols].sum(axis=1)

# Categorical column with the number of planets in each row: the count of
# non-null mass columns, formatted as "<n>_planets"
df["n_planets"] = df[mass_cols].notnull().sum(axis=1).astype("object")
df["n_planets"] = df["n_planets"].apply(lambda x: f"{x}_planets")
# NOTE(review): despite its name, `acceleration_x` is the row-wise sum of every
# column whose name contains both "planet" and "_x" (the x positions, as far as
# this notebook shows), not a derivative — confirm this is the intended feature.
df["acceleration_x"] = df[
    [col for col in df.columns if "planet" in col and "_x" in col]
].sum(axis=1)
# Set acceleration_x to "increasing" if greater than 0 else "decreasing"
df["acceleration_x"] = (
    df["acceleration_x"]
    .apply(lambda x: "increasing" if x > 0 else "decreasing")
    .astype("object")
)
# Same construction for the y direction.
df["acceleration_y"] = df[
    [col for col in df.columns if "planet" in col and "_y" in col]
].sum(axis=1)
df["acceleration_y"] = df["acceleration_y"].apply(
    lambda x: "increasing" if x > 0 else "decreasing"
)


# Summary statistics of the numeric columns (cell output)
df.describe()

Unnamed: 0,idx,time_step,planet0_x,planet0_y,planet1_x,planet1_y,planet2_x,planet2_y,planet0_m,planet0_a,...,planet3_y,planet3_m,planet3_a,planet3_e,planet4_x,planet4_y,planet4_m,planet4_a,planet4_e,total_mass
count,5563957.0,5563957.0,5563957.0,5563957.0,5563957.0,5563957.0,4165044.0,4165044.0,5563957.0,5563957.0,...,2783627.0,2783627.0,2783627.0,2783627.0,1392864.0,1392864.0,1392864.0,1392864.0,1392864.0,5563957.0
mean,62486.35,9.748911,-0.1339198,0.07391138,-0.134014,0.07291389,-0.1305344,0.07065633,2.999306,1.624756,...,0.0655915,2.996303,1.623874,0.9980576,-0.1276881,0.06519469,3.002531,1.625815,1.001317,10.49149
std,36079.49,5.993534,1.228071,1.213232,1.22795,1.21265,1.217229,1.203678,1.157182,0.5876632,...,1.200148,1.15319,0.5270725,0.5764675,1.211648,1.199625,1.156856,0.5167198,0.5779763,3.99178
min,0.0,0.0,-3.294763,-2.997514,-3.284004,-2.998546,-3.28979,-2.99805,1.000003,1.0,...,-2.997621,1.000054,1.0,9.369537e-05,-3.273603,-2.998913,1.000103,1.0,6.720938e-05,2.014597
25%,31244.0,4.655172,-1.030131,-0.9020907,-1.030516,-0.9028009,-1.050662,-0.9211662,1.993948,1.0,...,-0.9321272,1.996853,1.191548,0.4980967,-1.071974,-0.9394428,2.00424,1.215927,0.5032645,7.282371
50%,62491.0,9.52381,-0.1542335,0.1117099,-0.1538916,0.1099474,-0.152552,0.1118031,2.994477,1.543047,...,0.1067355,3.001879,1.535683,0.995504,-0.1507184,0.1055276,3.000454,1.51904,1.003118,10.28205
75%,93728.0,14.4,0.8583344,0.9784987,0.8581045,0.9783998,0.8762358,0.9912902,4.005747,2.020672,...,0.9956064,3.986406,1.969644,1.497746,0.890678,0.999122,4.003332,1.950518,1.503457,13.45633
max,124999.0,24.0,2.99637,2.999014,2.993319,2.999536,2.990464,2.998478,4.999994,2.999984,...,3.000881,4.99999,2.999909,1.999957,2.98518,2.998936,4.999679,2.999497,1.999999,23.82455


In [6]:
# Count the rows for every combination of values in the object-dtype columns.
categorical_frame = df.select_dtypes(include="object")
categorical_frame.groupby(categorical_frame.columns.tolist()).size().reset_index(
    name="count"
)

Unnamed: 0,n_planets,acceleration_x,acceleration_y,count
0,2_planets,decreasing,decreasing,365720
1,2_planets,decreasing,increasing,410992
2,2_planets,increasing,decreasing,272240
3,2_planets,increasing,increasing,349961
4,3_planets,decreasing,decreasing,410092
5,3_planets,decreasing,increasing,404513
6,3_planets,increasing,decreasing,258536
7,3_planets,increasing,increasing,308276
8,4_planets,decreasing,decreasing,413451
9,4_planets,decreasing,increasing,416444


In [7]:
# Split on the trajectory id so that no trajectory straddles both sets:
# ids below 80% of the largest id go to train, the rest to test.
# NOTE(review): the saved output of this cell is
# `TypeError: unhashable type: 'numpy.ndarray'` — confirm it runs on a fresh kernel.
time_series_config = hp.TimeSeriesConfig.generate(df=df)
train_idx = int(df.idx.max() * 0.8)
train_df = df.loc[df.idx < train_idx].copy()
test_df = df.loc[df.idx >= train_idx].copy()
# del df
train_ds = hp.TimeSeriesDS(train_df, time_series_config)
test_ds = hp.TimeSeriesDS(test_df, time_series_config)
len(train_ds), len(test_ds)

TypeError: unhashable type: 'numpy.ndarray'

In [8]:
def make_batch(ds: "hp.TimeSeriesDS", start: int, length: int):
    """Stack `length` consecutive dataset items into one batch.

    Args:
        ds: Indexable dataset whose items are ``(numeric, categorical)`` pairs.
        start: Index of the first item to include.
        length: Number of consecutive items to include.

    Returns:
        A dict with keys ``"numeric"`` and ``"categorical"``; each value is a
        JAX array holding the corresponding item parts stacked along a new
        leading axis.
    """
    numeric = []
    categorical = []
    for i in range(start, start + length):
        # Look each item up once: indexing the dataset may be expensive, and a
        # single lookup guarantees both parts come from the same item.
        item = ds[i]
        numeric.append(item[0])
        categorical.append(item[1])
    return {"numeric": jnp.array(numeric), "categorical": jnp.array(categorical)}


# Small 4-item batch; later cells use it to initialise the model and the train state.
batch = make_batch(train_ds, 0, 4)
# batch

In [9]:
# `multiplier` scales the head count: n_heads = 8 * 4 = 32 with d_model = 512.
multiplier = 4
time_series_regressor = hp.TimeSeriesDecoder(
    time_series_config, d_model=512, n_heads=8 * multiplier
)

In [10]:
# Initialise the model variables from the smoke-test batch.
key = random.PRNGKey(0)
init_key, dropout_key = random.split(key)
# NOTE(review): `vars` shadows the Python builtin of the same name.
vars = time_series_regressor.init(
    {"params": init_key, "dropout": dropout_key},
    batch["numeric"],
    categorical_inputs=batch["categorical"].astype(jnp.int32),
    deterministic=False,
)
# Derive a new dropout key for the forward pass in the next cell.
# NOTE(review): both results of the split are fresh keys; `original_dropout_key`
# is not the key that was passed to `init`.
dropout_key, original_dropout_key = random.split(dropout_key)

In [11]:
# Forward pass on the smoke-test batch with dropout active (deterministic=False).
x = time_series_regressor.apply(
    vars,
    batch["numeric"],
    batch["categorical"].astype(jnp.int32),
    deterministic=False,
    rngs={"dropout": dropout_key},
)
print(x.get("numeric_out").shape)
# Print the categorical output's shape, or None when there is no categorical output
print(x.get("categorical_out").shape if x.get("categorical_out") is not None else None)

(4, 27, 59)
(4, 2, 59, 38)


In [12]:
def calculate_memory_footprint(params):
    """Return the total size, in bytes, of every array in a parameter pytree.

    Each leaf contributes ``size * dtype.itemsize`` bytes; an empty tree
    yields 0.
    """
    # Only the leaves matter here; the tree structure is discarded.
    leaves, _treedef = tree_flatten(params)
    return sum(leaf.size * leaf.dtype.itemsize for leaf in leaves)


def count_parameters(params):
    """Return the total number of scalar elements across all leaves of a pytree.

    The count is accumulated as a plain Python ``int``. That keeps it exact for
    arbitrarily large models and keeps it an integer when a leaf is a scalar
    (shape ``()``), which counts as one element.
    """
    return sum(
        # 64-bit product of the dimensions; the product of an empty shape is 1.
        int(np.prod(leaf.shape, dtype=np.int64))
        for leaf in jax.tree_util.tree_leaves(params)
    )


# NOTE(review): `vars` is the whole variable dict returned by `init`, so both
# figures cover everything in it, not only the "params" entry.
mem = calculate_memory_footprint(vars)
total_params = count_parameters(vars)


# Report size in megabytes (10^6 bytes) alongside the element count.
print(f"Memory of custom: {mem / 1e6:.2f} MB with {total_params:,} parameters")

Memory of custom: 141.42 MB with 35,355,005 parameters


In [13]:
# Silence ic() debug output from this point on.
ic.disable()

In [14]:
# Derive three independent keys from one root key.
# NOTE(review): only `mts_main_key` is used in the cells visible here.
mts_root_key = random.PRNGKey(44)
mts_main_key, ts_params_key, ts_data_key = random.split(mts_root_key, 3)

# NOTE(review): lower-case `mask_data` is not read by any visible cell.
mask_data = False


# NOTE(review): overwritten with 32 in the training cell before any data loader is built.
batch_size = 2

# The last argument (0.0001) is presumably the learning rate — confirm against
# `tsd.create_train_state`.
state = tsd.create_train_state(time_series_regressor, mts_main_key, batch, 0.0001)

In [15]:
# Global flag read by `return_results`, which forwards it to the model as `mask_data=`.
MASK_DATA = True

In [16]:
writer_name = "NewCategorical"

# Run name = timestamp + label. It names the TensorBoard log directory here and
# the checkpoint directory in the save cell.
writer_time = dt.now().strftime("%Y-%m-%dT%H:%M:%S")
model_name = writer_time + writer_name
train_summary_writer = SummaryWriter("runs/" + model_name)


test_set_key = random.PRNGKey(4454)

batch_size = 32
train_data_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_data_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=True)

batch_count = 0
base_key = random.PRNGKey(42)

# Optional cap on the number of training batches; None (or 0) means no cap.
max_iters = None
# Disable IC for training
ic.disable()
# try/finally so the writer is flushed and closed even when the NaN check
# below aborts the run.
try:
    for j in trange(1, desc=f"epochs for {train_summary_writer.log_dir}"):
        for i in tqdm(train_data_loader, leave=False, desc="batches"):
            state, loss, base_key = tsd.train_step(
                state,
                jnp.array(i[0]),
                jnp.array(i[1]),
                base_key,
                # mask_data=MASK_DATA,
            )
            if jnp.isnan(loss):
                raise ValueError("Nan Value in loss, stopping")
            batch_count += 1

            # Training loss is logged for every batch.
            train_summary_writer.add_scalar(
                "loss/loss", np.array(loss.item()), batch_count
            )
            # Every 10th batch, evaluate on one freshly drawn (shuffled) test batch.
            if batch_count % 10 == 0:
                numeric_eval, categorical_eval = next(iter(test_data_loader))
                test_loss, base_key = tsd.eval_step(
                    state,
                    jnp.array(numeric_eval),
                    jnp.array(categorical_eval),
                    base_key,
                    # mask_data=MASK_DATA,
                )
                train_summary_writer.add_scalar(
                    "loss/test_loss", np.array(test_loss.item()), batch_count
                )
                train_summary_writer.flush()
            if max_iters and batch_count > max_iters:
                break
        # Once the cap is reached, stop the epoch loop as well instead of
        # starting another epoch just to break out of it again.
        if max_iters and batch_count > max_iters:
            break
finally:
    train_summary_writer.close()

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


epochs for runs/2024-09-03T22:35:20NewCategorical:   0%|          | 0/1 [00:00<?, ?it/s]

batches:   0%|          | 0/3125 [00:00<?, ?it/s]

In [None]:
# PyTree checkpointer used by the save/restore cells below.
# NOTE(review): the save cell creates this object again, so this cell is redundant.
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()

In [None]:
# Bundle the train state with the number of batches trained so far.
ckpt = {"model": state, "step": batch_count}


# Checkpoints are stored under checkpoints/<run name>, as an absolute path.
checkpoint_dir = f"checkpoints/{model_name}"
checkpoint_dir = os.path.abspath(checkpoint_dir)

# os.makedirs(checkpoint_dir, exist_ok=True)

orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
# Derive the save arguments from the structure of the object being saved.
save_args = orbax_utils.save_args_from_target(ckpt)
orbax_checkpointer.save(checkpoint_dir, ckpt, save_args=save_args)

In [None]:
# Always restore the lexicographically last checkpoint directory. Run names
# created by this notebook start with a timestamp, so for those this is the
# most recent run.
all_checkpoints = sorted(os.listdir("checkpoints/"))
model_name = all_checkpoints[-1]
checkpoint_dir = os.path.abspath(f"checkpoints/{model_name}")

new_checkpoint = orbax_checkpointer.restore(checkpoint_dir)
# Build a fresh train state with the same arguments used for training, then
# swap in the restored parameters.
new_state = tsd.create_train_state(
    time_series_regressor, mts_main_key, batch, 0.0001
).replace(params=new_checkpoint["model"]["params"])



In [None]:
# Run the model on the first training item, wrapped in a batch of one.
# NOTE(review): this uses the in-memory `state`, not the `new_state` restored
# from the checkpoint in the previous cell.
numeric_inputs, categorical_inputs = train_ds[0]
numeric_inputs = jnp.array([numeric_inputs])
categorical_inputs = jnp.array([categorical_inputs])
test_results = state.apply_fn(
    {"params": state.params},
    # jnp.array(i[0]),
    # jnp.array(i[1]),
    numeric_inputs,
    categorical_inputs,
    deterministic=True,
    mask_data=False,
)
test_results["numeric_out"].shape

(1, 27, 59)

In [None]:
@dataclass
class Results:
    """Model outputs for one dataset item, kept together with the inputs fed to the model."""

    # Value stored under "numeric_out" in the model's output dict
    numeric_out: jnp.array
    # Value stored under "categorical_out" in the model's output dict
    categorical_out: jnp.array
    # Numeric input after the batch axis was added
    numeric_inputs: jnp.array
    # Categorical input after the batch axis was added
    categorical_inputs: jnp.array


def return_results(state, dataset, idx=0, mask_start: int = None):
    """Run the model on a single dataset item and collect inputs and outputs.

    Args:
        state: Object exposing ``apply_fn`` and ``params``.
        dataset: Indexable dataset whose items unpack into
            ``(numeric, categorical)``.
        idx: Index of the item to evaluate.
        mask_start: When truthy, both inputs are truncated along their second
            axis to the first ``mask_start`` entries. ``None`` (and ``0``)
            leave the inputs untouched.

    Returns:
        A ``Results`` holding the model's "numeric_out" and "categorical_out"
        together with the batched inputs.

    The module-level ``MASK_DATA`` flag is forwarded to the model as
    ``mask_data``.
    """
    sample_numeric, sample_categorical = dataset[idx]
    if mask_start:
        sample_numeric = sample_numeric[:, :mask_start]
        sample_categorical = sample_categorical[:, :mask_start]

    # The model expects a batch, so wrap the single item in a batch of one.
    batched_numeric = jnp.array([sample_numeric])
    batched_categorical = jnp.array([sample_categorical])

    predictions = state.apply_fn(
        {"params": state.params},
        numeric_inputs=batched_numeric,
        categorical_inputs=batched_categorical,
        deterministic=True,
        mask_data=MASK_DATA,
    )
    return Results(
        predictions["numeric_out"],
        predictions["categorical_out"],
        batched_numeric,
        batched_categorical,
    )


# Smoke test on the first training item.
# NOTE(review): this rebinds `x`, which earlier held the raw model output dict.
x = return_results(state, train_ds, 0)
x.categorical_out.shape

(1, 2, 59, 38)

In [None]:
# NOTE(review): lower-case `mask_data` is not read by any visible cell;
# `MASK_DATA` is the global that `return_results` forwards to the model.
mask_data = False
MASK_DATA = True


def process_results(arr: jnp.array, col_names: list, config: "hp.TimeSeriesConfig"):
    """Convert one batched model input/output array into a DataFrame.

    Args:
        arr: Array with a leading batch axis of size 1. After the batch axis
            the layout is ``(n_columns, n_steps)`` for plain values or
            ``(n_columns, n_steps, n_classes)`` for logits.
        col_names: One name per entry of the ``n_columns`` axis.
        config: Not used by this function; kept for interface compatibility.

    Returns:
        A DataFrame with one row per step and one column per name in
        ``col_names``. Logits are reduced to the index of the largest entry
        along the last axis.
    """
    arr = jnp.asarray(arr)
    # Drop only the leading batch axis. A bare squeeze() would also collapse a
    # singleton column or step axis, which breaks both the logit detection
    # below and the column labelling.
    if arr.ndim > 1 and arr.shape[0] == 1:
        arr = jnp.squeeze(arr, axis=0)
    if arr.ndim == 3:
        # A third axis means per-class logits: keep the winning class index.
        arr = jnp.argmax(arr, axis=-1)
    # Transpose so that steps become rows and the named columns become columns.
    df = pd.DataFrame(np.asarray(arr.T))
    df.columns = col_names
    return df


@dataclass
class DFComparison:
    """Pair of DataFrames produced by `show_results_df` for one dataset item."""

    # Categorical and numeric model inputs, concatenated column-wise
    input_df: pd.DataFrame
    # Categorical and numeric model outputs, concatenated column-wise
    output_df: pd.DataFrame


def show_results_df(
    state, time_series_config, dataset, idx: int = 0, mask_start: int = None
):
    """Evaluate one dataset item and return its inputs and outputs as DataFrames.

    The item is run through `return_results`; each of the four resulting
    arrays is converted with `process_results`, labelled with the categorical
    or numeric column tokens from `time_series_config`.

    Returns:
        A ``DFComparison`` whose ``input_df`` and ``output_df`` each hold the
        categorical columns followed by the numeric columns.
    """
    results = return_results(state, dataset, idx=idx, mask_start=mask_start)
    cat_cols = time_series_config.categorical_col_tokens
    num_cols = time_series_config.numeric_col_tokens

    def _frame(values, names):
        # Shared conversion for all four arrays.
        return process_results(values, names, time_series_config)

    input_df = pd.concat(
        [
            _frame(results.categorical_inputs, cat_cols),
            _frame(results.numeric_inputs, num_cols),
        ],
        axis=1,
    )
    output_df = pd.concat(
        [
            _frame(results.categorical_out, cat_cols),
            _frame(results.numeric_out, num_cols),
        ],
        axis=1,
    )
    return DFComparison(input_df, output_df)


# Input/output comparison for the first training item.
df_comp = show_results_df(
    state=state, time_series_config=time_series_config, dataset=train_ds, idx=0
)

In [None]:
# Last rows of the predicted categorical columns; the values are the argmax
# class indices produced by `process_results`.
df_comp.output_df.loc[:, time_series_config.categorical_col_tokens].tail()

Unnamed: 0,acceleration_x,acceleration_y
54,28,13
55,28,13
56,28,13
57,28,13
58,28,13


In [None]:
# NOTE(review): same expression as the cell directly above — candidate for removal.
df_comp.output_df.loc[:, time_series_config.categorical_col_tokens].tail()

Unnamed: 0,acceleration_x,acceleration_y
54,28,13
55,28,13
56,28,13
57,28,13
58,28,13


In [None]:
def plot_planets(df_pred: pd.DataFrame, df_actual: pd.DataFrame, column: str, offset=0):
    """Plot one column of the predicted frame against the same column of the actual frame.

    Args:
        df_pred: Frame holding the predicted values.
        df_actual: Frame holding the reference values.
        column: Name of the column to draw from both frames.
        offset: Accepted for backward compatibility; currently not used.
    """
    plt.figure(figsize=(15, 10))
    plt.plot(df_pred[column], label="Autoregressive")
    plt.plot(df_actual[column], label="Actual")
    plt.title(f"{column} Predictions")
    plt.legend()
    # Show ticks and grid lines every 1 step
    plt.xticks(np.arange(0, len(df_pred), 1))
    plt.grid()
    # Black reference line at y = 0
    plt.axhline(0, color="black")
    plt.show()

In [None]:
# Scratch check: a 1-D boolean mask of five entries has shape (5,).
jnp.array([True, True, False, False, True]).shape

(5,)

In [None]:
# Scratch check: indexing `.at[...]` with a boolean mask on the first axis
# overwrites whole rows (rows 0, 1 and 4 here).
# NOTE(review): rebinds `x` once more.
x = jnp.ones((5, 20))
print(x.shape)
xx = x.at[jnp.array([True, True, False, False, True]), :].set(0)
xx

(5, 20)


Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.]], dtype=float32)

In [None]:
@dataclass
class AutoRegressiveResults:
    """Numeric and categorical sequences used as input for autoregressive prediction."""

    numeric_inputs: jnp.array
    categorical_inputs: jnp.array

    @classmethod
    def from_ds(cls, ds: hp.TimeSeriesDS, idx: int, stop_idx: int = 10):
        """Build an instance from item `idx` of `ds`, keeping only the first
        `stop_idx` entries along the second axis of both parts."""
        sample = ds[idx]
        return cls(
            numeric_inputs=sample[0][:, :stop_idx],
            categorical_inputs=sample[1][:, :stop_idx],
        )


def auto_regressive_predictions(
    state: "train_state.TrainState",
    inputs: "AutoRegressiveResults",
) -> tuple:
    """Run the model once and append its prediction for the next step.

    Args:
        state: Object exposing ``apply_fn`` and ``params``.
        inputs: Object with ``numeric_inputs`` and ``categorical_inputs``
            arrays whose second axis is the step axis.

    Returns:
        A tuple ``(numeric_inputs, categorical_inputs)``: the input arrays,
        each extended by one step taken from the last step of the model
        output. Categorical logits are reduced to the argmax over the last
        axis. Rows that were NaN at every step of the input are NaN in the
        result as well.
    """
    numeric_inputs = inputs.numeric_inputs
    categorical_inputs = inputs.categorical_inputs
    # Rows that are NaN at every step must stay NaN after the new step is
    # appended, so record them before running the model.
    numeric_nan_rows = jnp.isnan(numeric_inputs).all(axis=1)
    categorical_nan_rows = jnp.isnan(categorical_inputs).all(axis=1)

    outputs = state.apply_fn(
        {"params": state.params},
        numeric_inputs=jnp.array([numeric_inputs]),
        categorical_inputs=jnp.array([categorical_inputs]),
        deterministic=True,
        mask_data=False,
    )
    # Take batch element 0 instead of squeeze(): squeeze would also collapse a
    # singleton row or step axis, and the slicing below would then fail.
    numeric_out = outputs["numeric_out"][0]
    categorical_out = jnp.argmax(outputs["categorical_out"][0], axis=-1)

    # `[:, -1:]` keeps the step axis, so the last prediction can be appended directly.
    numeric_inputs = jnp.concatenate([numeric_inputs, numeric_out[:, -1:]], axis=1)
    categorical_inputs = jnp.concatenate(
        [categorical_inputs, categorical_out[:, -1:]], axis=1
    )
    numeric_inputs = numeric_inputs.at[numeric_nan_rows].set(jnp.nan)
    categorical_inputs = categorical_inputs.at[categorical_nan_rows].set(jnp.nan)
    return numeric_inputs, categorical_inputs

In [None]:
# Seed the rollout with the first 10 steps of training item 0.
test_inputs = AutoRegressiveResults.from_ds(train_ds, 0, 10)
for i in trange(21):
    inputs_test = auto_regressive_predictions(state, test_inputs)
    # Feed the extended sequences back in. Without this, every iteration
    # re-predicts the same step from the unchanged 10-step seed.
    test_inputs = AutoRegressiveResults(*inputs_test)

# x = auto_regressive_predictions(state, test_ds[0], 10)

  0%|          | 0/21 [00:00<?, ?it/s]

In [None]:
# res = show_results_df(state, train_df, test_ds, idx=0, mask_start=30)

# plot_planets(res["pred"], res["actual_masked"], "planet2_x", offset=0)

: 