<https://github.com/PolymathicAI/xVal>


In [1]:
import os

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.95"

In [2]:
import jax.numpy as jnp  # Oddly works in colab to set gpu

arr = jnp.array([1, 2, 3])
arr.devices()

{cuda(id=0)}

In [3]:
import icecream
from icecream import ic

icecream.install()
ic_disable = False
if ic_disable:
    ic.disable()
ic.configureOutput(includeContext=True, contextAbsPath=True)

In [4]:
import os
import ast

from datetime import datetime as dt
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import hephaestus as hp
import jax
import jax.numpy as jnp
import numpy as np
import optax
import pandas as pd
from flax.training import train_state
from icecream import ic
from jax import random
from flax import struct
from flax.training import checkpoints
from jax.tree_util import tree_flatten
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm, trange
from hephaestus.models.simple_time_series import SimpleDS

pd.options.mode.copy_on_write = True



In [5]:
def line2df(line, idx):
    data_rows = []
    line = ast.literal_eval(line)
    for i, time_step in enumerate(line["data"]):
        row = {"time_step": i}
        # Add position data for each planet
        for j, position in enumerate(time_step):
            row[f"planet{j}_x"] = position[0]
            row[f"planet{j}_y"] = position[1]
        data_rows.append(row)

    df = pd.DataFrame(data_rows)
    description = line.pop("description")
    step_size = description.pop("stepsize")
    for k, v in description.items():
        for k_prop, v_prop in v.items():
            df[f"{k}_{k_prop}"] = v_prop
    df["time_step"] = df["time_step"] * step_size
    df.insert(0, "idx", idx)

    return df

In [6]:
files = os.listdir("data")
if "planets.parquet" not in files:
    with open("data/planets.data") as f:
        data = f.read().splitlines()

        dfs = []
        for idx, line in enumerate(tqdm(data)):
            dfs.append(line2df(line, idx))
        print("Concatenating dfs...")
        df = pd.concat(dfs)
    df.to_parquet("data/planets.parquet")
else:
    df = pd.read_parquet("data/planets.parquet")

In [7]:
# Get min, mean, and max number of time steps
df.groupby("idx").count().time_step.agg(["min", "mean", "max"])

min     30.000000
mean    44.511656
max     59.000000
Name: time_step, dtype: float64

In [8]:
# Get train test split at 80/20
train_idx = int(df.idx.max() * 0.8)
train_df = df.loc[df.idx < train_idx].copy()
test_df = df.loc[df.idx >= train_idx].copy()
# del df
train_ds = SimpleDS(train_df)
test_ds = SimpleDS(test_df)
len(train_ds), len(test_ds)

(99999, 25001)

In [9]:
test_df.shape, train_df.shape

((1110975, 27), (4452982, 27))

In [10]:
df.idx.max()

124999

In [11]:
def make_batch(ds: SimpleDS, start: int, length: int):
    data = []
    for i in range(start, length + start):
        data.append(ds[i])

    return jnp.array(data)


batch = make_batch(train_ds, 0, 4)

In [12]:
time_series_regressor = hp.simple_time_series.SimplePred(train_ds, d_model=64 * 4)

In [13]:
key = random.PRNGKey(0)
init_key, dropout_key = random.split(key)
vars = time_series_regressor.init(
    {"params": init_key, "dropout": dropout_key}, batch, deterministic=False
)
dropout_key, original_dropout_key = random.split(dropout_key)
x = time_series_regressor.apply(
    vars, batch, deterministic=False, rngs={"dropout": dropout_key}
)

ic| simple_time_series.py:224 in __call__()
    mask.shape: (4, 26, 1, 59, 59)
ic| simple_time_series.py:227 in __call__()
    col_embeddings.shape: (26, 256)
    numeric_inputs.shape: (4, 26, 59)
ic| simple_time_series.py:231 in __call__()
    "before swap": 'before swap'
    repeated_numeric_indices.shape: (59, 26)
ic| simple_time_series.py:234 in __call__()
    "after swap": 'after swap'
    repeated_numeric_indices.shape: (26, 59)
ic| simple_time_series.py:236 in __call__()
    "Embedding!!": 'Embedding!!'
    numeric_col_embeddings.shape: (26, 59, 256)
ic| simple_time_series.py:244 in __call__()
    "Retiling": 'Retiling'
    numeric_col_embeddings.shape: (4, 26, 59, 256)
ic| simple_time_series.py:249 in __call__()
    "Before Broadcast": 'Before Broadcast'
    numeric_inputs.shape: (4, 26, 59)
    numeric_col_embeddings.shape: (4, 26, 59, 256)
ic| simple_time_series.py:262 in __call__()
    "Before where": 'Before where'
    numeric_broadcast.shape: (4, 26, 59, 256)
    nan_mask.

In [14]:
key = jax.random.PRNGKey(0)
key1, key2 = jax.random.split(key)

# Create random arrays
numeric_inputs = jax.random.normal(key1, (4, 26, 59))
numeric_col_embeddings = jax.random.normal(key2, (4, 26, 59, 256))
# numeric_inputs.shape: (4, 26, 59)
# numeric_col_embeddings.shape: (4, 26, 59, 256)
(numeric_inputs[:, :, :, None] * numeric_col_embeddings).shape

(4, 26, 59, 256)

In [15]:
time_series_regressor.tabulate(
    {"params": init_key, "dropout": dropout_key},
    batch,
    console_kwargs={"force_jupyter": True, "width": 120},
)

ic| simple_time_series.py:224 in __call__()
    mask.shape: (4, 26, 1, 59, 59)
ic| simple_time_series.py:227 in __call__()
    col_embeddings.shape: (26, 256)
    numeric_inputs.shape: (4, 26, 59)
ic| simple_time_series.py:231 in __call__()
    "before swap": 'before swap'
    repeated_numeric_indices.shape: (59, 26)
ic| simple_time_series.py:234 in __call__()
    "after swap": 'after swap'
    repeated_numeric_indices.shape: (26, 59)
ic| simple_time_series.py:236 in __call__()
    "Embedding!!": 'Embedding!!'
    numeric_col_embeddings.shape: (26, 59, 256)
ic| simple_time_series.py:244 in __call__()
    "Retiling": 'Retiling'
    numeric_col_embeddings.shape: (4, 26, 59, 256)
ic| simple_time_series.py:249 in __call__()
    "Before Broadcast": 'Before Broadcast'
    numeric_inputs.shape: (4, 26, 59)
    numeric_col_embeddings.shape: (4, 26, 59, 256)
ic| simple_time_series.py:262 in __call__()
    "Before where": 'Before where'
    numeric_broadcast.shape: (4, 26, 59, 256)
    nan_mask.

'\n\n'

In [16]:
def calculate_memory_footprint(params):
    """Calculate total memory footprint of JAX model parameters and total
    number of parameters."""
    total_bytes = 0
    # Flatten the parameter tree structure into a list of arrays
    flat_params, _ = tree_flatten(params)
    for param in flat_params:
        # Calculate bytes: number of elements * size of each element
        bytes_per_param = param.size * param.dtype.itemsize
        total_bytes += bytes_per_param
    return total_bytes


def count_parameters(params):
    return sum(jnp.prod(jnp.array(p.shape)) for p in jax.tree_util.tree_leaves(params))


mem = calculate_memory_footprint(vars)
total_params = count_parameters(vars)


print(f"Memory of custom: {mem / 1e6:.2f} MB with {total_params:,} parameters")

Memory of custom: 3.24 MB with 809,889 parameters


In [17]:
batch.shape

(4, 59, 26)

In [18]:
mts_root_key = random.PRNGKey(44)
mts_main_key, ts_params_key, ts_data_key = random.split(mts_root_key, 3)


def clip_gradients(gradients, max_norm):
    total_norm = jnp.sqrt(sum(jnp.sum(jnp.square(grad)) for grad in gradients.values()))
    scale = max_norm / (total_norm + 1e-6)
    clipped_gradients = jax.tree_map(
        lambda grad: jnp.where(total_norm > max_norm, grad * scale, grad), gradients
    )
    return clipped_gradients


def base_loss(inputs, outputs):
    # Create mask for nan inputs
    # inputs = jnp.swapaxes(inputs, 0, 1)
    inputs = jnp.swapaxes(inputs, 1, 2)
    nan_mask = jnp.isnan(inputs)
    inputs = jnp.where(nan_mask, jnp.zeros_like(inputs), inputs)
    outputs = jnp.where(nan_mask, jnp.zeros_like(outputs), outputs)
    # shift the outputs and inputs by one on either side 
    inputs = jnp.
    raw_loss = optax.squared_error(outputs, inputs)
    masked_loss = jnp.where(nan_mask, 0.0, raw_loss)
    loss = masked_loss.sum() / (~nan_mask).sum()

    return loss


def calculate_loss(params, state, inputs, dataset: SimpleDS, dropout_key, mask_key):
    outputs = state.apply_fn(
        {"params": params},
        # hp.mask_tensor(inputs, dataset, prng_key=mask_key),
        inputs,
        rngs={"dropout": dropout_key},
        deterministic=False,
    )
    loss = base_loss(inputs, outputs)
    # Create mask for nan inputs

    return loss


@jax.jit
def train_step(state: train_state.TrainState, batch, base_key):
    dropout_key, mask_key, new_key = jax.random.split(base_key, 3)

    def loss_fn(params):
        return calculate_loss(params, state, batch, train_ds, dropout_key, mask_key)

    grad_fn = jax.value_and_grad(loss_fn)

    # (loss, individual_losses), grad = grad_fn(state.params)
    loss, grad = grad_fn(state.params)
    # grad = replace_nans(grad)
    # grad = clip_gradients(grad, 1.0)
    state = state.apply_gradients(grads=grad)

    return state, loss, new_key


def evaluate(params, state, inputs, dataset: SimpleDS, dropout_key, mask_key):
    outputs = state.apply_fn(
        {"params": params},
        # hp.mask_tensor(inputs, dataset, prng_key=mask_key),
        inputs,
        deterministic=True,
    )
    loss = base_loss(inputs, outputs)
    return loss


@jax.jit
def eval_step(state: train_state.TrainState, batch, base_key):
    mask_key, dropout_key, new_key = jax.random.split(base_key, 3)

    def loss_fn(params):
        return evaluate(params, state, batch, train_ds, dropout_key, mask_key)

    # (loss, individual_losses), grad = grad_fn(state.params)
    loss = loss_fn(state.params)
    return loss, new_key


def create_train_state(model, prng, batch, lr):
    init_key, dropout_key = random.split(prng)
    params = model.init(
        {"params": init_key, "dropout": dropout_key}, batch, deterministic=False
    )
    # optimizer = optax.chain(optax.adam(lr))
    optimizer = optax.chain(optax.clip_by_global_norm(0.4), optax.adam(lr))
    # optimizer_state = optimizer.init(params)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=params["params"],
        tx=optimizer,
        # tx_state=optimizer_state,
    )


batch_size = 2
# batch = train_ds[0]
# state = create_train_state(time_series_regressor, mts_main_key, batch, 0.0001)
state = create_train_state(time_series_regressor, mts_main_key, batch, 0.0001)

ic| simple_time_series.

py:224 in __call__()
    mask.shape: (4, 26, 1, 59, 59)
ic| simple_time_series.py:227 in __call__()
    col_embeddings.shape: (26, 256)
    numeric_inputs.shape: (4, 26, 59)
ic| simple_time_series.py:231 in __call__()
    "before swap": 'before swap'
    repeated_numeric_indices.shape: (59, 26)
ic| simple_time_series.py:234 in __call__()
    "after swap": 'after swap'
    repeated_numeric_indices.shape: (26, 59)
ic| simple_time_series.py:236 in __call__()
    "Embedding!!": 'Embedding!!'
    numeric_col_embeddings.shape: (26, 59, 256)
ic| simple_time_series.py:244 in __call__()
    "Retiling": 'Retiling'
    numeric_col_embeddings.shape: (4, 26, 59, 256)
ic| simple_time_series.py:249 in __call__()
    "Before Broadcast": 'Before Broadcast'
    numeric_inputs.shape: (4, 26, 59)
    numeric_col_embeddings.shape: (4, 26, 59, 256)
ic| simple_time_series.py:262 in __call__()
    "Before where": 'Before where'
    numeric_broadcast.shape: (4, 26, 59, 256)
    nan_mask.shape: (4, 26, 59)
ic| 

In [19]:
writer_name = "RowsLast"

writer_time = dt.now().strftime("%Y-%m-%dT%H:%M:%S")
model_name = writer_time + writer_name
train_summary_writer = SummaryWriter("runs/" + model_name)


test_set_key = random.PRNGKey(4454)

train_data_loader = DataLoader(train_ds, batch_size=512, shuffle=True)
test_data_loader = DataLoader(test_ds, batch_size=512, shuffle=True)
batch_count = 0
base_key = random.PRNGKey(42)

# Disable IC for training
ic.disable()
for j in trange(1, desc=f"epochs for {train_summary_writer.log_dir}"):
    # arrs = train_data_loader()
    for i in tqdm(train_data_loader, leave=False, desc="batches"):
        # for i in trange(len(pre_train) // batch_size, leave=False):
        # for i in trange(len(pre_train) // batch_size //10, leave=False):
        # batch = make_batch(train_ds, i[0], 4)

        state, loss, base_key = train_step(state, jnp.array(i), base_key)
        if jnp.isnan(loss):
            raise ValueError("Nan Value in loss, stopping")
        batch_count += 1

        if batch_count % 1 == 0:
            train_summary_writer.add_scalar(
                "loss/loss", np.array(loss.item()), batch_count
            )
        if batch_count % 10 == 0:
            test_loss, base_key = eval_step(
                state, jnp.array(next(iter(test_data_loader))), base_key
            )
            train_summary_writer.add_scalar(
                "loss/test_loss", np.array(test_loss.item()), batch_count
            )

train_summary_writer.close()

epochs for runs/2024-05-28T19:51:37RowsLast:   0%|          | 0/1 [00:00<?, ?it/s]

batches:   0%|          | 0/196 [00:00<?, ?it/s]

In [20]:
len(train_ds) / 512

195.310546875

In [21]:
import orbax
from flax.training import orbax_utils

ckpts_dir = "ckpts"
absolute_ckpts_dir = os.path.abspath(os.path.join(ckpts_dir, model_name))

ckpt = {"state": state, "batch_count": batch_count}

orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(ckpt)
orbax_checkpointer.save(absolute_ckpts_dir, ckpt, save_args=save_args, force=True)



In [33]:
def return_results(state, dataset, idx=0, mask_start: int = None):
    inputs = dataset[idx]
    if mask_start:
        inputs[mask_start:] = np.nan
    inputs = jnp.array([inputs])
    outputs = state.apply_fn(
        {"params": state.params},
        # hp.mask_tensor(jnp.array([train_ds[0]]), dataset, prng_key=key),
        inputs,
        deterministic=True,
    )
    return outputs, inputs

In [38]:
def show_results_df(state, base_df, dataset, idx: int = 0, mask_start: int = None):
    outputs, inputs = return_results(state, dataset, idx=idx, mask_start=mask_start)
    outputs = jnp.squeeze(outputs)
    df_pred = pd.DataFrame(outputs.T)
    df_pred.columns = base_df.columns[1:]

    inputs = jnp.squeeze(inputs)
    df_actual_masked = pd.DataFrame(inputs)
    df_actual_masked.columns = base_df.columns[1:]
    diff_df = df_pred - df_actual_masked

    inputs_no_mask = jnp.array([dataset[idx]])
    df_no_mask = pd.DataFrame(jnp.squeeze(inputs_no_mask))
    df_no_mask.columns = base_df.columns[1:]
    diff_df_no_mask = df_pred - df_no_mask
    return {
        "pred": df_pred,
        "actual_maksed": df_actual_masked,
        "diff_masked": diff_df,
        "diff_no_mask": diff_df_no_mask,
    }


res = show_results_df(state, train_df, train_ds, idx=0, mask_start=10)

In [40]:
res["diff_masked"].head(11)

Unnamed: 0,time_step,planet0_x,planet0_y,planet1_x,planet1_y,planet2_x,planet2_y,planet0_m,planet0_a,planet0_e,...,planet3_x,planet3_y,planet3_m,planet3_a,planet3_e,planet4_x,planet4_y,planet4_m,planet4_a,planet4_e
0,-0.35777,-0.047076,-0.215068,0.025685,-0.092193,0.054196,-0.289221,-0.006884,0.013739,-0.112133,...,,,,,,,,,,
1,0.044599,0.019554,-0.15554,-0.045098,-0.100935,0.009763,-0.276559,-0.068556,-0.061455,-0.160055,...,,,,,,,,,,
2,0.118069,0.039835,-0.051931,0.057597,-0.078524,0.037579,-0.31871,-0.014227,-0.063348,-0.100266,...,,,,,,,,,,
3,0.073607,-0.000671,-0.151576,-0.029292,-0.158168,-0.010127,-0.379353,-0.052349,-0.041198,-0.107174,...,,,,,,,,,,
4,0.001059,-0.037753,-0.190545,-0.082807,-0.109776,-0.047755,-0.461931,-0.07962,0.022997,-0.061437,...,,,,,,,,,,
5,-0.016992,-0.087956,-0.167907,-0.07684,-0.010825,0.001469,-0.472241,-0.065588,0.030891,-0.04824,...,,,,,,,,,,
6,-0.006848,-0.034331,-0.095636,-0.070119,0.044749,0.040074,-0.430049,-0.026372,0.008751,-0.054768,...,,,,,,,,,,
7,-0.033706,0.035607,-0.042126,0.005883,0.06353,-0.020378,-0.38508,-0.046992,-0.074249,-0.118004,...,,,,,,,,,,
8,-0.040974,0.073527,-0.133016,-0.015784,0.091087,-0.028764,-0.316528,-0.046147,-0.037118,-0.093719,...,,,,,,,,,,
9,-0.098638,0.054315,-0.167286,-0.124322,0.004615,-0.035692,-0.315266,-0.04402,-0.001616,-0.033457,...,,,,,,,,,,


In [41]:
res["diff_no_mask"].head(11)

Unnamed: 0,time_step,planet0_x,planet0_y,planet1_x,planet1_y,planet2_x,planet2_y,planet0_m,planet0_a,planet0_e,...,planet3_x,planet3_y,planet3_m,planet3_a,planet3_e,planet4_x,planet4_y,planet4_m,planet4_a,planet4_e
0,-0.35777,-0.047076,-0.215068,0.025685,-0.092193,0.054196,-0.289221,-0.006884,0.013739,-0.112133,...,,,,,,,,,,
1,0.044599,0.019554,-0.15554,-0.045098,-0.100935,0.009763,-0.276559,-0.068556,-0.061455,-0.160055,...,,,,,,,,,,
2,0.118069,0.039835,-0.051931,0.057597,-0.078524,0.037579,-0.31871,-0.014227,-0.063348,-0.100266,...,,,,,,,,,,
3,0.073607,-0.000671,-0.151576,-0.029292,-0.158168,-0.010127,-0.379353,-0.052349,-0.041198,-0.107174,...,,,,,,,,,,
4,0.001059,-0.037753,-0.190545,-0.082807,-0.109776,-0.047755,-0.461931,-0.07962,0.022997,-0.061437,...,,,,,,,,,,
5,-0.016992,-0.087956,-0.167907,-0.07684,-0.010825,0.001469,-0.472241,-0.065588,0.030891,-0.04824,...,,,,,,,,,,
6,-0.006848,-0.034331,-0.095636,-0.070119,0.044749,0.040074,-0.430049,-0.026372,0.008751,-0.054768,...,,,,,,,,,,
7,-0.033706,0.035607,-0.042126,0.005883,0.06353,-0.020378,-0.38508,-0.046992,-0.074249,-0.118004,...,,,,,,,,,,
8,-0.040974,0.073527,-0.133016,-0.015784,0.091087,-0.028764,-0.316528,-0.046147,-0.037118,-0.093719,...,,,,,,,,,,
9,-0.098638,0.054315,-0.167286,-0.124322,0.004615,-0.035692,-0.315266,-0.04402,-0.001616,-0.033457,...,,,,,,,,,,


In [25]:
idx = 100
test_result = return_results(state, test_ds, idx=idx)
test_result = jnp.squeeze(test_result)
df_pred = pd.DataFrame(test_result.T)
df_actual_masked = pd.DataFrame(train_ds[idx])
df_pred.columns = train_df.drop("idx", axis=1).columns
df_actual_masked.columns = train_df.drop("idx", axis=1).columns
df_pred.head()

Unnamed: 0,time_step,planet0_x,planet0_y,planet1_x,planet1_y,planet2_x,planet2_y,planet0_m,planet0_a,planet0_e,...,planet3_x,planet3_y,planet3_m,planet3_a,planet3_e,planet4_x,planet4_y,planet4_m,planet4_a,planet4_e
0,-0.490324,0.957647,-0.204687,2.266601,-0.207227,2.309985,-0.248194,1.624715,0.949058,0.095831,...,1.268444,-0.223621,4.195572,1.443578,1.245591,1.57456,-0.151188,4.361728,1.813501,0.991018
1,0.80838,0.777317,0.582421,2.158467,0.528388,2.144089,0.519811,1.529973,0.927393,0.074323,...,1.125876,0.751289,4.206074,1.346461,1.109982,1.501735,0.711718,4.370624,1.82615,1.12584
2,1.616103,0.105701,0.867759,1.886571,0.892146,2.031134,1.009875,1.565807,0.967285,0.132231,...,0.855595,1.125196,4.25492,1.391695,1.069855,1.495232,1.247108,4.383264,1.85295,1.185014
3,2.290273,-0.633031,0.713062,1.609085,1.197222,1.852196,1.368202,1.540126,0.960211,0.064489,...,0.292088,1.281565,4.232001,1.344712,1.058634,1.181173,1.452768,4.411496,1.791916,1.13309
4,2.971382,-1.019092,0.230082,1.348955,1.649029,1.61532,1.689265,1.515148,0.913508,0.048351,...,-0.418265,1.326439,4.193496,1.34836,1.069153,0.681974,1.610971,4.411561,1.733157,1.071785


In [26]:
df_pred_masked = pd.DataFrame(jnp.squeeze(test_result_masked).T)
df_pred_masked.columns = train_df.drop("idx", axis=1).columns
df_pred_masked.head()

NameError: name 'test_result_masked' is not defined

In [None]:
df_pred

Unnamed: 0,time_step,planet0_x,planet0_y,planet1_x,planet1_y,planet2_x,planet2_y,planet0_m,planet0_a,planet0_e,...,planet3_x,planet3_y,planet3_m,planet3_a,planet3_e,planet4_x,planet4_y,planet4_m,planet4_a,planet4_e
0,-0.872413,2.239077,-0.331503,1.14134,-0.434859,2.161565,-0.43446,1.228074,2.738998,1.792781,...,1.332433,-0.480385,2.072505,1.62806,1.010441,1.674201,-0.506628,3.700562,2.0177,1.662664
1,0.596177,2.366127,0.40071,0.787214,0.648504,1.985467,0.466675,1.096087,2.858291,1.734295,...,1.326582,0.583495,2.195369,1.500912,0.828569,1.618879,0.567062,3.649784,2.112479,1.818143
2,1.31766,2.328211,0.65454,0.327443,0.906084,1.86952,0.77583,1.149743,2.963109,1.778262,...,1.120704,0.897498,2.223309,1.50139,0.810674,1.582358,0.954042,3.643643,2.106834,1.9151
3,1.765285,2.172432,0.939959,-0.236792,0.886375,1.761162,1.062812,1.113723,3.024565,1.833238,...,0.713003,1.078541,2.150734,1.39656,0.838221,1.433248,1.184962,3.674434,2.03986,1.856166
4,2.335949,2.090542,1.411212,-0.750069,0.587026,1.620063,1.483251,1.088198,3.00136,1.864917,...,0.271445,1.275727,2.05303,1.435364,0.803338,1.144273,1.460208,3.671819,1.939544,1.767251
5,3.139543,1.877757,1.794686,-1.015816,0.159706,1.451967,1.710505,1.145259,2.943053,1.939366,...,-0.178328,1.450426,1.97744,1.480067,0.792377,0.908102,1.815544,3.690362,1.875548,1.711974
6,3.77859,1.611909,1.959486,-1.005651,-0.271707,1.365617,1.880034,1.138903,2.841297,1.881519,...,-0.74374,1.416732,1.99586,1.516062,0.786899,0.526088,1.948521,3.719086,1.906242,1.688567
7,4.369803,1.3366,1.954697,-0.651495,-0.746266,1.151062,1.996518,1.078638,2.800012,1.816502,...,-1.300079,1.07204,2.014664,1.531587,0.807726,0.020142,1.806325,3.664069,1.971412,1.694124
8,4.957397,1.15719,2.033644,-0.032156,-0.838292,0.787889,2.209671,1.12333,2.826603,1.815081,...,-1.405501,0.717358,2.075088,1.46818,0.79027,-0.416963,1.741753,3.612466,1.960248,1.801533
9,5.576791,0.912068,2.204164,0.443044,-0.673263,0.315479,2.244567,1.093641,2.868494,1.885712,...,-1.484567,0.297686,2.033107,1.426914,0.853074,-0.845995,1.653671,3.564578,1.89506,1.874196


In [None]:
df_actual_masked

Unnamed: 0,time_step,planet0_x,planet0_y,planet1_x,planet1_y,planet2_x,planet2_y,planet0_m,planet0_a,planet0_e,...,planet3_x,planet3_y,planet3_m,planet3_a,planet3_e,planet4_x,planet4_y,planet4_m,planet4_a,planet4_e
0,,,,,,,,,,,...,,,,,,,,,,
1,,,,,,,,,,,...,,,,,,,,,,
2,,,,,,,,,,,...,,,,,,,,,,
3,,,,,,,,,,,...,,,,,,,,,,
4,,,,,,,,,,,...,,,,,,,,,,
5,,,,,,,,,,,...,,,,,,,,,,
6,,,,,,,,,,,...,,,,,,,,,,
7,,,,,,,,,,,...,,,,,,,,,,
8,,,,,,,,,,,...,,,,,,,,,,
9,,,,,,,,,,,...,,,,,,,,,,


In [None]:
# Subtract the actual from the predicted
df_diff = df_pred - df_actual_masked
df_diff

Unnamed: 0,time_step,planet0_x,planet0_y,planet1_x,planet1_y,planet2_x,planet2_y,planet0_m,planet0_a,planet0_e,...,planet3_x,planet3_y,planet3_m,planet3_a,planet3_e,planet4_x,planet4_y,planet4_m,planet4_a,planet4_e
0,,,,,,,,,,,...,,,,,,,,,,
1,,,,,,,,,,,...,,,,,,,,,,
2,,,,,,,,,,,...,,,,,,,,,,
3,,,,,,,,,,,...,,,,,,,,,,
4,,,,,,,,,,,...,,,,,,,,,,
5,,,,,,,,,,,...,,,,,,,,,,
6,,,,,,,,,,,...,,,,,,,,,,
7,,,,,,,,,,,...,,,,,,,,,,
8,,,,,,,,,,,...,,,,,,,,,,
9,,,,,,,,,,,...,,,,,,,,,,


In [None]:
test_result_masked = return_results(state, train_ds, mask_start=10)

In [None]:
model_name = f"big_train_{dt.now()}_"

current_dir = os.getcwd()

if not os.path.exists("./pre_trained_models/"):
    os.makedirs("./pre_trained_models/")

path = os.path.join(current_dir, "./pre_trained_models/")


ckpt_dir = f"./pre_trained_models/{model_name}"

# checkpoints.save_checkpoint(
#     ckpt_dir=path, target=state, step=batch_count, overwrite=True, prefix=model_name
# )