# Define the model


In [1]:
# %%capture installs
# %pip install polars
# %pip install torchviz
# %conda install -y conda install -c  python-graphviz
# %conda install -c fastchan python-graphviz -y

In this tutorial, we train a `nn.TransformerEncoder` model on a language modeling task. The language modeling task is to assign a probability for the likelihood of a given word (or a sequence of words) to follow a sequence of words. A sequence of tokens are passed to the embedding layer first, followed by a positional encoding layer to account for the order of the word (see the next paragraph for more details). The `nn.TransformerEncoder` consists of multiple layers of [nn.TransformerEncoderLayer](https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoderLayer.html)_. Along with the input sequence, a square attention mask is required because the self-attention layers in `nn.TransformerEncoder` are only allowed to attend the earlier positions in the sequence. For the language modeling task, any tokens on the future positions should be masked. To produce a probability distribution over output words, the output of the `nn.TransformerEncoder` model is passed through a linear layer followed by a log-softmax function.


In [2]:
import math
import os
import time
from tempfile import TemporaryDirectory
from typing import Tuple

import re
from numbers import Number

from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from tqdm import trange, tqdm

import torch
import polars as pl
import numpy as np


from torch import nn, Tensor

import hephaestus as hp

In [3]:
torch.__version__

'2.1.0.dev20230623'

In [4]:
weather = pl.read_parquet("~/Hephaestus/data/weather_clean.parquet")
weather.head()

x,y,station_name,climate_identifier,province_code,local_year,local_month,local_day,local_hour,temp,temp_flag,dew_point_temp,dew_point_temp_flag,humidex,precip_amount,precip_amount_flag,relative_humidity,relative_humidity_flag,station_pressure,station_pressure_flag,wind_chill,wind_direction,wind_direction_flag,wind_speed,wind_speed_flag
f64,f64,str,str,str,str,str,str,str,f64,str,f64,str,f64,f64,str,f64,str,f64,str,f64,f64,str,f64,str
-114.000297,51.109447,"""CALGARY INT'L …","""3031094""","""AB""","""2010""","""1""","""1""","""0""",-21.6,"""missing""",-23.9,"""missing""",,,"""missing""",82.0,"""missing""",89.38,"""missing""",,,"""M""",,"""M"""
-114.000297,51.109447,"""CALGARY INT'L …","""3031094""","""AB""","""2010""","""1""","""1""","""1""",-21.2,"""missing""",-23.5,"""missing""",,,"""missing""",82.0,"""missing""",89.25,"""missing""",,,"""M""",,"""M"""
-114.000297,51.109447,"""CALGARY INT'L …","""3031094""","""AB""","""2010""","""1""","""1""","""2""",-20.8,"""missing""",-23.0,"""missing""",,,"""missing""",82.0,"""missing""",89.21,"""missing""",,,"""M""",,"""M"""
-114.000297,51.109447,"""CALGARY INT'L …","""3031094""","""AB""","""2010""","""1""","""1""","""3""",-20.4,"""missing""",-22.6,"""missing""",,,"""missing""",83.0,"""missing""",89.12,"""missing""",,,"""M""",,"""M"""
-114.000297,51.109447,"""CALGARY INT'L …","""3031094""","""AB""","""2010""","""1""","""1""","""4""",-20.4,"""missing""",-22.7,"""missing""",,,"""missing""",82.0,"""missing""",89.04,"""missing""",,,"""M""",,"""M"""


In [5]:
weather = hp.scale_numeric(weather)

In [6]:
if torch.backends.mps.is_built():
    device_name = "mps"
elif torch.cuda.is_available():
    device_name = "cuda"
else:
    device_name = "cpu"
device = torch.device(device_name)
print(device)

mps


In [7]:
weather = hp.make_lower_remove_special_chars(weather)
weather_val_tokens = hp.get_unique_utf8_values(weather)
weather_col_tokens = hp.get_col_tokens(weather)

In [8]:
special_tokens = np.array(
    [
        "missing",
        "<mask>",
        "<pad>",
        "<unk>",
        "<numeric>",
        ":",
        ",",
        "<row-start>",
        "<row-end>",
    ]
)
tokens = np.unique(
    np.concatenate(
        (
            weather_val_tokens,
            weather_col_tokens,
            special_tokens,
        )
    )
)
tokens

array([',', '0', '1', '10', '11', '12', '13', '14', '15', '16', '17',
       '18', '19', '2', '20', '2010', '2011', '2012', '2013', '2014',
       '2015', '2016', '2017', '2018', '2019', '2020', '2021', '2022',
       '21', '22', '23', '24', '25', '26', '27', '28', '29', '3', '30',
       '3012206', '3026knq', '3031094', '3033890', '3035208', '3062696',
       '31', '4', '5', '6', '7', '8', '9', ':', '<mask>', '<numeric>',
       '<pad>', '<row-end>', '<row-start>', '<unk>', 'ab', 'amount',
       'calgary int l cs', 'chill', 'climate', 'code', 'day', 'dew',
       'direction', 'edmonton international cs', 'flag',
       'fort mcmurray cs', 'hour', 'humidex', 'humidity', 'identifier',
       'lethbridge cda', 'local', 'm', 'missing', 'month', 'name',
       'pincher creek climate', 'point', 'precip', 'pressure', 'province',
       'relative', 'speed', 'station', 'sundre a', 'temp', 'wind', 'x',
       'y', 'year'], dtype=object)

In [9]:
weather_ds = hp.TabularDataset(
    weather,
    tokens,
    special_tokens=special_tokens,
    shuffle_cols=False,
    max_row_length=140,
)

print(len(weather_ds[0]))

140


In [10]:
weather_ds.vocab

array([',', '0', '1', '10', '11', '12', '13', '14', '15', '16', '17',
       '18', '19', '2', '20', '2010', '2011', '2012', '2013', '2014',
       '2015', '2016', '2017', '2018', '2019', '2020', '2021', '2022',
       '21', '22', '23', '24', '25', '26', '27', '28', '29', '3', '30',
       '3012206', '3026knq', '3031094', '3033890', '3035208', '3062696',
       '31', '4', '5', '6', '7', '8', '9', ':', '<mask>', '<numeric>',
       '<pad>', '<row-end>', '<row-start>', '<unk>', 'ab', 'amount',
       'calgary int l cs', 'chill', 'climate', 'code', 'day', 'dew',
       'direction', 'edmonton international cs', 'flag',
       'fort mcmurray cs', 'hour', 'humidex', 'humidity', 'identifier',
       'lethbridge cda', 'local', 'm', 'missing', 'month', 'name',
       'pincher creek climate', 'point', 'precip', 'pressure', 'province',
       'relative', 'speed', 'station', 'sundre a', 'temp', 'wind', 'x',
       'y', 'year'], dtype=object)

# Load and batch data


The model hyperparameters are defined below. The `vocab` size is equal to the length of the vocab object.


In [11]:
n_token = len(weather_ds.vocab)  # size of vocabulary
d_model = 32  # embedding dimension
d_hid = 200  # dimension of the feedforward network model in ``nn.TransformerEncoder``
n_layers = 2  # number of ``nn.TransformerEncoderLayer`` in ``nn.TransformerEncoder``
n_head = 2  # number of heads in ``nn.MultiheadAttention``
dropout = 0.2  # dropout probability
model = hp.TransformerModel(
    n_token, d_model, n_head, d_hid, n_layers, device, dropout
).to(device)
# initial_memory_allocated = torch.cuda.memory_allocated()
# initial_memory_reserved = torch.cuda.memory_reserved()
# c,n = model(data)
# final_memory_allocated = torch.cuda.memory_allocated()
# final_memory_reserved = torch.cuda.memory_reserved()
# print(f'Increase in memory allocated: {final_memory_allocated - initial_memory_allocated}')
# print(f'Increase in memory reserved: {final_memory_reserved - initial_memory_reserved}')



# Run the model


We use [CrossEntropyLoss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) _with the [SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html)_ (stochastic gradient descent) optimizer. The learning rate is initially set to 5.0 and follows a [StepLR](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.StepLR.html) _schedule. During training, we use [nn.utils.clip_grad_norm\_](<https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html)_> to prevent gradients from exploding.


In [12]:
import copy
import time


lr = 0.99  # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size =100, gamma=0.5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.9,
    patience=5,
    threshold=0.001,
    threshold_mode="rel",
    cooldown=0,
    min_lr=0.01,
    eps=1e-08,
    verbose=False,
)


def train(model: nn.Module) -> None:
    model.train()  # turn on train mode
    total_loss = 0.0
    log_interval = 50
    n_row = 50
    start_time = time.time()
    for batch, i in enumerate(trange(0, len(weather_ds) - 1, n_row)):
        data, targets = hp.batch_data(weather_ds, i, n_row=n_row)
        class_output, numeric_output = model(data)
        loss, loss_dict = hp.hephaestus_loss(
            class_output, numeric_output, targets, tokens, special_tokens, device
        )

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        if batch % log_interval == 0 and batch > 0:
            # lr = scheduler.get_last_lr()[0]
            lr = optimizer.param_groups[0]["lr"]

            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = total_loss / log_interval
            ppl = math.exp(cur_loss)
            print(  # f'| epoch {epoch:3d} | {batch:5d}/{num_batches:5d} batches | '
                f"lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | ",
                f"loss {cur_loss:5.2f} | ppl {ppl:8.2f}",
                loss_dict,
            )
            total_loss = 0
            start_time = time.time()
            scheduler.step(loss)

Loop over epochs. Save the model if the validation loss is the best we've seen so far. Adjust the learning rate after each epoch.


In [13]:
train(model)

  0%|          | 50/12714 [00:31<2:00:36,  1.75it/s]

lr 0.99 | ms/batch 655.22 |  loss  3.31 | ppl    27.32 

  0%|          | 51/12714 [00:33<3:40:45,  1.05s/it]

{'reg_loss': tensor(0.4131, device='mps:0', grad_fn=<MseLossBackward0>), 'class_loss': tensor(1.6221, device='mps:0', grad_fn=<NllLossBackward0>)}


  1%|          | 101/12714 [00:59<1:35:18,  2.21it/s]

lr 0.99 | ms/batch 521.63 |  loss  1.77 | ppl     5.89 {'reg_loss': tensor(0.4007, device='mps:0', grad_fn=<MseLossBackward0>), 'class_loss': tensor(0.9961, device='mps:0', grad_fn=<NllLossBackward0>)}


  1%|          | 151/12714 [01:20<1:55:24,  1.81it/s]

lr 0.99 | ms/batch 402.97 |  loss  1.28 | ppl     3.58 {'reg_loss': tensor(0.7777, device='mps:0', grad_fn=<MseLossBackward0>), 'class_loss': tensor(0.6775, device='mps:0', grad_fn=<NllLossBackward0>)}


  2%|▏         | 201/12714 [01:42<1:27:02,  2.40it/s]

lr 0.99 | ms/batch 438.38 |  loss  1.13 | ppl     3.10 {'reg_loss': tensor(0.1815, device='mps:0', grad_fn=<MseLossBackward0>), 'class_loss': tensor(0.5331, device='mps:0', grad_fn=<NllLossBackward0>)}


  2%|▏         | 251/12714 [02:02<1:24:11,  2.47it/s]

lr 0.99 | ms/batch 402.92 |  loss  0.92 | ppl     2.52 {'reg_loss': tensor(0.5036, device='mps:0', grad_fn=<MseLossBackward0>), 'class_loss': tensor(0.4491, device='mps:0', grad_fn=<NllLossBackward0>)}


  2%|▏         | 300/12714 [02:23<1:38:49,  2.09it/s]


Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/opt/homebrew/Caskroom/miniconda/base/envs/pytorch/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3508, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/var/folders/rs/qflxwtyx6kvfj8jcqx5zm5hr0000gn/T/ipykernel_11667/2658576871.py", line 1, in <module>
    train(model)
  File "/var/folders/rs/qflxwtyx6kvfj8jcqx5zm5hr0000gn/T/ipykernel_11667/33392831.py", line 30, in train
    class_output, numeric_output = model(data)
                                   ^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/miniconda/base/envs/pytorch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/miniconda/base/envs/pytorch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^

In [None]:
torch.save(model.state_dict(), "model_path.pt")

In [None]:
data, targets = hp.batch_data(weather_ds, idx=0, n_row=3)
print(len(data), len(targets))
# data[0]

In [None]:
cat, num = model(data)

In [None]:
def custom_loss_manual(class_preds, numeric_preds, raw_data):
    cross_entropy = nn.CrossEntropyLoss()
    mse_loss = nn.MSELoss()
    raw_data_numeric_class = raw_data[:]

    for idx, val in enumerate(raw_data_numeric_class):
        if val.is_numeric:
            val = hp.StringNumeric(value="<numeric>")
            val.gen_embed_idx(tokens, special_tokens)
            raw_data_numeric_class[idx] = val

    class_target = torch.tensor([i.embedding_idx for i in raw_data_numeric_class]).to(
        device
    )
    class_loss = cross_entropy(class_preds, class_target)

    actual_num_idx = torch.tensor(
        [idx for idx, j in enumerate(raw_data) if j.is_numeric]
    ).to(device)
    pred_nums = numeric_preds[actual_num_idx]
    # print(actual_num_idx.shape)
    actual_nums = torch.tensor([i.value for i in raw_data if i.is_numeric]).to(device)
    # print(actual_nums.shape)
    # print(pred_nums.shape)
    reg_loss = mse_loss(pred_nums, actual_nums)
    reg_loss_adjuster = 1  # class_loss/reg_loss

    return reg_loss * reg_loss_adjuster + class_loss, {
        "reg_loss": reg_loss,
        "class_loss": class_loss,
    }


custom_loss_manual(cat, num, targets)

In [None]:
lsm = nn.Softmax(dim=0)
l_cat = lsm(cat.T)
l_cat = torch.argmax(l_cat, dim=0)

In [None]:
gen_tokens = []
for idx, pred in enumerate(l_cat):
    token = tokens[pred - 1]
    if token == "<numeric>":
        gen_tokens.append("num_" + str(num[idx].item()))
    else:
        gen_tokens.append(token)
preds = " ".join(gen_tokens)
preds.split("<row-end>")[0]

In [None]:
actuals = [str(i.value) for i in targets]
actuals_ = " ".join(actuals)
actuals_.split("<row-end>")[0]

In [None]:
def evaluate(model, data, i):
    model.eval
    with torch.no_grad():
        data, targets = hp.batch_data(weather_ds, i, n_row=4)
        class_output, numeric_output = model(data)
        loss, loss_dict = hp.hephaestus_loss(class_output, numeric_output, targets)

        return data, targets, class_output, numeric_output, loss, loss_dict


data, targets, class_output, numeric_output, loss, loss_dict = evaluate(
    model, weather_ds, 1
)