<https://github.com/PolymathicAI/xVal>


In [None]:
import os

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.95"

In [None]:
import jax.numpy as jnp  # Oddly works in colab to set gpu

arr = jnp.array([1, 2, 3])
arr.devices()

In [None]:
import icecream
from icecream import ic

icecream.install()
ic_disable = False
if ic_disable:
    ic.disable()
ic.configureOutput(includeContext=True, contextAbsPath=True)

In [None]:
import os
import ast
import re

from datetime import datetime as dt
from torch.utils.data import DataLoader
import seaborn as sns
import matplotlib.pyplot as plt
import hephaestus as hp
import jax
import jax.numpy as jnp
import numpy as np
import optax
import pandas as pd
from flax.training import train_state
from icecream import ic
from jax import random
from jax.tree_util import tree_flatten
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm, trange
from hephaestus.models.simple_time_series import SimpleDS

pd.options.mode.copy_on_write = True

In [None]:
from transformers import FlaxBertModel, BertTokenizerFast
import jax.numpy as jnp

# Load pre-trained BERT model and tokenizer
model_name = "bert-base-uncased"
model = FlaxBertModel.from_pretrained(model_name)
tokenizer = BertTokenizerFast.from_pretrained(model_name)

# Get the embeddings matrix
embeddings = model.params["embeddings"]["word_embeddings"]["embedding"]

# Now you can access specific embeddings like this:
# For example, to get embeddings for tokens 23, 293, and 993:
selected_embeddings = jnp.take(embeddings, jnp.array([23, 293, 993]), axis=0)

# If you want to get embeddings for specific words:
words = ["hello", "world", "example"]
tokens = tokenizer.convert_tokens_to_ids(words)
word_embeddings = jnp.take(embeddings, jnp.array(tokens), axis=0)
word_embeddings.shape

In [None]:
embeddings[jnp.array([23, 293, 993])].shape

In [None]:
def line2df(line, idx):
    """Convert one raw text record into a per-time-step dataframe.

    Args:
        line: String holding a Python literal dict. It must have a ``"data"``
            entry (a sequence of time steps, each a sequence of positions
            indexable at 0 and 1) and a ``"description"`` entry (a dict with a
            ``"stepsize"`` value plus further entries that map a name to a
            dict of properties).
        idx: Identifier written to the leading ``idx`` column of every row.

    Returns:
        A dataframe with one row per time step: ``idx``, ``time_step``
        (step number multiplied by the step size), ``planet<j>_x`` /
        ``planet<j>_y`` columns, and one constant ``<name>_<property>`` column
        per description property.
    """
    record = ast.literal_eval(line)

    def build_row(step_no, positions):
        # One flat dict per time step; insertion order fixes the column order.
        row = {"time_step": step_no}
        for planet_no, position in enumerate(positions):
            row[f"planet{planet_no}_x"] = position[0]
            row[f"planet{planet_no}_y"] = position[1]
        return row

    df = pd.DataFrame(
        [build_row(step_no, positions) for step_no, positions in enumerate(record["data"])]
    )

    description = record.pop("description")
    step_size = description.pop("stepsize")
    # What is left in the description is a mapping name -> {property: value};
    # each property is broadcast to every row as its own column.
    for name, properties in description.items():
        for prop_name, prop_value in properties.items():
            df[f"{name}_{prop_name}"] = prop_value

    # Turn the integer step counter into elapsed time.
    df["time_step"] = df["time_step"] * step_size
    df.insert(0, "idx", idx)

    return df

In [None]:
# Build the planets dataframe once and cache it as parquet; when the cache file
# is already present in `data/`, read it instead of re-parsing the raw file.
files = os.listdir("data")
if "planets.parquet" not in files:
    with open("data/planets.data") as f:
        data = f.read().splitlines()

        # One dataframe per raw line; the line number becomes that frame's `idx`.
        dfs = []
        for idx, line in enumerate(tqdm(data)):
            dfs.append(line2df(line, idx))
        print("Concatenating dfs...")
        # NOTE(review): concatenated without ignore_index, so row labels repeat
        # across the per-line frames.
        df = pd.concat(dfs)
    df.to_parquet("data/planets.parquet")
else:
    df = pd.read_parquet("data/planets.parquet")


# Sum every `planet<n>_m` column into a single `total_mass` column
mass_regex = re.compile(r"planet(\d+)_m")
mass_cols = [col for col in df.columns if mass_regex.match(col)]
df["total_mass"] = df[mass_cols].sum(axis=1)

# Categorical `n_planets`: the number of non-null mass columns in each row
df["n_planets"] = df[mass_cols].notnull().sum(axis=1).astype("category")
# Sum every column whose name contains both "planet" and "_x".
# NOTE(review): despite the name, `acceleration_x` is derived from the summed
# planet x columns, not from a computed acceleration; confirm the intent.
df["acceleration_x"] = df[
    [col for col in df.columns if "planet" in col and "_x" in col]
].sum(axis=1)
# Replace the sum with a label: "increasing" if greater than 0 else "decreasing"
df["acceleration_x"] = df["acceleration_x"].apply(
    lambda x: "increasing" if x > 0 else "decreasing"
)
# Same two steps for the columns whose name contains both "planet" and "_y"
df["acceleration_y"] = df[
    [col for col in df.columns if "planet" in col and "_y" in col]
].sum(axis=1)
df["acceleration_y"] = df["acceleration_y"].apply(
    lambda x: "increasing" if x > 0 else "decreasing"
)


df.describe()

In [None]:
# Train/test split on `idx` rather than on individual rows: rows whose idx is
# below 80% of the largest idx go to train, all remaining rows go to test, so
# rows sharing an idx never straddle the two sets.
train_idx = int(df.idx.max() * 0.8)
train_df = df.loc[df.idx < train_idx].copy()
test_df = df.loc[df.idx >= train_idx].copy()
# del df
# Wrap each split in the project dataset class; the last line displays both sizes.
train_ds = SimpleDS(train_df)
test_ds = SimpleDS(test_df)
len(train_ds), len(test_ds)

In [None]:
test_df

In [None]:
test_df.shape, train_df.shape

In [None]:
df.idx.max()

In [None]:
jnp.array([train_ds[0][0]])

In [None]:
def make_batch(ds: SimpleDS, start: int, length: int):
    """Stack ``length`` consecutive samples of ``ds`` into one batch.

    Each sample is read as ``ds[i]``; element 0 is collected as the numeric
    part and element 1 as the categorical part.

    Args:
        ds: Dataset indexable by integer position.
        start: Index of the first sample to include.
        length: Number of consecutive samples to include.

    Returns:
        A dict with keys ``"numeric"`` and ``"categorical"``, each a
        ``jnp.array`` built from the collected per-sample parts.
    """
    numeric = []
    categorical = []
    for i in range(start, start + length):
        # Index the dataset once per sample: both parts then come from the same
        # __getitem__ call and the lookup cost is not paid twice.
        sample = ds[i]
        numeric.append(sample[0])
        categorical.append(sample[1])
    return {"numeric": jnp.array(numeric), "categorical": jnp.array(categorical)}


batch = make_batch(train_ds, 0, 4)
# batch

In [None]:
# time_series_regressor = hp.simple_time_series.SimplePred(
#     train_ds, d_model=2048, n_heads=16 # large
# )
multiplier = 4
time_series_regressor = hp.simple_time_series.SimplePred(
    train_ds, d_model=512, n_heads=8 * multiplier
)

In [None]:
# train_ds.reservoir_encoded

In [None]:
# train_ds.reservoir_encoded
train_ds.reservoir_encoded[jnp.array([0, 1, 2, 23])]

In [None]:
batch["categorical"]

In [None]:
test_arr = jnp.array([1.0, 2.0, 3.0, 4.0])
# Convert to int
test_arr = test_arr.astype(jnp.int32)
test_arr

In [None]:
# Initialise the model variables from the sample batch, with separate RNG
# streams for parameter initialisation and for dropout.
# NOTE(review): `vars` shadows the Python builtin of the same name. Later cells
# read this variable, so any rename has to be applied across the notebook.
key = random.PRNGKey(0)
init_key, dropout_key = random.split(key)
vars = time_series_regressor.init(
    {"params": init_key, "dropout": dropout_key},
    batch["numeric"],
    categorical_inputs=batch["categorical"].astype(jnp.int32),
    deterministic=False,
)
# Split the dropout key again so the key used after this cell differs from the
# one consumed by init; the other half is kept as `original_dropout_key`.
dropout_key, original_dropout_key = random.split(dropout_key)

```
mask.shape=(4, 26, 1, 59, 59), attn_weights.shape=(4, 26, 16, 59, 59)
Mask example
[[0. 0. 0. ... 0. 0. 0.]
 [0. 1. 0. ... 0. 0. 0.]
 [0. 1. 1. ... 0. 0. 0.]
 ...
 [1. 1. 1. ... 1. 0. 0.]
 [1. 1. 1. ... 1. 1. 0.]
 [1. 1. 1. ... 1. 1. 1.]]
```


In [None]:
df.shape

In [None]:
# ic.disable()

In [None]:
x = time_series_regressor.apply(
    vars,
    batch["numeric"],
    batch["categorical"].astype(jnp.int32),
    deterministic=False,
    rngs={"dropout": dropout_key},
)
print(x.get("numeric_out").shape)
# Check if categorical input is None and print None or it's shape
print(x.get("categorical_out").shape if x.get("categorical_out") is not None else None)

In [None]:
# time_series_regressor.tabulate(
#     {"params": init_key, "dropout": dropout_key},
#     batch["numeric"],
#     console_kwargs={"force_jupyter": True, "width": 120},
# )

In [None]:
def calculate_memory_footprint(params):
    """Return the total size in bytes of all array leaves in ``params``.

    The parameter tree is flattened into its leaves and, for each leaf, the
    element count is multiplied by the byte width of its dtype.

    Args:
        params: Pytree whose leaves expose ``.size`` and ``.dtype.itemsize``.

    Returns:
        The summed byte count (0 for a tree without leaves).
    """
    leaves, _treedef = tree_flatten(params)
    return sum(leaf.size * leaf.dtype.itemsize for leaf in leaves)


def count_parameters(params):
    """Return the total number of scalar elements across all leaves of ``params``.

    The count is taken from each leaf's ``.shape`` on the host and accumulated
    as a Python ``int``. The result therefore cannot overflow a 32-bit device
    integer, stays an integer for scalar (shape ``()``) leaves, and needs no
    device computation per leaf.

    Args:
        params: Pytree whose leaves expose ``.shape``.

    Returns:
        The element count as a Python ``int`` (0 for a tree without leaves).
    """
    return sum(
        # dtype=int64 keeps the product exact for very large leaves; the product
        # of an empty shape is 1, which is the right count for a scalar leaf.
        int(np.prod(leaf.shape, dtype=np.int64))
        for leaf in jax.tree_util.tree_leaves(params)
    )


mem = calculate_memory_footprint(vars)
total_params = count_parameters(vars)


print(f"Memory of custom: {mem / 1e6:.2f} MB with {total_params:,} parameters")

In [None]:
mts_root_key = random.PRNGKey(44)
mts_main_key, ts_params_key, ts_data_key = random.split(mts_root_key, 3)

mask_data = False


def clip_gradients(gradients, max_norm):
    """Rescale a gradient pytree so its global L2 norm does not exceed ``max_norm``.

    The norm is computed over every leaf of ``gradients``, so arbitrarily
    nested parameter trees are supported, not only a flat dict of arrays.

    Args:
        gradients: Pytree of gradient arrays.
        max_norm: Largest allowed global L2 norm.

    Returns:
        A pytree with the same structure. Leaves are multiplied by
        ``max_norm / (norm + 1e-6)`` when the global norm is above
        ``max_norm`` and returned unchanged otherwise.
    """
    leaves = jax.tree_util.tree_leaves(gradients)
    total_norm = jnp.sqrt(sum(jnp.sum(jnp.square(grad)) for grad in leaves))
    # The epsilon keeps the division finite when the norm is exactly zero.
    scale = max_norm / (total_norm + 1e-6)
    # jnp.where (not a Python `if`) so the function also works on traced values.
    return jax.tree_util.tree_map(
        lambda grad: jnp.where(total_norm > max_norm, grad * scale, grad), gradients
    )


def base_loss(inputs, outputs):
    """Mean absolute error between ``outputs`` and ``inputs`` shifted one step ahead.

    The target for position ``t`` along the last axis is ``inputs`` at
    ``t + 1``. The final position has no target and is padded with NaN.
    Positions whose target is NaN (padding or missing data) are excluded from
    both the sum and the count.

    Args:
        inputs: Rank-3 array; the last axis is the one that gets shifted.
        outputs: Array of predictions with the same shape as ``inputs``.

    Returns:
        Scalar mean absolute error over the non-NaN targets, or 0.0 when no
        valid target exists (instead of the NaN a 0/0 division would produce).
    """
    inputs_offset = 1
    # Drop the first step so that index t now holds the value from step t + 1.
    inputs = inputs[:, :, inputs_offset:]
    # Under jit these prints run at trace time only, not on every call.
    print(f"Inputs shape: {inputs.shape=}")
    # Pad the end with NaN to restore the original length.
    temp_null = jnp.full((inputs.shape[0], inputs.shape[1], inputs_offset), jnp.nan)
    inputs = jnp.concatenate(
        [inputs, temp_null],
        axis=2,
    )
    print(f"Inputs shape after addition: {inputs.shape=}")
    nan_mask = jnp.isnan(inputs)
    # Zero both sides at masked positions before subtracting so no NaN enters
    # the arithmetic (and therefore none enters the gradient).
    inputs = jnp.where(nan_mask, jnp.zeros_like(inputs), inputs)
    outputs = jnp.where(nan_mask, jnp.zeros_like(outputs), outputs)

    raw_loss = jnp.abs(outputs - inputs)
    masked_loss = jnp.where(nan_mask, 0.0, raw_loss)
    valid_count = (~nan_mask).sum()
    # Clamp the denominator: a fully masked batch yields 0.0 rather than NaN.
    loss = masked_loss.sum() / jnp.maximum(valid_count, 1)

    return loss


def calculate_loss(params, state, inputs, dropout_key, mask_data: bool = True):
    """Run a training-mode forward pass and score it with ``base_loss``.

    Args:
        params: Parameter pytree passed to the model as ``{"params": params}``.
        state: Object whose ``apply_fn`` performs the forward pass.
        inputs: Batch fed to the model and also used as the loss reference.
        dropout_key: PRNG key supplied for the ``"dropout"`` RNG stream.
        mask_data: Forwarded unchanged to ``apply_fn``.

    Returns:
        The value returned by ``base_loss(inputs, predictions)``.
    """
    # deterministic=False: the pass runs in training mode and uses the dropout rng.
    predictions = state.apply_fn(
        {"params": params},
        inputs,
        rngs={"dropout": dropout_key},
        deterministic=False,
        mask_data=mask_data,
    )
    return base_loss(inputs, predictions)


@jax.jit
def train_step(state: train_state.TrainState, batch, base_key, mask_data=True):
    """Perform one optimisation step and advance the PRNG key.

    Args:
        state: Current train state (parameters plus optimiser state).
        batch: Input batch passed to ``calculate_loss``.
        base_key: PRNG key; split three ways inside the step.
        mask_data: Selects the masked or the unmasked loss branch.

    Returns:
        ``(new_state, loss, new_key)`` where ``new_key`` is the third key of
        the split and is meant to be passed to the next step.
    """
    # The middle key of the split is not consumed in this step.
    dropout_key, _unused_mask_key, new_key = jax.random.split(base_key, 3)

    def masked_loss(params):
        return calculate_loss(params, state, batch, dropout_key, mask_data=True)

    def unmasked_loss(params):
        return calculate_loss(params, state, batch, dropout_key, mask_data=False)

    def loss_fn(params):
        # The function is jitted without static arguments, so `mask_data` is a
        # traced value here; lax.cond does the branching that a Python `if`
        # could not do on a tracer.
        return jax.lax.cond(mask_data, masked_loss, unmasked_loss, params)

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    updated_state = state.apply_gradients(grads=grads)

    return updated_state, loss, new_key


def evaluate(params, state, inputs, mask_data: bool = True):
    """Score a deterministic forward pass with ``base_loss``.

    Args:
        params: Parameter pytree passed to the model as ``{"params": params}``.
        state: Object whose ``apply_fn`` performs the forward pass.
        inputs: Batch fed to the model and also used as the loss reference.
        mask_data: Forwarded unchanged to ``apply_fn``.

    Returns:
        The value returned by ``base_loss(inputs, predictions)``.
    """
    # deterministic=True and no rngs: nothing stochastic is requested here.
    predictions = state.apply_fn(
        {"params": params},
        inputs,
        deterministic=True,
        mask_data=mask_data,
    )
    return base_loss(inputs, predictions)


@jax.jit
def eval_step(state: train_state.TrainState, batch, base_key, mask_data=True):
    """Compute the evaluation loss for one batch without updating the state.

    The forward pass goes through ``evaluate``, which runs the model with
    ``deterministic=True``. The reported loss is therefore not perturbed by
    dropout, unlike a pass through the training-mode ``calculate_loss``.

    Args:
        state: Train state providing ``apply_fn`` and ``params``.
        batch: Input batch to score.
        base_key: PRNG key; split three ways so the returned key advances the
            caller's key stream exactly as before.
        mask_data: Selects the masked or the unmasked branch.

    Returns:
        ``(loss, new_key)``.
    """
    # No randomness is consumed by a deterministic pass; only the third key of
    # the split is needed, to hand back to the caller.
    _, _, new_key = jax.random.split(base_key, 3)

    def masked_eval(params):
        return evaluate(params, state, batch, mask_data=True)

    def unmasked_eval(params):
        return evaluate(params, state, batch, mask_data=False)

    # `mask_data` is a traced value under jit, so branch with lax.cond.
    loss = jax.lax.cond(mask_data, masked_eval, unmasked_eval, state.params)
    return loss, new_key


def create_train_state(model, prng, batch, lr):
    """Initialise ``model`` on ``batch`` and wrap it in a ``TrainState``.

    Args:
        model: Module providing ``init`` and ``apply``.
        prng: PRNG key, split into a parameter-init key and a dropout key.
        batch: Example input handed to ``model.init``.
        lr: Learning rate for Adam.

    Returns:
        A ``train_state.TrainState`` holding the ``"params"`` collection of
        the initialised variables and an optimiser that clips gradients to a
        global norm of 0.4 before applying Adam.
    """
    init_key, dropout_key = random.split(prng)
    rng_streams = {"params": init_key, "dropout": dropout_key}
    variables = model.init(rng_streams, batch, deterministic=False)

    # Global-norm clipping runs first, then the Adam update.
    tx = optax.chain(optax.clip_by_global_norm(0.4), optax.adam(lr))

    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=variables["params"],
        tx=tx,
    )


batch_size = 2
# batch = train_ds[0]
# state = create_train_state(time_series_regressor, mts_main_key, batch, 0.0001)
state = create_train_state(time_series_regressor, mts_main_key, batch, 0.0001)

In [None]:
# Name the run after the wall-clock start time plus a fixed label; the same
# `model_name` is reused further down as the checkpoint directory name.
writer_name = "BERT_Embeddings"

writer_time = dt.now().strftime("%Y-%m-%dT%H:%M:%S")
model_name = writer_time + writer_name
train_summary_writer = SummaryWriter("runs/" + model_name)

MASK_DATA = True

test_set_key = random.PRNGKey(4454)

batch_size = 16
train_data_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_data_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=True)

batch_count = 0
base_key = random.PRNGKey(42)

# Hard cap on the number of optimisation steps across all epochs.
max_iters = 200
# Disable IC for training
ic.disable()
# try/finally: the writer is closed (and its buffered events written) even when
# the NaN guard below aborts the run.
try:
    for j in trange(1, desc=f"epochs for {train_summary_writer.log_dir}"):
        for i in tqdm(train_data_loader, leave=False, desc="batches"):
            state, loss, base_key = train_step(
                state, jnp.array(i), base_key, mask_data=MASK_DATA
            )
            if jnp.isnan(loss):
                raise ValueError("Nan Value in loss, stopping")
            batch_count += 1

            # Training loss is logged on every step.
            train_summary_writer.add_scalar(
                "loss/loss", np.array(loss.item()), batch_count
            )
            # Every 10 steps, score one freshly drawn test batch.
            if batch_count % 10 == 0:
                test_loss, base_key = eval_step(
                    state,
                    jnp.array(next(iter(test_data_loader))),
                    base_key,
                    mask_data=MASK_DATA,
                )
                train_summary_writer.add_scalar(
                    "loss/test_loss", np.array(test_loss.item()), batch_count
                )
                train_summary_writer.flush()
            # `>=` so that exactly max_iters steps are taken, not max_iters + 1.
            if batch_count >= max_iters:
                break
        # Leave the epoch loop as well once the step budget is used up.
        if batch_count >= max_iters:
            break
finally:
    train_summary_writer.close()

In [None]:
import orbax.checkpoint
from flax.training import orbax_utils

# Checkpoint directory is named after the current value of `model_name`
ckpt_dir = f"ckpts/{model_name}/"
# The path is made absolute before it is handed to the checkpointer
ckpt_dir = os.path.abspath(ckpt_dir)

# if os.path.exists(ckpt_dir):
#     shutil.rmtree(ckpt_dir)

# Bundle the train state with the number of batches trained so far
ckpt = {"state": state, "step": batch_count}
# Save to disk
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(ckpt)
orbax_checkpointer.save(ckpt_dir, ckpt, save_args=save_args)

# Round trip: build a fresh train state, then overwrite its parameters with the
# ones read back from disk (only `params` is taken from the restored tree).
new_state = create_train_state(time_series_regressor, mts_main_key, batch, 0.0001)
# Load from disk
ckpt1 = orbax_checkpointer.restore(ckpt_dir)
new_state = new_state.replace(params=ckpt1["state"]["params"])

In [None]:
# Load the parameters of a specific, previously saved run.
# NOTE(review): this overwrites `model_name` with a hard-coded run name and
# relies on `orbax` having been imported by an earlier cell.
model_name = "2024-06-13T03:10:09MAE_Loss_Large"
ckpt_dir = f"ckpts/{model_name}/"
# make absolute path
ckpt_dir = os.path.abspath(ckpt_dir)
# Fresh train state that receives the restored parameters
new_state = create_train_state(time_series_regressor, mts_main_key, batch, 0.0001)
# Load from disk
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()

ckpt1 = orbax_checkpointer.restore(ckpt_dir)
new_state = new_state.replace(params=ckpt1["state"]["params"])
# state = new_state
# state = new_state

In [None]:
def return_results(state, dataset, idx=0, mask_start: int = None):
    """Run the model deterministically on one dataset sample.

    Args:
        state: Object providing ``apply_fn`` and ``params``.
        dataset: Indexable dataset; ``dataset[idx]`` is the sample used.
        idx: Position of the sample in ``dataset``.
        mask_start: When truthy, only the first ``mask_start`` entries along
            the sample's second axis are kept. ``None`` and ``0`` both leave
            the sample untouched.

    Returns:
        ``(outputs, inputs)`` where ``inputs`` is the (possibly truncated)
        sample with a leading batch axis of size 1.
    """
    sample = dataset[idx]
    if mask_start:
        sample = sample[:, :mask_start]
    batched = jnp.array([sample])
    # The mask flag comes from the notebook-level MASK_DATA constant.
    predictions = state.apply_fn(
        {"params": state.params},
        batched,
        deterministic=True,
        mask_data=MASK_DATA,
    )
    return predictions, batched

In [None]:
mask_data = False
MASK_DATA = True


def show_results_df(state, base_df, dataset, idx: int = 0, mask_start: int = None):
    """Collect predictions, reference inputs and their differences as dataframes.

    Args:
        state: Train state passed through to ``return_results``.
        base_df: Dataframe whose columns, minus the first, label the result columns.
        dataset: Dataset from which sample ``idx`` is taken.
        idx: Sample position in ``dataset``.
        mask_start: Forwarded to ``return_results`` to truncate the model input.

    Returns:
        Dict with the frames ``"pred"``, ``"actual_masked"`` (the truncated
        input the model saw), ``"actual_no_mask"`` (the full sample),
        ``"diff_masked"`` and ``"diff_no_mask"`` (prediction minus the
        respective reference).
    """

    def to_frame(batched):
        # Drop the batch axis and transpose before labelling the columns.
        frame = pd.DataFrame(jnp.squeeze(batched).T)
        frame.columns = base_df.columns[1:]
        return frame

    outputs, inputs = return_results(state, dataset, idx=idx, mask_start=mask_start)

    df_pred = to_frame(outputs)
    df_actual_masked = to_frame(inputs)
    df_no_mask = to_frame(jnp.array([dataset[idx]]))

    return {
        "pred": df_pred,
        "actual_masked": df_actual_masked,
        "actual_no_mask": df_no_mask,
        "diff_masked": df_pred - df_actual_masked,
        "diff_no_mask": df_pred - df_no_mask,
    }


res = show_results_df(new_state, train_df, train_ds, idx=0, mask_start=10)

In [None]:
def show_heatmap(df, title):
    """Draw an annotated heatmap of ``df`` on a diverging palette centred at 0.

    Columns that are entirely NaN are removed first, then rows that are
    entirely NaN.

    Args:
        df: Dataframe of numeric values to plot.
        title: Figure title.
    """
    visible = df.dropna(axis=1, how="all").dropna(axis=0, how="all")

    _fig, ax = plt.subplots(figsize=(15, 10))
    palette = sns.diverging_palette(220, 20, as_cmap=True)
    sns.heatmap(visible, cmap=palette, center=0, annot=True, fmt=".2f", ax=ax)
    ax.set_title(title)
    plt.show()


show_heatmap(res["diff_masked"], "Diff Masked")

In [None]:
# First 14 rows of the prediction-minus-unmasked-input differences. The title
# names the frame actually plotted (it previously repeated "Diff Masked").
show_heatmap(res["diff_no_mask"].head(14), "Diff No Mask")

In [None]:
test_key = random.PRNGKey(4454)
x = jax.random.normal(test_key, (4, 26, 59, 256))

In [None]:
# Compare the model output for two prefixes of the same test sample.
# NOTE(review): despite the variable names, BOTH calls pass mask_data=True; the
# only difference is the prefix length along the second axis (10 vs 20).
no_mask_out = state.apply_fn(
    {"params": state.params},
    # jnp.array([test_ds[0][:10, :]]),
    jnp.array([test_ds[0][:, :10]]),
    deterministic=True,
    mask_data=True,
)
mask_out = state.apply_fn(
    {"params": state.params},
    jnp.array([test_ds[0][:, :20]]),
    deterministic=True,
    mask_data=True,
)
# Drop the batch axis and transpose; columns are labelled with every test_df
# column except the first.
mask_out_df = pd.DataFrame(jnp.squeeze(mask_out).T)
mask_out_df.columns = test_df.columns[1:]
no_mask_out_df = pd.DataFrame(jnp.squeeze(no_mask_out).T)
no_mask_out_df.columns = test_df.columns[1:]

# pandas aligns the subtraction on the row index, so rows present in only one
# of the two frames come out as NaN.
# NOTE(review): presumably the shorter prefix yields fewer rows; confirm.
test_diff = mask_out_df - no_mask_out_df
test_diff

In [None]:
def plot_planets(
    df_pred: pd.DataFrame,
    df_actual: pd.DataFrame,
    column: str,
    offset=0,
    pred_label: str = "Autoregressive",
    actual_label: str = "Actual",
):
    """Plot one column of the predicted frame against the same column of the actual frame.

    Args:
        df_pred: Dataframe of predictions.
        df_actual: Dataframe of reference values.
        column: Name of the column to plot from both frames.
        offset: Accepted for backward compatibility; not used by the plot.
        pred_label: Legend label for the prediction line. Override it when the
            predictions were not produced autoregressively.
        actual_label: Legend label for the reference line.
    """
    plt.figure(figsize=(15, 10))
    plt.plot(df_pred[column], label=pred_label)
    plt.plot(df_actual[column], label=actual_label)
    plt.title(f"{column} Predictions")
    plt.legend()
    # Show ticks and grid lines every 1 step
    plt.xticks(np.arange(0, len(df_pred), 1))
    plt.grid()
    # Black line at y = 0 as a visual reference
    plt.axhline(0, color="black")
    plt.show()

In [None]:
def auto_regressive_predictions(
    state: train_state.TrainState, inputs: jnp.ndarray
) -> np.ndarray:
    """Extend ``inputs`` by one predicted step along its second axis.

    The model is run deterministically on ``inputs`` (with a batch axis added)
    and the last entry of its output along the second axis is appended to
    ``inputs``. Rows of ``inputs`` that were entirely NaN before the call are
    reset to NaN in the result, so the model's values for them are discarded.

    Args:
        state: Train state providing ``apply_fn`` and ``params``.
        inputs: 2-D array; axis 1 is the axis that grows by one.

    Returns:
        A NumPy array with one more entry along axis 1 than ``inputs``.
    """
    # Rows with no observed value at all stay empty in the result.
    all_nan_rows = jnp.isnan(inputs).all(axis=1)

    predictions = jnp.squeeze(
        state.apply_fn(
            {"params": state.params},
            jnp.array([inputs]),
            deterministic=True,
            mask_data=MASK_DATA,
        )
    )
    # Last predicted step, reshaped to a single column for concatenation.
    next_step = np.array(predictions[:, -1])[:, None]

    extended = np.array(jnp.concatenate([inputs, next_step], axis=1))
    extended[all_nan_rows] = np.nan
    return extended

In [None]:
base_inputs = test_ds[300]
inputs_test = base_inputs[:, :10]
print(inputs_test.shape)
for i in trange(21):
    inputs_test = auto_regressive_predictions(state, inputs_test)

# x = auto_regressive_predictions(state, test_ds[0], 10)

In [None]:
df_auto = pd.DataFrame(inputs_test.T)
df_actual = pd.DataFrame(base_inputs.T)
df_auto.columns = train_df.columns[1:]
df_actual.columns = train_df.columns[1:]
df_diff = df_auto - df_actual

# Drop rows that are all nan
df_diff = df_diff.dropna(axis=0, how="all")

In [None]:
plot_planets(df_auto, df_actual, "time_step")

In [None]:
plot_planets(df_auto, df_actual, "planet3_x")

In [None]:
# One-shot (non-autoregressive) prediction for the first test sample, plotted
# against the inputs.
# inputs = jnp.array(next(iter(test_data_loader)))
# NOTE(review): make_batch returns a dict with "numeric" and "categorical"
# entries, yet `inputs_test` is passed below as a single array (to apply_fn and
# to jnp.squeeze). This cell presumably needs inputs_test["numeric"]; verify.
inputs_test = make_batch(test_ds, 0, 1)

# NOTE(review): apply_fn is taken from `new_state` but the parameters come
# from `state`; confirm which set of parameters is meant to be evaluated.
outputs = new_state.apply_fn(
    {"params": state.params},
    # hp.mask_tensor(inputs, dataset, prng_key=mask_key),
    inputs_test,
    deterministic=True,
    mask_data=True,
)
# Drop the batch axis and transpose; columns are labelled with every test_df
# column except the first.
df_actual = pd.DataFrame(jnp.squeeze(inputs_test).T)
df_actual.columns = test_df.columns[1:]

df_pred = pd.DataFrame(jnp.squeeze(outputs).T)
df_pred.columns = test_df.columns[1:]
plot_planets(df_pred, df_actual, "time_step")
plot_planets(df_pred, df_actual, "planet2_y")

In [None]:
inputs_test.shape

In [None]:
show_heatmap(df_diff, "Auto Regressive Predictions")

In [None]:
# plot planet0_x from df_auto and df_actual


res = show_results_df(state, train_df, train_ds, idx=299, mask_start=20)
plot_planets(res["pred"], res["actual_masked"], "planet2_y", offset=0)
# plot_planets(df_auto, df_actual, "planet2_y")

In [None]:
# plot planet0_x from df_auto and df_actual
res = show_results_df(state, train_df, train_ds, idx=20, mask_start=20)
plot_planets(res["pred"], res["actual_no_mask"], "planet1_y", offset=0)

In [None]:
res = show_results_df(state, train_df, test_ds, idx=0, mask_start=30)

plot_planets(res["pred"], res["actual_masked"], "planet2_x", offset=0)

In [None]:
loss, key = eval_step(state, jnp.array(next(iter(test_data_loader))), base_key)
loss, key