# Diamond Transformer

This notebook investigates the use of a transformer to predict the price of diamonds. The dataset is from Kaggle and can be found [here](https://www.kaggle.com/shivam2503/diamonds).

While many traditional ML and DL techniques work on the dataset, our approach uses far less labeled data while achieving similar results. This is done by using a transformer.

--------------------------------------------------------------------------------

## Libraries


In [1]:
import hashlib
import math

import polars as pl


import polars as pl
import numpy as np


from torch import nn, Tensor
from tqdm import trange, tqdm

import hephaestus as hp
import torch
from torch.utils.tensorboard import SummaryWriter


The diamonds dataset: df = pl.read_csv("../data/diamonds.csv") df.head()


In [2]:
df = pl.read_csv("../data/diamonds.csv")
df.head()

carat,cut,color,clarity,depth,table,price,x,y,z
f64,str,str,str,f64,f64,i64,f64,f64,f64
0.23,"""Ideal""","""E""","""SI2""",61.5,55.0,326,3.95,3.98,2.43
0.21,"""Premium""","""E""","""SI1""",59.8,61.0,326,3.89,3.84,2.31
0.23,"""Good""","""E""","""VS1""",56.9,65.0,327,4.05,4.07,2.31
0.29,"""Premium""","""I""","""VS2""",62.4,58.0,334,4.2,4.23,2.63
0.31,"""Good""","""J""","""SI2""",63.3,58.0,335,4.34,4.35,2.75


In [3]:
df.describe()

describe,carat,cut,color,clarity,depth,table,price,x,y,z
str,f64,str,str,str,f64,f64,f64,f64,f64,f64
"""count""",53940.0,"""53940""","""53940""","""53940""",53940.0,53940.0,53940.0,53940.0,53940.0,53940.0
"""null_count""",0.0,"""0""","""0""","""0""",0.0,0.0,0.0,0.0,0.0,0.0
"""mean""",0.79794,,,,61.749405,57.457184,3932.799722,5.731157,5.734526,3.538734
"""std""",0.474011,,,,1.432621,2.234491,3989.439738,1.121761,1.142135,0.705699
"""min""",0.2,"""Fair""","""D""","""I1""",43.0,43.0,326.0,0.0,0.0,0.0
"""max""",5.01,"""Very Good""","""J""","""VVS2""",79.0,95.0,18823.0,10.74,58.9,31.8
"""median""",0.7,,,,61.8,57.0,2401.0,5.7,5.71,3.53
"""25%""",0.4,,,,61.0,56.0,950.0,4.71,4.72,2.91
"""75%""",1.04,,,,62.5,59.0,5325.0,6.54,6.54,4.04


In [4]:
df = hp.scale_numeric(df)
df.head()

carat,cut,color,clarity,depth,table,price,x,y,z
f64,str,str,str,f64,f64,f64,f64,f64,f64
-1.198157,"""Ideal""","""E""","""SI2""",-0.17409,-1.099662,-0.904087,-1.587823,-1.536181,-1.571115
-1.24035,"""Premium""","""E""","""SI1""",-1.360726,1.585514,-0.904087,-1.64131,-1.658759,-1.741159
-1.198157,"""Good""","""E""","""VS1""",-3.384987,3.375631,-0.903836,-1.498677,-1.457382,-1.741159
-1.071577,"""Premium""","""I""","""VS2""",0.454129,0.242926,-0.902081,-1.364959,-1.317293,-1.287708
-1.029384,"""Good""","""J""","""SI2""",1.082348,0.242926,-0.901831,-1.240155,-1.212227,-1.117663


In [5]:
df = hp.make_lower_remove_special_chars(df)
val_tokens = hp.get_unique_utf8_values(df)
col_tokens = hp.get_col_tokens(df)

In [6]:
df.head()

carat,cut,color,clarity,depth,table,price,x,y,z
f64,str,str,str,f64,f64,f64,f64,f64,f64
-1.198157,"""ideal""","""e""","""si2""",-0.17409,-1.099662,-0.904087,-1.587823,-1.536181,-1.571115
-1.24035,"""premium""","""e""","""si1""",-1.360726,1.585514,-0.904087,-1.64131,-1.658759,-1.741159
-1.198157,"""good""","""e""","""vs1""",-3.384987,3.375631,-0.903836,-1.498677,-1.457382,-1.741159
-1.071577,"""premium""","""i""","""vs2""",0.454129,0.242926,-0.902081,-1.364959,-1.317293,-1.287708
-1.029384,"""good""","""j""","""si2""",1.082348,0.242926,-0.901831,-1.240155,-1.212227,-1.117663


In [7]:
special_tokens = np.array(
    [
        "missing",
        "<mask>",
        "<numeric_mask>"
        "<pad>",
        "<unk>",
        ":",
        ",",
        "<row-start>",
        "<row-end>",
    ]
)

In [8]:
tokens = np.unique(
    np.concatenate(
        (
            val_tokens,
            col_tokens,
            special_tokens,
            np.array(["<numeric>",])
        )
    )
)
tokens

array([',', ':', '<mask>', '<numeric>', '<numeric_mask><pad>',
       '<row-end>', '<row-start>', '<unk>', 'carat', 'clarity', 'color',
       'cut', 'd', 'depth', 'e', 'f', 'fair', 'g', 'good', 'h', 'i', 'i1',
       'ideal', 'if', 'j', 'missing', 'premium', 'price', 'si1', 'si2',
       'table', 'very good', 'vs1', 'vs2', 'vvs1', 'vvs2', 'x', 'y', 'z'],
      dtype=object)

# Train Test Split

To show the actual model performance out of sample we split the data into a training and test set. The training set will be used to train the model and the test set will be used to evaluate the model performance. We will use 80% of the data for training and 20% for testing.

We also remove the price column from the training and test sets and will only use a tiny subset of the data to simulate an industrial process with lots of input data but expensive and limited labeled data.


In [9]:
df = (
    df.with_columns(
        pl.concat_str(pl.all().exclude("price").cast(pl.Utf8)).alias("all_cols")
    )
    .with_columns(
        pl.col("all_cols")
        .apply(lambda x: hashlib.md5(x.encode()).hexdigest())
        .alias("hash")
    )
    .drop("all_cols")
)
df.select(pl.col("hash").is_duplicated().sum())

hash
u32
685


In [10]:
# Shuffle for randomness
df = df.sample(fraction=1.0, seed=42)
df.head()

carat,cut,color,clarity,depth,table,price,x,y,z,hash
f64,str,str,str,f64,f64,f64,f64,f64,f64,str
-1.198157,"""ideal""","""e""","""si2""",-0.17409,-1.099662,-0.904087,-1.587823,-1.536181,-1.571115,"""501249736f26c3…"
-1.24035,"""premium""","""e""","""si1""",-1.360726,1.585514,-0.904087,-1.64131,-1.658759,-1.741159,"""5427305ea67a9e…"
-1.198157,"""good""","""e""","""vs1""",-3.384987,3.375631,-0.903836,-1.498677,-1.457382,-1.741159,"""155663ac256a0e…"
-1.071577,"""premium""","""i""","""vs2""",0.454129,0.242926,-0.902081,-1.364959,-1.317293,-1.287708,"""45ee1317264026…"
-1.029384,"""good""","""j""","""si2""",1.082348,0.242926,-0.901831,-1.240155,-1.212227,-1.117663,"""bce254e518e158…"


In [11]:
train_fraction = 0.8
n_train = int(train_fraction * len(df))
train_test_df = df.select(pl.all().exclude(["price", "hash"]))

train, test = train_test_df.head(n_train), train_test_df.tail(
    len(train_test_df) - n_train
)

In [12]:
train.head(), train.shape

(shape: (5, 9)
 ┌───────────┬─────────┬───────┬─────────┬───┬───────────┬───────────┬───────────┬───────────┐
 │ carat     ┆ cut     ┆ color ┆ clarity ┆ … ┆ table     ┆ x         ┆ y         ┆ z         │
 │ ---       ┆ ---     ┆ ---   ┆ ---     ┆   ┆ ---       ┆ ---       ┆ ---       ┆ ---       │
 │ f64       ┆ str     ┆ str   ┆ str     ┆   ┆ f64       ┆ f64       ┆ f64       ┆ f64       │
 ╞═══════════╪═════════╪═══════╪═════════╪═══╪═══════════╪═══════════╪═══════════╪═══════════╡
 │ -1.198157 ┆ ideal   ┆ e     ┆ si2     ┆ … ┆ -1.099662 ┆ -1.587823 ┆ -1.536181 ┆ -1.571115 │
 │ -1.24035  ┆ premium ┆ e     ┆ si1     ┆ … ┆ 1.585514  ┆ -1.64131  ┆ -1.658759 ┆ -1.741159 │
 │ -1.198157 ┆ good    ┆ e     ┆ vs1     ┆ … ┆ 3.375631  ┆ -1.498677 ┆ -1.457382 ┆ -1.741159 │
 │ -1.071577 ┆ premium ┆ i     ┆ vs2     ┆ … ┆ 0.242926  ┆ -1.364959 ┆ -1.317293 ┆ -1.287708 │
 │ -1.029384 ┆ good    ┆ j     ┆ si2     ┆ … ┆ 0.242926  ┆ -1.240155 ┆ -1.212227 ┆ -1.117663 │
 └───────────┴─────────┴───────┴───

In [13]:
train_test_df.head()

carat,cut,color,clarity,depth,table,x,y,z
f64,str,str,str,f64,f64,f64,f64,f64
-1.198157,"""ideal""","""e""","""si2""",-0.17409,-1.099662,-1.587823,-1.536181,-1.571115
-1.24035,"""premium""","""e""","""si1""",-1.360726,1.585514,-1.64131,-1.658759,-1.741159
-1.198157,"""good""","""e""","""vs1""",-3.384987,3.375631,-1.498677,-1.457382,-1.741159
-1.071577,"""premium""","""i""","""vs2""",0.454129,0.242926,-1.364959,-1.317293,-1.287708
-1.029384,"""good""","""j""","""si2""",1.082348,0.242926,-1.240155,-1.212227,-1.117663


In [14]:
ds = hp.TabularDataset(
    train,
    tokens,
    special_tokens=special_tokens,
    shuffle_cols=True,
    max_row_length=50,
)

print(len(ds[0]))

50


In [19]:
row = [str(i.value) for i in ds[1]]
print("".join(row))

<row-start>y:-1.6587588116247007,carat:-1.2403497908694328,depth:-1.3607258727059102,z:-1.7411588294109859,x:-1.6413100716719282,cut:premium,color:e,table:1.5855140097261475,clarity:si1,<row-end><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>


In [20]:
print([i.value for i in ds[0]])

['<row-start>', 'carat', ':', -1.1981566989627475, ',', 'clarity', ':', 'si2', ',', 'table', ':', -1.0996617971586031, ',', 'z', ':', -1.5711146235593887, ',', 'color', ':', 'e', ',', 'depth', ':', -0.17408989455083768, ',', 'cut', ':', 'ideal', ',', 'x', ':', -1.5878227303011756, ',', 'y', ':', -1.5361813230221135, ',', '<row-end>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']


In [21]:
if torch.backends.mps.is_built():
    device_name = "mps"
elif torch.cuda.is_available():
    device_name = "cuda"
else:
    device_name = "cpu"
device = torch.device(device_name)
print(device)

mps


In [22]:
data, targets = hp.batch_data(ds, 1, n_row=1)

In [23]:
n_token = len(ds.vocab)  # size of vocabulary
d_model = 64  # embedding dimension
d_hid = 1_000  # dimension of the feedforward network model in ``nn.TransformerEncoder``
n_layers = 4  # number of ``nn.TransformerEncoderLayer`` in ``nn.TransformerEncoder``
n_head = 4  # number of heads in ``nn.MultiheadAttention``
dropout = 0.2  # dropout probability
model = hp.TransformerModel(
    n_token, d_model, n_head, d_hid, n_layers, device, dropout
).to(device)



In [24]:
import copy
import time


lr = 0.24  # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size =100, gamma=0.5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.9,
    patience=5,
    threshold=0.001,
    threshold_mode="rel",
    cooldown=0,
    min_lr=0.01,
    eps=1e-08,
    verbose=False,
)


def train(model: nn.Module) -> None:
    writer = SummaryWriter()
    model.train()  # turn on train mode
    total_loss = 0.0
    log_interval = 1000
    lr_interval = 100
    n_row = 10 # one because it's not time series
    start_time = time.time()
    for batch, i in enumerate(trange(0, len(ds) - 1, n_row)):
        data, targets = hp.batch_data(ds, i, n_row=n_row)
        class_output, numeric_output = model(data)
        loss, loss_dict = hp.hephaestus_loss(
            class_output, numeric_output, targets, tokens, special_tokens, device
        )
        num_loss = loss_dict["reg_loss"].item()
        class_loss = loss_dict["class_loss"].item()
        writer.add_scalar("Loss/total_loss", loss, batch)
        writer.add_scalar("Loss/numeric_loss", num_loss, batch)
        writer.add_scalar("Loss/class_loss", class_loss, batch)
        writer.add_scalar("Metrics/learning_rate", optimizer.param_groups[0]["lr"], batch)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        if batch % lr_interval == 0 and batch > 0:
            scheduler.step(total_loss)
        if batch % log_interval == 0 and batch > 0:
            # lr = scheduler.get_last_lr()[0]
            lr = optimizer.param_groups[0]["lr"]

            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = total_loss / log_interval
            ppl = math.exp(cur_loss)
            print(  # f'| epoch {epoch:3d} | {batch:5d}/{num_batches:5d} batches | '
                f"lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | ",
                f"loss {cur_loss:5.2f} | ppl {ppl:8.2f}",
                loss_dict,
            )
            total_loss = 0
            start_time = time.time()
            # scheduler.step(loss)
    writer.close()

In [25]:
epochs = 3
for i in range(epochs):
    print(f"Epoch {i+1}/{epochs}")
    train(model)

Epoch 1/3


  0%|          | 0/4316 [00:00<?, ?it/s]

 23%|██▎       | 1002/4316 [00:40<02:41, 20.58it/s]

lr 0.22 | ms/batch 40.67 |  loss  0.74 | ppl     2.09 {'reg_loss': tensor(0.1524, device='mps:0', grad_fn=<MseLossBackward0>), 'class_loss': tensor(0.2297, device='mps:0', grad_fn=<NllLossBackward0>)}


 46%|████▋     | 2006/4316 [01:20<01:27, 26.37it/s]

lr 0.19 | ms/batch 39.44 |  loss  0.36 | ppl     1.43 {'reg_loss': tensor(0.1175, device='mps:0', grad_fn=<MseLossBackward0>), 'class_loss': tensor(0.1547, device='mps:0', grad_fn=<NllLossBackward0>)}


 70%|██████▉   | 3005/4316 [01:58<00:47, 27.63it/s]

lr 0.17 | ms/batch 38.02 |  loss  0.46 | ppl     1.58 {'reg_loss': tensor(0.2088, device='mps:0', grad_fn=<MseLossBackward0>), 'class_loss': tensor(0.1550, device='mps:0', grad_fn=<NllLossBackward0>)}


 93%|█████████▎| 4005/4316 [02:34<00:11, 27.15it/s]

lr 0.16 | ms/batch 36.27 |  loss  0.27 | ppl     1.32 {'reg_loss': tensor(0.0864, device='mps:0', grad_fn=<MseLossBackward0>), 'class_loss': tensor(0.1417, device='mps:0', grad_fn=<NllLossBackward0>)}


100%|██████████| 4316/4316 [02:46<00:00, 25.97it/s]


Epoch 2/3


 23%|██▎       | 1003/4316 [00:32<01:43, 32.06it/s]

lr 0.13 | ms/batch 32.40 |  loss  0.27 | ppl     1.31 {'reg_loss': tensor(0.1632, device='mps:0', grad_fn=<MseLossBackward0>), 'class_loss': tensor(0.1680, device='mps:0', grad_fn=<NllLossBackward0>)}


 46%|████▋     | 2005/4316 [01:04<01:17, 30.00it/s]

lr 0.11 | ms/batch 31.97 |  loss  0.26 | ppl     1.30 {'reg_loss': tensor(0.0997, device='mps:0', grad_fn=<MseLossBackward0>), 'class_loss': tensor(0.1778, device='mps:0', grad_fn=<NllLossBackward0>)}


 70%|██████▉   | 3003/4316 [01:37<00:42, 31.22it/s]

lr 0.10 | ms/batch 32.62 |  loss  0.36 | ppl     1.43 {'reg_loss': tensor(0.0714, device='mps:0', grad_fn=<MseLossBackward0>), 'class_loss': tensor(0.1544, device='mps:0', grad_fn=<NllLossBackward0>)}


 93%|█████████▎| 4002/4316 [02:08<00:10, 30.85it/s]

lr 0.09 | ms/batch 31.91 |  loss  0.23 | ppl     1.25 {'reg_loss': tensor(0.1603, device='mps:0', grad_fn=<MseLossBackward0>), 'class_loss': tensor(0.1473, device='mps:0', grad_fn=<NllLossBackward0>)}


100%|██████████| 4316/4316 [02:18<00:00, 31.06it/s]


Epoch 3/3


 23%|██▎       | 1005/4316 [00:34<01:51, 29.69it/s]

lr 0.08 | ms/batch 33.86 |  loss  0.23 | ppl     1.26 {'reg_loss': tensor(0.0697, device='mps:0', grad_fn=<MseLossBackward0>), 'class_loss': tensor(0.1282, device='mps:0', grad_fn=<NllLossBackward0>)}


 46%|████▋     | 2003/4316 [01:08<01:20, 28.76it/s]

lr 0.07 | ms/batch 34.47 |  loss  0.23 | ppl     1.26 {'reg_loss': tensor(0.0466, device='mps:0', grad_fn=<MseLossBackward0>), 'class_loss': tensor(0.1233, device='mps:0', grad_fn=<NllLossBackward0>)}


 70%|██████▉   | 3006/4316 [01:43<00:46, 28.38it/s]

lr 0.05 | ms/batch 35.05 |  loss  0.33 | ppl     1.40 {'reg_loss': tensor(0.0627, device='mps:0', grad_fn=<MseLossBackward0>), 'class_loss': tensor(0.1179, device='mps:0', grad_fn=<NllLossBackward0>)}


 93%|█████████▎| 4005/4316 [02:18<00:09, 31.32it/s]

lr 0.05 | ms/batch 34.73 |  loss  0.21 | ppl     1.23 {'reg_loss': tensor(0.0566, device='mps:0', grad_fn=<MseLossBackward0>), 'class_loss': tensor(0.1276, device='mps:0', grad_fn=<NllLossBackward0>)}


100%|██████████| 4316/4316 [02:28<00:00, 29.14it/s]


In [26]:
# %%
torch.save(model.state_dict(), "models/diamonds2.pth")

In [35]:
data, targets = hp.batch_data(ds, 1, n_row=2)

row1 = ds[1]
data_str = "".join([str(i.value) for i in data])
row1_str = "".join([str(i.value) for i in row1])
print("data: ", data_str)
print("row1: ", row1_str)

data:  <row-start><mask>:-1.6587588116247007,x:-1.6413100716719282,table:1.5855140097261475,depth:-1.3607258727059102,clarity:si1,cut:premium,<mask>:-1.2403497908694328,<mask>:e,z:-1.7411588294109859,<row-end><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><row-start>y:<numeric_mask>,carat:-1.1981566989627475,z:-1.7411588294109859,color:e,x:-1.4986771613499217,<mask>:3.3756312143159812,depth:<numeric_mask>,clarity:vs1,cut:good,<row-end><mask><pad><pad><pad><pad><pad><pad><pad><pad><pad><mask><pad>
row1:  <row-start>carat:-1.2403497908694328,z:-1.7411588294109859,y:-1.6587588116247007,color:e,clarity:si1,table:1.5855140097261475,cut:premium,depth:-1.3607258727059102,x:-1.6413100716719282,<row-end><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>


In [27]:
def evaluate(model: nn.Module, ds, idx) -> None:
    model.eval()  # turn on train mode
    n_row = 1  # one because it's not time series
    with torch.no_grad():
        data, targets = hp.batch_data(ds, idx, n_row=n_row)
        class_output, numeric_output = model(data)
        loss, loss_dict = hp.hephaestus_loss(
            class_output, numeric_output, targets, tokens, special_tokens, device
        )
        return {
            "loss": loss.item(),
            "loss_dict": loss_dict,
            "data": data,
            "targets": targets,
            "class_output": class_output,
            "numeric_output": numeric_output,
        }

In [28]:
ds_test = hp.TabularDataset(
    test,
    tokens,
    special_tokens=special_tokens,
    shuffle_cols=False,
    max_row_length=50,
)

In [29]:
res = evaluate(model, ds_test, 1)

In [30]:
res["loss_dict"]["reg_loss"].item()

0.0019309771014377475

In [31]:
actuals = [str(i.value) for i in res["targets"]]
actuals_ = " ".join(actuals)
actual_str = actuals_.split("<row-end>")[0]
actual_str

'<row-start> carat : -0.5863568663158119 , cut : ideal , color : e , clarity : si1 , depth : 0.24472280362154117 , table : -0.20460319486368622 , x : -0.5626486873617523 , y : -0.520539274600677 , z : -0.5083383369869077 , '

In [32]:
masked_str = [str(i.value) for i in res["data"]]
masked_str = " ".join(actuals)
masked_str = masked_str.split("<row-end>")[0]
masked_str

'<row-start> carat : -0.5863568663158119 , cut : ideal , color : e , clarity : si1 , depth : 0.24472280362154117 , table : -0.20460319486368622 , x : -0.5626486873617523 , y : -0.520539274600677 , z : -0.5083383369869077 , '

In [36]:
def show_results(res):
    actuals = [str(i.value) for i in res["targets"]]
    actuals_ = " ".join(actuals)
    actual_str = actuals_.split("<row-end>")[0]
    masked_str = [str(i.value) for i in res["data"]]
    masked_str = " ".join(masked_str)
    masked_str = masked_str.split("<row-end>")[0]

    lsm = nn.Softmax(dim=0)
    softmax_cats = lsm(res["class_output"])
    softmax_cats = torch.argmax(softmax_cats, dim=1)
    gen_tokens = []
    for idx, pred in enumerate(softmax_cats):
        token = tokens[pred-1]
        if token == "<numeric>":
            gen_tokens.append(str(res["numeric_output"][idx].item()))
        else:
            gen_tokens.append(token)
    preds = " ".join(gen_tokens)

    s = f"Targets   : {actual_str}\n" + \
        f"Masked    : {masked_str}\n" + \
        f"Predicted : {preds.split('<row-end>')[0]}"
    return s

In [37]:
for i in range(5):
    res = evaluate(model, ds, i)
    print(f"Row {i}")
    print(show_results(res))
    print("")

Row 0
Targets   : <row-start> carat : -1.1981566989627475 , z : -1.5711146235593887 , clarity : si2 , y : -1.5361813230221135 , depth : -0.17408989455083768 , table : -1.0996617971586031 , cut : ideal , color : e , x : -1.5878227303011756 , 
Masked    : <row-start> carat : -1.1981566989627475 , z : -1.5711146235593887 , clarity : si2 , y : <numeric_mask> , depth : <numeric_mask> , table : -1.0996617971586031 , cut : ideal , color : <mask> , x : -1.5878227303011756 , 
Predicted : <row-start> carat : -0.9883768558502197 , z : -1.2116670608520508 , clarity : si2 , y : e , depth : e , table : -0.9551116228103638 , cut : ideal , color : premium , x : -1.2383604049682617 , 

Row 1
Targets   : <row-start> x : -1.6413100716719282 , depth : -1.3607258727059102 , y : -1.6587588116247007 , color : e , cut : premium , carat : -1.2403497908694328 , clarity : si1 , table : 1.5855140097261475 , z : -1.7411588294109859 , 
Masked    : <row-start> x : -1.6413100716719282 , depth : -1.3607258727059102 , 

In [None]:
test.with_row_count().filter(pl.col("carat") == max(test["carat"]))

row_nr,carat,cut,color,clarity,depth,table,x,y,z
u32,f64,str,str,str,f64,f64,f64,f64,f64
9270,1.059174,"""fair""","""h""","""i1""",1.850171,-0.204603,1.068715,0.985413,1.277126
9653,1.059174,"""fair""","""e""","""i1""",3.316016,0.242926,0.943911,0.889102,1.362148


In [None]:
res = evaluate(model, ds_test, 9270)
print(show_results(res))

Targets   : <row-start> carat : 1.0591737180449114 , cut : fair , color : h , clarity : i1 , depth : 1.8501714799489926 , table : -0.20460319486368622 , x : 1.0687152244462004 , y : 0.9854127282311091 , z : 1.277125824454861 , 
Masked    : <row-start> carat : 1.0591737180449114 , cut : <mask> , color : h , clarity : i1 , depth : 1.8501714799489926 , table : -0.20460319486368622 , x : 1.0687152244462004 , y : 0.9854127282311091 , <mask> : 1.277125824454861 , 
Predicted : <row-start> carat : 0.9448257684707642 , cut : ideal , color : h , clarity : i1 , depth : 1.5026516914367676 , table : -0.3050450086593628 , x : 0.9185384511947632 , y : 0.8406640291213989 , z : 1.0746914148330688 , <row-end> fair <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk>


In [None]:
test.with_row_count().filter(pl.col("carat") == min(test["carat"]))

row_nr,carat,cut,color,clarity,depth,table,x,y,z
u32,f64,str,str,str,f64,f64,f64,f64,f64
837,-1.24035,"""premium""","""e""","""si2""",0.105119,-0.652132,-1.685883,-1.67627,-1.656137


In [None]:
res = evaluate(model, ds_test, 837)
print(show_results(res))

Targets   : <row-start> carat : -1.2403497908694328 , cut : premium , color : e , clarity : si2 , depth : 0.10511857089741324 , table : -0.6521324960111446 , x : -1.6858828561475554 , y : -1.6762698814250705 , z : -1.6561367264851874 , 
Masked    : <row-start> carat : -1.2403497908694328 , cut : premium , <mask> : e , clarity : si2 , <mask> : 0.10511857089741324 , table : -0.6521324960111446 , x : -1.6858828561475554 , y : -1.6762698814250705 , z : <numeric_mask> , 
Predicted : <row-start> carat : -0.954312801361084 , cut : premium , color : e , clarity : si2 , depth : 0.2146976888179779 , table : -0.5769697427749634 , x : -1.261174201965332 , y : -1.2771449089050293 , z : -0.5968276262283325 , <row-end> <unk> <unk> <unk> <unk> <unk> <unk> ideal <unk> <unk> <unk> <unk> <unk>


In [None]:
# preds = [str(i.value) for i in res["class_output"]]
# preds_ = " ".join(preds)
# preds_.split("<row-end>")[0]

In [None]:
lsm = nn.Softmax(dim=0)
softmax_cats = lsm(res["class_output"])
softmax_cats = torch.argmax(softmax_cats, dim=1)

In [None]:
softmax_cats, softmax_cats.shape

(tensor([ 7,  9,  2,  4,  1, 12,  2, 23,  1, 11,  2, 18,  1, 10,  2, 29,  1, 14,
          2,  4,  1, 31,  2,  4,  1, 37,  2,  4,  1, 38,  2,  4,  1, 39,  2,  4,
          1,  6,  8,  8,  8,  8,  8,  8, 17,  8,  8,  8,  8,  8],
        device='mps:0'),
 torch.Size([50]))

In [None]:
gen_tokens = []
for idx, pred in enumerate(softmax_cats):
    token = tokens[pred-1]
    if token == "<numeric>":
        gen_tokens.append(str(res["numeric_output"][idx].item()))
    else:
        gen_tokens.append(token)
preds = " ".join(gen_tokens)
print(f"""Predicted Row:\n\n{preds.split("<row-end>")[0]}""")
print(f"""\nActual Row:\n{actual_str}""")

Predicted Row:

<row-start> carat : -0.6318099498748779 , cut : ideal , color : g , clarity : si1 , depth : -0.9217300415039062 , table : -0.2088424265384674 , x : -0.28772419691085815 , y : -0.6106739044189453 , z : -0.47605791687965393 , 

Actual Row:
<row-start> carat : -0.5863568663158119 , cut : ideal , color : e , clarity : si1 , depth : 0.24472280362154117 , table : -0.20460319486368622 , x : -0.5626486873617523 , y : -0.520539274600677 , z : -0.5083383369869077 , 


In [None]:
def gen_random_number():
    num = np.random.rand()
    return num

for i in range(10):
    print(gen_random_number())


0.9683498582258889
0.8052319157998697
0.3877478747690536
0.9181466762759316
0.5964743420946779
0.0875938108238562
0.40845215274416213
0.9158682045747477
0.1687792220992772
0.8808815567770297
