# Diamond Transformer

This notebook investigates the use of a transformer to predict the price of diamonds. The dataset is from Kaggle and can be found [here](https://www.kaggle.com/shivam2503/diamonds).

While many traditional ML and DL techniques work on the dataset, our approach uses far less labeled data while achieving similar results. This is done by using a transformer.

--------------------------------------------------------------------------------

## Libraries


In [1]:
import hashlib
import math
from datetime import datetime as dt
from pathlib import Path

import numpy as np
import polars as pl
import torch
from torch import nn, Tensor
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import trange, tqdm

import hephaestus as hp

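The diamonds dataset can be loaded with `df = pl.read_csv("../data/diamonds.csv")` and previewed with `df.head()`. The cells below instead generate a synthetic arithmetic dataset.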


In [2]:
import random
import numpy as np


# Function to generate a random number, either integer or float
def generate_random_number():
    return random.randint(1, 10) if random.random() < 0.5 else random.uniform(1, 10)


# Number of rows in the DataFrame
n_rows = 100_000

# Generate random values for num1 and num2 (both integers and floats)
num1 = [generate_random_number() for _ in range(n_rows)]
num2 = [generate_random_number() for _ in range(n_rows)]

# Randomly choose an operation for each row
operations = [
    random.choice(["multiply", "divide", "add", "subtract"]) for _ in range(n_rows)
]

# Create a DataFrame with the columns num1, num2, and operation
df = pl.DataFrame({"num1": num1, "num2": num2, "operation": operations})

# Apply the operation to each row to create the result column
df = df.with_columns(
    pl.when(pl.col("operation") == "multiply")
    .then(pl.col("num1") * pl.col("num2"))
    .when(pl.col("operation") == "divide")
    .then(pl.col("num1") / pl.col("num2"))
    .when(pl.col("operation") == "add")
    .then(pl.col("num1") + pl.col("num2"))
    .when(pl.col("operation") == "subtract")
    .then(pl.col("num1") - pl.col("num2"))
    .otherwise(None)
    .alias("result")
)


# Print the DataFrame
df.head()

num1,num2,operation,result
f64,f64,str,f64
684.0,93.428636,"""multiply""",63905.187075
440.382552,30.0,"""multiply""",13211.476564
645.0,169.406437,"""add""",814.406437
589.0,595.671255,"""add""",1184.671255
945.0,542.315009,"""subtract""",402.684991


In [3]:
df.describe()

describe,num1,num2,operation,result
str,f64,f64,str,f64
"""count""",10000.0,10000.0,"""10000""",10000.0
"""null_count""",0.0,0.0,"""0""",0.0
"""mean""",503.037639,501.664594,,62970.758504
"""std""",288.368825,286.497977,,153451.869404
"""min""",1.0,1.0,"""add""",-980.694696
"""max""",1000.0,1000.0,"""subtract""",984850.571308
"""median""",501.379767,503.984314,,504.571328
"""25%""",256.793421,257.780895,,0.991464
"""75%""",754.058153,749.0,,1871.0


In [4]:
df = hp.scale_numeric(df)
df.head()

num1,num2,operation,result
f64,f64,str,f64
0.627538,-1.424917,"""multiply""",0.006089
-0.217274,-1.64631,"""multiply""",-0.324266
0.492294,-1.159723,"""add""",-0.405054
0.298099,0.328123,"""add""",-0.402641
1.532629,0.141887,"""subtract""",-0.407737


In [5]:
df = hp.make_lower_remove_special_chars(df)
val_tokens = hp.get_unique_utf8_values(df)
col_tokens = hp.get_col_tokens(df)

In [6]:
df.head()

num1,num2,operation,result
f64,f64,str,f64
0.627538,-1.424917,"""multiply""",0.006089
-0.217274,-1.64631,"""multiply""",-0.324266
0.492294,-1.159723,"""add""",-0.405054
0.298099,0.328123,"""add""",-0.402641
1.532629,0.141887,"""subtract""",-0.407737


In [7]:
special_tokens = np.array(
    [
        "missing",
        "<mask>",
        "<numeric_mask>",
        "<pad>",
        "<unk>",
        ":",
        ",",
        "<row-start>",
        "<row-end>",
    ]
)

In [8]:
tokens = np.unique(
    np.concatenate(
        (
            val_tokens,
            col_tokens,
            special_tokens,
            [
                "<numeric>",
            ],
        )
    )
)
tokens

array([',', ':', '<mask>', '<numeric>', '<numeric_mask>', '<pad>',
       '<row-end>', '<row-start>', '<unk>', 'add', 'divide', 'missing',
       'multiply', 'num1', 'num2', 'operation', 'result', 'subtract'],
      dtype=object)

# Train Test Split

To show the actual model performance out of sample, we split the data into a training set and a test set. The training set is used to train the model and the test set is used to evaluate it. We use 80% of the data for training and 20% for testing.

We also remove the price column from the training and test sets and will only use a tiny subset of the data to simulate an industrial process with lots of input data but expensive and limited labeled data.


In [9]:
# Shuffle for randomness
df = df.sample(fraction=1.0, seed=42)
df.head()

num1,num2,operation,result
f64,f64,str,f64
0.627538,-1.424917,"""multiply""",0.006089
-0.217274,-1.64631,"""multiply""",-0.324266
0.492294,-1.159723,"""add""",-0.405054
0.298099,0.328123,"""add""",-0.402641
1.532629,0.141887,"""subtract""",-0.407737


In [10]:
train_fraction = 0.8
n_train = int(train_fraction * len(df))
train_test_df = df.select(pl.all())

train, test = train_test_df.head(n_train), train_test_df.tail(
    len(train_test_df) - n_train
)

labeled_train, labeled_test = df.head(n_train), df.tail(len(df) - n_train)

In [11]:
train.head(), train.shape

(shape: (5, 4)
 ┌───────────┬───────────┬───────────┬───────────┐
 │ num1      ┆ num2      ┆ operation ┆ result    │
 │ ---       ┆ ---       ┆ ---       ┆ ---       │
 │ f64       ┆ f64       ┆ str       ┆ f64       │
 ╞═══════════╪═══════════╪═══════════╪═══════════╡
 │ 0.627538  ┆ -1.424917 ┆ multiply  ┆ 0.006089  │
 │ -0.217274 ┆ -1.64631  ┆ multiply  ┆ -0.324266 │
 │ 0.492294  ┆ -1.159723 ┆ add       ┆ -0.405054 │
 │ 0.298099  ┆ 0.328123  ┆ add       ┆ -0.402641 │
 │ 1.532629  ┆ 0.141887  ┆ subtract  ┆ -0.407737 │
 └───────────┴───────────┴───────────┴───────────┘,
 (8000, 4))

In [12]:
train_test_df.head()

num1,num2,operation,result
f64,f64,str,f64
0.627538,-1.424917,"""multiply""",0.006089
-0.217274,-1.64631,"""multiply""",-0.324266
0.492294,-1.159723,"""add""",-0.405054
0.298099,0.328123,"""add""",-0.402641
1.532629,0.141887,"""subtract""",-0.407737


In [13]:
ds = hp.TabularDataset(
    train,
    tokens,
    special_tokens=special_tokens,
    shuffle_cols=False,
    max_row_length=19,
)

print(len(ds[0]))

19


In [14]:
print([i.value for i in ds[0]])

['<row-start>', 'num1', ':', 0.627537880086947, ',', 'num2', ':', -1.4249174175476262, ',', 'operation', ':', 'multiply', ',', 'result', ':', 0.006089391901095394, ',', '<row-end>', '<pad>']


In [15]:
if torch.backends.mps.is_built():
    device_name = "mps"
elif torch.cuda.is_available():
    device_name = "cuda"
else:
    device_name = "cpu"
device = torch.device(device_name)
print(device)

mps


In [16]:
n_token = len(ds.vocab)  # size of vocabulary
d_model = 96  # embedding dimension
d_hid = 1_000  # dimension of the feedforward network model in ``nn.TransformerEncoder``
n_layers = 12  # number of ``nn.TransformerEncoderLayer`` in ``nn.TransformerEncoder``
n_head = 12  # number of heads in ``nn.MultiheadAttention``
dropout = 0.2  # dropout probability
model = hp.TransformerModel(
    n_token, d_model, n_head, d_hid, n_layers, device, dropout
).to(device)



In [17]:
# Test the model out:
data, targets = hp.batch_data(ds, 1, n_row=1)
class_out, numeric_out = model(data)

In [18]:
# diff = d["pre_num_scal_embed"] / d["post_num_scal_embed"]

In [19]:
# d["pre_num_scal_embed"] * torch.eye(d["pre_num_scal_embed"].shape[1]).to(device)

In [20]:
for i in range(0):
    print(i)

In [66]:
import copy
import time


lr = 0.001  # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size =100, gamma=0.5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.8,
    patience=5,
    threshold=0.001,
    threshold_mode="rel",
    cooldown=0,
    min_lr=0.01,
    eps=1e-08,
    verbose=False,
)


def train(model: nn.Module, epochs=1, model_name="") -> None:
    model.train()  # turn on train mode
    total_loss = 0.0
    log_interval = 1000
    lr_eval_interval = 25
    n_row = 100  # one because it's not time series
    start_time = time.time()
    for epoch in trange(1, epochs + 1, leave=True, desc="Epoch"):
        pbar = trange(0, len(ds) - 1, n_row, desc=f"Batch for {epoch}/{epochs}")
        writer = SummaryWriter("runs/" + model_name + "_run_" + str(epoch))
        for batch, i in enumerate(pbar):
            data, targets = hp.batch_data(ds, i, n_row=n_row)
            class_output, numeric_output = model(data)
            loss, loss_dict = hp.hephaestus_loss(
                class_output, numeric_output, targets, tokens, special_tokens, device
            )
            num_loss = loss_dict["reg_loss"].item()
            class_loss = loss_dict["class_loss"].item()
            writer.add_scalar("Loss/total_loss", loss, batch)
            writer.add_scalar("Loss/numeric_loss", num_loss, batch)
            writer.add_scalar("Loss/class_loss", class_loss, batch)
            writer.add_scalar(
                "Metrics/learning_rate", optimizer.param_groups[0]["lr"], batch
            )
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()
            pbar.set_postfix(
                {
                    "tl": f"{loss:.2f}",
                    "cl": f"{class_loss:.2f}",
                    "nl": f"{num_loss:.2f}",
                    "tr": f"{optimizer.param_groups[0]['lr']:.2f}",
                },
                refresh=True,
            )
            total_loss += loss.item()
            if batch % lr_eval_interval == 0:
                # pbar.set_p(
                #     f"tl: {}, nl: {num_loss:.2f}, cl: {class_loss:.2f}, tr: {optimizer.param_groups[0]['lr']:.2f}"
                # )

                scheduler.step(loss)

                # if batch % log_interval == 0:
                #     # lr = scheduler.get_last_lr()[0]
                #     lr = optimizer.param_groups[0]["lr"]

                #     ms_per_batch = (time.time() - start_time) * 1000 / log_interval
                #     cur_loss = total_loss / log_interval
                #     ppl = math.exp(cur_loss)
                #     print(  # f'| epoch {epoch:3d} | {batch:5d}/{num_batches:5d} batches | '
                #         f"lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | ",
                #         f"loss {cur_loss:5.2f} | ppl {ppl:8.2f}",
                #         f"num_loss {num_loss:5.2f} | class_loss {class_loss:5.2f}",
                #     )
                #     total_loss = 0
                start_time = time.time()
                # scheduler.step(loss)
        writer.close()

In [67]:
model_time = dt.now()
model_time = model_time.strftime("%Y-%m-%dT%H:%M:%S")

exp_name = "numeric_scaling6"

model_name = model_time + "_" + exp_name
epochs = 3
train(model=model, epochs=epochs, model_name=model_name)

Epoch:   0%|          | 0/3 [00:00<?, ?it/s]

Batch for 1/3:   0%|          | 0/80 [00:00<?, ?it/s]

Batch for 2/3:   0%|          | 0/80 [00:00<?, ?it/s]

Batch for 3/3:   0%|          | 0/80 [00:00<?, ?it/s]

In [68]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# model = CustomNumericAttention(d_model=512, n_head=8) # Replace with your model
param_count = count_parameters(model)
print(f"Total trainable parameters: {param_count:,}")

Total trainable parameters: 2,776,946


In [41]:
# %%
save_model = True
if save_model:
    MODEL_PATH = "models/" + model_name + ".pth"
    torch.save(model.state_dict(), MODEL_PATH)

In [69]:
ds_test = hp.TabularDataset(
    test,
    tokens,
    special_tokens=special_tokens,
    shuffle_cols=False,
    max_row_length=50,
)

In [70]:
def evaluate(model: nn.Module, ds, idx) -> None:
    model.eval()  # turn on train mode
    n_row = 1  # one because it's not time series
    with torch.no_grad():
        data, targets = hp.batch_data(ds, idx, n_row=n_row)
        class_output, numeric_output = model(data)
        loss, loss_dict = hp.hephaestus_loss(
            class_output, numeric_output, targets, tokens, special_tokens, device
        )
        return {
            "loss": loss.item(),
            "loss_dict": loss_dict,
            "data": data,
            "targets": targets,
            "class_output": class_output,
            "numeric_output": numeric_output,
        }

In [71]:
test

num1,num2,operation,result
f64,f64,str,f64
-0.222069,-0.093071,"""divide""",-0.410356
-0.760505,1.052537,"""multiply""",1.074778
-0.596589,0.714712,"""add""",-0.403601
1.45287,0.309759,"""divide""",-0.410351
0.846681,-0.99011,"""subtract""",-0.406913
0.336244,0.043056,"""divide""",-0.410354
1.622331,-0.777194,"""multiply""",1.35483
-1.0474,-0.145427,"""multiply""",0.192173
0.55095,-0.853984,"""subtract""",-0.407723
1.709486,0.4846,"""multiply""",3.7469


In [72]:
ds_test = hp.TabularDataset(
    test,
    tokens,
    special_tokens=special_tokens,
    shuffle_cols=False,
    max_row_length=18,
)

In [73]:
def show_results(res):

    def replacer(x):
        x = x.replace("  ", ":")
        x = x.replace(" ,", ",")
        return x

    actuals = [str(i.value) for i in res["targets"]]
    actuals_ = " ".join(actuals)
    actual_str = actuals_.split("<row-end>")[0]
    actual_str = replacer(actual_str)
    masked_str = [str(i.value) for i in res["data"]]
    masked_str = " ".join(masked_str)
    masked_str = masked_str.split("<row-end>")[0]
    masked_str = replacer(masked_str)
    lsm = nn.Softmax(dim=0)
    softmax_cats = lsm(res["class_output"])
    softmax_cats = torch.argmax(softmax_cats, dim=1)
    gen_tokens = []
    for idx, pred in enumerate(softmax_cats):
        token = tokens[pred - 1]
        if token == "<numeric>":
            gen_tokens.append(str(res["numeric_output"][idx].item()))
        else:
            gen_tokens.append(token)
    preds = " ".join(gen_tokens)
    preds = replacer(preds)
    
    s = (
        f"Targets   : {actual_str}\n"
        + f"Masked    : {masked_str}\n"
        + f"Predicted : {preds.split('<row-end>')[0]}"
    )
    return s

In [65]:
for i in range(10):
    res = evaluate(model, ds, i)
    print(f"Row {i}")
    print(show_results(res))
    print("")

Row 0
Targets   : <row-start> num1 : 0.627537880086947, num2 : -1.4249174175476262, operation : multiply, result : 0.006089391901095394, 
Masked    : <row-start> num1 : 0.627537880086947, num2 : -1.4249174175476262, operation : multiply, result : 0.006089391901095394, 
Predicted : <row-start> num1 : 0.746845006942749, num2 : subtract, operation : multiply, result : -0.0796755775809288, 

Row 1
Targets   : <row-start> num1 : -0.21727413399197865, num2 : -1.6463103806205621, operation : multiply, result : -0.3242663783358256, 
Masked    : <row-start> num1 : -0.21727413399197865, num2 : -1.6463103806205621, operation : multiply, result : -0.3242663783358256, 
Predicted : <row-start> num1 : -0.2761440575122833, num2 : subtract, operation : multiply, result : -0.31431835889816284, 

Row 2
Targets   : <row-start> num1 : 0.4922944119488227, num2 : -1.1597225240082345, operation : add, result : -0.4050543816063307, 
Masked    : <row-start> num1 : 0.4922944119488227, num2 : -1.1597225240082345,

In [32]:
def evaluate(model: nn.Module, ds, idx) -> None:
    model.eval()  # turn on train mode
    n_row = 1  # one because it's not time series
    with torch.no_grad():
        data, targets = hp.batch_data(ds, idx, n_row=n_row)
        class_output, numeric_output = model(data)
        loss, loss_dict = hp.hephaestus_loss(
            class_output, numeric_output, targets, tokens, special_tokens, device
        )
        return {
            "loss": loss.item(),
            "loss_dict": loss_dict,
            "data": data,
            "targets": targets,
            "class_output": class_output,
            "numeric_output": numeric_output,
        }

In [33]:
res = evaluate(model, ds_test, 5)

In [34]:
def make_str(l: list):
    result_list = []
    for i in l:
        if i.is_numeric:
            result_list.append(i.numeric_value)
        else:
            result_list.append(i.value)
    result_list = [str(i) for i in result_list]
    return " ".join(result_list).split("<row-end>")[0]


def prediction_str(class_output, numeric_output, tokens):
    lsm = nn.Softmax(dim=0)
    softmax_cats = lsm(class_output)
    softmax_cats = torch.argmax(softmax_cats, dim=1)
    gen_tokens = []
    for idx, pred in enumerate(softmax_cats):
        token = tokens[pred]
        if token == "<numeric>":
            gen_tokens.append(str(numeric_output[idx].item()))
        else:
            gen_tokens.append(token)
    preds = " ".join(gen_tokens).split("<row-end>")[0]
    return preds


def print_results(d: dict):
    data_str = make_str(d["data"])
    target_str = make_str(d["targets"])
    pred_str = prediction_str(d["class_output"], d["numeric_output"], tokens)
    print(f"Data  : {data_str}\nTarget: {target_str}\nPredict: {pred_str}")


print_results(res)

AttributeError: 'StringNumeric' object has no attribute 'numeric_value'

In [None]:
res["class_output"].shape

torch.Size([50, 41])

In [None]:
res["numeric_output"][:, None].shape

torch.Size([50, 1])

In [None]:
torch.cat((res["numeric_output"][None, :], res["class_output"].T)).shape

torch.Size([42, 50])

In [None]:
def prediction_str(class_output, numeric_output, tokens):
    lsm = nn.Softmax(dim=0)
    softmax_cats = lsm(class_output)
    softmax_cats = torch.argmax(softmax_cats, dim=1)
    gen_tokens = []
    for idx, pred in enumerate(softmax_cats):
        token = tokens[pred]
        if token == "<numeric>":
            gen_tokens.append(str(numeric_output[idx].item()))
        else:
            gen_tokens.append(token)
    preds = " ".join(gen_tokens).split("<row-end>")[0]
    return preds

In [None]:
import torch
import numpy as np


def generate_data(N):
    x = torch.randint(1, 100, (N, 2), dtype=torch.float32)
    y_mul = x[:, 0] * x[:, 1]
    y_div = x[:, 0] / x[:, 1]
    y = torch.stack((y_mul, y_div), dim=1)
    return x, y

In [None]:
import torch.nn as nn


class MathNet(nn.Module):
    def __init__(self):
        super(MathNet, self).__init__()
        self.fc1 = nn.Linear(2, 10)
        self.fc2 = nn.Linear(10, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [None]:
N = 1000000
data_x, data_y = generate_data(N)

# Split into training and testing
train_x, test_x = data_x[:8000], data_x[8000:]
train_y, test_y = data_y[:8000], data_y[8000:]

model = MathNet()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(1000):
    # Forward pass
    outputs = model(train_x)
    loss = criterion(outputs, train_y)

    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 100 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item()}")

Epoch 0, Loss: 5478397.5
Epoch 100, Loss: 1675928.875
Epoch 200, Loss: 784210.25
Epoch 300, Loss: 779894.25
Epoch 400, Loss: 775376.625
Epoch 500, Loss: 770643.8125
Epoch 600, Loss: 765713.625
Epoch 700, Loss: 760585.8125
Epoch 800, Loss: 755247.5
Epoch 900, Loss: 749672.125


In [None]:
with torch.no_grad():
    test_outputs = model(test_x)
    test_loss = criterion(test_outputs, test_y)
    print(f"Test Loss: {test_loss.item()}")

Test Loss: 737482.3125


In [None]:
train_x.shape

torch.Size([8000, 2])

In [None]:
test_x.shape

torch.Size([992000, 2])

In [None]:
test_y

tensor([[1.3350e+03, 5.9333e+00],
        [7.4000e+01, 1.8500e+01],
        [2.7600e+02, 1.7250e+01],
        ...,
        [9.8600e+02, 8.5294e-01],
        [9.0000e+02, 1.1111e-01],
        [2.5550e+03, 2.0857e+00]])

In [None]:
test_outputs

tensor([[ 5.5403e+03, -6.8926e-01],
        [ 5.4591e+03, -7.7647e-01],
        [ 1.7531e+03,  9.1369e+00],
        ...,
        [ 8.5710e+02,  8.9137e-01],
        [ 6.9791e+03, -1.7934e+00],
        [ 5.2343e+01,  1.2579e+00]])

In [None]:
def generate_data(N):
    x = torch.randint(1, 1000, (N, 2), dtype=torch.float32)
    y_mul = x[:, 0] * x[:, 1]
    y_div = x[:, 0] / x[:, 1]
    y = torch.stack((y_mul, y_div), dim=1)
    return x, y

In [None]:
from torch.nn import functional as F


class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, queries, mask):
        N = queries.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]

        # Split the embedding into self.heads different pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = queries.reshape(N, query_len, self.heads, self.head_dim)

        # Scaled dot-product attention
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        attention = torch.nn.Softmax(dim=3)(energy)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        out = self.fc_out(out)
        return out


class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, 256), nn.ReLU(), nn.Linear(256, embed_size)
        )

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)

        # Add skip connection, run through normalization and finally feed forward
        x = self.norm1(attention + query)
        forward = self.feed_forward(x)
        out = self.norm2(forward + x)
        return out


class MathNet(nn.Module):
    def __init__(self, embed_size, heads):
        super(MathNet, self).__init__()
        self.embed_size = embed_size

        self.fc1 = nn.Linear(2, embed_size)
        self.transformer_block = TransformerBlock(embed_size, heads)
        self.fc2 = nn.Linear(embed_size, 2)

    def forward(self, x):
        x = self.fc1(x)
        x = x.repeat(1, 3, 1)  # Repeat x three times to form keys, values, queries
        x = self.transformer_block(x, x, x, mask=None)
        x = self.fc2(x[:, 0, :])  # We only need the first result from the sequence
        return x

In [None]:
model = MathNet(embed_size=64, heads=4)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(1000):
    # Forward pass
    outputs = model(train_x)
    loss = criterion(outputs, train_y)

    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 100 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item()}")

  return F.mse_loss(input, target, reduction=self.reduction)


Epoch 0, Loss: 5377728.5
Epoch 100, Loss: 4712922.5


KeyboardInterrupt: 