In [1]:
from datetime import datetime as dt
from itertools import chain
import os
import numpy as np
import pandas as pd
import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn

from tqdm.notebook import tqdm, trange
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
import seaborn as sns
import hephaestus_jax as hp

# Read the diamonds table from the CSV shipped under ../data and preview the
# first rows. No preprocessing happens in this cell.
# NOTE(review): the preview shows an "Unnamed: 0" column, presumably an index
# that was saved into the CSV -- confirm whether it should reach the model.
df = pd.read_csv(filepath_or_buffer="../data/diamonds.csv")
df.head(n=5)

Unnamed: 0,carat,cut,color,clarity,depth,table,price,x,y,z
0,0.23,Ideal,E,SI2,61.5,55.0,326,3.95,3.98,2.43
1,0.21,Premium,E,SI1,59.8,61.0,326,3.89,3.84,2.31
2,0.23,Good,E,VS1,56.9,65.0,327,4.05,4.07,2.31
3,0.29,Premium,I,VS2,62.4,58.0,334,4.2,4.23,2.63
4,0.31,Good,J,SI2,63.3,58.0,335,4.34,4.35,2.75


In [2]:
# Wrap the raw frame in the project's TabularDS container, with "price" as the
# target column. The later cells read X_train_numeric / X_test_categorical
# from this object.
dataset = hp.TabularDS(df, target_column="price")

In [3]:
# Build the transformer from the dataset; n_heads=8 is forwarded to
# hp.TabTransformer.
model = hp.TabTransformer(dataset, n_heads=8)

# Tiny sample batch used by the following cells to initialise the model and
# to check output shapes.
batch_size = 3

# Take the numeric and categorical features from the SAME split and the same
# row range so that each row of the batch describes one sample. (Previously
# the categorical half was sliced from the test split while the numeric half
# came from the train split.)
# NOTE(review): assumes TabularDS exposes X_train_categorical alongside
# X_train_numeric -- confirm against hephaestus_jax.
test_num = dataset.X_train_numeric[0:batch_size, :]
test_num_mask = hp.mask_tensor(test_num, model)
test_cat = dataset.X_train_categorical[0:batch_size, :]
test_cat_mask = hp.mask_tensor(test_cat, model)

In [4]:
# Base PRNG key (seed 0); it is reused as the "params" stream below, which is
# the same value the seed-0 key literal would produce.
key = random.PRNGKey(0)

# Named RNG streams handed to model.init: one for parameters, one for dropout.
rngs = dict(params=key, dropout=random.PRNGKey(1))

# Initialise one set of variables per task from the same sample batch and the
# same RNG streams. Order matters only for readability of the printed output:
# "regression" is initialised first, then "mlm".
regression_variables, mlm_variables = [
    model.init(rngs, test_num_mask, test_cat_mask, task=task_name)
    for task_name in ("regression", "mlm")
]

numeric_out shape: (3, 576)
numeric_out shape decoded: (3, 6)


In [5]:
# Show the top-level collections returned by model.init for the "mlm" task.
mlm_variables.keys()

dict_keys(['params'])

In [6]:
# Forward pass for the "regression" task using the parameters initialised
# for that task, with a dedicated dropout key (seed 43).
x = model.apply(
    dict(params=regression_variables["params"]),
    test_num_mask,
    test_cat_mask,
    task="regression",
    rngs=dict(dropout=jax.random.PRNGKey(43)),
)
# The regression output is a single array; display its shape.
x.shape

(3, 9)

In [7]:
# Forward pass for the "mlm" task using the parameters initialised for that
# task, with a dedicated dropout key (seed 43). The result is indexed as
# x[0] / x[1] in the next cell.
x = model.apply(
    dict(params=mlm_variables["params"]),
    test_num_mask,
    test_cat_mask,
    task="mlm",
    rngs=dict(dropout=jax.random.PRNGKey(43)),
)

numeric_out shape: (3, 576)
numeric_out shape decoded: (3, 6)


In [8]:
# Shapes of the first two outputs of the "mlm" forward pass.
# NOTE(review): presumably x[0] is the categorical prediction and x[1] the
# numeric reconstruction -- confirm against hp.TabTransformer.
x[0].shape, x[1].shape

((3, 9, 33), (3, 6))