<https://www.kaggle.com/competitions/predict-energy-behavior-of-prosumers>


In [1]:
import os

# Let XLA preallocate up to 95% of accelerator memory. Per the JAX GPU
# memory-allocation docs this is read when the backend initialises, so it is
# set here, before jax is imported in the next cell.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.95"

In [2]:
import jax.numpy as jnp  # Oddly works in colab to set gpu

# Create a tiny array and show which device(s) JAX placed it on.
# NOTE(review): the recorded output is {CpuDevice(id=0)}, i.e. no accelerator
# was active in that run.
arr = jnp.array([1, 2, 3])
arr.devices()

{CpuDevice(id=0)}

In [3]:
# Record the installed TensorFlow/TensorBoard-related package versions.
!pip freeze | grep tensor

tensorboard==2.15.2
tensorboard-data-server==0.7.2
tensorflow==2.15.0
tensorflow-estimator==2.15.0
tensorflow-io-gcs-filesystem==0.36.0
tensorflow-macos==2.15.0
tensorstore==0.1.54


In [4]:
import icecream
from icecream import ic

ic_disable = True  # set to False to see ic() debug output

icecream.install()  # register ic() as a builtin so any cell/function can call it
if ic_disable:
    ic.disable()

In [5]:
import os
from dataclasses import dataclass, field
from datetime import datetime as dt
from itertools import chain

import hephaestus as hp
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import pandas as pd
import seaborn as sns
from flax import linen as nn
from flax import struct  # Flax dataclasses
from flax.training import checkpoints, train_state
from flax.training.early_stopping import EarlyStopping
from icecream import ic
from jax import profiler, random
from jax.tree_util import tree_flatten
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm, trange

pd.options.mode.copy_on_write = True

In [6]:
# Load the preprocessed energy dataset (path relative to the working directory).
data_path = "data/processed_energy.parquet"
data = pd.read_parquet(data_path)

In [7]:
def sinify_datetime(df, col):
    """Add cyclic calendar features derived from datetime column ``col``.

    The new columns are assigned onto ``df`` itself, which is also returned:

    * ``{col}_base_year``: calendar year.
    * ``{col}_day_sin`` / ``{col}_day_cos``: day of *year* on a 365.25-day cycle.
    * ``{col}_hour_sin`` / ``{col}_hour_cos``: hour of day on a 24-hour cycle.
    * ``{col}_is_weekend``: 1 for Saturday/Sunday, otherwise 0 (int64).

    Args:
        df: DataFrame containing a datetime64 column ``col``.
        col: Name of the datetime column; also used as the new columns' prefix.

    Returns:
        ``df`` with the feature columns added.
    """
    moments = df[col].dt
    # The 365.25 divisor implies an annual cycle. ``dt.day`` is the day of the
    # *month* (1-31) and only spans ~8% of that period, so use ``dayofyear``.
    day_angle = 2 * np.pi * moments.dayofyear / 365.25
    hour_angle = 2 * np.pi * moments.hour / 24

    df[col + "_base_year"] = moments.year
    df[col + "_day_sin"] = np.sin(day_angle)
    df[col + "_day_cos"] = np.cos(day_angle)
    df[col + "_hour_sin"] = np.sin(hour_angle)
    df[col + "_hour_cos"] = np.cos(hour_angle)
    # dayofweek: Monday=0 ... Sunday=6.
    df[col + "_is_weekend"] = moments.dayofweek.isin([5, 6]).astype("int64")

    return df

In [8]:
# Training split. A copy of `data` is passed so the shared frame is left
# untouched by whatever preprocessing the dataset applies.
train_ds = hp.TabularTimeSeriesData(
    data.copy(),
    target_column="target",  # alternative target: "euros_per_mwh"
    batch_size=3120 // 2,  # rows per sample, presumably — TODO confirm in hephaestus
)
train_ds[0]

(Array([[78, 91, 95, 75],
        [78, 86, 95, 75],
        [78, 91, 95, 76],
        ...,
        [81, 86, 84, 73],
        [81, 91, 84, 75],
        [81, 86, 84, 75]], dtype=int32),
 Array([[-2.1993124 , -1.73205   , -1.6867976 , ..., -0.5924703 ,
         -0.481902  , -0.68075854],
        [-2.1993124 , -1.7320483 , -1.6867976 , ..., -0.5924703 ,
         -0.481902  , -0.68075854],
        [-2.1993124 , -1.7320465 , -1.6357528 , ..., -0.5924703 ,
         -0.481902  , -0.68075854],
        ...,
        [-2.1935778 , -1.7313237 , -0.2575419 , ..., -0.5924703 ,
         -0.481902  , -0.68075854],
        [-2.1935778 , -1.731322  , -0.20649704, ..., -0.5924703 ,
         -0.481902  , -0.68075854],
        [-2.1935778 , -1.7313204 , -0.20649704, ..., -0.5924703 ,
         -0.481902  , -0.68075854]], dtype=float32),
 Array([[  0.713],
        [ 96.59 ],
        [  0.   ],
        ...,
        [550.723],
        [  0.   ],
        [177.887]], dtype=float32))

In [9]:
# Test split: same configuration as train_ds, selected with type="test".
test_ds = hp.TabularTimeSeriesData(
    data.copy(),
    target_column="target",  # alternative target: "euros_per_mwh"
    batch_size=3120 // 2,
    type="test",
)
test_ds[0]

(Array([[78, 86, 80, 76],
        [78, 91, 80, 74],
        [78, 86, 80, 74],
        ...,
        [81, 91, 83, 73],
        [81, 86, 83, 73],
        [81, 91, 83, 75]], dtype=int32),
 Array([[ 0.86485595,  1.5875819 ,  0.55917567, ...,  0.4454942 ,
          0.06976517,  1.1171488 ],
        [ 0.86485595,  1.5875837 ,  0.6102205 , ...,  0.4454942 ,
          0.06976517,  1.1171488 ],
        [ 0.86485595,  1.5875853 ,  0.6102205 , ...,  0.4454942 ,
          0.06976517,  1.1171488 ],
        ...,
        [ 0.8705905 ,  1.1045619 , -0.6148558 , ..., -0.5924703 ,
         -0.481902  , -0.68075854],
        [ 0.8705905 ,  1.1045637 , -0.6148558 , ..., -0.5924703 ,
         -0.481902  , -0.68075854],
        [ 0.8705905 ,  1.1045654 , -0.563811  , ..., -0.5924703 ,
         -0.481902  , -0.68075854]], dtype=float32),
 Array([[  6.828],
        [185.183],
        [649.653],
        ...,
        [  0.   ],
        [215.54 ],
        [  0.   ]], dtype=float32))

In [10]:
# Compare the number of categorical columns reported by each dataset
# (the recorded output shows 4 for both).
test_ds.n_cat_cols, train_ds.n_cat_cols

(4, 4)

In [11]:
# Build the masked time-series model from the training dataset
# (presumably it reads column counts / token vocabulary from train_ds — TODO confirm).
time_series_regressor = hp.time_series.MaskedTimeSeriesRegClass(train_ds)

In [12]:
def calculate_memory_footprint(params):
    """Return the total byte size of every array leaf in the pytree ``params``.

    Each leaf contributes ``size * dtype.itemsize`` bytes; an empty tree is 0.
    """
    leaves = tree_flatten(params)[0]
    return sum(leaf.size * leaf.dtype.itemsize for leaf in leaves)

In [13]:
@struct.dataclass
class Batch:
    """Samples drawn from a dataset and stacked along a new leading axis.

    Declared with ``flax.struct.dataclass`` so instances can be passed through
    ``jax.jit``-compiled functions.
    """

    # Stacked numeric features; e.g. (n_samples, rows, n_numeric) per the recorded outputs.
    numeric: jnp.ndarray
    # Stacked categorical features.
    categorical: jnp.ndarray
    # Stacked targets.
    y: jnp.ndarray

    @classmethod
    def from_ds_idx(
        cls,
        ds: hp.TabularTimeSeriesData,
        start: int = None,
        count: int = None,
        arr: list = None,
    ):
        """Build a ``Batch`` by indexing ``ds`` and stacking the results.

        Choose exactly one selection mode: ``start``/``count`` or ``arr``.

        Args:
            ds: Dataset whose ``ds[idx]`` returns ``(categorical, numeric, y)``.
            start: Range mode offset. NOTE: the range covers indices
                ``start + 1`` .. ``start + count`` inclusive, so index
                ``start`` itself is skipped. This existing behaviour is kept;
                TODO confirm it is intentional (``ds[0]`` appears to have a
                different row count than later samples).
            count: Number of samples to take in range mode.
            arr: Iterable of indices (ints, NumPy scalars or size-1 JAX arrays).

        Returns:
            A ``Batch`` whose fields are the ``jnp.array`` stacks of the
            per-sample arrays.

        Raises:
            ValueError: If only one of ``start``/``count`` is given, if no
                selection is given, or if both ``start``/``count`` and ``arr``
                are given.
        """
        if (start is None) != (count is None):
            defined = "start" if count is None else "count"
            raise ValueError(
                "Both `count` and `start` must be defined or neither; "
                + f"got only {defined} defined"
            )
        if start is None and arr is None:
            raise ValueError("Provide either start/count or arr")
        if start is not None and arr is not None:
            raise ValueError("Do not provide start/count and arr values")

        if arr is not None:
            # `.item()` unwraps NumPy/JAX scalars and size-1 arrays; plain ints
            # have no `.item()`, so fall back to int().
            indexer = (
                idx.item() if hasattr(idx, "item") else int(idx) for idx in arr
            )
        else:
            indexer = range(start + 1, start + count + 1)

        numeric_list, categorical_list, y_list = [], [], []
        for idx in indexer:
            categorical, numeric, y = ds[idx]
            numeric_list.append(numeric)
            categorical_list.append(categorical)
            y_list.append(y)

        return cls(
            numeric=jnp.array(numeric_list),
            categorical=jnp.array(categorical_list),
            y=jnp.array(y_list),
        )


# Smoke test: build a one-sample batch from each split and compare field shapes.
ts_train = Batch.from_ds_idx(train_ds, start=0, count=1)
ts_test = Batch.from_ds_idx(test_ds, start=0, count=1)
for attr_name in ("numeric", "y", "categorical"):
    print(getattr(ts_train, attr_name).shape, getattr(ts_test, attr_name).shape)

assert ts_train.categorical.shape == ts_test.categorical.shape

(1, 1560, 65) (1, 1560, 65)
(1, 1560, 1) (1, 1560, 1)
(1, 1560, 4) (1, 1560, 4)


In [14]:
# Root PRNG key for the run, split into three streams:
#   mts_main_key  - model initialisation
#   ts_params_key - not used in the visible cells
#   ts_data_key   - masking in calculate_loss and shuffling in the training loop
mts_root_key = random.PRNGKey(44)
mts_main_key, ts_params_key, ts_data_key = random.split(mts_root_key, num=3)


def replace_nans(grads, replace_value=0):
    """Replace NaN entries in every leaf of the pytree ``grads``.

    Args:
        grads: Pytree of arrays (e.g. a nested gradient dict).
        replace_value: Value substituted wherever a leaf is NaN.

    Returns:
        A pytree with the same structure and NaNs replaced.
    """
    # jax.tree_util.tree_map replaces the deprecated (and since removed)
    # jax.tree_map alias.
    return jax.tree_util.tree_map(
        lambda g: jnp.where(jnp.isnan(g), replace_value, g), grads
    )


def clip_gradients(gradients, max_norm):
    """Rescale a gradient pytree so its global L2 norm is at most ``max_norm``.

    The norm is taken over every leaf in the tree, so nested (Flax-style)
    parameter dicts work, not just flat dicts of arrays.

    Args:
        gradients: Pytree of gradient arrays.
        max_norm: Maximum allowed global L2 norm.

    Returns:
        A pytree with the same structure. It is scaled by
        ``max_norm / (norm + 1e-6)`` when the norm exceeds ``max_norm``,
        and left unchanged otherwise.
    """
    leaves = jax.tree_util.tree_leaves(gradients)
    total_norm = jnp.sqrt(sum(jnp.sum(jnp.square(leaf)) for leaf in leaves))
    # The epsilon guards against division by zero for all-zero gradients.
    scale = max_norm / (total_norm + 1e-6)
    return jax.tree_util.tree_map(
        lambda grad: jnp.where(total_norm > max_norm, grad * scale, grad), gradients
    )


def calculate_loss(
    params,
    state,
    inputs,
    dataset,
    numeric_loss_scale=1.0,
    target_loss_scale=0.001,
    prng_key=None,
):
    """Compute the masked-reconstruction loss for one batch.

    The model is applied to masked copies of the categorical and numeric
    inputs, and three heads are scored:

    * categorical: softmax cross-entropy of ``category_out`` against the
      categorical inputs concatenated with ``dataset.numeric_col_tokens``
      (tiled over the first two axes of ``category_out``);
    * numeric: mean squared error against the unmasked ``inputs.numeric``;
    * target: mean squared error against ``inputs.y`` with its last axis
      squeezed.

    Args:
        params: Model parameters passed to ``state.apply_fn``.
        state: Train state providing ``apply_fn``.
        inputs: ``Batch`` with ``categorical``, ``numeric`` and ``y`` arrays.
        dataset: Dataset used by ``hp.mask_tensor`` and supplying
            ``numeric_col_tokens``.
        numeric_loss_scale: Weight of the numeric reconstruction loss.
        target_loss_scale: Weight of the target regression loss.
        prng_key: Key for ``hp.mask_tensor``. ``None`` uses the notebook-level
            ``ts_data_key`` (the original behaviour). That presumably gives the
            same mask on every call, assuming ``hp.mask_tensor`` is
            deterministic in its key. Pass a fresh key to vary the masking.

    Returns:
        ``(total_loss, loss_dict)``. ``loss_dict`` holds ``target_loss``,
        ``numeric_loss`` and ``categorical_loss``.
    """
    mask_key = ts_data_key if prng_key is None else prng_key
    category_out, numeric_out, target_out = state.apply_fn(
        {"params": params},
        hp.mask_tensor(inputs.categorical, dataset, prng_key=mask_key),
        hp.mask_tensor(inputs.numeric, dataset, prng_key=mask_key),
    )

    # Tile the per-column numeric tokens over the two leading axes of
    # category_out (presumably batch x time) so they can be appended to the
    # categorical targets along the last axis.
    numeric_col_tokens = dataset.numeric_col_tokens.clone()[None, None, :]
    repeated_numeric_col_tokens = jnp.tile(
        numeric_col_tokens, (category_out.shape[0], category_out.shape[1], 1)
    )
    categorical_targets = jnp.concatenate(
        [inputs.categorical, repeated_numeric_col_tokens], axis=-1
    )

    categorical_loss = optax.softmax_cross_entropy_with_integer_labels(
        category_out, categorical_targets
    ).mean()
    target_loss = optax.squared_error(target_out, jnp.squeeze(inputs.y, -1)).mean()
    numeric_loss = optax.squared_error(numeric_out, inputs.numeric).mean()

    total_loss = (
        categorical_loss
        + numeric_loss * numeric_loss_scale
        + target_loss * target_loss_scale
    )
    return total_loss, {
        "target_loss": target_loss.astype(float),
        "numeric_loss": numeric_loss.astype(float),
        "categorical_loss": categorical_loss.astype(float),
    }


@jax.jit
def eval_step(state: train_state.TrainState, batch):
    """Compute the loss on ``batch`` without updating ``state``.

    Masking metadata and numeric column tokens come from the notebook-level
    ``train_ds``, even when ``batch`` was drawn from a different dataset.

    Returns:
        ``(loss, loss_dict)`` exactly as produced by ``calculate_loss``.
    """
    return calculate_loss(state.params, state, batch, train_ds)


@jax.jit
def train_step(state: train_state.TrainState, batch):
    """Take one optimizer step on ``batch``.

    Returns:
        ``(new_state, loss, loss_dict)``. ``loss_dict`` holds the individual
        loss components from ``calculate_loss``.
    """
    value_and_grad_fn = jax.value_and_grad(
        lambda params: calculate_loss(params, state, batch, train_ds), has_aux=True
    )
    (loss, loss_dict), grads = value_and_grad_fn(state.params)
    # Gradients are applied unmodified; any clipping is left to state.tx.
    new_state = state.apply_gradients(grads=grads)
    return new_state, loss, loss_dict


def create_train_state(model, prng, batch, lr, max_grad_norm=0.4):
    """Initialise ``model`` on ``batch`` and wrap it in a Flax ``TrainState``.

    Args:
        model: Flax module; ``model.init`` is called with the batch's
            categorical and numeric arrays.
        prng: PRNG key used for parameter initialisation.
        batch: ``Batch`` used only to trace input shapes for ``model.init``.
        lr: Adam learning rate.
        max_grad_norm: Global-norm threshold applied by
            ``optax.clip_by_global_norm`` before Adam. The default of 0.4
            matches the previously hard-coded value.

    Returns:
        A ``TrainState`` holding ``model.apply``, the initialised ``params``
        and the chained optimizer.
    """
    variables = model.init(prng, batch.categorical, batch.numeric)
    optimizer = optax.chain(
        optax.clip_by_global_norm(max_grad_norm),
        optax.adam(lr),
    )
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=variables["params"],
        tx=optimizer,
    )


# Build a small batch from train_ds and use it to initialise the model
# parameters and optimizer state (learning rate 1e-4).
batch_size = 2
batch = Batch.from_ds_idx(train_ds, start=0, count=batch_size)
state = create_train_state(time_series_regressor, mts_main_key, batch, lr=1e-4)

In [15]:
# Profile one training step; traces are written under runs/.
# The context manager stops the trace even if train_step raises.
# block_until_ready keeps the trace open until the asynchronously dispatched
# step has actually finished executing.
with jax.profiler.trace("runs/"):
    state, loss, loss_dict = train_step(state, batch)
    loss.block_until_ready()

In [None]:
# TensorBoard writer for this run, named by start time plus a "shuffled" tag.
run_timestamp = dt.now().strftime("%Y-%m-%dT%H:%M:%S")
summary_writer = SummaryWriter(f"runs/{run_timestamp}_shuffled")


def create_iterations(ds, batch_size, prng):
    """Shuffle the indices of ``ds`` and split them into mini-batches.

    Args:
        ds: Any object supporting ``len()``.
        batch_size: Maximum number of indices per group; must be positive.
        prng: JAX PRNG key used for the permutation.

    Returns:
        A list of 1-D index arrays. Every group has ``batch_size`` elements
        except possibly the last, which holds the remainder. An empty ``ds``
        yields an empty list.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    n_items = len(ds)
    if n_items == 0:
        return []

    shuffled = jax.random.permutation(prng, jnp.arange(n_items))
    # Split at explicit boundaries. Passing a section *count* to jnp.split
    # requires an exact division, so it raised (or produced wrong-sized
    # groups) whenever n_items % batch_size != 0.
    return jnp.split(shuffled, list(range(batch_size, n_items, batch_size)))


def _log_losses(writer, prefix, total_loss, loss_dict, step):
    """Write the total loss and its three components to TensorBoard under ``prefix``."""
    writer.add_scalar(f"{prefix}/loss", float(total_loss), step)
    for component in ("categorical", "numeric", "target"):
        writer.add_scalar(
            f"{prefix}/{component}", float(loss_dict[f"{component}_loss"]), step
        )


num_epochs = 100
train_batch_size = 2
eval_every = 1  # evaluate on the fixed held-out batch every N training steps

test_batch = Batch.from_ds_idx(test_ds, 0, 2)

batch_count = 0
for epoch in trange(num_epochs):
    # Fold the epoch index into the data key so each epoch gets a new shuffle.
    # Reusing ts_data_key directly repeated the same permutation every epoch.
    epoch_key = random.fold_in(ts_data_key, epoch)
    for batch_indices in tqdm(
        create_iterations(train_ds, train_batch_size, epoch_key), leave=False
    ):
        batch = Batch.from_ds_idx(ds=train_ds, arr=batch_indices)
        state, loss, loss_dict = train_step(state, batch)
        _log_losses(summary_writer, "ts_train", loss, loss_dict, batch_count)
        batch_count += 1

        if batch_count % eval_every == 0:
            test_loss, test_loss_dict = eval_step(state, test_batch)
            _log_losses(
                summary_writer, "ts_test", test_loss, test_loss_dict, batch_count
            )

# Make sure buffered TensorBoard events reach disk.
summary_writer.flush()

In [None]:
# Compare the categorical widths of a two-sample train batch and test batch.
# NOTE(review): the recorded output is (2, 1560, 4) vs (2, 1560, 0), i.e. the
# test dataset yields no categorical columns here — TODO investigate.
(
    Batch.from_ds_idx(train_ds, 0, 2).categorical.shape,
    Batch.from_ds_idx(test_ds, 0, 2).categorical.shape,
)

((2, 1560, 4), (2, 1560, 0))

In [None]:
# Compare the numeric column token counts of the two datasets.
# NOTE(review): the recorded output is (65,) vs (69,) — the splits disagree on
# the numeric schema; TODO confirm why.
train_ds.numeric_col_tokens.shape, test_ds.numeric_col_tokens.shape

((65,), (69,))

In [None]:
# Compare the numeric feature shapes of a two-sample train batch and test batch
# (the recorded output shows 65 vs 69 numeric columns).
(
    Batch.from_ds_idx(train_ds, 0, 2).numeric.shape,
    Batch.from_ds_idx(test_ds, 0, 2).numeric.shape,
)

((2, 1560, 65), (2, 1560, 69))

In [None]:
# Inspect the shapes of the fixed evaluation batch.
# NOTE(review): the recorded output shows 1561 rows, while the cell above shows
# 1560 for the same call — outputs likely come from different kernel states.
test_batch.y.shape, test_batch.numeric.shape, test_batch.categorical.shape

((2, 1561, 1), (2, 1561, 69), (2, 1561, 0))

In [None]:
# Inspect the evaluation batch's categorical array (recorded output: empty).
test_batch.categorical

Array([], dtype=float32)

In [None]:
# Numeric shape of sample 0 in each split.
# NOTE(review): the recorded output shows 1561 rows for index 0, whereas batches
# built from later indices show 1560 — the first sample may be irregular.
test_ds[0][1].shape, train_ds[0][1].shape

((1561, 69), (1561, 65))

In [None]:
# Build a batch from an explicit index list containing sample 0.
# NOTE(review): the recorded output has all-empty fields — TODO investigate.
Batch.from_ds_idx(test_ds, arr=[jnp.array([0])])

Batch(numeric=Array([], dtype=float32), categorical=Array([], dtype=float32), y=Array([], dtype=float32))

In [None]:
# Save the trained state as a Flax checkpoint under ./pre_trained_models/.
# The timestamp uses a filename-safe format: str(datetime) contains spaces and
# colons, which are awkward or invalid in checkpoint file names on some systems.
model_name = f"big_train_{dt.now().strftime('%Y-%m-%dT%H-%M-%S')}_"

current_dir = os.getcwd()
# Absolute directory for save_checkpoint (newer Flax/Orbax checkpointing
# rejects relative paths — confirm against the installed version).
path = os.path.join(current_dir, "pre_trained_models")
os.makedirs(path, exist_ok=True)

ckpt_dir = f"./pre_trained_models/{model_name}"  # not used below; kept for later cells

checkpoints.save_checkpoint(
    ckpt_dir=path, target=state, step=batch_count, overwrite=True, prefix=model_name
)

In [None]:
def jax_array_memory_usage(array):
    """Return the number of bytes occupied by the elements of ``array``.

    Works for any array exposing ``size`` and ``dtype`` (JAX or NumPy).

    Returns:
        ``size * dtype.itemsize`` as a Python ``int``. ``array.size`` is used
        instead of ``np.prod(array.shape)`` because ``np.prod(())`` is the
        float ``1.0``, which made 0-d arrays report a float byte count.
    """
    return int(array.size) * array.dtype.itemsize


# Report the memory held by the current `batch`'s categorical + numeric inputs.
cat_memory_usage = jax_array_memory_usage(batch.categorical)
num_memory_usage = jax_array_memory_usage(batch.numeric)
memory_usage = cat_memory_usage + num_memory_usage
memory_usage_gb = memory_usage / 1024**3  # bytes -> GiB
print(f"Memory usage: {memory_usage} bytes")
print(f"Memory usage: {memory_usage_gb} gb")