In [1]:
import polars as pl
import numpy as np
from transformers import BertTokenizer, BertForPreTraining

import numpy as np
import torch.nn as nn
import torch
import hephaestus as hp

import hashlib
import math

In [2]:
df = pl.read_csv("../data/diamonds.csv")
df.head()

carat,cut,color,clarity,depth,table,price,x,y,z
f64,str,str,str,f64,f64,i64,f64,f64,f64
0.23,"""Ideal""","""E""","""SI2""",61.5,55.0,326,3.95,3.98,2.43
0.21,"""Premium""","""E""","""SI1""",59.8,61.0,326,3.89,3.84,2.31
0.23,"""Good""","""E""","""VS1""",56.9,65.0,327,4.05,4.07,2.31
0.29,"""Premium""","""I""","""VS2""",62.4,58.0,334,4.2,4.23,2.63
0.31,"""Good""","""J""","""SI2""",63.3,58.0,335,4.34,4.35,2.75


In [3]:
df = hp.make_lower_remove_special_chars(df)
val_tokens = hp.get_unique_utf8_values(df)
col_tokens = hp.get_col_tokens(df)

In [4]:
special_tokens = np.array(
    [
        "missing",
        "<mask>",
        "<numeric_mask>" "<pad>",
        "<unk>",
        ":",
        ",",
        "<row-start>",
        "<row-end>",
    ]
)

In [5]:
tokens = np.unique(
    np.concatenate(
        (
            val_tokens,
            col_tokens,
            special_tokens,
            np.array(
                [
                    "<numeric>",
                ]
            ),
        )
    )
)
tokens

array([',', ':', '<mask>', '<numeric>', '<numeric_mask><pad>',
       '<row-end>', '<row-start>', '<unk>', 'carat', 'clarity', 'color',
       'cut', 'd', 'depth', 'e', 'f', 'fair', 'g', 'good', 'h', 'i', 'i1',
       'ideal', 'if', 'j', 'missing', 'premium', 'price', 'si1', 'si2',
       'table', 'very good', 'vs1', 'vs2', 'vvs1', 'vvs2', 'x', 'y', 'z'],
      dtype=object)

In [6]:
df = (
    df.with_columns(
        pl.concat_str(pl.all().exclude("price").cast(pl.Utf8)).alias("all_cols")
    )
    .with_columns(
        pl.col("all_cols")
        .apply(lambda x: hashlib.md5(x.encode()).hexdigest())
        .alias("hash")
    )
    .drop("all_cols")
)
df.select(pl.col("hash").is_duplicated().sum())

hash
u32
685


In [7]:
train_fraction = 0.8
n_train = int(train_fraction * len(df))
train_test_df = df.select(pl.all().exclude(["price", "hash"]))

train, test = train_test_df.head(n_train), train_test_df.tail(
    len(train_test_df) - n_train
)

In [8]:
ds = hp.TabularDataset(
    train,
    tokens,
    special_tokens=special_tokens,
    shuffle_cols=True,
    max_row_length=50,
)

print(len(ds[0]))

50


In [9]:
if torch.backends.mps.is_built():
    device_name = "mps"
elif torch.cuda.is_available():
    device_name = "cuda"
else:
    device_name = "cpu"
device = torch.device(device_name)
print(device)

mps


In [10]:
class StringNumericEmbedding(nn.Module):
    def __init__(
        self,
        state_dict,
        device: torch.device,
        tokenizer,
        bert_model_name="bert-base-uncased",
    ):
        super().__init__()
        self.device = device
        # self.bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name)
        self.bert_tokenizer = tokenizer
        self.word_embeddings = nn.Embedding(*state_dict["weight"].shape).to(device)
        self.word_embeddings.load_state_dict(state_dict)  # .to(device)
        self.numeric_embedding = nn.Sequential(
            nn.Linear(1, 128),  # First hidden layer
            nn.ReLU(),
            nn.Linear(128, 64),  # Second hidden layer
            nn.ReLU(),
            nn.Linear(64, state_dict["weight"].shape[1]),  # Output layer
        ).to(device)

        # self.numeric_embedding = nn.Linear(1, d_model).to(device)

    def forward(self, input: hp.StringNumeric):
        tensor_list = []
        for idx, val in enumerate(input):
            if val.is_numeric:
                val = torch.Tensor([val.value]).float().to(self.device)
                val = self.numeric_embedding(val)
                val = val.reshape(1, 1, -1)  # val.shape[0])
                tensor_list.append(val)
            else:
                tokens_ids = self.bert_tokenizer.encode_plus(
                    val.value, return_tensors="pt", add_special_tokens=False
                )
                tensor_list.append(
                    self.word_embeddings(tokens_ids["input_ids"].to(self.device))
                )

        # return tensor_list
        return torch.cat(tensor_list, dim=-2)

In [11]:
class HybridBertModel(nn.Module):
    def __init__(
        self,
        device: torch.device,
        bert_model_name="bert-base-uncased",
    ):
        super(HybridBertModel, self).__init__()

        # BERT Tokenizer and Model
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.bert_lm = BertForPreTraining.from_pretrained("bert-base-uncased")
        # self.tokenizer = BertTokenizer.from_pretrained(bert_model_name)
        self.tokenizer.add_tokens(
            [
                "<numeric>",
                "<numeric-mask>",
                "<row-start>",
                "<row-end>",
            ]
        )

        # Add tokens to BERT model

        # self.bert = BertModel.from_pretrained(bert_model_name).to(device)
        self.bert_lm.resize_token_embeddings(len(self.tokenizer))
        self.bert_embedding_state_dict = (
            self.bert_lm.bert.embeddings.word_embeddings.state_dict()
        )
        self.embedding_dim = self.bert_lm.bert.config.hidden_size
        self.string_numeric_embd = StringNumericEmbedding(
            state_dict=self.bert_embedding_state_dict,
            device=device,
            tokenizer=self.tokenizer,
            bert_model_name=bert_model_name,
        )
        # self.decoder = nn.Linear(self.embedding_dim, len(self.tokenizer)).to(device)
        # Numeric Neural Net for numbers prediction after BERT
        self.numeric_predictor = nn.Sequential(
            nn.Linear(self.embedding_dim, 128), nn.ReLU(), nn.Linear(128, 1)
        )

    def forward(self, input: hp.StringNumeric):
        input = self.string_numeric_embd(input)
        bert_output = self.bert_lm.bert(inputs_embeds=input)
        last_hidden_state = bert_output.last_hidden_state
        pooled_output = bert_output.pooler_output

        bert_logits = self.bert_lm.cls(last_hidden_state, pooled_output)
        numeric_prediction = self.numeric_predictor(last_hidden_state)
        # mlm_output = self.decoder(mlm_output.last_hidden_state)
        return bert_logits, numeric_prediction


# Sample usage:

# Assuming we have our input prepared as:
# input_data = [
#     # hp.StringNumeric("<row-start>"),
#     hp.StringNumeric("Hello"),
#     # hp.StringNumeric(42.0),
#     hp.StringNumeric("world"),
#     # hp.StringNumeric(12),
#     # hp.StringNumeric("<row-end>"),
# ]
input_data = [
    hp.StringNumeric(i)
    for i in "Hello Greg, my name is Kai, nice to [MASK] [MASK]!".split()
]
input_data.append(hp.StringNumeric(42.0))
model = HybridBertModel(device=device).to(device)
# input_data_tensor = StringNumeric(input_data, device=device)
cat_preds, numeric_preds = model(input_data)
# print(type(mlm_output))

In [17]:
x["input_ids"]

tensor([[ 7592,  6754,  1010,  2026,  2171,  2003, 11928,  1010,  3835,  2000,
           103,   103,   999, 30522]])

In [19]:
model.tokenizer.decode(token_ids=x["input_ids"][0])

'hello greg, my name is kai, nice to [MASK] [MASK]! <numeric>'

In [114]:
model.tokenizer.decode(preds)

', greg, my name is kai, welcome to be you!,'