# Diamond Transformer

This notebook investigates the use of a transformer to predict the price of diamonds. The dataset is from Kaggle and can be found [here](https://www.kaggle.com/shivam2503/diamonds).

While many traditional ML and DL techniques work on the dataset, our approach uses far less labeled data while achieving similar results. This is done by using a transformer.

--------------------------------------------------------------------------------

## Libraries


In [1]:
import hashlib
import math

import polars as pl


import polars as pl
import numpy as np


from torch import nn, Tensor
from tqdm.notebook import trange, tqdm

import hephaestus as hp
import torch
from torch.utils.tensorboard import SummaryWriter

The diamonds dataset: df = pl.read_csv("../data/diamonds.csv") df.head()


In [2]:
df = pl.read_csv("../data/diamonds.csv")
df.head()

carat,cut,color,clarity,depth,table,price,x,y,z
f64,str,str,str,f64,f64,i64,f64,f64,f64
0.23,"""Ideal""","""E""","""SI2""",61.5,55.0,326,3.95,3.98,2.43
0.21,"""Premium""","""E""","""SI1""",59.8,61.0,326,3.89,3.84,2.31
0.23,"""Good""","""E""","""VS1""",56.9,65.0,327,4.05,4.07,2.31
0.29,"""Premium""","""I""","""VS2""",62.4,58.0,334,4.2,4.23,2.63
0.31,"""Good""","""J""","""SI2""",63.3,58.0,335,4.34,4.35,2.75


In [3]:
df.describe()

describe,carat,cut,color,clarity,depth,table,price,x,y,z
str,f64,str,str,str,f64,f64,f64,f64,f64,f64
"""count""",53940.0,"""53940""","""53940""","""53940""",53940.0,53940.0,53940.0,53940.0,53940.0,53940.0
"""null_count""",0.0,"""0""","""0""","""0""",0.0,0.0,0.0,0.0,0.0,0.0
"""mean""",0.79794,,,,61.749405,57.457184,3932.799722,5.731157,5.734526,3.538734
"""std""",0.474011,,,,1.432621,2.234491,3989.439738,1.121761,1.142135,0.705699
"""min""",0.2,"""Fair""","""D""","""I1""",43.0,43.0,326.0,0.0,0.0,0.0
"""max""",5.01,"""Very Good""","""J""","""VVS2""",79.0,95.0,18823.0,10.74,58.9,31.8
"""median""",0.7,,,,61.8,57.0,2401.0,5.7,5.71,3.53
"""25%""",0.4,,,,61.0,56.0,950.0,4.71,4.72,2.91
"""75%""",1.04,,,,62.5,59.0,5325.0,6.54,6.54,4.04


In [4]:
df = hp.scale_numeric(df)
df.head()

carat,cut,color,clarity,depth,table,price,x,y,z
f64,str,str,str,f64,f64,f64,f64,f64,f64
-1.198157,"""Ideal""","""E""","""SI2""",-0.17409,-1.099662,-0.904087,-1.587823,-1.536181,-1.571115
-1.24035,"""Premium""","""E""","""SI1""",-1.360726,1.585514,-0.904087,-1.64131,-1.658759,-1.741159
-1.198157,"""Good""","""E""","""VS1""",-3.384987,3.375631,-0.903836,-1.498677,-1.457382,-1.741159
-1.071577,"""Premium""","""I""","""VS2""",0.454129,0.242926,-0.902081,-1.364959,-1.317293,-1.287708
-1.029384,"""Good""","""J""","""SI2""",1.082348,0.242926,-0.901831,-1.240155,-1.212227,-1.117663


In [5]:
df = hp.make_lower_remove_special_chars(df)
val_tokens = hp.get_unique_utf8_values(df)
col_tokens = hp.get_col_tokens(df)

In [6]:
df.head()

carat,cut,color,clarity,depth,table,price,x,y,z
f64,str,str,str,f64,f64,f64,f64,f64,f64
-1.198157,"""ideal""","""e""","""si2""",-0.17409,-1.099662,-0.904087,-1.587823,-1.536181,-1.571115
-1.24035,"""premium""","""e""","""si1""",-1.360726,1.585514,-0.904087,-1.64131,-1.658759,-1.741159
-1.198157,"""good""","""e""","""vs1""",-3.384987,3.375631,-0.903836,-1.498677,-1.457382,-1.741159
-1.071577,"""premium""","""i""","""vs2""",0.454129,0.242926,-0.902081,-1.364959,-1.317293,-1.287708
-1.029384,"""good""","""j""","""si2""",1.082348,0.242926,-0.901831,-1.240155,-1.212227,-1.117663


In [7]:
special_tokens = np.array(
    [
        "missing",
        "[MASK]",
        "[NUMERIC]",
        "<unk>",
        ":",
        ",",
        "<row-start>",
        "<row-end>",
    ]
)

In [8]:
tokens = np.unique(
    np.concatenate(
        (
            val_tokens,
            col_tokens,
            special_tokens,
            np.array(
                [
                    "<numeric>",
                ]
            ),
        )
    )
)
tokens

array([',', ':', '<numeric>', '<row-end>', '<row-start>', '<unk>',
       '[MASK]', '[NUMERIC]', 'carat', 'clarity', 'color', 'cut', 'd',
       'depth', 'e', 'f', 'fair', 'g', 'good', 'h', 'i', 'i1', 'ideal',
       'if', 'j', 'missing', 'premium', 'price', 'si1', 'si2', 'table',
       'very good', 'vs1', 'vs2', 'vvs1', 'vvs2', 'x', 'y', 'z'],
      dtype=object)

# Train Test Split

To show the actual model performance out of sample we split the data into a training and test set. The training set will be used to train the model and the test set will be used to evaluate the model performance. We will use 80% of the data for training and 20% for testing.

We also remove the price column from the training and test sets and will only use a tiny subset of the data to simulate an industrial process with lots of input data but expensive and limited labeled data.


In [9]:
df = (
    df.with_columns(
        pl.concat_str(pl.all().exclude("price").cast(pl.Utf8)).alias("all_cols")
    )
    .with_columns(
        pl.col("all_cols")
        .apply(lambda x: hashlib.md5(x.encode()).hexdigest())
        .alias("hash")
    )
    .drop("all_cols")
)
df.select(pl.col("hash").is_duplicated().sum())

hash
u32
685


In [10]:
# Shuffle for randomness
df = df.sample(fraction=1.0, seed=42)
df.head()

carat,cut,color,clarity,depth,table,price,x,y,z,hash
f64,str,str,str,f64,f64,f64,f64,f64,f64,str
-1.198157,"""ideal""","""e""","""si2""",-0.17409,-1.099662,-0.904087,-1.587823,-1.536181,-1.571115,"""501249736f26c3…"
-1.24035,"""premium""","""e""","""si1""",-1.360726,1.585514,-0.904087,-1.64131,-1.658759,-1.741159,"""5427305ea67a9e…"
-1.198157,"""good""","""e""","""vs1""",-3.384987,3.375631,-0.903836,-1.498677,-1.457382,-1.741159,"""155663ac256a0e…"
-1.071577,"""premium""","""i""","""vs2""",0.454129,0.242926,-0.902081,-1.364959,-1.317293,-1.287708,"""45ee1317264026…"
-1.029384,"""good""","""j""","""si2""",1.082348,0.242926,-0.901831,-1.240155,-1.212227,-1.117663,"""bce254e518e158…"


In [11]:
train_fraction = 0.8
n_train = int(train_fraction * len(df))
train_test_df = df.select(pl.all().exclude(["price", "hash"]))

train, test = train_test_df.head(n_train), train_test_df.tail(
    len(train_test_df) - n_train
)

In [12]:
train.head(), train.shape

(shape: (5, 9)
 ┌───────────┬─────────┬───────┬─────────┬───┬───────────┬───────────┬───────────┬───────────┐
 │ carat     ┆ cut     ┆ color ┆ clarity ┆ … ┆ table     ┆ x         ┆ y         ┆ z         │
 │ ---       ┆ ---     ┆ ---   ┆ ---     ┆   ┆ ---       ┆ ---       ┆ ---       ┆ ---       │
 │ f64       ┆ str     ┆ str   ┆ str     ┆   ┆ f64       ┆ f64       ┆ f64       ┆ f64       │
 ╞═══════════╪═════════╪═══════╪═════════╪═══╪═══════════╪═══════════╪═══════════╪═══════════╡
 │ -1.198157 ┆ ideal   ┆ e     ┆ si2     ┆ … ┆ -1.099662 ┆ -1.587823 ┆ -1.536181 ┆ -1.571115 │
 │ -1.24035  ┆ premium ┆ e     ┆ si1     ┆ … ┆ 1.585514  ┆ -1.64131  ┆ -1.658759 ┆ -1.741159 │
 │ -1.198157 ┆ good    ┆ e     ┆ vs1     ┆ … ┆ 3.375631  ┆ -1.498677 ┆ -1.457382 ┆ -1.741159 │
 │ -1.071577 ┆ premium ┆ i     ┆ vs2     ┆ … ┆ 0.242926  ┆ -1.364959 ┆ -1.317293 ┆ -1.287708 │
 │ -1.029384 ┆ good    ┆ j     ┆ si2     ┆ … ┆ 0.242926  ┆ -1.240155 ┆ -1.212227 ┆ -1.117663 │
 └───────────┴─────────┴───────┴───

In [13]:
train_test_df.head()

carat,cut,color,clarity,depth,table,x,y,z
f64,str,str,str,f64,f64,f64,f64,f64
-1.198157,"""ideal""","""e""","""si2""",-0.17409,-1.099662,-1.587823,-1.536181,-1.571115
-1.24035,"""premium""","""e""","""si1""",-1.360726,1.585514,-1.64131,-1.658759,-1.741159
-1.198157,"""good""","""e""","""vs1""",-3.384987,3.375631,-1.498677,-1.457382,-1.741159
-1.071577,"""premium""","""i""","""vs2""",0.454129,0.242926,-1.364959,-1.317293,-1.287708
-1.029384,"""good""","""j""","""si2""",1.082348,0.242926,-1.240155,-1.212227,-1.117663


In [14]:
ds = hp.TabularDataset(
    train,
    tokens,
    special_tokens=special_tokens,
    shuffle_cols=True,
    max_row_length=50,
)

print(len(ds[0]))

38


In [15]:
row = [str(i.value) for i in ds[1]]
print("".join(row))

<row-start>color:e,table:1.5855140097261475,y:-1.6587588116247007,depth:-1.3607258727059102,cut:premium,carat:-1.2403497908694328,x:-1.6413100716719282,clarity:si1,z:-1.7411588294109859,<row-end>


In [16]:
print([i.value for i in ds[0]])

['<row-start>', 'clarity', ':', 'si2', ',', 'color', ':', 'e', ',', 'x', ':', -1.5878227303011756, ',', 'cut', ':', 'ideal', ',', 'table', ':', -1.0996617971586031, ',', 'depth', ':', -0.17408989455083768, ',', 'carat', ':', -1.1981566989627475, ',', 'z', ':', -1.5711146235593887, ',', 'y', ':', -1.5361813230221135, ',', '<row-end>']


In [17]:
if torch.backends.mps.is_built():
    device_name = "mps"
elif torch.cuda.is_available():
    device_name = "cuda"
else:
    device_name = "cpu"
device = torch.device(device_name)
print(device)

mps


In [18]:
# n_token = len(ds.vocab)  # size of vocabulary
# d_model = 96  # embedding dimension
# d_hid = 1_000  # dimension of the feedforward network model in ``nn.TransformerEncoder``
# n_layers = 6  # number of ``nn.TransformerEncoderLayer`` in ``nn.TransformerEncoder``
# n_head = 6  # number of heads in ``nn.MultiheadAttention``
# dropout = 0.2  # dropout probability
# model = hp.TransformerModel(
#     n_token, d_model, n_head, d_hid, n_layers, device, dropout
# ).to(device)
model = hp.TransformerModel(device=device).to(device)
data, targets, is_numeric_mask = hp.batch_data(ds, 1, model, n_row=1)

In [19]:
import copy
import time


lr = 0.9  # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size =100, gamma=0.5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.8,
    patience=5,
    threshold=0.001,
    threshold_mode="rel",
    cooldown=0,
    min_lr=0.01,
    eps=1e-08,
    verbose=False,
)


def train(model: nn.Module, epochs=1, model_name="") -> None:
    model.train()  # turn on train mode
    total_loss = 0.0
    log_interval = 1000
    lr_eval_interval = 25
    n_row = 2  # one because it's not time series
    start_time = time.time()
    for epoch in trange(1, epochs + 1, leave=True, desc="Epoch"):
        pbar = trange(0, len(ds) - 1, n_row, desc=f"Batch for {epoch}/{epochs}")
        writer = SummaryWriter("runs/" + model_name + "_run_" + str(epoch))
        for batch, i in enumerate(pbar):
            data, targets = hp.batch_data(
                ds, i, model, n_row=n_row
            )  # hp.batch_data(ds, 1, model, n_row=1)
            class_preds, numeric_preds = model(data)
            loss, loss_dict = hp.hephaestus_loss(
                class_preds, numeric_preds, targets, model
            )
            num_loss = loss_dict["reg_loss"].item()
            class_loss = loss_dict["class_loss"].item()
            writer.add_scalar("Loss/total_loss", loss, batch)
            writer.add_scalar("Loss/numeric_loss", num_loss, batch)
            writer.add_scalar("Loss/class_loss", class_loss, batch)
            writer.add_scalar(
                "Metrics/learning_rate", optimizer.param_groups[0]["lr"], batch
            )
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
            optimizer.step()
            pbar.set_postfix(
                {
                    "tl": f"{loss:.2f}",
                    "cl": f"{class_loss:.2f}",
                    "nl": f"{num_loss:.2f}",
                    "tr": f"{optimizer.param_groups[0]['lr']:.2f}",
                },
                refresh=True,
            )
            total_loss += loss.item()
            if batch % lr_eval_interval == 0:
                # pbar.set_p(
                #     f"tl: {}, nl: {num_loss:.2f}, cl: {class_loss:.2f}, tr: {optimizer.param_groups[0]['lr']:.2f}"
                # )

                scheduler.step(loss)

                start_time = time.time()
                # scheduler.step(loss)
        writer.close()

In [20]:
is_numeric_mask = [True, False, True, True, False]
random_nums = torch.rand(len(is_numeric_mask), 10).to(device)
actual_num_index = [idx for idx, val in enumerate(is_numeric_mask) if val]
actual_num_index = torch.tensor(actual_num_index).to(device)
random_nums[actual_num_index].shape

torch.Size([3, 10])

In [104]:
lr = 0.1  # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size =100, gamma=0.5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.8,
    patience=5,
    threshold=0.001,
    threshold_mode="rel",
    cooldown=0,
    min_lr=0.0001,
    eps=1e-08,
    verbose=False,
)

epochs = 1
model_name = "diamonds_scaled_lowerMin"
model.train()  # turn on train mode
total_loss = 0.0
log_interval = 1000
lr_eval_interval = 25
n_row = 1  # one because it's not time series
start_time = time.time()
for epoch in trange(1, epochs + 1, leave=True, desc="Epoch"):
    pbar = trange(0, len(ds) - 1, n_row, desc=f"Batch for {epoch}/{epochs}")
    writer = SummaryWriter("runs/" + model_name + "_run_" + str(epoch))
    for batch, i in enumerate(pbar):
        data, target, is_numeric_mask = hp.batch_data(
            ds, i, model, n_row=n_row
        )  # hp.batch_data(ds, 1, model, n_row=1)
        class_preds, numeric_preds = model(data)
        loss, loss_dict = hp.hephaestus_loss(
            class_preds, numeric_preds, target, is_numeric_mask, model
        )
        num_loss = loss_dict["reg_loss"].item()
        class_loss = loss_dict["class_loss"].item()
        writer.add_scalar("Loss/total_loss", loss, batch)
        writer.add_scalar("Loss/numeric_loss", num_loss, batch)
        writer.add_scalar("Loss/class_loss", class_loss, batch)
        writer.add_scalar(
            "Metrics/learning_rate", optimizer.param_groups[0]["lr"], batch
        )
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        pbar.set_postfix(
            {
                "tl": f"{loss:.2f}",
                "cl": f"{class_loss:.2f}",
                "nl": f"{num_loss:.2f}",
                "tr": f"{optimizer.param_groups[0]['lr']:.2f}",
            },
            refresh=True,
        )
        total_loss += loss.item()
        if batch % lr_eval_interval == 0:
            # pbar.set_p(
            #     f"tl: {}, nl: {num_loss:.2f}, cl: {class_loss:.2f}, tr: {optimizer.param_groups[0]['lr']:.2f}"
            # )

            scheduler.step(loss)

            start_time = time.time()
            # scheduler.step(loss)
    writer.close()

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Batch for 1/1:   0%|          | 0/43151 [00:00<?, ?it/s]

In [105]:
# save model
torch.save(model.state_dict(), "models/BERTDiamondsBetter.pt")
# model.save_pretrained('models/BERTDiamonds.pt')

In [22]:
data, target, is_numeric_mask = hp.batch_data(
    ds, 5, model, n_row=1
)  # hp.batch_data(ds, 1, model, n_row=1)
target_tensor = hp.gen_class_target_tokens(model, target)
class_preds, numeric_preds = model(data)
print(
    f"class_preds.shape: {class_preds.shape}, target_tensor.shape: {target_tensor.shape}"
)

class_preds.shape: torch.Size([1, 43, 30527]), target_tensor.shape: torch.Size([43])


In [77]:
data, target, is_numeric_mask = hp.batch_data(ds, 0, model, n_row=1)
with torch.no_grad():
    class_preds, numeric_preds = model(data)

In [78]:
class_preds_max = torch.argmax(class_preds.squeeze(), dim=1)
class_preds_decode = model.tokenizer.decode(class_preds_max)
print(class_preds_decode)

[CLS] <row-start> z : [numeric], y : [numeric], clarity : si1, cut : ideal, table : [numeric], depth : [numeric], x : [numeric], carat : [numeric], color : e, <row-end>


In [79]:
class_numeric = []
for idx, num in enumerate(class_preds_max):
    decoded_num = model.tokenizer.decode([num], skip_special_tokens=True)
    # print(decoded_num)
    if decoded_num == "[numeric]":
        class_numeric.append(str(numeric_preds.squeeze()[idx].item()))
    else:
        class_numeric.append(decoded_num)
print(" ".join(class_numeric))

 <row-start> z : -1.0915077924728394 , y : -1.0985910892486572 , clarity : si ##1 , cut : ideal , table : -0.740872323513031 , depth : -0.5159561634063721 , x : -0.8603442907333374 , cara ##t : -0.8042788505554199 , color : e , <row-end>


In [80]:
actuals = []
for i in target:
    actuals.append(str(i.value))
print(" ".join(actuals))

<row-start> z : -1.5711146235593887 , y : -1.5361813230221135 , clarity : si2 , cut : ideal , table : -1.0996617971586031 , depth : -0.17408989455083768 , x : -1.5878227303011756 , carat : -1.1981566989627475 , color : e , <row-end>


In [None]:
from datetime import datetime as dt

model_time = dt.now()
model_time = model_time.strftime("%Y-%m-%dT%H:%M:%S")

exp_name = "non-scale"

model_name = model_time + "_" + exp_name
epochs = 3
train(model=model, epochs=epochs, model_name=model_name)

In [None]:
ds_test = hp.TabularDataset(
    test,
    tokens,
    special_tokens=special_tokens,
    shuffle_cols=False,
    max_row_length=50,
)

In [None]:
res = hp.evaluate_custom(model, ds_test, 1)