# Diamond Transformer

This notebook investigates the use of a transformer to predict the price of diamonds. The dataset is from Kaggle and can be found [here](https://www.kaggle.com/shivam2503/diamonds).

While many traditional ML and DL techniques work on the dataset, our approach uses far less labeled data while achieving similar results. This is done by using a transformer.

--------------------------------------------------------------------------------

## Libraries


In [5]:
import hashlib
import math

import polars as pl


import polars as pl
import numpy as np


from torch import nn, Tensor
from tqdm.notebook import trange, tqdm

import hephaestus as hp
import torch
from torch.utils.tensorboard import SummaryWriter

The diamonds dataset: df = pl.read_csv("../data/diamonds.csv") df.head()


In [6]:
df = pl.read_csv("../data/diamonds.csv")
df.head()

carat,cut,color,clarity,depth,table,price,x,y,z
f64,str,str,str,f64,f64,i64,f64,f64,f64
0.23,"""Ideal""","""E""","""SI2""",61.5,55.0,326,3.95,3.98,2.43
0.21,"""Premium""","""E""","""SI1""",59.8,61.0,326,3.89,3.84,2.31
0.23,"""Good""","""E""","""VS1""",56.9,65.0,327,4.05,4.07,2.31
0.29,"""Premium""","""I""","""VS2""",62.4,58.0,334,4.2,4.23,2.63
0.31,"""Good""","""J""","""SI2""",63.3,58.0,335,4.34,4.35,2.75


In [7]:
df.describe()

describe,carat,cut,color,clarity,depth,table,price,x,y,z
str,f64,str,str,str,f64,f64,f64,f64,f64,f64
"""count""",53940.0,"""53940""","""53940""","""53940""",53940.0,53940.0,53940.0,53940.0,53940.0,53940.0
"""null_count""",0.0,"""0""","""0""","""0""",0.0,0.0,0.0,0.0,0.0,0.0
"""mean""",0.79794,,,,61.749405,57.457184,3932.799722,5.731157,5.734526,3.538734
"""std""",0.474011,,,,1.432621,2.234491,3989.439738,1.121761,1.142135,0.705699
"""min""",0.2,"""Fair""","""D""","""I1""",43.0,43.0,326.0,0.0,0.0,0.0
"""max""",5.01,"""Very Good""","""J""","""VVS2""",79.0,95.0,18823.0,10.74,58.9,31.8
"""median""",0.7,,,,61.8,57.0,2401.0,5.7,5.71,3.53
"""25%""",0.4,,,,61.0,56.0,950.0,4.71,4.72,2.91
"""75%""",1.04,,,,62.5,59.0,5325.0,6.54,6.54,4.04


In [8]:
# df = hp.scale_numeric(df)
df.head()

carat,cut,color,clarity,depth,table,price,x,y,z
f64,str,str,str,f64,f64,i64,f64,f64,f64
0.23,"""Ideal""","""E""","""SI2""",61.5,55.0,326,3.95,3.98,2.43
0.21,"""Premium""","""E""","""SI1""",59.8,61.0,326,3.89,3.84,2.31
0.23,"""Good""","""E""","""VS1""",56.9,65.0,327,4.05,4.07,2.31
0.29,"""Premium""","""I""","""VS2""",62.4,58.0,334,4.2,4.23,2.63
0.31,"""Good""","""J""","""SI2""",63.3,58.0,335,4.34,4.35,2.75


In [9]:
df = hp.make_lower_remove_special_chars(df)
val_tokens = hp.get_unique_utf8_values(df)
col_tokens = hp.get_col_tokens(df)

In [10]:
df.head()

carat,cut,color,clarity,depth,table,price,x,y,z
f64,str,str,str,f64,f64,i64,f64,f64,f64
0.23,"""ideal""","""e""","""si2""",61.5,55.0,326,3.95,3.98,2.43
0.21,"""premium""","""e""","""si1""",59.8,61.0,326,3.89,3.84,2.31
0.23,"""good""","""e""","""vs1""",56.9,65.0,327,4.05,4.07,2.31
0.29,"""premium""","""i""","""vs2""",62.4,58.0,334,4.2,4.23,2.63
0.31,"""good""","""j""","""si2""",63.3,58.0,335,4.34,4.35,2.75


In [11]:
special_tokens = np.array(
    [
        "missing",
        "<mask>",
        "<numeric_mask>" "<pad>",
        "<unk>",
        ":",
        ",",
        "<row-start>",
        "<row-end>",
    ]
)

In [12]:
tokens = np.unique(
    np.concatenate(
        (
            val_tokens,
            col_tokens,
            special_tokens,
            np.array(
                [
                    "<numeric>",
                ]
            ),
        )
    )
)
tokens

array([',', ':', '<mask>', '<numeric>', '<numeric_mask><pad>',
       '<row-end>', '<row-start>', '<unk>', 'carat', 'clarity', 'color',
       'cut', 'd', 'depth', 'e', 'f', 'fair', 'g', 'good', 'h', 'i', 'i1',
       'ideal', 'if', 'j', 'missing', 'premium', 'price', 'si1', 'si2',
       'table', 'very good', 'vs1', 'vs2', 'vvs1', 'vvs2', 'x', 'y', 'z'],
      dtype=object)

# Train Test Split

To show the actual model performance out of sample we split the data into a training and test set. The training set will be used to train the model and the test set will be used to evaluate the model performance. We will use 80% of the data for training and 20% for testing.

We also remove the price column from the training and test sets and will only use a tiny subset of the data to simulate an industrial process with lots of input data but expensive and limited labeled data.


In [13]:
df = (
    df.with_columns(
        pl.concat_str(pl.all().exclude("price").cast(pl.Utf8)).alias("all_cols")
    )
    .with_columns(
        pl.col("all_cols")
        .apply(lambda x: hashlib.md5(x.encode()).hexdigest())
        .alias("hash")
    )
    .drop("all_cols")
)
df.select(pl.col("hash").is_duplicated().sum())

hash
u32
685


In [14]:
# Shuffle for randomness
df = df.sample(fraction=1.0, seed=42)
df.head()

carat,cut,color,clarity,depth,table,price,x,y,z,hash
f64,str,str,str,f64,f64,i64,f64,f64,f64,str
0.23,"""ideal""","""e""","""si2""",61.5,55.0,326,3.95,3.98,2.43,"""11e7fc6699aadd…"
0.21,"""premium""","""e""","""si1""",59.8,61.0,326,3.89,3.84,2.31,"""da2cbf3ae11a51…"
0.23,"""good""","""e""","""vs1""",56.9,65.0,327,4.05,4.07,2.31,"""119c1ab2831f02…"
0.29,"""premium""","""i""","""vs2""",62.4,58.0,334,4.2,4.23,2.63,"""8f6f35f6b9cc64…"
0.31,"""good""","""j""","""si2""",63.3,58.0,335,4.34,4.35,2.75,"""8df56fd59d650e…"


In [15]:
train_fraction = 0.8
n_train = int(train_fraction * len(df))
train_test_df = df.select(pl.all().exclude(["price", "hash"]))

train, test = train_test_df.head(n_train), train_test_df.tail(
    len(train_test_df) - n_train
)

In [16]:
train.head(), train.shape

(shape: (5, 9)
 ┌───────┬─────────┬───────┬─────────┬───┬───────┬──────┬──────┬──────┐
 │ carat ┆ cut     ┆ color ┆ clarity ┆ … ┆ table ┆ x    ┆ y    ┆ z    │
 │ ---   ┆ ---     ┆ ---   ┆ ---     ┆   ┆ ---   ┆ ---  ┆ ---  ┆ ---  │
 │ f64   ┆ str     ┆ str   ┆ str     ┆   ┆ f64   ┆ f64  ┆ f64  ┆ f64  │
 ╞═══════╪═════════╪═══════╪═════════╪═══╪═══════╪══════╪══════╪══════╡
 │ 0.23  ┆ ideal   ┆ e     ┆ si2     ┆ … ┆ 55.0  ┆ 3.95 ┆ 3.98 ┆ 2.43 │
 │ 0.21  ┆ premium ┆ e     ┆ si1     ┆ … ┆ 61.0  ┆ 3.89 ┆ 3.84 ┆ 2.31 │
 │ 0.23  ┆ good    ┆ e     ┆ vs1     ┆ … ┆ 65.0  ┆ 4.05 ┆ 4.07 ┆ 2.31 │
 │ 0.29  ┆ premium ┆ i     ┆ vs2     ┆ … ┆ 58.0  ┆ 4.2  ┆ 4.23 ┆ 2.63 │
 │ 0.31  ┆ good    ┆ j     ┆ si2     ┆ … ┆ 58.0  ┆ 4.34 ┆ 4.35 ┆ 2.75 │
 └───────┴─────────┴───────┴─────────┴───┴───────┴──────┴──────┴──────┘,
 (43152, 9))

In [17]:
train_test_df.head()

carat,cut,color,clarity,depth,table,x,y,z
f64,str,str,str,f64,f64,f64,f64,f64
0.23,"""ideal""","""e""","""si2""",61.5,55.0,3.95,3.98,2.43
0.21,"""premium""","""e""","""si1""",59.8,61.0,3.89,3.84,2.31
0.23,"""good""","""e""","""vs1""",56.9,65.0,4.05,4.07,2.31
0.29,"""premium""","""i""","""vs2""",62.4,58.0,4.2,4.23,2.63
0.31,"""good""","""j""","""si2""",63.3,58.0,4.34,4.35,2.75


In [18]:
ds = hp.TabularDataset(
    train,
    tokens,
    special_tokens=special_tokens,
    shuffle_cols=True,
    max_row_length=50,
)

print(len(ds[0]))

50


In [19]:
row = [str(i.value) for i in ds[1]]
print("".join(row))

<row-start>clarity:si1,x:3.89,depth:59.8,color:e,cut:premium,y:3.84,z:2.31,carat:0.21,table:61.0,<row-end><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>


In [16]:
print([i.value for i in ds[0]])

['<row-start>', 'depth', ':', 61.5, ',', 'x', ':', 3.95, ',', 'z', ':', 2.43, ',', 'carat', ':', 0.23, ',', 'table', ':', 55.0, ',', 'cut', ':', 'ideal', ',', 'clarity', ':', 'si2', ',', 'color', ':', 'e', ',', 'y', ':', 3.98, ',', '<row-end>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']


In [17]:
if torch.backends.mps.is_built():
    device_name = "mps"
elif torch.cuda.is_available():
    device_name = "cuda"
else:
    device_name = "cpu"
device = torch.device(device_name)
print(device)

mps


In [18]:
data, targets = hp.batch_data(ds, 1, n_row=1)

In [19]:
n_token = len(ds.vocab)  # size of vocabulary
d_model = 96  # embedding dimension
d_hid = 1_000  # dimension of the feedforward network model in ``nn.TransformerEncoder``
n_layers = 6  # number of ``nn.TransformerEncoderLayer`` in ``nn.TransformerEncoder``
n_head = 6  # number of heads in ``nn.MultiheadAttention``
dropout = 0.2  # dropout probability
model = hp.TransformerModel(
    n_token, d_model, n_head, d_hid, n_layers, device, dropout
).to(device)



In [20]:
import copy
import time


lr = 0.9  # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size =100, gamma=0.5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.8,
    patience=5,
    threshold=0.001,
    threshold_mode="rel",
    cooldown=0,
    min_lr=0.01,
    eps=1e-08,
    verbose=False,
)


def train(model: nn.Module, epochs=1, model_name="") -> None:
    model.train()  # turn on train mode
    total_loss = 0.0
    log_interval = 1000
    lr_eval_interval = 25
    n_row = 100  # one because it's not time series
    start_time = time.time()
    for epoch in trange(1, epochs + 1, leave=True, desc="Epoch"):
        pbar = trange(0, len(ds) - 1, n_row, desc=f"Batch for {epoch}/{epochs}")
        writer = SummaryWriter("runs/" + model_name + "_run_" + str(epoch))
        for batch, i in enumerate(pbar):
            data, targets = hp.batch_data(ds, i, n_row=n_row)
            class_output, numeric_output = model(data)
            loss, loss_dict = hp.hephaestus_loss(
                class_output, numeric_output, targets, tokens, special_tokens, device
            )
            num_loss = loss_dict["reg_loss"].item()
            class_loss = loss_dict["class_loss"].item()
            writer.add_scalar("Loss/total_loss", loss, batch)
            writer.add_scalar("Loss/numeric_loss", num_loss, batch)
            writer.add_scalar("Loss/class_loss", class_loss, batch)
            writer.add_scalar(
                "Metrics/learning_rate", optimizer.param_groups[0]["lr"], batch
            )
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
            optimizer.step()
            pbar.set_postfix(
                {
                    "tl": f"{loss:.2f}",
                    "cl": f"{class_loss:.2f}",
                    "nl": f"{num_loss:.2f}",
                    "tr": f"{optimizer.param_groups[0]['lr']:.2f}",
                },
                refresh=True,
            )
            total_loss += loss.item()
            if batch % lr_eval_interval == 0:
                # pbar.set_p(
                #     f"tl: {}, nl: {num_loss:.2f}, cl: {class_loss:.2f}, tr: {optimizer.param_groups[0]['lr']:.2f}"
                # )

                scheduler.step(loss)

                start_time = time.time()
                # scheduler.step(loss)
        writer.close()

In [21]:
from datetime import datetime as dt

model_time = dt.now()
model_time = model_time.strftime("%Y-%m-%dT%H:%M:%S")

exp_name = "non-scale"

model_name = model_time + "_" + exp_name
epochs = 3
train(model=model, epochs=epochs, model_name=model_name)

Epoch:   0%|          | 0/3 [00:00<?, ?it/s]

Batch for 1/3:   0%|          | 0/432 [00:00<?, ?it/s]

Batch for 2/3:   0%|          | 0/432 [00:00<?, ?it/s]

Batch for 3/3:   0%|          | 0/432 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [22]:
# %%
torch.save(model.state_dict(), model_name)

In [23]:
data, targets = hp.batch_data(ds, 1, n_row=2)

row1 = ds[1]
data_str = "".join([str(i.value) for i in data])
row1_str = "".join([str(i.value) for i in row1])
print("data: ", data_str)
print("row1: ", row1_str)

data:  <row-start>z:2.31,table:61.0,<mask>:3.89,color:e,y:3.84,clarity:si1,cut:premium,depth:59.8,carat:<numeric_mask>,<row-end><pad><pad><pad><pad><pad><mask><pad><pad><pad><pad><pad><pad><row-start>y:4.07,table:65.0,carat:0.23,<mask>:56.9,clarity:vs1,z:2.31,x:4.05,color:e,cut:good,<row-end><pad><pad><pad><pad><mask><mask><pad><mask><mask><pad><pad><pad>
row1:  <row-start>carat:0.21,cut:premium,x:3.89,color:e,table:61.0,y:3.84,clarity:si1,depth:59.8,z:2.31,<row-end><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>


In [24]:
def evaluate(model: nn.Module, ds, idx) -> None:
    model.eval()  # turn on train mode
    n_row = 1  # one because it's not time series
    with torch.no_grad():
        data, targets = hp.batch_data(ds, idx, n_row=n_row)
        class_output, numeric_output = model(data)
        loss, loss_dict = hp.hephaestus_loss(
            class_output, numeric_output, targets, tokens, special_tokens, device
        )
        return {
            "loss": loss.item(),
            "loss_dict": loss_dict,
            "data": data,
            "targets": targets,
            "class_output": class_output,
            "numeric_output": numeric_output,
        }

In [25]:
ds_test = hp.TabularDataset(
    test,
    tokens,
    special_tokens=special_tokens,
    shuffle_cols=False,
    max_row_length=50,
)

In [26]:
res = evaluate(model, ds_test, 1)

In [27]:
res["loss_dict"]["reg_loss"].item()

98.4974365234375

In [28]:
actuals = [str(i.value) for i in res["targets"]]
actuals_ = " ".join(actuals)
actual_str = actuals_.split("<row-end>")[0]
actual_str

'<row-start> carat : 0.52 , cut : ideal , color : e , clarity : si1 , depth : 62.1 , table : 57.0 , x : 5.1 , y : 5.14 , z : 3.18 , '

In [29]:
masked_str = [str(i.value) for i in res["data"]]
masked_str = " ".join(actuals)
masked_str = masked_str.split("<row-end>")[0]
masked_str

'<row-start> carat : 0.52 , cut : ideal , color : e , clarity : si1 , depth : 62.1 , table : 57.0 , x : 5.1 , y : 5.14 , z : 3.18 , '

In [30]:
def show_results(res):
    actuals = [str(i.value) for i in res["targets"]]
    actuals_ = " ".join(actuals)
    actual_str = actuals_.split("<row-end>")[0]
    masked_str = [str(i.value) for i in res["data"]]
    masked_str = " ".join(masked_str)
    masked_str = masked_str.split("<row-end>")[0]

    lsm = nn.Softmax(dim=0)
    softmax_cats = lsm(res["class_output"])
    softmax_cats = torch.argmax(softmax_cats, dim=1)
    gen_tokens = []
    for idx, pred in enumerate(softmax_cats):
        token = tokens[pred - 1]
        if token == "<numeric>":
            gen_tokens.append(str(res["numeric_output"][idx].item()))
        else:
            gen_tokens.append(token)
    preds = " ".join(gen_tokens)

    s = (
        f"Targets   : {actual_str}\n"
        + f"Masked    : {masked_str}\n"
        + f"Predicted : {preds.split('<row-end>')[0]}"
    )
    return s

In [31]:
for i in range(15):
    res = evaluate(model, ds, i)
    print(f"Row {i}")
    print(show_results(res))
    print("")

Row 0
Targets   : <row-start> depth : 61.5 , clarity : si2 , table : 55.0 , cut : ideal , x : 3.95 , y : 3.98 , z : 2.43 , color : e , carat : 0.23 , 
Masked    : <row-start> depth : 61.5 , clarity : si2 , <mask> : <numeric_mask> , cut : ideal , x : 3.95 , y : <numeric_mask> , z : 2.43 , color : <mask> , <mask> : 0.23 , 
Predicted : color <unk> : 61.21510696411133 , clarity : : , clarity : <unk> , <row-start> table very good , very good table 15.061422348022461 , : cut <unk> , table table 9.649274826049805 , <unk> : <unk> , <unk> : : , clarity <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk>

Row 1
Targets   : <row-start> z : 2.31 , carat : 0.21 , y : 3.84 , depth : 59.8 , clarity : si1 , cut : premium , table : 61.0 , color : e , x : 3.89 , 
Masked    : <row-start> z : 2.31 , carat : 0.21 , y : 3.84 , depth : 59.8 , clarity : si1 , cut : premium , table : <numeric_mask> , <mask> : e , x : 3.89 , 
Predicted : color table : 9.25050163269043 , <row-start> : : , : :

In [None]:
test.with_row_count().filter(pl.col("carat") == max(test["carat"]))

row_nr,carat,cut,color,clarity,depth,table,x,y,z
u32,f64,str,str,str,f64,f64,f64,f64,f64
9270,1.059174,"""fair""","""h""","""i1""",1.850171,-0.204603,1.068715,0.985413,1.277126
9653,1.059174,"""fair""","""e""","""i1""",3.316016,0.242926,0.943911,0.889102,1.362148


In [None]:
res = evaluate(model, ds_test, 9270)
print(show_results(res))

Targets   : <row-start> carat : 1.0591737180449114 , cut : fair , color : h , clarity : i1 , depth : 1.8501714799489926 , table : -0.20460319486368622 , x : 1.0687152244462004 , y : 0.9854127282311091 , z : 1.277125824454861 , 
Masked    : <row-start> carat : 1.0591737180449114 , cut : <mask> , color : h , clarity : i1 , depth : 1.8501714799489926 , table : -0.20460319486368622 , x : 1.0687152244462004 , y : 0.9854127282311091 , <mask> : 1.277125824454861 , 
Predicted : <row-start> carat : 0.9448257684707642 , cut : ideal , color : h , clarity : i1 , depth : 1.5026516914367676 , table : -0.3050450086593628 , x : 0.9185384511947632 , y : 0.8406640291213989 , z : 1.0746914148330688 , <row-end> fair <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk>


In [None]:
test.with_row_count().filter(pl.col("carat") == min(test["carat"]))

row_nr,carat,cut,color,clarity,depth,table,x,y,z
u32,f64,str,str,str,f64,f64,f64,f64,f64
837,-1.24035,"""premium""","""e""","""si2""",0.105119,-0.652132,-1.685883,-1.67627,-1.656137


In [None]:
res = evaluate(model, ds_test, 837)
print(show_results(res))

Targets   : <row-start> carat : -1.2403497908694328 , cut : premium , color : e , clarity : si2 , depth : 0.10511857089741324 , table : -0.6521324960111446 , x : -1.6858828561475554 , y : -1.6762698814250705 , z : -1.6561367264851874 , 
Masked    : <row-start> carat : -1.2403497908694328 , cut : premium , <mask> : e , clarity : si2 , <mask> : 0.10511857089741324 , table : -0.6521324960111446 , x : -1.6858828561475554 , y : -1.6762698814250705 , z : <numeric_mask> , 
Predicted : <row-start> carat : -0.954312801361084 , cut : premium , color : e , clarity : si2 , depth : 0.2146976888179779 , table : -0.5769697427749634 , x : -1.261174201965332 , y : -1.2771449089050293 , z : -0.5968276262283325 , <row-end> <unk> <unk> <unk> <unk> <unk> <unk> ideal <unk> <unk> <unk> <unk> <unk>


In [None]:
# preds = [str(i.value) for i in res["class_output"]]
# preds_ = " ".join(preds)
# preds_.split("<row-end>")[0]

In [None]:
lsm = nn.Softmax(dim=0)
softmax_cats = lsm(res["class_output"])
softmax_cats = torch.argmax(softmax_cats, dim=1)

In [None]:
softmax_cats, softmax_cats.shape

(tensor([ 7,  9,  2,  4,  1, 12,  2, 23,  1, 11,  2, 18,  1, 10,  2, 29,  1, 14,
          2,  4,  1, 31,  2,  4,  1, 37,  2,  4,  1, 38,  2,  4,  1, 39,  2,  4,
          1,  6,  8,  8,  8,  8,  8,  8, 17,  8,  8,  8,  8,  8],
        device='mps:0'),
 torch.Size([50]))

In [None]:
gen_tokens = []
for idx, pred in enumerate(softmax_cats):
    token = tokens[pred - 1]
    if token == "<numeric>":
        gen_tokens.append(str(res["numeric_output"][idx].item()))
    else:
        gen_tokens.append(token)
preds = " ".join(gen_tokens)
print(f"""Predicted Row:\n\n{preds.split("<row-end>")[0]}""")
print(f"""\nActual Row:\n{actual_str}""")

Predicted Row:

<row-start> carat : -0.6318099498748779 , cut : ideal , color : g , clarity : si1 , depth : -0.9217300415039062 , table : -0.2088424265384674 , x : -0.28772419691085815 , y : -0.6106739044189453 , z : -0.47605791687965393 , 

Actual Row:
<row-start> carat : -0.5863568663158119 , cut : ideal , color : e , clarity : si1 , depth : 0.24472280362154117 , table : -0.20460319486368622 , x : -0.5626486873617523 , y : -0.520539274600677 , z : -0.5083383369869077 , 


In [None]:
def gen_random_number():
    num = np.random.rand()
    return num


for i in range(10):
    print(gen_random_number())

0.9683498582258889
0.8052319157998697
0.3877478747690536
0.9181466762759316
0.5964743420946779
0.0875938108238562
0.40845215274416213
0.9158682045747477
0.1687792220992772
0.8808815567770297
