<https://www.kaggle.com/competitions/predict-energy-behavior-of-prosumers>


In [1]:
import os

# Set in the very first cell, before any cell imports JAX.
# NOTE(review): presumably this is the fraction of accelerator memory the XLA
# client is allowed to claim and must be in place before JAX initialises its
# backend - confirm against the JAX GPU memory allocation docs.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.95"

In [2]:
import jax.numpy as jnp  # Oddly works in colab to set gpu

# Sanity check: build a tiny array and display the device(s) it was placed on.
arr = jnp.array([1, 2, 3])
arr.devices()

{cuda(id=0)}

In [3]:
# Mount Google Drive only when the Colab package is importable, so the notebook
# can also be run outside Colab. Outside Colab "google.colab" then never enters
# sys.modules, which is what the IN_COLAB check in the next cell relies on.
try:
    from google.colab import drive
except ImportError:
    print("google.colab not available; skipping Drive mount")
else:
    drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
import sys

IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    # from google.colab import drive
    # drive.mount('/content/drive')
    # %pip -q install clu
    %cd '/content/drive/MyDrive/Colab Notebooks/Hephaestus'
    %pip -q install icecream
    %pip install -Uq tensorboard-plugin-profile
    %pip install -Uq tensorboard
    # %pip install -U -q jax
    %load_ext tensorboard
    %pip install -U tensorboard-plugin-profile

    import jax.tools.colab_tpu

    try:
        jax.tools.colab_tpu.setup_tpu()
    except RuntimeError:
        print("No TPU")

    # %tensorboard \
    #     --logdir '/content/drive/MyDrive/Colab Notebooks/Hephaestus/runs' \
    #     --load_fast=false

/content/drive/MyDrive/Colab Notebooks/Hephaestus
No TPU


In [5]:
import icecream
from icecream import ic

# Register icecream, then switch `ic` off while ic_disable is True. Later cells
# contain ic(...) debugging calls; set ic_disable = False to turn them back on.
icecream.install()
ic_disable = True
if ic_disable:
    ic.disable()

In [6]:
import os
from dataclasses import dataclass, field
from datetime import datetime as dt
from itertools import chain

import hephaestus as hp
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import pandas as pd
import seaborn as sns
from flax import linen as nn
from flax import struct  # Flax dataclasses
from flax.training import checkpoints, train_state
from flax.training.early_stopping import EarlyStopping
from icecream import ic
from jax import profiler, random
from jax.config import config
from jax.tree_util import tree_flatten
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm, trange

pd.options.mode.copy_on_write = True

  from jax.config import config


In [7]:
def sinify_datetime(df, col):
    """Add cyclical and calendar features derived from a datetime column.

    The new columns are written onto ``df`` itself (the frame is modified in
    place) and the same frame is returned.

    Args:
        df: DataFrame containing ``col``.
        col: Name of a datetime-typed column (it must support ``.dt``).

    Returns:
        ``df`` with these columns added, each prefixed by ``col``:
        ``_base_year``; ``_day_sin`` / ``_day_cos`` (position within the year);
        ``_hour_sin`` / ``_hour_cos`` (position within the day);
        ``_is_weekend`` (1 for Saturday/Sunday, otherwise 0).
    """
    stamps = df[col].dt

    df[col + "_base_year"] = stamps.year

    # The cycle length is one year (365.25 days), so the phase must come from
    # the day of the *year*. The day of the month (1-31) would only ever cover
    # a small slice of the cycle and repeat every month.
    year_phase = 2 * np.pi * stamps.dayofyear / 365.25
    df[col + "_day_sin"] = np.sin(year_phase)
    df[col + "_day_cos"] = np.cos(year_phase)

    day_phase = 2 * np.pi * stamps.hour / 24
    df[col + "_hour_sin"] = np.sin(day_phase)
    df[col + "_hour_cos"] = np.cos(day_phase)

    # dayofweek values 5 and 6 are Saturday and Sunday.
    df[col + "_is_weekend"] = stamps.dayofweek.isin([5, 6]).astype("int64")

    return df

In [8]:
# Folder holding the competition files, relative to the working directory.
DATA_DIR = "data/predict-energy-behavior-of-prosumers"

train = pd.read_csv(f"{DATA_DIR}/train.csv")
gas_df = pd.read_csv(f"{DATA_DIR}/gas_prices.csv")
electricity_df = pd.read_csv(f"{DATA_DIR}/electricity_prices.csv")
client_df = pd.read_csv(f"{DATA_DIR}/client.csv")
fw_df = pd.read_csv(f"{DATA_DIR}/forecast_weather.csv")
hw_df = pd.read_csv(f"{DATA_DIR}/historical_weather.csv")
# `locations` is a custom data set, used by other participants in the
# competition, that helps build a more consolidated merged data set: it is
# joined to the weather tables on (longitude, latitude) to obtain the county.
locations = pd.read_csv(f"{DATA_DIR}/county_lon_lats.csv")

# Working names used by the rest of the cell. These are plain aliases (no copy
# is made), so in-place edits through them also show up on the *_df names.
data = train
client = client_df
hist_weather = hw_df
forecast_weather = fw_df
electricity = electricity_df
gas = gas_df

# County id -> county name lookup, reshaped into two columns.
counties = pd.read_json(
    f"{DATA_DIR}/county_id_to_name_map.json",
    orient="index",
).reset_index()
counties.columns = ["county", "county_name"]

# Drop the training rows whose (target) is missing
data = data[data["target"].notnull()]

# Parse the (datetime) column as timezone-aware UTC timestamps
data["datetime"] = pd.to_datetime(data["datetime"], utc=True)

# Rename (forecast_date) to (datetime) so it can be merged with the train data later
electricity = electricity.rename(columns={"forecast_date": "datetime"})

# Parse the (datetime) column as timezone-aware UTC timestamps
electricity["datetime"] = pd.to_datetime(electricity["datetime"], utc=True)

# Shift (data_block_id) in the client data back by 2 because it is 2 steps ahead of the train data's (data_block_id).
# NOTE: `client` is the same object as `client_df`, so this in-place edit changes both names;
# it is only safe because the CSV is re-read at the top of this cell on every run.
client["data_block_id"] -= 2


"""locations is a custom data that will help replace (latitude) and (longitude)
columns by the counties for each coordination | you can find the data in Input """
# Drop the leftover CSV index column from the coordinate -> county lookup
locations = locations.drop("Unnamed: 0", axis=1)

# Round (latitude) and (longitude) to one decimal place so they match the keys in `locations`.
# NOTE: `forecast_weather` is still the same object as `fw_df` here, so this edits both names.
forecast_weather[["latitude", "longitude"]] = (
    forecast_weather[["latitude", "longitude"]].astype(float).round(1)
)

# Attach the county from `locations` to each coordinate pair in the forecast_weather data
forecast_weather = forecast_weather.merge(
    locations, how="left", on=["longitude", "latitude"]
)

# Drop rows containing any NaN (this includes coordinates that matched no county)
forecast_weather.dropna(axis=0, inplace=True)

# Convert the (county) column to integer
forecast_weather["county"] = forecast_weather["county"].astype("int64")

# Drop the columns we won't need | (forecast_datetime) is used instead of (origin_datetime)
forecast_weather.drop(
    ["origin_datetime", "latitude", "longitude", "hours_ahead", "data_block_id"],
    axis=1,
    inplace=True,
)

# Rename (forecast_datetime) to (datetime) so it can be merged with the train data later
forecast_weather.rename(columns={"forecast_datetime": "datetime"}, inplace=True)

# Parse the (datetime) column as timezone-aware UTC timestamps
forecast_weather["datetime"] = pd.to_datetime(forecast_weather["datetime"], utc=True)

"""Grouping all forecast_weather columns mean values by hour, So each hour
will have the mean values of the forecast_weather columns"""
# The groupby lines below are the ones echoed in this cell's warning output.
forecast_weather_datetime = (
    forecast_weather.groupby([forecast_weather["datetime"].dt.to_period("h")])[
        list(forecast_weather.drop(["county", "datetime"], axis=1).columns)
    ]
    .mean()
    .reset_index()
)

# The groupby key was an hourly period; convert it back to a UTC timestamp
forecast_weather_datetime["datetime"] = pd.to_datetime(
    forecast_weather_datetime["datetime"].dt.to_timestamp(), utc=True
)

"""Grouping all forecast_weather columns mean values by hour and county, So each hour and county
will have the mean values of the forecast_weather columns for each county"""
forecast_weather_datetime_county = (
    forecast_weather.groupby(
        ["county", forecast_weather["datetime"].dt.to_period("h")]
    )[list(forecast_weather.drop(["county", "datetime"], axis=1).columns)]
    .mean()
    .reset_index()
)

# The groupby key was an hourly period; convert it back to a UTC timestamp
forecast_weather_datetime_county["datetime"] = pd.to_datetime(
    forecast_weather_datetime_county["datetime"].dt.to_timestamp(), utc=True
)

# Round (latitude) and (longitude) to one decimal place so they match the keys in `locations`.
# NOTE: `hist_weather` is still the same object as `hw_df` here, so this edits both names.
hist_weather[["latitude", "longitude"]] = (
    hist_weather[["latitude", "longitude"]].astype(float).round(1)
)

# Attach the county from `locations` to each coordinate pair in the historical_weather data
hist_weather = hist_weather.merge(locations, how="left", on=["longitude", "latitude"])

# Drop rows containing any NaN (this includes coordinates that matched no county)
hist_weather.dropna(axis=0, inplace=True)

# The coordinates are no longer needed once the county is attached
hist_weather.drop(["latitude", "longitude"], axis=1, inplace=True)

# Convert (county) to integer
hist_weather["county"] = hist_weather["county"].astype("int64")

# Parse the (datetime) column as timezone-aware UTC timestamps
hist_weather["datetime"] = pd.to_datetime(hist_weather["datetime"], utc=True)

# Distinct (datetime, data_block_id) pairs, used to put data_block_id back after
# the aggregations below. hist_weather can hold many rows for one timestamp
# (that is why it is averaged per hour), so merging against the raw two-column
# slice repeated every aggregated row once per matching source row. Merging
# against the de-duplicated pairs yields the same distinct rows without the
# intermediate blow-up.
hist_block_ids = hist_weather[["datetime", "data_block_id"]].drop_duplicates()

# Columns that get averaged: everything except the keys and data_block_id
hist_value_columns = list(
    hist_weather.drop(["county", "datetime", "data_block_id"], axis=1).columns
)

# Mean of every historical_weather column per hour (across all counties)
hist_weather_datetime = (
    hist_weather.groupby([hist_weather["datetime"].dt.to_period("h")])[
        hist_value_columns
    ]
    .mean()
    .reset_index()
)

# The groupby key was an hourly period; convert it back to a UTC timestamp
hist_weather_datetime["datetime"] = pd.to_datetime(
    hist_weather_datetime["datetime"].dt.to_timestamp(), utc=True
)

# Put (data_block_id) back | it is the key used to merge with the train data
hist_weather_datetime = hist_weather_datetime.merge(
    hist_block_ids, how="left", on="datetime"
)

# Mean of every historical_weather column per hour and county
hist_weather_datetime_county = (
    hist_weather.groupby(["county", hist_weather["datetime"].dt.to_period("h")])[
        hist_value_columns
    ]
    .mean()
    .reset_index()
)

# The groupby key was an hourly period; convert it back to a UTC timestamp
hist_weather_datetime_county["datetime"] = pd.to_datetime(
    hist_weather_datetime_county["datetime"].dt.to_timestamp(), utc=True
)

# Put (data_block_id) back for the per-county table as well
hist_weather_datetime_county = hist_weather_datetime_county.merge(
    hist_block_ids, how="left", on="datetime"
)

# Calendar parts of the training timestamp, appended in this column order:
# year, month, day, hour, dayofweek, dayofyear
data = data.assign(
    **{
        part: getattr(data["datetime"].dt, part)
        for part in ("year", "month", "day", "hour", "dayofweek", "dayofyear")
    }
)

# Hour of day on the electricity prices; used as a join key further down
electricity = electricity.assign(hour=electricity["datetime"].dt.hour)

# Train data + client data (the client (date) column is not carried over)
data = pd.merge(
    data,
    client.drop(columns=["date"]),
    how="left",
    on=["data_block_id", "county", "is_business", "product_type"],
)

# Train data + gas prices, matched on the data block
data = pd.merge(
    data,
    gas[["data_block_id", "lowest_price_per_mwh", "highest_price_per_mwh"]],
    how="left",
    on="data_block_id",
)

# Train data + electricity price, matched on hour of day within the data block
data = pd.merge(
    data,
    electricity[["euros_per_mwh", "hour", "data_block_id"]],
    how="left",
    on=["hour", "data_block_id"],
)

# Train data + forecast weather averaged per hour
data = pd.merge(data, forecast_weather_datetime, how="left", on=["datetime"])

# Train data + forecast weather averaged per hour and county. The weather
# columns now exist twice, so the suffixes tell the two sets apart.
data = pd.merge(
    data,
    forecast_weather_datetime_county,
    how="left",
    on=["datetime", "county"],
    suffixes=("_fcast_mean", "_fcast_mean_by_county"),
)

# Historical weather is joined on (data_block_id, hour) rather than on the
# timestamp: add the hour, remove duplicate rows, then drop (datetime).
hist_weather_datetime = (
    hist_weather_datetime.assign(hour=hist_weather_datetime["datetime"].dt.hour)
    .drop_duplicates()
    .drop(columns="datetime")
)
hist_weather_datetime_county = (
    hist_weather_datetime_county.assign(
        hour=hist_weather_datetime_county["datetime"].dt.hour
    )
    .drop_duplicates()
    .drop(columns="datetime")
)

# Train data + historical weather averaged per hour
data = pd.merge(
    data, hist_weather_datetime, how="left", on=["data_block_id", "hour"]
)

# Train data + historical weather averaged per hour and county
data = pd.merge(
    data,
    hist_weather_datetime_county,
    how="left",
    on=["data_block_id", "county", "hour"],
    suffixes=("_hist_mean", "_hist_mean_by_county"),
)

# Fill the remaining NaNs inside each (year, day, hour) group by forward-filling
# and then backward-filling from the other rows of that group (no mean is
# computed here) | Helps for the county missing values.
# group_keys=False is passed explicitly: it is what the pandas warning emitted
# by this statement recommends for keeping the original row index, so the
# reset_index() below keeps producing a single "index" column.
# NOTE(review): that "index" column is not dropped below and therefore stays
# in `data` as a numeric column (it is visible in `data.columns`) - confirm
# that this is intended.
data = (
    data.groupby(["year", "day", "hour"], as_index=False, group_keys=False)
    .apply(lambda x: x.ffill().bfill())
    .reset_index()
)
# Replace the numeric county id with the county name
data = data.merge(counties, how="left", on="county").drop(columns=["county"])
product_dict = {0: "Combined", 1: "Fixed", 2: "General service", 3: "Spot"}

# Convert the dictionary to a list of tuples
product_list = list(product_dict.items())

# Create the DataFrame
product = pd.DataFrame(product_list, columns=["product_type", "product_name"])

# Replace the remaining coded columns with readable labels
data = data.merge(product, how="left", on="product_type").drop(columns=["product_type"])
data["is_business"] = data["is_business"].map({0: "Residential", 1: "Business"})
data["is_consumption"] = data["is_consumption"].map({0: "Production", 1: "Consumption"})

# Dropping unneeded data
data.drop(
    ["row_id", "data_block_id", "year", "datetime"],
    axis=1,
    inplace=True,
)

# Reorder the columns: object (string) columns first, everything else after
data = data[
    list(
        chain(
            data.select_dtypes(include=["object"]).columns,
            data.select_dtypes(exclude=["object"]).columns,
        )
    )
]

# Small sample kept for quick inspection
data_head = data.head(100)

  forecast_weather.groupby([forecast_weather["datetime"].dt.to_period("h")])[
  ["county", forecast_weather["datetime"].dt.to_period("h")]
  hist_weather.groupby([hist_weather["datetime"].dt.to_period("h")])[
  hist_weather.groupby(["county", hist_weather["datetime"].dt.to_period("h")])[
To preserve the previous behavior, use

	>>> .groupby(..., group_keys=False)


	>>> .groupby(..., group_keys=True)
  .apply(lambda x: x.ffill().bfill())


In [9]:
data.columns

Index(['is_business', 'is_consumption', 'county_name', 'product_name', 'index',
       'target', 'prediction_unit_id', 'month', 'day', 'hour', 'dayofweek',
       'dayofyear', 'eic_count', 'installed_capacity', 'lowest_price_per_mwh',
       'highest_price_per_mwh', 'euros_per_mwh', 'temperature_fcast_mean',
       'dewpoint_fcast_mean', 'cloudcover_high_fcast_mean',
       'cloudcover_low_fcast_mean', 'cloudcover_mid_fcast_mean',
       'cloudcover_total_fcast_mean', '10_metre_u_wind_component_fcast_mean',
       '10_metre_v_wind_component_fcast_mean',
       'direct_solar_radiation_fcast_mean',
       'surface_solar_radiation_downwards_fcast_mean', 'snowfall_fcast_mean',
       'total_precipitation_fcast_mean', 'temperature_fcast_mean_by_county',
       'dewpoint_fcast_mean_by_county', 'cloudcover_high_fcast_mean_by_county',
       'cloudcover_low_fcast_mean_by_county',
       'cloudcover_mid_fcast_mean_by_county',
       'cloudcover_total_fcast_mean_by_county',
       '10_metre_u_wi

In [10]:
# Wrap the merged frame in the hephaestus time-series dataset, with "target"
# as the target column.
# NOTE(review): batch_size is presumably the number of rows in each item
# returned by pre_train[i] - confirm against hp.TabularTimeSeriesData.
pre_train = hp.TabularTimeSeriesData(
    data,
    batch_size=3120 * 1,  # target_column="euros_per_mwh"
    target_column="target",
)  # .head(2688 * 100)

# pre_train = hp.TabularDS(df)

In [11]:
pre_train[1]

(Array([[72, 78, 83, 87],
        [72, 73, 83, 87],
        [75, 78, 83, 94],
        ...,
        [75, 78, 91, 87],
        [75, 73, 91, 87],
        [72, 78, 84, 87]], dtype=int32),
 Array([[-1.7266936 ,  0.09977204,  0.70107555, ..., -0.5924703 ,
         -0.481902  , -0.68075854],
        [-1.726692  ,  0.09977204,  0.70107555, ..., -0.5924703 ,
         -0.481902  , -0.68075854],
        [-1.7266903 ,  0.15081689,  0.70107555, ..., -0.5924703 ,
         -0.481902  , -0.68075854],
        ...,
        [-1.7213409 , -1.2784388 ,  0.70107555, ..., -0.5924703 ,
         -0.481902  , -0.68075854],
        [-1.7213391 , -1.2784388 ,  0.70107555, ..., -0.5924703 ,
         -0.481902  , -0.68075854],
        [-1.7213374 , -1.227394  ,  0.70107555, ..., -0.5924703 ,
         -0.481902  , -0.68075854]], dtype=float32),
 Array([[5.2000e-02],
        [8.7508e+01],
        [0.0000e+00],
        ...,
        [0.0000e+00],
        [2.7202e+01],
        [6.8000e-02]], dtype=float32))

In [12]:
time_series_regressor = hp.time_series.MaskedTimeSeries(pre_train)

In [13]:
pre_train[0]

(Array([[72, 78, 93, 94],
        [72, 73, 93, 94],
        [72, 78, 93, 79],
        ...,
        [72, 78, 83, 94],
        [72, 73, 83, 94],
        [72, 78, 83, 87]], dtype=int32),
 Array([[-1.73205   , -1.6867976 ,  0.70107555, ..., -0.5924703 ,
         -0.481902  , -0.68075854],
        [-1.7320483 , -1.6867976 ,  0.70107555, ..., -0.5924703 ,
         -0.481902  , -0.68075854],
        [-1.7320465 , -1.6357528 ,  0.70107555, ..., -0.5924703 ,
         -0.481902  , -0.68075854],
        ...,
        [-1.7266971 ,  0.04872719,  0.70107555, ..., -0.5924703 ,
         -0.481902  , -0.68075854],
        [-1.7266954 ,  0.04872719,  0.70107555, ..., -0.5924703 ,
         -0.481902  , -0.68075854],
        [-1.7266936 ,  0.09977204,  0.70107555, ..., -0.5924703 ,
         -0.481902  , -0.68075854]], dtype=float32),
 Array([[7.1300e-01],
        [9.6590e+01],
        [0.0000e+00],
        ...,
        [0.0000e+00],
        [1.4178e+01],
        [5.2000e-02]], dtype=float32))

In [14]:
# time_length = 1000 # 6336 * 3  # 112 * 3 #
# Use the first dataset item as a sample input; it unpacks into the categorical
# array, the numeric array and the target array, in that order.
test_cat, test_num, y = pre_train[0]

# Masked versions of the sample. No prng_key is passed here, unlike the
# mask_tensor calls in the loss and training cells.
test_num_mask = hp.mask_tensor(test_num, pre_train)
test_cat_mask = hp.mask_tensor(test_cat, pre_train)

In [15]:
def calculate_memory_footprint(params):
    """Return the combined size, in bytes, of every array in a parameter pytree.

    Args:
        params: Pytree whose leaves expose ``size`` and ``dtype.itemsize``.

    Returns:
        Sum over all leaves of ``size * dtype.itemsize``; 0 for an empty tree.
    """
    # Only the leaves matter here; the tree structure itself is discarded.
    leaves, _structure = tree_flatten(params)
    return sum(leaf.size * leaf.dtype.itemsize for leaf in leaves)

In [16]:
# Split one root key (seed 44) into three keys. tsm_main_key seeds the model
# initialisation below and ts_data_key is passed to hp.mask_tensor in later
# cells; ts_params_key is not referenced in the cells visible here.
tsm_root_key = random.PRNGKey(44)
tsm_main_key, ts_params_key, ts_data_key = random.split(tsm_root_key, 3)

# Initialise the model variables on a 2-element batch made by stacking the
# sample with its masked copy (categorical inputs first, numeric second).
time_series_regressor_vars = time_series_regressor.init(
    tsm_main_key,
    jnp.array([test_cat, test_cat_mask]),
    jnp.array([test_num, test_num_mask]),
)
# shapes=[(2, 336, 1, 23), (), (2, 336, 23, 640)]

In [17]:
pre_train.n_tokens

96

In [18]:
calculate_memory_footprint(time_series_regressor_vars) / (1024**3)

0.6398345232009888

In [19]:
pre_train.category_columns

['is_business', 'is_consumption', 'county_name', 'product_name']

In [20]:
class TSBatch:
    """Plain container bundling the three arrays that make up one batch."""

    def __init__(
        self,
        numeric: jnp.ndarray,
        categorical: jnp.ndarray,
        y: jnp.ndarray,
    ):
        # Stored exactly as given: no copy, cast or shape check is performed.
        self.numeric, self.categorical, self.y = numeric, categorical, y


def make_batch(ds: hp.TabularTimeSeriesData, start: int, count: int) -> TSBatch:
    """Stack ``count`` consecutive dataset items into one batch.

    Args:
        ds: Dataset whose items unpack as ``(categorical, numeric, y)``.
        start: Index of the first item to include.
        count: Number of consecutive items to include.

    Returns:
        A ``TSBatch`` in which each field is the array built from the
        corresponding per-item arrays, in index order.
    """
    samples = [ds[idx] for idx in range(start, start + count)]
    return TSBatch(
        numeric=jnp.array([numeric for _, numeric, _ in samples]),
        categorical=jnp.array([categorical for categorical, _, _ in samples]),
        y=jnp.array([target for _, _, target in samples]),
    )


# batcha = make_batch(pre_train, 0, 12)

# r = time_series_regressor.apply(
#     {"params": time_series_regressor_vars["params"]},
#     hp.mask_tensor(batch.categorical, pre_train, prng_key=ts_data_key),
#     hp.mask_tensor(batch.numeric, pre_train, prng_key=ts_data_key),
# )
# r[0].shape, r[1].shape  # ((12, 336, 668), (12, 336, 23))

pre_t


In [21]:
# max_len = 10000
# d_pos_encoding = 32
# n_features = 24
# n_epochs = 9
# seq_len = 200
# position = jnp.arange(max_len)[:, jnp.newaxis]
# div_term = jnp.exp(
#     jnp.arange(0, d_pos_encoding, 2) * -(jnp.log(10000.0) / d_pos_encoding)
# )
# pe = jnp.zeros((max_len, d_pos_encoding))
# pe = pe.at[:, 0::2].set(jnp.sin(position * div_term))
# pe = pe.at[:, 1::2].set(jnp.cos(position * div_term))
# pe = pe[:seq_len, :]
# pe = pe[None, :, :, None]
# pe = jnp.tile(pe, (n_epochs, 1, 1, n_features))
# pe = pe.transpose((0, 1, 3, 2))  # (batch_size, seq_len, n_features, d_model)
# # concatenate the positional encoding with the input

In [22]:
# Plot the positional encoding
# plt.figure(figsize=(15, 5))
# plt.pcolormesh(pe[0, :, 0, :], cmap="viridis")

In [23]:
# batch_size = 12
# for i in range(len(pre_train) // batch_size):
#     batch = make_batch(pre_train, i * batch_size, batch_size)
#     r = time_series_regressor.apply(
#         {"params": time_series_regressor_vars["params"]},
#         hp.mask_tensor(batch.numeric, pre_train, prng_key=ts_data_key),
#         hp.mask_tensor(batch.categorical, pre_train, prng_key=ts_data_key),
#     )
#     print(r[0].shape, r[1].shape)

In [24]:
# Split one root key (seed 42) into three keys. mts_main_key is the key later
# handed to create_train_state; mts_params_key and mts_dropout_key are not
# referenced in the cells visible here.
mts_root_key = random.PRNGKey(42)

mts_main_key, mts_params_key, mts_dropout_key = random.split(mts_root_key, 3)

In [25]:
# mts_mi = hp.create_masked_time_series_model_inputs(pre_train, 0, 3)

In [26]:
@dataclass
class MultiLoss:
    """Combined loss together with the individual terms it was built from."""

    # Categorical loss plus the scaled numeric and target losses.
    total_loss: jnp.ndarray
    # Mean cross-entropy over the categorical outputs.
    categorical_loss: jnp.ndarray
    # Mean squared error of the numeric outputs (before scaling).
    numeric_loss: jnp.ndarray
    # Mean squared error against the target values (before scaling).
    target_loss: jnp.ndarray


def calculate_loss(params, state, inputs, dataset):
    """Compute the combined loss for one batch with a two-output model.

    A later cell defines another ``calculate_loss`` with a different signature;
    once that cell has run it replaces this definition.

    Args:
        params: Parameter collection, passed to the model as ``{"params": params}``.
        state: Object whose ``apply`` method runs the model and returns
            ``(category_out, numeric_out)``.
        inputs: Batch exposing ``categorical``, ``numeric`` and ``y``.
        dataset: Dataset that provides ``numeric_col_tokens`` and is handed to
            ``hp.mask_tensor``. The masking used to read the global
            ``pre_train`` and ignore this argument.

    Returns:
        ``(total_loss, MultiLoss)``, where ``total_loss`` is the categorical
        loss plus the numeric loss (scale 1.0) plus the target loss (scale 3.0).
    """
    numeric_loss_scale = 1.0
    target_loss_scale = 3.0
    # ts_data_key is the notebook-level PRNG key; both inputs use the same key.
    category_out, numeric_out = state.apply(
        {"params": params},
        hp.mask_tensor(inputs.categorical, dataset, prng_key=ts_data_key),
        hp.mask_tensor(inputs.numeric, dataset, prng_key=ts_data_key),
    )
    # Repeat the numeric column tokens over the two leading axes of
    # category_out so they can be appended to the categorical labels.
    numeric_col_tokens = dataset.numeric_col_tokens.clone()
    numeric_col_tokens = numeric_col_tokens[None, None, :]
    repeated_numeric_col_tokens = jnp.tile(
        numeric_col_tokens, (category_out.shape[0], category_out.shape[1], 1)
    )
    ic(repeated_numeric_col_tokens.shape, inputs.categorical.shape, category_out.shape)
    categorical_targets = jnp.concatenate(
        [
            inputs.categorical,
            repeated_numeric_col_tokens,
        ],
        axis=-1,
    )
    ic(categorical_targets.shape)
    categorical_loss = optax.softmax_cross_entropy_with_integer_labels(
        category_out, categorical_targets
    ).mean()
    numeric_loss = optax.squared_error(numeric_out, inputs.numeric).mean()
    # NOTE(review): the target loss compares numeric_out (not a dedicated target
    # output) with inputs.y - confirm this is what this model version intends.
    target_loss = optax.squared_error(numeric_out, inputs.y).mean()
    total_loss = (
        categorical_loss
        + numeric_loss * numeric_loss_scale
        + target_loss * target_loss_scale
    )
    return total_loss, MultiLoss(
        total_loss=total_loss,
        categorical_loss=categorical_loss,
        numeric_loss=numeric_loss,
        target_loss=target_loss,
    )


# Smoke test: evaluate the loss once on a 2-item batch with the freshly
# initialised parameters. The model object itself is passed as `state`.
batch = make_batch(pre_train, 0, 2)
calculate_loss(
    time_series_regressor_vars["params"], time_series_regressor, batch, pre_train
)

Array(5.3340983, dtype=float32)

In [1]:
%tensorboard \
        --logdir '/content/drive/MyDrive/Colab Notebooks/Hephaestus/runs' \
        --load_fast=false

UsageError: Line magic function `%tensorboard` not found.


In [28]:
# Start a JAX profiler trace written to the runs folder. It is paired with the
# profiler.stop_trace() call at the end of this cell; with
# create_perfetto_link=True the recorded output shows an "Open URL in browser"
# Perfetto link.
profiler.start_trace(
    "/content/drive/MyDrive/Colab Notebooks/Hephaestus/runs", create_perfetto_link=True
)


def replace_nans(grads, replace_value=0):
    """Replace NaN values in gradients with a specified value.

    Args:
        grads: Pytree of arrays (for example a nested dict of gradients).
        replace_value: Value written wherever an entry is NaN.

    Returns:
        A pytree with the same structure as ``grads`` and NaNs replaced.
    """
    # jax.tree_util.tree_map is the stable spelling; the top-level
    # jax.tree_map alias is deprecated.
    return jax.tree_util.tree_map(
        lambda g: jnp.where(jnp.isnan(g), replace_value, g), grads
    )


def clip_gradients(gradients, max_norm):
    """Rescale a gradient pytree so its global L2 norm does not exceed ``max_norm``.

    Args:
        gradients: Pytree of arrays; arbitrarily nested containers are allowed.
        max_norm: Largest permitted global L2 norm.

    Returns:
        A pytree with the same structure. Every leaf is multiplied by
        ``max_norm / (norm + 1e-6)`` when the global norm exceeds ``max_norm``
        and returned unchanged otherwise.
    """
    # Walk every leaf of the tree. Iterating over ``gradients.values()`` only
    # works for a flat dict of arrays and breaks on nested parameter trees.
    leaves = jax.tree_util.tree_leaves(gradients)
    total_norm = jnp.sqrt(sum(jnp.sum(jnp.square(leaf)) for leaf in leaves))
    # The small epsilon guards against division by zero.
    scale = max_norm / (total_norm + 1e-6)
    return jax.tree_util.tree_map(
        lambda grad: jnp.where(total_norm > max_norm, grad * scale, grad), gradients
    )


def _loss_terms(
    params,
    state,
    categorical_mask,
    numeric_mask,
    categorical_labs,
    numeric_labs,
    target_labs,
    dataset,
):
    """Run the model once and return ``(total_loss, dict of all loss terms)``."""
    numeric_loss_scale = 1.0
    target_loss_scale = 2.0
    category_out, numeric_out, target_out = state.apply_fn(
        {"params": params},
        categorical_mask,
        numeric_mask,
    )
    # Repeat the numeric column tokens over the two leading axes of
    # category_out so they can be appended to the categorical labels.
    numeric_col_tokens = dataset.numeric_col_tokens.clone()
    numeric_col_tokens = numeric_col_tokens[None, None, :]
    repeated_numeric_col_tokens = jnp.tile(
        numeric_col_tokens, (category_out.shape[0], category_out.shape[1], 1)
    )
    categorical_targets = jnp.concatenate(
        [
            categorical_labs,
            repeated_numeric_col_tokens,
        ],
        axis=-1,
    )
    categorical_loss = optax.softmax_cross_entropy_with_integer_labels(
        category_out, categorical_targets
    ).mean()

    target_loss = optax.squared_error(target_out, target_labs).mean()
    numeric_loss = optax.squared_error(numeric_out, numeric_labs).mean()
    total_loss = (
        categorical_loss
        + numeric_loss * numeric_loss_scale
        + target_loss * target_loss_scale
    )
    # A plain dict (not the MultiLoss dataclass) so that the terms can be
    # returned from a jitted function as an ordinary pytree.
    terms = {
        "total_loss": total_loss,
        "categorical_loss": categorical_loss,
        "numeric_loss": numeric_loss,
        "target_loss": target_loss,
    }
    return total_loss, terms


def calculate_loss(
    params,
    state,
    categorical_mask,
    numeric_mask,
    categorical_labs,
    numeric_labs,
    target_labs,
    dataset,
):
    """Compute the combined loss for one batch with the three-output model.

    Args:
        params: Parameter collection, passed to the model as ``{"params": params}``.
        state: Train state whose ``apply_fn`` returns
            ``(category_out, numeric_out, target_out)``.
        categorical_mask: Masked categorical inputs fed to the model.
        numeric_mask: Masked numeric inputs fed to the model.
        categorical_labs: Unmasked categorical labels.
        numeric_labs: Unmasked numeric labels.
        target_labs: Target labels compared with ``target_out``.
        dataset: Dataset providing ``numeric_col_tokens``.

    Returns:
        The scalar total loss: categorical loss plus the numeric loss
        (scale 1.0) plus the target loss (scale 2.0).
    """
    total_loss, _ = _loss_terms(
        params,
        state,
        categorical_mask=categorical_mask,
        numeric_mask=numeric_mask,
        categorical_labs=categorical_labs,
        numeric_labs=numeric_labs,
        target_labs=target_labs,
        dataset=dataset,
    )
    return total_loss


@jax.jit
def train_step(
    state: train_state.TrainState,
    categorical_mask,
    numeric_mask,
    categorical_labs,
    target_labs,
    numeric_labs,
):
    """Run one optimisation step.

    Note the parameter order: ``target_labs`` comes before ``numeric_labs``.

    Returns:
        ``(new_state, total_loss, multi_loss)`` where ``multi_loss`` is a dict
        with the keys ``total_loss``, ``categorical_loss``, ``numeric_loss``
        and ``target_loss``.
    """

    def loss_fn(params):
        # Uses the notebook-level dataset `pre_train` for the column tokens.
        return _loss_terms(
            params,
            state,
            categorical_mask=categorical_mask,
            numeric_mask=numeric_mask,
            categorical_labs=categorical_labs,
            numeric_labs=numeric_labs,
            target_labs=target_labs,
            dataset=pre_train,
        )

    # has_aux=True makes value_and_grad return ((loss, aux), grads). Without it
    # only (loss, grads) comes back and the former three-way unpack failed.
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

    (total_loss, multi_loss), grad = grad_fn(state.params)
    grad = replace_nans(grad)
    # grad = clip_gradients(grad, 1.0)
    state = state.apply_gradients(grads=grad)

    return state, total_loss, multi_loss


def create_train_state(model, prng, batch, lr):
    """Initialise ``model`` on ``batch`` and wrap it in a ``TrainState``.

    Args:
        model: Module exposing ``init`` and ``apply``.
        prng: PRNG key handed to ``model.init``.
        batch: Object whose ``categorical`` and ``numeric`` attributes serve as
            example inputs for initialisation.
        lr: Learning rate for Adam.

    Returns:
        A ``train_state.TrainState`` holding the ``"params"`` collection and an
        optimiser that clips gradients to a global norm of 0.4 before Adam.
    """
    variables = model.init(prng, batch.categorical, batch.numeric)
    gradient_transform = optax.chain(
        optax.clip_by_global_norm(0.4),
        optax.adam(lr),
    )
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=variables["params"],
        tx=gradient_transform,
    )


batch_size = 2
batch = make_batch(pre_train, 0, batch_size)

state = create_train_state(time_series_regressor, mts_main_key, batch, 0.0001)

summary_writer = SummaryWriter(
    "runs/" + dt.now().strftime("%Y-%m-%dT%H:%M:%S") + "_" + "better_batch"
)
batch_count = 0
try:
    for j in trange(2):
        for i in trange(len(pre_train) // batch_size, leave=False):
            batch = make_batch(pre_train, i * batch_size, batch_size)
            # NOTE(review): the same ts_data_key is passed on every step and for
            # both inputs; if hp.mask_tensor derives the mask purely from the
            # key, every batch is masked identically - confirm, and derive a
            # fresh key per step if so.
            categorical_mask = hp.mask_tensor(
                batch.categorical, pre_train, prng_key=ts_data_key
            )
            numeric_mask = hp.mask_tensor(
                batch.numeric, pre_train, prng_key=ts_data_key
            )

            # Keyword arguments: train_step takes target_labs *before*
            # numeric_labs. The former positional call passed only five values,
            # so batch.numeric landed in target_labs and numeric_labs was missing.
            state, loss, multi_loss = train_step(
                state,
                categorical_mask=categorical_mask,
                numeric_mask=numeric_mask,
                categorical_labs=batch.categorical,
                target_labs=batch.y,
                numeric_labs=batch.numeric,
            )
            summary_writer.add_scalar("loss", loss.item(), batch_count)
            batch_count += 1
            # print(f"epoch: {j}, batch: {i}, {loss}")
finally:
    # Release the writer and stop the trace even when training is interrupted
    # (the recorded run of this cell ended in a KeyboardInterrupt).
    summary_writer.close()
    profiler.stop_trace()

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/258 [00:00<?, ?it/s]

  0%|          | 0/258 [00:00<?, ?it/s]

Open URL in browser: https://ui.perfetto.dev/#!/?url=http://127.0.0.1:9001/perfetto_trace.json.gz


KeyboardInterrupt: 

In [None]:
# Checkpoint file prefix, unique per run because it embeds the current timestamp.
model_name = f"new_data_{dt.now()}_"

current_dir = os.getcwd()

# exist_ok=True replaces the separate exists() check and also copes with the
# folder appearing between the check and the creation.
os.makedirs("./pre_trained_models/", exist_ok=True)

# Absolute folder the checkpoint is written to; the restore cell reads `path` too.
path = os.path.join(current_dir, "./pre_trained_models/")


# NOTE(review): ckpt_dir is not used by the save call below (which writes to
# `path` with prefix=model_name); it is kept in case other cells read it.
ckpt_dir = f"./pre_trained_models/{model_name}"

checkpoints.save_checkpoint(
    ckpt_dir=path, target=state, step=batch_count, overwrite=True, prefix=model_name
)

In [None]:
# Restore with the same prefix that save_checkpoint was given. Without it the
# default prefix is searched, no file in `path` matches the custom-prefixed
# checkpoint, and `state` is handed back unchanged.
state2 = checkpoints.restore_checkpoint(path, target=state, prefix=model_name)

In [None]:
state2

In [None]:
len(pre_train), len(pre_train) // batch_size

In [None]:
data.shape[0] // 3120

In [None]:
profiler.stop_trace()

In [None]:
len(pre_train)

In [None]:
def jax_array_memory_usage(array):
    """Return how many bytes the elements of ``array`` occupy.

    The figure is the element size of ``array.dtype`` multiplied by the number
    of elements implied by ``array.shape``.
    """
    return array.dtype.itemsize * np.prod(array.shape)


# Footprint of the current batch's categorical and numeric arrays, reported in
# bytes and in GiB (bytes / 1024**3).
cat_memory_usage, num_memory_usage = (
    jax_array_memory_usage(part) for part in (batch.categorical, batch.numeric)
)
memory_usage = cat_memory_usage + num_memory_usage
memory_usage_gb = memory_usage / 1024**3
print(f"Memory usage: {memory_usage} bytes")
print(f"Memory usage: {memory_usage_gb} gb")