<https://github.com/PolymathicAI/xVal>


In [1]:
import os

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.95"

In [2]:
import jax.numpy as jnp  # Oddly works in colab to set gpu

arr = jnp.array([1, 2, 3])
arr.devices()

{cuda(id=0)}

In [3]:
import icecream
from icecream import ic

icecream.install()
ic_disable = True
if ic_disable:
    ic.disable()
ic.configureOutput(includeContext=True, contextAbsPath=True)

In [4]:
import os
import ast

from datetime import datetime as dt
from torch.utils.data import Dataset

import hephaestus as hp
import jax
import jax.numpy as jnp
import numpy as np
import optax
import pandas as pd
from flax.training import train_state
from icecream import ic
from jax import random
from flax import struct
from jax.tree_util import tree_flatten
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm, trange

pd.options.mode.copy_on_write = True



In [5]:
def line2df(line, idx):
    data_rows = []
    line = ast.literal_eval(line)
    for i, time_step in enumerate(line["data"]):
        row = {"time_step": i}
        # Add position data for each planet
        for j, position in enumerate(time_step):
            row[f"planet{j}_x"] = position[0]
            row[f"planet{j}_y"] = position[1]
        data_rows.append(row)

    df = pd.DataFrame(data_rows)
    description = line.pop("description")
    step_size = description.pop("stepsize")
    for k, v in description.items():
        for k_prop, v_prop in v.items():
            df[f"{k}_{k_prop}"] = v_prop
    df["time_step"] = df["time_step"] * step_size
    df.insert(0, "idx", idx)

    return df

In [6]:
files = os.listdir("data")
if "planets.parquet" not in files:
    with open("data/planets.data") as f:
        data = f.read().splitlines()

        dfs = []
        for idx, line in enumerate(tqdm(data)):
            dfs.append(line2df(line, idx))
        print("Concatenating dfs...")
        df = pd.concat(dfs)
    df.to_parquet("data/planets.parquet")
else:
    df = pd.read_parquet("data/planets.parquet")

In [7]:
# Get min, mean, and max number of time steps
df.groupby("idx").count().time_step.agg(["min", "mean", "max"])

min     30.000000
mean    44.511656
max     59.000000
Name: time_step, dtype: float64

In [8]:
class SimpleDS(Dataset):
    def __init__(self, df):
        # Add nan padding to make sure all sequences are the same length
        # use the idx column to group by
        self.max_seq_len = df.groupby("idx").count().time_step.max()

        self.df = df
        self.batch_size = self.max_seq_len

        self.special_tokens = ["[PAD]", "[NUMERIC_MASK]", "[MASK]"]
        self.cat_mask = "[MASK]"
        self.numeric_mask = "[NUMERIC_MASK]"

        self.col_tokens = [col_name for col_name in df.columns if col_name != "idx"]

        self.tokens = self.special_tokens + self.col_tokens

        self.token_dict = {token: i for i, token in enumerate(self.tokens)}
        self.token_decoder_dict = {i: token for i, token in enumerate(self.tokens)}
        self.n_tokens = len(self.tokens)
        self.numeric_indices = jnp.array(
            [self.tokens.index(i) for i in self.col_tokens]
        )

        self.numeric_mask_token = self.tokens.index(self.numeric_mask)

    def __len__(self):
        return df.idx.max() + 1  # probably should be max idx + 1 thanks

    def __getitem__(self, set_idx):
        batch = self.df.loc[
            df.idx == set_idx, [col for col in self.df.columns if col != "idx"]
        ]
        batch = np.array(batch.values)
        # Add padding
        batch_len, n_cols = batch.shape
        pad_len = self.max_seq_len - batch_len
        padding = np.full((pad_len, n_cols), jnp.nan)
        batch = np.concatenate([batch, padding], axis=0)
        return batch

In [9]:
np.full((3, 3), np.nan)

array([[nan, nan, nan],
       [nan, nan, nan],
       [nan, nan, nan]])

In [10]:
train_ds = SimpleDS(df)

In [11]:
time_series_regressor = hp.simple_time_series.SimplePred(train_ds, d_model=64 * 4)

In [12]:
train_ds[0]

array([[ 0.        ,  1.56006022, -0.85443699, ...,         nan,
                nan,         nan],
       [ 0.46511628,  1.68985799, -0.5143588 , ...,         nan,
                nan,         nan],
       [ 0.93023256,  1.75358875, -0.15420858, ...,         nan,
                nan,         nan],
       ...,
       [        nan,         nan,         nan, ...,         nan,
                nan,         nan],
       [        nan,         nan,         nan, ...,         nan,
                nan,         nan],
       [        nan,         nan,         nan, ...,         nan,
                nan,         nan]])

In [13]:
def make_batch(ds: SimpleDS, start: int, length: int):
    data = []
    for i in range(start, length + start):
        data.append(ds[i])

    return jnp.array(data)


batch = make_batch(train_ds, 0, 4)

In [14]:
vars = time_series_regressor.init(random.PRNGKey(0), batch)

x = time_series_regressor.apply(vars, batch)

In [15]:
x

Array([[[-0.00999765, -0.00999765, -0.00999765, ..., -0.00999765,
         -0.00999765, -0.00999765],
        [ 0.00572604,  0.00572604,  0.00572604, ...,  0.00572604,
          0.00572604,  0.00572604],
        [ 0.00717407,  0.00717407,  0.00717407, ...,  0.00717407,
          0.00717407,  0.00717407],
        ...,
        [-0.15913501, -0.15913501, -0.15913501, ..., -0.15913501,
         -0.15913501, -0.15913501],
        [-0.18229163, -0.18229163, -0.18229163, ..., -0.18229163,
         -0.18229163, -0.18229163],
        [-0.14735968, -0.14735968, -0.14735968, ..., -0.14735968,
         -0.14735968, -0.14735968]],

       [[ 0.01031545,  0.01031545,  0.01031545, ...,  0.01031545,
          0.01031545,  0.01031545],
        [ 0.03705181,  0.03705181,  0.03705181, ...,  0.03705181,
          0.03705181,  0.03705181],
        [ 0.0409942 ,  0.0409942 ,  0.0409942 , ...,  0.0409942 ,
          0.0409942 ,  0.0409942 ],
        ...,
        [-0.15913501, -0.15913501, -0.15913501, ..., -

In [16]:
batch.shape

(4, 59, 26)

In [17]:
def calculate_memory_footprint(params):
    """Calculate total memory footprint of JAX model parameters."""
    total_bytes = 0
    # Flatten the parameter tree structure into a list of arrays
    flat_params, _ = tree_flatten(params)
    for param in flat_params:
        # Calculate bytes: number of elements * size of each element
        bytes_per_param = param.size * param.dtype.itemsize
        total_bytes += bytes_per_param
    return total_bytes

In [27]:
def create_causal_mask(tensor: jnp.ndarray):
    """Create a causal mask to mask out future values."""
    mask = jnp.tril(jnp.ones((tensor.shape[0], tensor.shape[1])))
    return mask


def create_padding_mask(tensor: jnp.ndarray):
    """Create a padding mask to mask out padded values."""
    mask = jnp.isnan(tensor)
    return mask


def mask_array(tensor: jnp.array):
    """Create a mask for the tensor"""
    causal_mask = create_causal_mask(tensor)
    padding_mask = create_padding_mask(tensor)
    mask = jnp.logical_or(causal_mask, padding_mask)
    return mask

In [30]:
jnp.array(train_ds[0]).shape

(59, 26)

In [37]:
def create_causal_mask(batch_size, time_steps, n_columns):
    # Create a 2D mask for the time steps
    # Only future steps are masked
    mask = jnp.triu(np.ones((time_steps, time_steps)), k=1).astype(bool)

    # Expand the mask across the batch size and n_columns
    # The resulting mask will have the shape: (batch_size, time_steps, n_columns, time_steps)
    mask = jnp.tile(mask[:, None, :], (1, n_columns, 1))
    mask = jnp.tile(mask[None, :, :, :], (batch_size, 1, 1, 1))

    return mask


# Example usage
batch_size = 2
time_steps = 4
n_columns = 3
mask = create_causal_mask(batch_size, time_steps, n_columns)
example_tensor = jnp.ones((batch_size, time_steps, n_columns))
masked_tensor = jnp.where(mask, example_tensor, jnp.nan)
masked_tensor

ValueError: Incompatible shapes for broadcasting: shapes=[(2, 4, 3, 4), (2, 4, 3), ()]

In [33]:
create_causal_mask(jnp.ones((4, 5)))

Array([[1., 0., 0., 0., 0.],
       [1., 1., 0., 0., 0.],
       [1., 1., 1., 0., 0.],
       [1., 1., 1., 1., 0.]], dtype=float32)

In [21]:
mts_root_key = random.PRNGKey(44)
mts_main_key, ts_params_key, ts_data_key = random.split(mts_root_key, 3)


def clip_gradients(gradients, max_norm):
    total_norm = jnp.sqrt(sum(jnp.sum(jnp.square(grad)) for grad in gradients.values()))
    scale = max_norm / (total_norm + 1e-6)
    clipped_gradients = jax.tree_map(
        lambda grad: jnp.where(total_norm > max_norm, grad * scale, grad), gradients
    )
    return clipped_gradients


# @jax.jit
def calculate_loss(params, state, inputs, dataset: SimpleDS):
    out = state.apply_fn(
        {"params": params},
        hp.mask_tensor(inputs, dataset, prng_key=ts_data_key),
    )

    # Create mask for nan inputs
    nan_mask = jnp.isnan(inputs)
    # inputs = jnp.where(nan_mask, jnp.zeros_like(inputs), inputs)
    # out = jnp.where(nan_mask, jnp.zeros_like(out), out)

    loss_raw = optax.squared_error(out, inputs)
    loss = jnp.compress(~nan_mask, loss_raw).mean()
    # print(f"Loss ITEM!!!! {loss.item()}")
    return loss


# @jax.jit
def eval_step(state: train_state.TrainState, batch):
    def loss_fn(params):
        return calculate_loss(params, state, batch, train_ds)

    # (loss, individual_losses), grad = grad_fn(state.params)
    loss, loss_dict = loss_fn(state.params)
    return loss, loss_dict


@jax.jit
def train_step(state: train_state.TrainState, batch):
    def loss_fn(params):
        return calculate_loss(params, state, batch, train_ds)

    grad_fn = jax.value_and_grad(loss_fn)

    # (loss, individual_losses), grad = grad_fn(state.params)
    loss, grad = grad_fn(state.params)
    # grad = replace_nans(grad)
    # grad = clip_gradients(grad, 1.0)
    state = state.apply_gradients(grads=grad)

    return state, loss


def create_train_state(model, prng, batch, lr):
    params = model.init(prng, batch)
    # optimizer = optax.chain(optax.adam(lr))
    optimizer = optax.chain(optax.clip_by_global_norm(0.4), optax.adam(lr))
    # optimizer_state = optimizer.init(params)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=params["params"],
        tx=optimizer,
        # tx_state=optimizer_state,
    )


batch_size = 2
# batch = train_ds[0]

state = create_train_state(time_series_regressor, mts_main_key, batch, 0.0001)

In [22]:
from torch.utils.data import DataLoader

from torch.utils.data import DataLoader


In [23]:
writer_name = "SimpleTS"

writer_time = dt.now().strftime("%Y-%m-%dT%H:%M:%S")

train_summary_writer = SummaryWriter(
    "runs/" + writer_time + "_wow_" + writer_name + "_train"
)


test_set_key = random.PRNGKey(4454)

train_data_loader = DataLoader(train_ds, batch_size=512, shuffle=True)
batch_count = 0
for j in trange(2, desc=f"epochs for {train_summary_writer.log_dir}"):
    # arrs = train_data_loader()
    for i in tqdm(train_data_loader, leave=False, desc="batches"):
        # for i in trange(len(pre_train) // batch_size, leave=False):
        # for i in trange(len(pre_train) // batch_size //10, leave=False):
        # batch = make_batch(train_ds, i[0], 4)

        state, loss = train_step(state, jnp.array(i))
        if jnp.isnan(loss):
            raise ValueError("Nan Value in loss, stopping")
        batch_count += 1

        if batch_count % 1 == 0:
            train_summary_writer.add_scalar(
                "loss/loss", np.array(loss.item()), batch_count
            )

epochs for runs/2024-05-13T01:58:49_wow_SimpleTS_train:   0%|          | 0/14 [00:00<?, ?it/s]

batches:   0%|          | 0/245 [00:00<?, ?it/s]

batches:   0%|          | 0/245 [00:00<?, ?it/s]

batches:   0%|          | 0/245 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [50]:
test_result = jnp.squeeze(
    state.apply_fn({"params": state.params}, jnp.array([train_ds[0]]))
)
test_result

Array([[1.4373543, 1.4373543, 1.4373913, ..., 1.4373543, 1.4373543,
        1.4373543],
       [1.4121699, 1.4121698, 1.4121699, ..., 1.4121699, 1.4121699,
        1.4121699],
       [1.3919055, 1.3919055, 1.3919055, ..., 1.3919055, 1.3919055,
        1.3919055],
       ...,
       [2.2000782, 2.2000782, 2.2000782, ..., 2.2000782, 2.2000782,
        2.2000782],
       [2.1688452, 2.1688452, 2.1688452, ..., 2.1688452, 2.1688452,
        2.1688452],
       [2.1502194, 2.1502194, 2.1502194, ..., 2.1502194, 2.1502194,
        2.1502194]], dtype=float32)

In [46]:
jnp.array(train_ds[0]).shape, test_result.shape

((59, 26), (1, 59, 26))

In [48]:
jnp.squeeze(test_result).shape

(59, 26)

In [52]:
df_pred = pd.DataFrame(test_result)
df_actual = pd.DataFrame(train_ds[0])

In [54]:
df_pred

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,16,17,18,19,20,21,22,23,24,25
0,1.437354,1.437354,1.437391,1.437354,1.437354,1.437354,1.437354,1.437391,1.437354,1.437354,...,1.437354,1.437354,1.437354,1.437354,1.437354,1.437354,1.437354,1.437354,1.437354,1.437354
1,1.41217,1.41217,1.41217,1.41217,1.41217,1.41217,1.41217,1.41217,1.41217,1.41217,...,1.41217,1.41217,1.41217,1.41217,1.41217,1.41217,1.41217,1.41217,1.41217,1.41217
2,1.391906,1.391906,1.391906,1.391906,1.391906,1.391906,1.391906,1.391906,1.391906,1.391906,...,1.391906,1.391906,1.391906,1.391906,1.391906,1.391906,1.391906,1.391906,1.391906,1.391906
3,1.36637,1.366354,1.366348,1.366347,1.366355,1.366348,1.366363,1.366355,1.366354,1.366354,...,1.366355,1.366355,1.366355,1.366355,1.366355,1.366355,1.366355,1.366355,1.366355,1.366355
4,1.323469,1.321951,1.321842,1.321652,1.321838,1.321779,1.322127,1.32192,1.321971,1.321971,...,1.321885,1.321885,1.321885,1.321885,1.321885,1.321885,1.321885,1.321885,1.321885,1.321885
5,1.271234,1.243167,1.242631,1.240222,1.241583,1.240969,1.244538,1.243003,1.243547,1.243487,...,1.242753,1.242753,1.242753,1.242753,1.242753,1.242753,1.242753,1.242753,1.242753,1.242753
6,2.28667,1.145433,1.14432,1.115955,1.124038,1.120317,1.162274,1.145752,1.1522,1.15174,...,1.142506,1.142506,1.142506,1.142506,1.142506,1.142506,1.142506,1.142506,1.142506,1.142506
7,4.521352,1.14699,1.187843,0.945027,0.935768,0.912765,1.324656,1.174479,1.234436,1.230192,...,1.145297,1.145297,1.145297,1.145297,1.145297,1.145297,1.145297,1.145297,1.145297,1.145297
8,3.774994,1.007358,1.450285,0.487091,0.141636,0.032512,1.942989,1.287989,1.576423,1.552956,...,1.142628,1.142628,1.142628,1.142628,1.142628,1.142628,1.142628,1.142628,1.142628,1.142628
9,4.412226,0.58843,1.866369,0.34909,-0.716906,-1.021551,2.44313,1.454736,1.956791,1.914964,...,1.168766,1.168766,1.168766,1.168766,1.168766,1.168766,1.168766,1.168766,1.168766,1.168766


In [55]:
df_actual

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,16,17,18,19,20,21,22,23,24,25
0,0.0,1.56006,-0.854437,0.720639,0.691729,0.944008,2.700632,1.312562,1.944263,1.897807,...,,,,,,,,,,
1,0.465116,1.689858,-0.514359,0.333295,0.942289,0.681604,2.785811,1.312562,1.944263,1.897807,...,,,,,,,,,,
2,0.930233,1.753589,-0.154209,-0.124995,0.992368,0.412951,2.845461,1.312562,1.944263,1.897807,...,,,,,,,,,,
3,1.395349,1.748068,0.212022,-0.556775,0.831727,0.14054,2.879232,1.312562,1.944263,1.897807,...,,,,,,,,,,
4,1.860465,1.673573,0.569904,-0.870579,0.494812,-0.133144,2.887018,1.312562,1.944263,1.897807,...,,,,,,,,,,
5,2.325581,1.533811,0.905607,-1.000151,0.053178,-0.405638,2.868949,1.312562,1.944263,1.897807,...,,,,,,,,,,
6,2.790698,1.335526,1.206832,-0.918192,-0.399688,-0.67453,2.825381,1.312562,1.944263,1.897807,...,,,,,,,,,,
7,3.255814,1.087824,1.463513,-0.64198,-0.767958,-0.937474,2.756889,1.312562,1.944263,1.897807,...,,,,,,,,,,
8,3.72093,0.801365,1.668203,-0.229808,-0.973515,-1.192211,2.664256,1.312562,1.944263,1.897807,...,,,,,,,,,,
9,4.186047,0.487549,1.816168,0.231092,-0.972449,-1.436585,2.548458,1.312562,1.944263,1.897807,...,,,,,,,,,,


In [None]:
model_name = f"big_train_{dt.now()}_"

current_dir = os.getcwd()

if not os.path.exists("./pre_trained_models/"):
    os.makedirs("./pre_trained_models/")

path = os.path.join(current_dir, "./pre_trained_models/")


ckpt_dir = f"./pre_trained_models/{model_name}"

# checkpoints.save_checkpoint(
#     ckpt_dir=path, target=state, step=batch_count, overwrite=True, prefix=model_name
# )

In [None]:
def jax_array_memory_usage(array):
    """Calculate the memory usage of a JAX array in bytes."""
    # Get the number of bytes per element based on the data type
    bytes_per_element = array.dtype.itemsize
    # Calculate the total number of elements in the array
    total_elements = np.prod(array.shape)
    # Calculate total memory usage
    memory_usage_bytes = bytes_per_element * total_elements
    return memory_usage_bytes


cat_memory_usage = jax_array_memory_usage(batch.categorical)
num_memory_usage = jax_array_memory_usage(batch.numeric)
memory_usage = cat_memory_usage + num_memory_usage
memory_usage_gb = memory_usage / 1024 / 1024 / 1024
print(f"Memory usage: {memory_usage} bytes")
print(f"Memory usage: {memory_usage_gb} gb")