In [143]:
import glob
import math
import os
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import wandb
from tqdm.auto import tqdm

from otc.data.dataloader import TabDataLoader
from otc.data.dataset import TabDataset
from otc.features.build_features import features_classical, features_classical_size
from otc.models.activation import ReGLU
from otc.models.fttransformer import (
    CLSToken,
    FeatureTokenizer,
    FTTransformer,
    Transformer,
)

In [None]:
# W&B artifact holding the data set used below (the name suggests
# log-transformed, standardized and clipped features — see artifact metadata)
dataset = "fbv/thesis/ise_unsupervised_log_standardized_clipped:latest"

run = wandb.init(entity="fbv", project="thesis")
artifact = run.use_artifact(dataset)
data_dir = artifact.download()

In [None]:
# Subsample the training set for faster experiments (e.g. testing the cache)
# while preserving the relative ordering of the rows. `DataFrame.sample`
# returns rows in random order, so sample row *positions* with the same
# `frac` / `random_state` and take them in ascending order instead.
frac = 0.05

X_train = pd.read_parquet(Path(data_dir, "train_set.parquet"), engine="fastparquet")

sampled_positions = pd.Series(np.arange(len(X_train))).sample(frac=frac, random_state=42)
X_train = X_train.iloc[np.sort(sampled_positions.to_numpy())]
y_train = X_train["buy_sell"]
X_train = X_train[features_classical_size]

In [None]:
# y = 0
# NOTE(review): unclear what `y = 0` refers to — presumably the label used for
# unlabelled trades; confirm and reword.
# Wrap the sampled features and their buy/sell labels into the project's dataset.
training_data = TabDataset(X_train, y_train)



# TODOs:💡
- pass pre-training command to objective ✅
- filter unlabelled trades in fit ✅
- add pretraining method
- save pre-trained model to checkpoint
- make sure it works with finetuning
- refactor Transformer head to TargetHead

In [144]:
class CLSHead(nn.Module):
    """
    Two-layer MLP projection head that emits one logit per token.

    The token at sequence position 0 is dropped; every remaining token
    embedding goes through ``Linear(d_in, d_hidden) -> ReLU -> Linear(d_hidden, 1)``
    and the trailing singleton dimension is squeezed away.

    Shapes (example from this notebook): transformer output ``[4, 17, 96]``
    with ``d_in=96`` -> logits ``[4, 16]``.

    Args:
        d_in: size of the last (embedding) dimension of the input.
        d_hidden: width of the hidden layer.
    """

    def __init__(self, *, d_in: int, d_hidden: int):
        super().__init__()
        # attribute names determine the state_dict keys ("first.*", "out.*")
        self.first = nn.Linear(d_in, d_hidden)
        self.out = nn.Linear(d_hidden, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` of shape (batch, n_tokens, d_in) to logits (batch, n_tokens - 1)."""
        # NOTE(review): dropping index 0 assumes the CLS token is the first
        # token. If `CLSToken` appends it at the end (as rtdl's CLSToken does),
        # this drops the first feature token and keeps the CLS token instead —
        # confirm against otc.models.fttransformer.CLSToken (x[:, :-1] then).
        feature_tokens = x[:, 1:]
        hidden = F.relu(self.first(feature_tokens))
        return self.out(hidden).squeeze(2)

In [145]:
# toy instantiation with 768-dim embeddings to inspect the layer structure
head = CLSHead(d_hidden=768, d_in=768)
print(head)

CLSHead(
  (first): Linear(in_features=768, out_features=768, bias=True)
  (out): Linear(in_features=768, out_features=1, bias=True)
)


In [146]:
class ShufflePermutations:
    """
    Generate random row indices used to corrupt features by shuffling.

    For every element of a feature tensor a row index in ``[0, n_rows)`` is
    drawn uniformly at random and independently (i.e. with replacement, not a
    true permutation), so a value can be replaced by the value of the same
    column in a random row.

    Args:
        X_num: numerical features — assumed 2-D (rows x features).
        X_cat: categorical features (2-D) or ``None``.
    """

    def __init__(self, X_num, X_cat):
        self.X_num = X_num
        self.X_cat = X_cat

    def permute(self, X):
        """
        Draw a random row index for every element of ``X``.

        Returns:
            ``None`` if ``X`` is ``None``; otherwise a ``torch.long`` tensor
            shaped like ``X`` with entries in ``[0, X.shape[0])``.
        """
        return None if X is None else torch.randint_like(X, X.shape[0], dtype=torch.long)

    def gen_permutations(self):
        """Return random row indices for ``(X_num, X_cat)``, numerical first."""
        # numerical indices are drawn before categorical ones (RNG order)
        index_num = self.permute(self.X_num)
        index_cat = self.permute(self.X_cat)
        return index_num, index_cat

In [147]:
# Seed the global torch RNG so the toy batch, the sampled row indices and the
# random draws in the following cells are reproducible across runs.
torch.manual_seed(42)

d_num, d_cat = 8, 8
batch_size = 4

# toy batch: standard-normal numerical features, categorical codes in [0, 10)
X_num = torch.randn(batch_size, d_num)
X_cat = torch.randint(0, 10, (batch_size, d_cat))

perm_class = ShufflePermutations(X_num, X_cat)
x_num_perm, x_cat_perm = perm_class.gen_permutations()

tensor([[3, 0, 3, 2, 1, 2, 2, 3],
        [0, 2, 0, 3, 1, 0, 2, 2],
        [2, 2, 1, 1, 0, 2, 1, 2],
        [2, 0, 1, 3, 0, 2, 0, 2]])
tensor([[2, 2, 3, 2, 2, 1, 0, 2],
        [2, 2, 3, 0, 1, 2, 3, 0],
        [3, 2, 1, 2, 0, 1, 2, 3],
        [0, 0, 2, 2, 0, 2, 3, 2]])


In [148]:
# inspect the random row indices drawn for the categorical features
x_cat_perm

tensor([[2, 2, 3, 2, 2, 1, 0, 2],
        [2, 2, 3, 0, 1, 2, 3, 0],
        [3, 2, 1, 2, 0, 1, 2, 3],
        [0, 0, 2, 2, 0, 2, 3, 2]])

In [149]:
corrupt_probability = 0.15


def gen_masks(X, perm, p=None):
    """
    Sample a boolean corruption mask for ``X``.

    Each element is selected with probability ``p``. A selected element stays
    in the mask only if replacing it with ``X[perm[i, j], j]`` actually changes
    its value, so masked entries are exactly those that differ after that
    replacement.

    Args:
        X: 2-D tensor (rows x features).
        perm: long tensor shaped like ``X`` holding source row indices.
        p: corruption probability; defaults to the module-level
            ``corrupt_probability`` (looked up at call time).

    Returns:
        Boolean tensor shaped like ``X``.
    """
    if p is None:
        p = corrupt_probability
    # Bernoulli draws; only the shape/dtype/device of X matter here
    masks = torch.empty_like(X).bernoulli(p=p).bool()
    # value each element would take after corruption: X[perm[i, j], j]
    replacement = X[perm, torch.arange(X.shape[1], device=X.device)]
    return masks & (X != replacement)

# FIXME: probably generate for train and val set
# corruption masks for the toy batch (numerical drawn before categorical)
x_num_mask, x_cat_mask = (
    gen_masks(X_num, x_num_perm),
    gen_masks(X_cat, x_cat_perm),
)

In [150]:
# numerical entries selected for corruption (True = value will be replaced)
x_num_mask

tensor([[False, False, False, False, False, False, False,  True],
        [False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False,  True]])

In [151]:
# numerical toy features before corruption (the next cell replaces masked entries)
X_num

tensor([[ 1.0082, -2.0447, -1.8514, -0.3190,  0.7799,  2.2147, -0.2914, -0.4975],
        [ 0.6028,  1.1816,  1.0263,  1.2949,  0.4654,  0.1010,  0.6004, -1.8047],
        [-0.1933,  2.0371,  0.3731,  0.9525,  0.5256, -1.1464, -0.4436,  1.2238],
        [-0.3117, -0.8931,  2.6008, -0.0808,  0.2362, -0.6265, -0.5757,  0.2028]])

In [152]:
# Corrupt the toy batch: at every masked position, replace the value by the
# value of the same column in the randomly drawn row, i.e. X[perm[i, j], j].
# `torch.where` builds new tensors instead of writing into X_num / X_cat in
# place, so the original tensors (also referenced by `perm_class`) stay
# intact; the names are rebound to the corrupted versions for later cells.
x_num_permuted = torch.gather(X_num, 0, x_num_perm)
X_num = torch.where(x_num_mask, x_num_permuted, X_num)

if X_cat is not None:

    # along the 0 axis get elements based on perm_cat
    x_cat_permuted = torch.gather(X_cat, 0, x_cat_perm)

    # replace at mask
    X_cat = torch.where(x_cat_mask, x_cat_permuted, X_cat)

In [153]:
# re-inspect the numerical mask next to the corrupted features below
x_num_mask

tensor([[False, False, False, False, False, False, False,  True],
        [False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False,  True]])

In [154]:
# numerical toy features after corruption: masked entries now hold values from other rows
X_num

tensor([[ 1.0082, -2.0447, -1.8514, -0.3190,  0.7799,  2.2147, -0.2914,  0.2028],
        [ 0.6028,  1.1816,  1.0263,  1.2949,  0.4654,  0.1010,  0.6004, -1.8047],
        [-0.1933,  2.0371,  0.3731,  0.9525,  0.5256, -1.1464, -0.4436,  1.2238],
        [-0.3117, -0.8931,  2.6008, -0.0808,  0.2362, -0.6265, -0.5757,  1.2238]])

In [172]:
d_token = 96

# tokenizer for the toy batch (d_num numerical, d_cat categorical features)
params_feature_tokenizer = {
    "num_continous": d_num,  # (sic) keyword name used by FeatureTokenizer
    # 11 levels per categorical feature covers the toy codes drawn from [0, 10)
    "cat_cardinalities": [11] * d_cat,
    "d_token": d_token,
}

feature_tokenizer = FeatureTokenizer(**params_feature_tokenizer)

params_transformer = {
    "d_token": d_token,
    "n_blocks": 3,
    "attention_n_heads": 8,
    "attention_initialization": "kaiming",
    "ffn_activation": ReGLU,
    "attention_normalization": nn.LayerNorm,
    "ffn_normalization": nn.LayerNorm,
    "ffn_dropout": 0.1,
    # derived from d_token so the FFN width stays 2 * d_token if d_token changes
    "ffn_d_hidden": d_token * 2,
    "attention_dropout": 0.1,
    "residual_dropout": 0.1,
    "prenormalization": True,
    "first_prenormalization": False,
    "last_layer_query_idx": None,
    "n_tokens": None,
    "kv_compression_ratio": None,
    "kv_compression_sharing": None,
    "head_activation": nn.ReLU,
    "head_normalization": nn.LayerNorm,
    "d_out": 1,
}

transformer = Transformer(**params_transformer)

# keep a handle on the transformer's original (supervised) head: pre-training
# replaces it with nn.Identity, and it is restored afterwards
target_head = transformer.head

In [170]:
d_hidden = 32


class PretrainModel(nn.Module):
    """
    FT-Transformer backbone with an RTD-like (replaced token detection) head.

    Tokenizes the features, adds a CLS token via ``CLSToken``, runs the
    transformer with its own classification head disabled, and projects the
    token embeddings to one logit per feature with ``CLSHead``.

    Without arguments the module-level ``feature_tokenizer``, ``transformer``,
    ``d_token`` and ``d_hidden`` are used. Pass them explicitly to avoid
    depending on notebook globals: ``d_token`` is rebound further down in the
    notebook, and every default-constructed model shares the same global
    tokenizer/transformer objects.

    Note: the backbone's ``head`` attribute is replaced by ``nn.Identity``
    in place, so the given (or global) transformer object is modified.

    Args:
        tokenizer: feature tokenizer module; defaults to ``feature_tokenizer``.
        backbone: transformer module; defaults to ``transformer``.
        token_dim: token embedding size used for the CLS token and the head;
            defaults to ``d_token``.
        hidden_dim: hidden width of the projection head; defaults to ``d_hidden``.
    """

    def __init__(self, tokenizer=None, backbone=None, *, token_dim=None, hidden_dim=None):
        super().__init__()

        # fall back to the notebook globals, looked up at construction time
        if token_dim is None:
            token_dim = d_token
        if hidden_dim is None:
            hidden_dim = d_hidden

        self.cls_token = CLSToken(token_dim, "uniform")

        self.feature_tokenizer = feature_tokenizer if tokenizer is None else tokenizer
        self.transformer = transformer if backbone is None else backbone

        # disable BERT-like classification head and replace with identity mapping
        self.transformer.head = nn.Identity()

        # enable RTD-like head with one class per feature
        self.head = CLSHead(
            d_in=token_dim,
            d_hidden=hidden_dim,
        )

    def forward(self, x_num, x_cat):
        """
        Predict, per feature token, a logit for "this value was replaced".

        Args:
            x_num: numerical features.
            x_cat: categorical features.

        Returns:
            Logits from ``CLSHead``, e.g. ``[batch, n_features]`` for the toy batch.
        """
        # tokenize
        x = self.feature_tokenizer(x_num, x_cat)

        # add cls token to input
        x = self.cls_token(x)

        # add backbone
        h = self.transformer(x)

        # add classification head
        return self.head(h)

In [171]:
# smoke test on the corrupted toy batch: expect one logit per feature token
# (the recorded output is torch.Size([4, 16]) for 8 numerical + 8 categorical features)
model = PretrainModel()
predictions = model(X_num, X_cat)
print(predictions.shape)

torch.Size([4, 16])


In [160]:
# move to device
device = "cpu"

model = PretrainModel().to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()


In [162]:
# cat masks if cat variables are present (identity check, not `!= None`,
# which is not meaningful for tensors)
# NOTE(review): num-then-cat order presumably matches the tokenizer's token order — confirm
masks = torch.cat([x_num_mask, x_cat_mask], dim=1) if X_cat is not None else x_num_mask

# logits to binary mask: a positive logit means "predicted as replaced"
hard_predictions = (predictions > 0).long()

# calculate column-wise accuracy: fraction of rows where prediction == mask
features_accuracy = (hard_predictions.bool() == masks).sum(0) / hard_predictions.shape[0]

mean_acc = features_accuracy.mean()

print(masks)
print(hard_predictions)
print(mean_acc)

tensor([[False, False, False, False, False, False, False,  True, False, False,
         False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False,  True,
         False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False],
        [False, False, False, False, False, False, False,  True, False,  True,
          True, False, False, False, False, False]])
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
tensor(0.0781)


In [173]:
# restore the transformer's original head (replaced by nn.Identity in PretrainModel);
# the next cell repeats this assignment
transformer.head = target_head

In [177]:
# swap pretraining head with classification head
transformer.head = target_head
clf = FTTransformer(feature_tokenizer=feature_tokenizer, transformer=transformer)

# TODO: save to checkpoint

# classify (argument order used throughout this notebook: categorical, then numerical)
clf(X_cat, X_num)

tensor([[-0.1340],
        [-0.3609],
        [-0.6000],
        [-0.3740]], grad_fn=<AddmmBackward0>)

## Runfield 🛫

In [None]:
class CLSHead(nn.Module):
    """
    Two-layer MLP projection head that emits one logit per token.

    Re-declared here so the Runfield section runs on its own; it shadows the
    earlier CLSHead definition with identical behavior.

    The token at sequence position 0 is dropped; every remaining token
    embedding goes through ``Linear(d_in, d_hidden) -> ReLU -> Linear(d_hidden, 1)``
    and the trailing singleton dimension is squeezed away.

    Shapes (example from this notebook): transformer output ``[4, 17, 96]``
    with ``d_in=96`` -> logits ``[4, 16]``.

    Args:
        d_in: size of the last (embedding) dimension of the input.
        d_hidden: width of the hidden layer.
    """

    def __init__(self, *, d_in: int, d_hidden: int):
        super().__init__()
        # attribute names determine the state_dict keys ("first.*", "out.*")
        self.first = nn.Linear(d_in, d_hidden)
        self.out = nn.Linear(d_hidden, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` of shape (batch, n_tokens, d_in) to logits (batch, n_tokens - 1)."""
        # NOTE(review): dropping index 0 assumes the CLS token is the first
        # token. If `CLSToken` appends it at the end (as rtdl's CLSToken does),
        # this drops the first feature token and keeps the CLS token instead —
        # confirm against otc.models.fttransformer.CLSToken (x[:, :-1] then).
        feature_tokens = x[:, 1:]
        hidden = F.relu(self.first(feature_tokens))
        return self.out(hidden).squeeze(2)

In [None]:
device = "cuda"
batch_size = 16192
epochs = 100

d_token = 192
n_blocks = 3
attention_dropout = 0.2
ffn_dropout = 0.1
residual_dropout = 0.0
attention_heads = 8


reduction = "mean"

# only numerical inputs here: one continuous token per feature column
feature_tokenizer_kwargs = dict(
    num_continous=X_train.shape[1],
    cat_cardinalities=(),
    d_token=d_token,
)

dl_params = dict(
    batch_size=batch_size,  # DataParallel splits batches across devices
    shuffle=False,
    device=device,
)

transformer_kwargs = dict(
    d_token=d_token,
    n_blocks=n_blocks,
    attention_n_heads=attention_heads,
    attention_initialization="kaiming",
    ffn_activation=ReGLU,
    attention_normalization=nn.LayerNorm,
    ffn_normalization=nn.LayerNorm,
    ffn_dropout=ffn_dropout,
    # fix at 4/3, as activation (see search space B in
    # https://arxiv.org/pdf/2106.11959v2.pdf)
    # is static with ReGLU / GeGLU
    ffn_d_hidden=int(d_token * (4 / 3)),
    attention_dropout=attention_dropout,
    residual_dropout=residual_dropout,  # see search space (B)
    prenormalization=True,
    first_prenormalization=False,
    last_layer_query_idx=None,
    n_tokens=None,
    kv_compression_ratio=None,
    kv_compression_sharing=None,
    head_activation=nn.GELU,  # nn.ReLU
    head_normalization=nn.LayerNorm,
    d_out=1,  # fix at 1, due to binary classification
)

# pre-training projection head
head_kwargs = dict(d_in=d_token, d_hidden=32)

optim_params = dict(lr=1e-4, weight_decay=0.00001)

module_params = dict(
    transformer=Transformer(**transformer_kwargs),  # type: ignore
    feature_tokenizer=FeatureTokenizer(**feature_tokenizer_kwargs),  # type: ignore # noqa: E501
    cat_features=None,
    cat_cardinalities=[],
)

clf = FTTransformer(**module_params)
criterion = nn.BCEWithLogitsLoss()

# replace the supervised target head with the pre-training CLSHead;
# `target_head` keeps the original so it can be restored later
target_head = clf.transformer.head
clf_head = CLSHead(**head_kwargs)
clf.transformer.head = clf_head

In [None]:
# NOTE(review): X_val / y_val are not defined in any cell above — load the
# validation split before running this cell.
training_data = TabDataset(X_train, y_train)
val_data = TabDataset(X_val, y_val)

# both loaders receive (x_cat, x_cont, weight, y) positionally plus dl_params
train_loader, val_loader = (
    TabDataLoader(data.x_cat, data.x_cont, data.weight, data.y, **dl_params)
    for data in (training_data, val_data)
)

In [None]:
class CosineWarmupScheduler(optim.lr_scheduler._LRScheduler):
    """
    Cosine learning-rate decay with linear warm-up, stepped per iteration.

    The learning rate of every parameter group is ``base_lr * factor(t)``:

        factor(t) = 0.5 * (1 + cos(pi * t / max_iters)) * min(1, t / warmup)

    where ``t`` is the number of ``scheduler.step()`` calls so far. The
    cosine term reaches 0 at ``t == max_iters`` and rises again afterwards.

    Args:
        optimizer: wrapped optimizer.
        warmup: number of warm-up steps; ``0`` disables warm-up.
        max_iters: number of steps over which the cosine term decays to 0.
    """

    def __init__(self, optimizer, warmup, max_iters):
        self.warmup = warmup
        self.max_num_iters = max_iters
        super().__init__(optimizer)

    def get_lr(self):
        lr_factor = self.get_lr_factor(epoch=self.last_epoch)
        return [base_lr * lr_factor for base_lr in self.base_lrs]

    def get_lr_factor(self, epoch):
        """Return the multiplicative LR factor for step ``epoch``."""
        lr_factor = 0.5 * (1 + np.cos(np.pi * epoch / self.max_num_iters))
        # linear warm-up; guarded because with warmup == 0 the condition
        # `epoch <= warmup` holds at step 0 and would divide by zero
        if self.warmup > 0 and epoch <= self.warmup:
            lr_factor *= epoch * 1.0 / self.warmup
        return lr_factor

In [None]:
# AdamW with learning rate / weight decay from the config cell
optimizer = optim.AdamW(
    clf.parameters(),
    lr=optim_params["lr"],
    weight_decay=optim_params["weight_decay"],
)

# the scheduler is stepped once per mini-batch, so count iterations, not epochs
max_iters = len(train_loader) * epochs
# saw recommendation of 5 - 10 % of total training budget or 100 to 500 steps
warmup = int(max_iters * 0.05)
print(f"warmup steps: {warmup}")
print(max_iters)

scheduler = CosineWarmupScheduler(
    optimizer=optimizer,
    warmup=warmup,
    max_iters=max_iters,
)

In [None]:
def checkpoint(model, filename, run_id=None):
    """
    Save ``model.state_dict()`` to ``filename``, replacing older checkpoints of the run.

    All files in ``checkpoints/`` whose name starts with the run id are deleted
    first, so only the latest checkpoint of the run is kept.

    Args:
        model: module whose ``state_dict`` is saved.
        filename: target path, e.g. ``checkpoints/<run id>-<step>.ptx``.
        run_id: id used to find old checkpoints; defaults to the W&B ``run.id``.
    """
    if run_id is None:
        run_id = run.id

    # remove old files (own loop variable, so `filename` is not shadowed)
    dir_checkpoints = "checkpoints/"
    for old_file in glob.glob(os.path.join(dir_checkpoints, f"{run_id}*")):
        os.remove(old_file)

    # create the directory of the requested target path
    os.makedirs(os.path.dirname(filename) or ".", exist_ok=True)

    # save new file under the requested name (not a literal "<run id>*" path)
    print("saving new checkpoints.")
    torch.save(model.state_dict(), filename)

In [None]:
# half precision, see https://pytorch.org/docs/stable/amp.html
scaler = torch.cuda.amp.GradScaler()

# NOTE(review): `EarlyStopping` is not defined or imported anywhere in this
# notebook; import it from the project (or define it) before running this cell.
early_stopping = EarlyStopping(patience=15)

# see https://stackoverflow.com/a/53628783/5755604
# no sigmoid required; numerically more stable
# do not reduce, calculate mean after multiplication with weight

step = 0
best_accuracy = -1
best_step = -1

for epoch in tqdm(range(epochs)):

    # perform training; eval mode is only set after the training batches,
    # so switching to train mode once per epoch is sufficient
    clf.train()
    loss_in_epoch_train = 0.0

    batch = 0

    # NOTE(review): the 4th item is presumably `TabDataset.y` (the buy/sell
    # labels, see how the loaders are built), not pre-training corruption
    # masks, and `x_cont` is not corrupted here — confirm the intended task.
    for x_cat, x_cont, _, masks in train_loader:

        optimizer.zero_grad()

        # for my implementation
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            logits = clf(x_cat, x_cont).flatten()
            train_loss = criterion(logits, masks)

        scaler.scale(train_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        scheduler.step()

        # accumulate a Python float: adding the loss tensor itself would keep
        # every batch's autograd history alive for the whole epoch
        train_loss_value = train_loss.item()
        loss_in_epoch_train += train_loss_value
        wandb.log({"train_loss_step": train_loss_value, "epoch": epoch, "batch": batch})

        batch += 1
        step += 1

    clf.eval()
    loss_in_epoch_val = 0.0
    correct = 0

    with torch.no_grad():
        for x_cat, x_cont, weights, targets in val_loader:

            # for my implementation
            logits = clf(x_cat, x_cont).flatten()

            val_loss = criterion(logits, targets)

            # get probabilities and round to nearest integer
            preds = torch.sigmoid(logits).round()
            correct += (preds == targets).sum().item()

            val_loss_value = val_loss.item()
            loss_in_epoch_val += val_loss_value
            wandb.log({"val_loss_step": val_loss_value, "epoch": epoch, "batch": batch})

            batch += 1

    # loss average over all batches (plain floats, so math.isnan below works)
    train_loss = loss_in_epoch_train / len(train_loader)
    val_loss = loss_in_epoch_val / len(val_loader)

    # correct samples / no samples
    val_accuracy = correct / len(X_val)
    if best_accuracy < val_accuracy:
        checkpoint(clf, f"checkpoints/{run.id}-{step}.ptx")
        best_accuracy = val_accuracy
        best_step = step

    wandb.log({"train_loss": train_loss, 'epoch': epoch})
    wandb.log({"val_loss": val_loss, 'epoch': epoch})

    print(f"train:{train_loss} val:{val_loss}")
    print(f"val accuracy:{val_accuracy}")

    # return early if val accuracy doesn't improve. Minus to minimize.
    early_stopping(-val_accuracy)
    if early_stopping.early_stop or math.isnan(train_loss) or math.isnan(val_loss):
        print("meh... early stopping")
        break
