In [1]:
from datetime import datetime as dt
from itertools import chain
import os
from dataclasses import dataclass, field

import numpy as np
import pandas as pd
import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn
import optax

from tqdm.notebook import tqdm, trange
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
import seaborn as sns
import hephaestus_jax as hp

# Load and preprocess the dataset (assuming you have a CSV file)
df = pd.read_csv("../data/diamonds.csv")
df.head()

Unnamed: 0,carat,cut,color,clarity,depth,table,price,x,y,z
0,0.23,Ideal,E,SI2,61.5,55.0,326,3.95,3.98,2.43
1,0.21,Premium,E,SI1,59.8,61.0,326,3.89,3.84,2.31
2,0.23,Good,E,VS1,56.9,65.0,327,4.05,4.07,2.31
3,0.29,Premium,I,VS2,62.4,58.0,334,4.2,4.23,2.63
4,0.31,Good,J,SI2,63.3,58.0,335,4.34,4.35,2.75


In [2]:
dataset = hp.TabularDS(df, target_column="price")

In [3]:
mtm = hp.MTM(dataset, n_heads=4)
trm = hp.TRM(dataset, n_heads=4)

In [4]:
batch_size = 3
test_num = dataset.X_train_numeric[0:batch_size, :]
test_num_mask = hp.mask_tensor(test_num, dataset)
test_cat = dataset.X_test_categorical[0:batch_size, :]
test_cat_mask = hp.mask_tensor(test_cat, dataset)

In [5]:
jnp.concatenate([test_num, test_cat], axis=1).shape

(3, 9)

In [6]:
key = random.PRNGKey(0)
rngs = {
    "params": jax.random.PRNGKey(0),
    "dropout": jax.random.PRNGKey(1),
}
mtm_variables = mtm.init(
    rngs,
    test_num_mask,
    test_cat_mask,
)

trm_variables = trm.init(
    rngs,
    test_num_mask,
    test_cat_mask,
)

In [7]:
regression_out = trm.apply(
    {"params": trm_variables["params"]},
    test_num_mask,
    test_cat_mask,
    rngs={"dropout": jax.random.PRNGKey(43)},
)
regression_out

Array([[-0.65439683, -0.67379665, -0.91364646, -0.79517096, -0.5047654 ,
        -0.73814166, -0.7040618 , -0.61579883, -0.7652412 ],
       [-0.5062399 , -0.47018635, -0.7961222 , -1.0067551 , -0.81049603,
        -0.89864063, -0.6520189 , -0.67292017, -0.9244095 ],
       [-0.32595527, -0.4032916 , -0.67198247, -0.44443786, -0.4422539 ,
        -0.46660137, -0.8696212 , -0.48016414, -0.38429704]],      dtype=float32)

In [8]:
mtm_out = mtm.apply(
    {"params": mtm_variables["params"]},
    test_num_mask,
    test_cat_mask,
    rngs={"dropout": jax.random.PRNGKey(43)},
)
mtm_out[0].shape, mtm_out[1].shape

((3, 9, 33), (3, 6))

In [9]:
mtm_out[0]

Array([[[-1.34741869e-02, -1.87000060e+00, -5.46815515e-01,
          3.28416049e-01, -4.31626856e-01,  1.29093671e+00,
          1.09174132e+00, -1.59050024e+00, -8.95736337e-01,
         -6.80820882e-01,  3.61232638e-01,  1.79834890e+00,
         -7.61659801e-01, -5.96087873e-01, -1.28177893e+00,
         -1.92998722e-01,  5.56062520e-01,  2.12468982e+00,
         -5.59963226e-01,  1.17033887e+00,  7.47882724e-01,
          1.89003319e-01, -1.26332355e+00, -1.98147893e+00,
          2.18777966e+00,  3.09999526e-01,  3.78764153e-01,
         -1.28439403e+00, -7.07636714e-01, -4.36969846e-01,
          5.44406354e-01,  5.47584653e-01,  9.76059914e-01],
        [-5.43075740e-01, -1.12100422e+00, -1.05681157e+00,
         -8.91107142e-01,  1.71672061e-01,  1.80148914e-01,
          9.94629145e-01, -1.04267263e+00, -2.99634188e-01,
          3.95746708e-01,  1.38280940e+00,  2.10540152e+00,
         -8.57059419e-01, -1.14014316e+00, -3.28990012e-01,
         -8.92484039e-02,  1.36850297e+

In [10]:
regression_out = trm.apply(
    {"params": trm_variables["params"]},
    test_num_mask,
    test_cat_mask,
    rngs={"dropout": jax.random.PRNGKey(43)},
)
regression_out.shape

(3, 9)

In [21]:
from flax import struct  # Flax dataclasses
from clu import metrics
from flax.training import train_state  # Useful dataclass to keep train state

In [25]:
@struct.dataclass
class Metrics(metrics.Collection):
    numeric_loss: metrics.Average.from_output("numeric_loss")
    categorical_loss: metrics.Average.from_output("cat_loss")
    combined_loss: metrics.Average.from_output("loss")

In [26]:
class TrainState(train_state.TrainState):
    metrics: Metrics


def create_train_state(model, rng, inputs, learning_rate: float):
    """Creates an initial `TrainState`."""
    params = model.init(
        rngs,
        inputs.masked_num_input,
        inputs.masked_cat_input,
    )["params"]

    # params = model.init(rng, num_input, cat_input)["params"]
    # tx = optax.adam(learning_rate)
    tx = optax.chain(
        optax.clip_by_global_norm(max_norm=1.0),  # Gradient clipping
        optax.adam(learning_rate),  # Learning rate
    )
    return TrainState.create(
        apply_fn=model.apply, params=params, tx=tx, metrics=Metrics.empty()
    )

In [27]:
@dataclass
class ModelInputs:
    num_input: jnp.ndarray
    cat_input: jnp.ndarray
    numeric_col_tokens: jnp.ndarray
    masked_num_input: jnp.ndarray = None
    masked_cat_input: jnp.ndarray = None
    cat_targets: jnp.ndarray = None

    def __post_init__(self):
        self.masked_num_input = hp.mask_tensor(self.num_input, dataset)
        self.masked_cat_input = hp.mask_tensor(self.cat_input, dataset)
        repeated_numeric_col_tokens = jnp.tile(
            self.numeric_col_tokens, (self.cat_input.shape[0], 1)
        )
        self.cat_targets = jnp.concatenate(
            [self.cat_input, repeated_numeric_col_tokens], axis=1
        )
        # self.cat_targets = jnp.broadcast_to(
        #     self.numeric_col_tokens, (self.cat_input.shape[0], -1)
        # )

In [28]:
@jax.jit
def eval_step(
    *,
    state,
    masked_num_input,
    masked_cat_input,
    cat_targets,
    num_input,
):
    cat_logits, numeric_preds = state.apply_fn(
        {"params": state.params},
        masked_num_input,
        masked_cat_input,
        rngs={"dropout": jax.random.PRNGKey(43)},
    )
    cat_loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=cat_logits, labels=cat_targets
    ).mean()  # TODO Do we want mean here?

    numeric_loss = optax.squared_error(numeric_preds, num_input).mean()
    loss = cat_loss + numeric_loss * 10  # Hyperparameter
    # loss_dict = {"loss": loss, "cat_loss": cat_loss, "numeric_loss": numeric_loss}
    # metrics_update = state.metrics.single_from_model_output(
    #     logits=
    # )

    return loss  # loss_dict


# from flax.training import clip_grads


@jax.jit
def train_step(
    state,
    masked_num_input,
    masked_cat_input,
    cat_targets,
    num_input,
):
    """Train for a single step."""

    def loss_fn(params):
        cat_logits, numeric_preds = state.apply_fn(
            {"params": params},
            masked_num_input,
            masked_cat_input,
            rngs={"dropout": jax.random.PRNGKey(43)},
        )
        cat_loss = optax.softmax_cross_entropy_with_integer_labels(
            logits=cat_logits, labels=cat_targets
        ).mean()  # TODO Do we want mean here?
        numeric_loss = optax.squared_error(numeric_preds, num_input).mean()
        loss = cat_loss + numeric_loss * 10  # Hyperparameter
        # loss_dict = {"loss": loss, "cat_loss": cat_loss, "numeric_loss": numeric_loss}
        # print(f"loss is: {loss}")
        # summary_writer.add_scalar("loss", loss)
        # summary_writer.add_scalar("cat_loss", cat_loss)
        # summary_writer.add_scalar("numeric_loss", numeric_loss)
        return loss

    grad_fn = jax.grad(loss_fn)
    grads = grad_fn(state.params)
    # grads = optax.clip_by_global_norm(grads, max_norm=threshold)

    state = state.apply_gradients(grads=grads)
    return state


@jax.jit
def compute_metrics(
    *,
    state,
    masked_num_input,
    masked_cat_input,
    cat_targets,
    num_input,
):
    cat_logits, numeric_preds = state.apply_fn(
        {"params": state.params},
        masked_num_input,
        masked_cat_input,
        rngs={"dropout": jax.random.PRNGKey(43)},
    )
    cat_loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=cat_logits, labels=cat_targets
    ).mean()  # TODO Do we want mean here?
    # print(f"cat_loss is: {cat_loss}")

    numeric_loss = optax.squared_error(numeric_preds, num_input).mean()
    loss = cat_loss + numeric_loss * 10
    loss_dict = {"loss": loss, "cat_loss": cat_loss, "numeric_loss": numeric_loss}
    # metric_updates = state.metrics.single_from_model_output(
    #     logits=logits, labels=batch["label"], loss=loss
    # )
    # metrics = state.metrics.merge(metric_updates)
    # state = state.replace(metrics=metrics)
    return loss_dict

In [29]:
inputs = ModelInputs(
    num_input=dataset.X_train_numeric[0:10],
    cat_input=dataset.X_train_categorical[0:10],
    numeric_col_tokens=dataset.numeric_col_tokens,
)

In [30]:
import numpy as np

init_rng = jax.random.PRNGKey(0)
learning_rate = 0.0001
state = create_train_state(
    mtm,
    init_rng,
    inputs,
    learning_rate,
)
del init_rng  # Must not be used anymore.
epochs = 1
batch_size = 1
model_name = "mtm_" + dt.now().strftime("%Y%m%d-%H%M%S")
summary_writer = SummaryWriter(f"runs/{model_name}")
pbar = trange(epochs, desc="Epochs", leave=True)
for epoch in pbar:  # range(epochs):
    # for i in range(0, dataset.X_train_numeric.shape[0], batch_size):
    for i in range(0, batch_size, batch_size):
        inputs = ModelInputs(
            num_input=dataset.X_train_numeric[i : i + batch_size],
            cat_input=dataset.X_train_categorical[i : i + batch_size],
            numeric_col_tokens=dataset.numeric_col_tokens,
        )
        state = train_step(
            state,
            inputs.masked_num_input,
            inputs.masked_cat_input,
            inputs.cat_targets,
            inputs.num_input,
        )
        loss_dict = compute_metrics(
            state=state,
            masked_num_input=inputs.masked_num_input,
            masked_cat_input=inputs.masked_cat_input,
            cat_targets=inputs.cat_targets,
            num_input=inputs.num_input,
        )
        summary_writer.add_scalar("loss", np.asarray(loss_dict["loss"]), epoch)
        # summary_writer.add_scalar("cat_loss", loss_dict["cat_loss"], epoch)
        # summary_writer.add_scalar("numeric_loss", loss_dict["numeric_loss"], epoch)
        # def train_step(state, masked_num_input, masked_cat_input, cat_targets):

Epochs:   0%|          | 0/1 [00:00<?, ?it/s]

In [31]:
np.asarray(loss_dict["loss"])

array(nan, dtype=float32)

In [32]:
cat_logits, numeric_preds = mtm.apply(
    {"params": state.params},
    inputs.masked_num_input,
    inputs.masked_cat_input,
    rngs={"dropout": jax.random.PRNGKey(43)},
)

In [None]:
numeric_preds

Array([[nan, nan, nan, nan, nan, nan]], dtype=float32)