Kaggle competition: [Predict Energy Behavior of Prosumers](https://www.kaggle.com/competitions/predict-energy-behavior-of-prosumers)


In [5]:
import os
from dataclasses import dataclass
from datetime import datetime as dt

import icecream
import jax
import jax.numpy as jnp  # Oddly works in colab to set gpu
import numpy as np
import optax
import pandas as pd
from flax.training import checkpoints, train_state
from icecream import ic
from jax import profiler, random
from jax.tree_util import tree_flatten
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import trange

import hephaestus as hp

icecream.install()  # expose `ic` as a builtin for every module
ic_disable = True  # set to False to see the ic(...) debug output
if ic_disable:
    ic.disable()

# Opt in to pandas copy-on-write semantics.
pd.set_option("mode.copy_on_write", True)

In [6]:
def sinify_datetime(df, col):
    """Add cyclical (sin/cos) calendar features derived from a datetime column.

    Adds the columns ``{col}_base_year``, ``{col}_day_sin``, ``{col}_day_cos``,
    ``{col}_hour_sin``, ``{col}_hour_cos`` and ``{col}_is_weekend`` to ``df``
    (assigning onto the passed frame) and returns it.

    Args:
        df: DataFrame holding a datetime64 column ``col``.
        col: Name of the datetime column.

    Returns:
        The same DataFrame object with the new columns added.
    """
    ts = df[col].dt
    # Day-of-year (1..366) over a 365.25-day period encodes the annual cycle;
    # day-of-month (1..31) would only cover ~1/12 of the circle.
    day_angle = 2 * np.pi * ts.dayofyear / 365.25
    hour_angle = 2 * np.pi * ts.hour / 24

    df[col + "_base_year"] = ts.year
    df[col + "_day_sin"] = np.sin(day_angle)
    df[col + "_day_cos"] = np.cos(day_angle)
    df[col + "_hour_sin"] = np.sin(hour_angle)
    df[col + "_hour_cos"] = np.cos(hour_angle)
    # Saturday (5) / Sunday (6) -> 1, every other day -> 0.
    df[col + "_is_weekend"] = ts.dayofweek.isin([5, 6]).astype("int64")

    return df

In [7]:
# Pre-processed competition data; the path is relative to the working directory.
data = pd.read_csv("data/processed_energy.csv")

In [8]:
# Inspect the available feature columns.
data.columns

Index(['is_business', 'is_consumption', 'county_name', 'product_name',
       'level_0', 'level_1', 'target', 'prediction_unit_id', 'month', 'day',
       'hour', 'dayofweek', 'dayofyear', 'eic_count', 'installed_capacity',
       'lowest_price_per_mwh', 'highest_price_per_mwh', 'euros_per_mwh',
       'temperature_fcast_mean', 'dewpoint_fcast_mean',
       'cloudcover_high_fcast_mean', 'cloudcover_low_fcast_mean',
       'cloudcover_mid_fcast_mean', 'cloudcover_total_fcast_mean',
       '10_metre_u_wind_component_fcast_mean',
       '10_metre_v_wind_component_fcast_mean',
       'direct_solar_radiation_fcast_mean',
       'surface_solar_radiation_downwards_fcast_mean', 'snowfall_fcast_mean',
       'total_precipitation_fcast_mean', 'temperature_fcast_mean_by_county',
       'dewpoint_fcast_mean_by_county', 'cloudcover_high_fcast_mean_by_county',
       'cloudcover_low_fcast_mean_by_county',
       'cloudcover_mid_fcast_mean_by_county',
       'cloudcover_total_fcast_mean_by_county',
 

In [9]:
# Wrap the frame as a time-series dataset. To experiment on a subset, pass
# data.head(n) instead of data; "euros_per_mwh" was also tried as a target.
# NOTE(review): 3120 appears to be the number of rows per window
# (cf. data.shape[0] // 3120 later) — confirm against hp.TabularTimeSeriesData.
pre_train = hp.TabularTimeSeriesData(
    data,
    target_column="target",
    batch_size=3120,
)

# pre_train = hp.TabularDS(df)

In [10]:
# Flax model configured from the dataset object (presumably its vocabulary and
# column layout — see hp.time_series.MaskedTimeSeriesRegression).
time_series_regressor = hp.time_series.MaskedTimeSeriesRegression(pre_train)

In [11]:
# time_length = 1000 # 6336 * 3  # 112 * 3 #
# One sample from the dataset, unpacked as (categorical, numeric, target).
test_cat, test_num, y = pre_train[0]

# Masked copies of the sample; stacked with the originals to initialise the model.
test_num_mask = hp.mask_tensor(test_num, pre_train)
test_cat_mask = hp.mask_tensor(test_cat, pre_train)

In [12]:
def calculate_memory_footprint(params):
    """Return the number of bytes held by every array leaf of a pytree.

    Args:
        params: Any JAX pytree (e.g. the variables returned by ``Module.init``)
            whose leaves expose ``.size`` and ``.dtype.itemsize``.

    Returns:
        Total size in bytes (divide by ``1024**3`` for GiB).
    """
    leaves = tree_flatten(params)[0]
    return sum(leaf.size * leaf.dtype.itemsize for leaf in leaves)

In [13]:
tsm_root_key = random.PRNGKey(44)
# tsm_main_key -> model init below; ts_data_key -> masking in the training loops.
# NOTE(review): ts_params_key is not used anywhere in the visible notebook.
tsm_main_key, ts_params_key, ts_data_key = random.split(tsm_root_key, 3)

# Initialise with a 2-sample batch: the raw sample stacked with its masked copy.
time_series_regressor_vars = time_series_regressor.init(
    tsm_main_key,
    jnp.array([test_cat, test_cat_mask]),
    jnp.array([test_num, test_num_mask]),
)
# shapes=[(2, 336, 1, 23), (), (2, 336, 23, 640)]

In [14]:
# Token count reported by the dataset (presumably the categorical vocabulary size).
pre_train.n_tokens

97

In [15]:
# Size of the initialised model variables in GiB.
calculate_memory_footprint(time_series_regressor_vars) / (1024**3)

2.5059131048619747

In [16]:
# Columns the dataset treats as categorical.
pre_train.category_columns

['is_business', 'is_consumption', 'county_name', 'product_name']

In [17]:
class TSBatch:
    """Plain container for one stacked mini-batch.

    Stores the numeric features, categorical features and targets exactly as
    passed to the constructor, as attributes ``numeric``, ``categorical``, ``y``.
    """

    def __init__(self, numeric: jnp.ndarray, categorical: jnp.ndarray, y: jnp.ndarray):
        self.numeric, self.categorical, self.y = numeric, categorical, y


def make_batch(ds: hp.TabularTimeSeriesData, start: int, count: int) -> TSBatch:
    """Stack ``count`` consecutive samples of ``ds`` beginning at index ``start``.

    Each ``ds[idx]`` is unpacked as ``(categorical, numeric, y)``; the per-sample
    values are stacked along a new leading axis with ``jnp.array``.

    Returns:
        A TSBatch whose ``numeric``, ``categorical`` and ``y`` have leading size ``count``.
    """
    cats, nums, targets = [], [], []
    for idx in range(start, start + count):
        sample_cat, sample_num, sample_y = ds[idx]
        nums.append(sample_num)
        cats.append(sample_cat)
        targets.append(sample_y)

    return TSBatch(
        numeric=jnp.array(nums),
        categorical=jnp.array(cats),
        y=jnp.array(targets),
    )


# batch = make_batch(pre_train, 0, 12)

# r = time_series_regressor.apply(
#     {"params": time_series_regressor_vars["params"]},
#     hp.mask_tensor(batch.categorical, pre_train, prng_key=ts_data_key),
#     hp.mask_tensor(batch.numeric, pre_train, prng_key=ts_data_key),
# )
# r[0].shape, r[1].shape  # ((12, 336, 668), (12, 336, 23))

In [18]:
# max_len = 10000
# d_pos_encoding = 32
# n_features = 24
# n_epochs = 9
# seq_len = 200
# position = jnp.arange(max_len)[:, jnp.newaxis]
# div_term = jnp.exp(
#     jnp.arange(0, d_pos_encoding, 2) * -(jnp.log(10000.0) / d_pos_encoding)
# )
# pe = jnp.zeros((max_len, d_pos_encoding))
# pe = pe.at[:, 0::2].set(jnp.sin(position * div_term))
# pe = pe.at[:, 1::2].set(jnp.cos(position * div_term))
# pe = pe[:seq_len, :]
# pe = pe[None, :, :, None]
# pe = jnp.tile(pe, (n_epochs, 1, 1, n_features))
# pe = pe.transpose((0, 1, 3, 2))  # (batch_size, seq_len, n_features, d_model)
# # concatenate the positional encoding with the input

In [None]:
# Plot the positional encoding
# plt.figure(figsize=(15, 5))
# plt.pcolormesh(pe[0, :, 0, :], cmap="viridis")

In [None]:
# batch_size = 12
# for i in range(len(pre_train) // batch_size):
#     batch = make_batch(pre_train, i * batch_size, batch_size)
#     r = time_series_regressor.apply(
#         {"params": time_series_regressor_vars["params"]},
#         hp.mask_tensor(batch.numeric, pre_train, prng_key=ts_data_key),
#         hp.mask_tensor(batch.categorical, pre_train, prng_key=ts_data_key),
#     )
#     print(r[0].shape, r[1].shape)

In [19]:
# Separate PRNG stream for creating the training state.
mts_root_key = random.PRNGKey(42)

# NOTE(review): only mts_main_key is used (create_train_state); the params and
# dropout keys are unused in the visible notebook.
mts_main_key, mts_params_key, mts_dropout_key = random.split(mts_root_key, 3)

In [20]:
# mts_mi = hp.create_masked_time_series_model_inputs(pre_train, 0, 3)

In [21]:
@dataclass
class MultiLoss:
    """Total loss together with its per-task components.

    NOTE: this is a plain ``dataclasses.dataclass``, not a registered JAX pytree,
    so it cannot be returned from a ``jax.jit``-compiled function as-is.
    """

    # Annotated with the array type; ``jnp.array`` is a constructor function.
    total_loss: jnp.ndarray
    categorical_loss: jnp.ndarray
    numeric_loss: jnp.ndarray
    target_loss: jnp.ndarray


def calculate_loss(params, state, inputs, dataset):
    """Compute the weighted multi-task loss for one batch.

    Args:
        params: Model parameters (the "params" collection).
        state: Object exposing ``.apply`` — the smoke test passes the flax module.
        inputs: Batch with ``.categorical``, ``.numeric`` and ``.y`` (e.g. TSBatch).
        dataset: Dataset used for masking and for ``numeric_col_tokens``.

    Returns:
        ``(total_loss, MultiLoss)``.
    """
    numeric_loss_scale = 1.0
    target_loss_scale = 1.0
    # NOTE(review): the recorded run of this cell raised "not enough values to
    # unpack (expected 3, got 2)" — confirm how many outputs the model returns.
    # Masking uses the `dataset` argument (not a notebook global).
    # NOTE(review): ts_data_key is a notebook global shared by both masks.
    category_out, numeric_out, target_out = state.apply(
        {"params": params},
        hp.mask_tensor(inputs.categorical, dataset, prng_key=ts_data_key),
        hp.mask_tensor(inputs.numeric, dataset, prng_key=ts_data_key),
    )
    # Numeric column tokens, tiled to the first two axes of `category_out` so
    # they can be appended to the categorical labels along the last axis.
    numeric_col_tokens = dataset.numeric_col_tokens.clone()
    numeric_col_tokens = numeric_col_tokens[None, None, :]
    repeated_numeric_col_tokens = jnp.tile(
        numeric_col_tokens, (category_out.shape[0], category_out.shape[1], 1)
    )
    ic(repeated_numeric_col_tokens.shape, inputs.categorical.shape, category_out.shape)
    categorical_targets = jnp.concatenate(
        [inputs.categorical, repeated_numeric_col_tokens],
        axis=-1,
    )
    ic(categorical_targets.shape)
    categorical_loss = optax.softmax_cross_entropy_with_integer_labels(
        category_out, categorical_targets
    ).mean()
    numeric_loss = optax.squared_error(numeric_out, inputs.numeric).mean()
    # Assumes y carries a trailing singleton axis that target_out lacks.
    target_loss = optax.squared_error(target_out, jnp.squeeze(inputs.y, -1)).mean()
    total_loss = (
        categorical_loss
        + numeric_loss * numeric_loss_scale
        + target_loss * target_loss_scale
    )
    return total_loss, MultiLoss(
        total_loss=total_loss,
        categorical_loss=categorical_loss,
        numeric_loss=numeric_loss,
        target_loss=target_loss,
    )


# Smoke-test the loss on a 2-sample batch.
# NOTE(review): the recorded run raised "ValueError: not enough values to
# unpack (expected 3, got 2)" — presumably from unpacking the model outputs.
batch = make_batch(pre_train, 0, 2)
calculate_loss(
    time_series_regressor_vars["params"], time_series_regressor, batch, pre_train
)

ValueError: not enough values to unpack (expected 3, got 2)

In [None]:
# Inspect the target shape of the most recent batch.
batch.y.shape

In [None]:
# Launch TensorBoard on the Colab Drive runs directory.
# NOTE(review): needs the TensorBoard notebook extension (`%load_ext tensorboard`),
# which is not loaded in the visible notebook. The SummaryWriter cells write to a
# relative "runs/" directory, so this logdir only matches if the working
# directory is the Hephaestus folder on Drive.
%tensorboard \
        --logdir '/content/drive/MyDrive/Colab Notebooks/Hephaestus/runs' \
        --load_fast=false

In [None]:
# profiler.start_trace(
#     "/content/drive/MyDrive/Colab Notebooks/Hephaestus/runs", create_perfetto_link=True
# )

# Drop the reference to the variable tree created by `time_series_regressor.init`
# so its memory can be reclaimed before training. Deleting an unbound global
# can only raise NameError (e.g. when this cell is re-run).
try:
    del time_series_regressor_vars
except NameError as e:
    print(e)


def replace_nans(grads, replace_value=0):
    """Replace NaN entries in every leaf of a gradient pytree.

    Args:
        grads: Pytree of arrays (e.g. the gradients from ``jax.value_and_grad``).
        replace_value: Value written wherever a leaf is NaN.

    Returns:
        Pytree with the same structure and NaNs replaced.
    """
    # `jax.tree_map` is deprecated (removed in recent JAX); use tree_util.
    return jax.tree_util.tree_map(
        lambda g: jnp.where(jnp.isnan(g), replace_value, g), grads
    )


def clip_gradients(gradients, max_norm):
    """Rescale a gradient pytree so its global L2 norm is at most ``max_norm``.

    Args:
        gradients: Pytree of arrays; may be nested (e.g. a flax params dict).
        max_norm: Maximum allowed global norm.

    Returns:
        Pytree of the same structure. When the global norm exceeds ``max_norm``,
        every leaf is multiplied by ``max_norm / (norm + 1e-6)``; otherwise the
        leaves are returned unchanged.
    """
    # tree_leaves walks nested dicts; `.values()` only handled a flat dict.
    leaves = jax.tree_util.tree_leaves(gradients)
    total_norm = jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in leaves))
    scale = max_norm / (total_norm + 1e-6)
    return jax.tree_util.tree_map(
        lambda g: jnp.where(total_norm > max_norm, g * scale, g), gradients
    )


def calculate_loss(
    params,
    state,
    categorical_mask,
    numeric_mask,
    target_labs,
    dataset,
):
    """Mean squared error between the model output and ``target_labs``.

    The model is evaluated through ``state.apply_fn`` on the masked inputs.
    NOTE(review): ``dataset`` is accepted but unused.
    """
    predictions = state.apply_fn({"params": params}, categorical_mask, numeric_mask)
    return optax.squared_error(predictions, target_labs).mean()


@jax.jit
def train_step(
    state: train_state.TrainState, categorical_mask, numeric_mask, target_labs
):
    """Run one jitted optimisation step on the target-only loss.

    NaN gradients are zeroed before the optimiser update.

    Returns:
        ``(new_state, target_loss)``.
    """

    def objective(candidate_params):
        return calculate_loss(
            candidate_params,
            state,
            categorical_mask=categorical_mask,
            numeric_mask=numeric_mask,
            target_labs=target_labs,
            dataset=pre_train,
        )

    loss_value, raw_grads = jax.value_and_grad(objective)(state.params)
    updated_state = state.apply_gradients(grads=replace_nans(raw_grads))
    return updated_state, loss_value


def create_train_state(model, prng, batch, lr, max_grad_norm=0.4):
    """Initialise ``model`` and wrap its parameters in a flax TrainState.

    Args:
        model: Flax module, initialised on ``(batch.categorical, batch.numeric)``.
        prng: PRNG key passed to ``model.init``.
        batch: Object exposing ``.categorical`` and ``.numeric`` (e.g. a TSBatch).
        lr: Adam learning rate.
        max_grad_norm: Global-norm gradient clip applied before Adam
            (defaults to the previously hard-coded 0.4).

    Returns:
        ``train_state.TrainState`` holding ``variables["params"]`` and the optimiser.
        Other variable collections returned by ``init`` are discarded.
    """
    variables = model.init(prng, batch.categorical, batch.numeric)
    optimizer = optax.chain(
        optax.clip_by_global_norm(max_grad_norm),
        optax.adam(lr),
    )
    # TrainState.create initialises the optimiser state from the params.
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=variables["params"],
        tx=optimizer,
    )


batch_size = 2
batch = make_batch(pre_train, 0, batch_size)

state = create_train_state(time_series_regressor, mts_main_key, batch, 0.0001)

summary_writer = SummaryWriter(
    "runs/" + dt.now().strftime("%Y-%m-%dT%H:%M:%S") + "_" + "better_batch"
)
batch_count = 0
try:
    for epoch in trange(2):
        for i in trange(len(pre_train) // batch_size, leave=False):
            batch = make_batch(pre_train, i * batch_size, batch_size)
            # JAX keys are deterministic: reusing `ts_data_key` every step would
            # reproduce the same random draw (presumably the same mask) on every
            # batch. Derive a fresh key per step and split it between the masks.
            cat_key, num_key = random.split(random.fold_in(ts_data_key, batch_count))
            categorical_mask = hp.mask_tensor(
                batch.categorical, pre_train, prng_key=cat_key
            )
            numeric_mask = hp.mask_tensor(batch.numeric, pre_train, prng_key=num_key)

            state, loss = train_step(state, categorical_mask, numeric_mask, batch.y)
            summary_writer.add_scalar("loss", loss.item(), batch_count)
            batch_count += 1
finally:
    # Flush pending TensorBoard events even if training is interrupted.
    summary_writer.close()

# profiler.stop_trace()

ValueError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 1113772544 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:   25.95MiB
              constant allocation:         0B
        maybe_live_out allocation:    1.04GiB
     preallocated temp allocation:    4.00MiB
  preallocated temp fragmentation:         0B (0.00%)
                 total allocation:    1.07GiB
              total fragmentation:       112B (0.00%)
Peak buffers:
	Buffer 1:
		Size: 1.04GiB
		Operator: op_name="jit(dot_general)/jit(main)/dot_general[dimension_numbers=(((3, 4), (0, 1)), ((), ())) precision=None preferred_element_type=None]" source_file="/usr/local/lib/python3.10/dist-packages/flax/linen/linear.py" source_line=195
		XLA Label: custom-call
		Shape: f32[424456,656]
		==========================

	Buffer 2:
		Size: 25.91MiB
		Entry Parameter Subshape: f32[2,3121,68,4,4]
		==========================

	Buffer 3:
		Size: 4.00MiB
		Operator: op_name="jit(dot_general)/jit(main)/dot_general[dimension_numbers=(((3, 4), (0, 1)), ((), ())) precision=None preferred_element_type=None]" source_file="/usr/local/lib/python3.10/dist-packages/flax/linen/linear.py" source_line=195
		XLA Label: custom-call
		Shape: s8[4194304]
		==========================

	Buffer 4:
		Size: 41.0KiB
		Entry Parameter Subshape: f32[4,4,656]
		==========================

	Buffer 5:
		Size: 16B
		Operator: op_name="jit(dot_general)/jit(main)/dot_general[dimension_numbers=(((3, 4), (0, 1)), ((), ())) precision=None preferred_element_type=None]" source_file="/usr/local/lib/python3.10/dist-packages/flax/linen/linear.py" source_line=195
		XLA Label: custom-call
		Shape: (f32[424456,656], s8[4194304])
		==========================



In [None]:
# Stop the active JAX profiler trace.
# NOTE(review): raises if no trace is running; with create_perfetto_link=True it
# blocks until the printed Perfetto link is opened.
profiler.stop_trace()

Open URL in browser: https://ui.perfetto.dev/#!/?url=http://127.0.0.1:9001/perfetto_trace.json.gz


KeyboardInterrupt: 

In [None]:
# Capture a profiler trace of the training run in this cell (stopped at the end
# of the cell). NOTE(review): hard-coded Colab Drive path.
profiler.start_trace(
    "/content/drive/MyDrive/Colab Notebooks/Hephaestus/runs", create_perfetto_link=True
)


# def replace_nans(grads, replace_value=0):
#     """Replace NaN values in gradients with a specified value."""
#     return jax.tree_map(lambda g: jnp.where(jnp.isnan(g), replace_value, g), grads)


# def clip_gradients(gradients, max_norm):
#     total_norm = jnp.sqrt(sum(jnp.sum(jnp.square(grad)) for grad in gradients.values()))
#     scale = max_norm / (total_norm + 1e-6)
#     clipped_gradients = jax.tree_map(
#         lambda grad: jnp.where(total_norm > max_norm, grad * scale, grad), gradients
#     )
#     return clipped_gradients


def _calculate_loss_terms(
    params,
    state,
    categorical_mask,
    numeric_mask,
    categorical_labs,
    numeric_labs,
    target_labs,
    dataset,
):
    """Compute the weighted multi-task loss and its components.

    Returns:
        ``(total_loss, components)`` where ``components`` is a dict with keys
        "total_loss", "categorical_loss", "numeric_loss" and "target_loss".
        A dict (a JAX pytree) is used so it can leave a jitted function.
    """
    numeric_loss_scale = 1.0
    target_loss_scale = 2.0
    category_out, numeric_out, target_out = state.apply_fn(
        {"params": params},
        categorical_mask,
        numeric_mask,
    )
    # Numeric column tokens, tiled to the first two axes of `category_out` so
    # they can be appended to the categorical labels along the last axis.
    numeric_col_tokens = dataset.numeric_col_tokens.clone()[None, None, :]
    repeated_numeric_col_tokens = jnp.tile(
        numeric_col_tokens, (category_out.shape[0], category_out.shape[1], 1)
    )
    categorical_targets = jnp.concatenate(
        [categorical_labs, repeated_numeric_col_tokens],
        axis=-1,
    )
    categorical_loss = optax.softmax_cross_entropy_with_integer_labels(
        category_out, categorical_targets
    ).mean()
    target_loss = optax.squared_error(target_out, target_labs).mean()
    numeric_loss = optax.squared_error(numeric_out, numeric_labs).mean()
    total_loss = (
        categorical_loss
        + numeric_loss * numeric_loss_scale
        + target_loss * target_loss_scale
    )
    components = {
        "total_loss": total_loss,
        "categorical_loss": categorical_loss,
        "numeric_loss": numeric_loss,
        "target_loss": target_loss,
    }
    return total_loss, components


def calculate_loss(
    params,
    state,
    categorical_mask,
    numeric_mask,
    categorical_labs,
    numeric_labs,
    target_labs,
    dataset,
):
    """Return the scalar weighted multi-task loss (see `_calculate_loss_terms`)."""
    total_loss, _ = _calculate_loss_terms(
        params,
        state,
        categorical_mask,
        numeric_mask,
        categorical_labs,
        numeric_labs,
        target_labs,
        dataset,
    )
    return total_loss


@jax.jit
def train_step(
    state: train_state.TrainState,
    categorical_mask,
    numeric_mask,
    categorical_labs,
    target_labs,
    numeric_labs,
):
    """Run one jitted optimisation step on the multi-task loss.

    NaN gradients are zeroed before the optimiser update.

    Returns:
        ``(new_state, total_loss, components)``; ``components`` is the dict of
        per-task losses from ``_calculate_loss_terms``.
    """

    def loss_fn(params):
        return _calculate_loss_terms(
            params,
            state,
            categorical_mask=categorical_mask,
            numeric_mask=numeric_mask,
            categorical_labs=categorical_labs,
            numeric_labs=numeric_labs,
            target_labs=target_labs,
            dataset=pre_train,
        )

    # With has_aux=True, value_and_grad returns ((loss, aux), grads).
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (total_loss, multi_loss), grad = grad_fn(state.params)
    grad = replace_nans(grad)
    state = state.apply_gradients(grads=grad)

    return state, total_loss, multi_loss


def create_train_state(model, prng, batch, lr, max_grad_norm=0.4):
    """Initialise ``model`` and wrap its parameters in a flax TrainState.

    Args:
        model: Flax module, initialised on ``(batch.categorical, batch.numeric)``.
        prng: PRNG key passed to ``model.init``.
        batch: Object exposing ``.categorical`` and ``.numeric`` (e.g. a TSBatch).
        lr: Adam learning rate.
        max_grad_norm: Global-norm gradient clip applied before Adam
            (defaults to the previously hard-coded 0.4).

    Returns:
        ``train_state.TrainState`` holding ``variables["params"]`` and the optimiser.
        Other variable collections returned by ``init`` are discarded.
    """
    variables = model.init(prng, batch.categorical, batch.numeric)
    optimizer = optax.chain(
        optax.clip_by_global_norm(max_grad_norm),
        optax.adam(lr),
    )
    # TrainState.create initialises the optimiser state from the params.
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=variables["params"],
        tx=optimizer,
    )


batch_size = 2
batch = make_batch(pre_train, 0, batch_size)

state = create_train_state(time_series_regressor, mts_main_key, batch, 0.0001)

summary_writer = SummaryWriter(
    "runs/" + dt.now().strftime("%Y-%m-%dT%H:%M:%S") + "_" + "better_batch"
)
batch_count = 0
try:
    for epoch in trange(2):
        for i in trange(len(pre_train) // batch_size, leave=False):
            batch = make_batch(pre_train, i * batch_size, batch_size)
            # Fresh, independent keys per step (JAX keys are deterministic, so
            # reusing ts_data_key would repeat the same random draw every batch).
            cat_key, num_key = random.split(random.fold_in(ts_data_key, batch_count))
            categorical_mask = hp.mask_tensor(
                batch.categorical, pre_train, prng_key=cat_key
            )
            numeric_mask = hp.mask_tensor(batch.numeric, pre_train, prng_key=num_key)

            # Labels are passed by keyword so each lands on the right parameter.
            # NOTE(review): the first loss cell squeezed y's last axis before
            # comparing with the target output — confirm target_labs' shape here.
            state, loss, multi_loss = train_step(
                state,
                categorical_mask,
                numeric_mask,
                categorical_labs=batch.categorical,
                target_labs=batch.y,
                numeric_labs=batch.numeric,
            )
            summary_writer.add_scalar("loss", loss.item(), batch_count)
            batch_count += 1
finally:
    # Flush TensorBoard events and stop the profiler trace started above even
    # if training raises, so the profiler is not left running.
    summary_writer.close()
    profiler.stop_trace()

In [None]:
# Timestamped checkpoint prefix. A strftime format keeps it free of the spaces
# and colons that str(dt.now()) would put into the file name.
model_name = f"new_data_{dt.now():%Y-%m-%dT%H-%M-%S}_"

current_dir = os.getcwd()

# Absolute checkpoint directory under the current working directory.
path = os.path.join(current_dir, "pre_trained_models")
os.makedirs(path, exist_ok=True)

ckpt_dir = f"./pre_trained_models/{model_name}"  # NOTE(review): unused below

checkpoints.save_checkpoint(
    ckpt_dir=path, target=state, step=batch_count, overwrite=True, prefix=model_name
)

In [None]:
# The checkpoint was saved with prefix=model_name; with the default prefix
# ("checkpoint_") no file matches, and flax silently returns `target` unchanged.
state2 = checkpoints.restore_checkpoint(path, target=state, prefix=model_name)

In [None]:
# Inspect the restored train state.
state2

In [None]:
# Number of samples, and number of full batches iterated per epoch by the loop.
len(pre_train), len(pre_train) // batch_size

In [None]:
# Rows divided by the dataset's batch_size (3120); presumably this should match
# len(pre_train) — confirm.
data.shape[0] // 3120

In [None]:
# NOTE(review): duplicate of the stop_trace at the end of the training cell;
# raises if no trace is currently running.
profiler.stop_trace()

In [None]:
# Number of samples in the dataset.
len(pre_train)

In [None]:
def jax_array_memory_usage(array):
    """Return the number of bytes occupied by the elements of ``array``.

    Args:
        array: Any object exposing ``.size`` and ``.dtype.itemsize``
            (JAX or NumPy array).

    Returns:
        Byte count as an int. ``array.size`` is used instead of
        ``np.prod(array.shape)``, which yields a float (1.0) for 0-d arrays.
    """
    return array.size * array.dtype.itemsize


# Bytes held by the most recent batch's categorical and numeric inputs.
cat_memory_usage, num_memory_usage = (
    jax_array_memory_usage(arr) for arr in (batch.categorical, batch.numeric)
)
memory_usage = cat_memory_usage + num_memory_usage
memory_usage_gb = memory_usage / 1024**3
print(f"Memory usage: {memory_usage} bytes")
print(f"Memory usage: {memory_usage_gb} gb")