In [None]:
import glob
import os
import math
from pathlib import Path
import sys

import numpy as np
import pandas as pd

from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import wandb

from otc.models.activation import ReGLU
from otc.models.fttransformer import (
    FeatureTokenizer,
    FTTransformer,
    Transformer,
    CLSHead,
)

from otc.data.dataset import TabDataset
from otc.data.dataloader import TabDataLoader
from otc.features.build_features import features_classical, features_classical_size
from otc.optim.early_stopping import EarlyStopping
from otc.optim.scheduler import CosineWarmupScheduler

In [None]:
# Project configuration and reproducibility.
# `setdefault` keeps a GCLOUD_PROJECT that is already set in the environment
# instead of silently overwriting it.
os.environ.setdefault("GCLOUD_PROJECT", "flowing-mantis-239216")

# Seed the global RNGs. torch's global RNG is used below for sampling the
# replacement indices, the corruption masks, and by the model (e.g. dropout).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# W&B artifact holding the training data for unsupervised pre-training
dataset = "fbv/thesis/ise_unsupervised_log_standardized_clipped:latest"

# start the run (its id is reused for checkpoint names) and download the artifact
run = wandb.init(entity="fbv", project="thesis")
artifact = run.use_artifact(dataset)
data_dir = artifact.download()

In [None]:
# Load the training set. For quick test runs set `frac < 1` to subsample rows.
# Sampled rows keep their original file order, so the later
# "first 80 % / last 20 %" train/val split remains an ordered split.
frac = 1  # e.g. 0.05 for a quick test run

X_train = pd.read_parquet(Path(data_dir, "train_set.parquet"), engine="fastparquet")

if frac < 1:
    # `DataFrame.sample` returns rows in random order (even for frac=1), which
    # would destroy the relative ordering; draw positions and sort them instead.
    n_sample = int(round(len(X_train) * frac))
    positions = np.sort(
        np.random.default_rng(42).choice(len(X_train), size=n_sample, replace=False)
    )
    X_train = X_train.iloc[positions]

y_train = X_train["buy_sell"]  # here: y = 0
X_train = X_train[features_classical_size]

In [None]:
# wrap features and labels, then expose the continuous / categorical tensors
training_data = TabDataset(X_train, y_train)
x_cont, x_cat = training_data.x_cont, training_data.x_cat

In [None]:
def gen_perm(X, generator=None):
    """
    Sample, per cell, the row index from which a replacement value is drawn.

    Despite the name this is *not* a permutation: every entry is drawn
    independently and uniformly from ``[0, X.shape[0])`` (with replacement).
    ``torch.gather(X, 0, idx)`` therefore picks, for each column, values
    taken from random rows of that same column.

    Args:
        X: tensor whose first dimension indexes rows, or ``None``.
        generator: optional ``torch.Generator`` for reproducible sampling
            (must live on ``X.device``); ``None`` uses the global RNG.

    Returns:
        ``torch.long`` tensor with the shape and device of ``X``, or ``None``
        when ``X`` is ``None``.
    """
    if X is None:
        return None
    if X.shape[0] == 0:
        # no rows to sample from; torch.randint would reject high == 0
        return torch.zeros_like(X, dtype=torch.long)
    return torch.randint(
        0, X.shape[0], X.shape, dtype=torch.long, device=X.device, generator=generator
    )

# replacement-row indices for numeric and (optional) categorical features
x_cont_perm, x_cat_perm = gen_perm(x_cont), gen_perm(x_cat)

In [None]:
def gen_masks(X, perm, corrupt_probability=0.15):
    """
    Build the boolean "was this cell replaced?" target used for detection.

    A cell ``(i, j)`` is marked when it is selected by a Bernoulli draw with
    probability ``corrupt_probability`` AND its candidate replacement
    ``X[perm[i, j], j]`` actually differs from ``X[i, j]``. Drawing an equal
    value is not a real corruption, so it is not marked.
    """
    selected = torch.empty_like(X).bernoulli(p=corrupt_probability).bool()
    replacement = torch.gather(X, 0, perm)
    return selected & (replacement != X)

# detection targets: numeric features always, categorical ones only if present
x_cont_mask = gen_masks(training_data.x_cont, x_cont_perm)
x_cat_mask = (
    None
    if training_data.x_cat is None
    else gen_masks(training_data.x_cat, x_cat_perm)
)

In [None]:
# Corrupt the inputs: where the mask is True, use the value gathered (along
# dim 0) from the sampled row; elsewhere keep the original value.
# New tensors are built with `torch.where` instead of assigning in place. This
# keeps the tensors held by `training_data` clean, so re-running this cell does
# not corrupt already-corrupted data a second time.
x_cont_permuted = torch.gather(training_data.x_cont, 0, x_cont_perm)
x_cont = torch.where(x_cont_mask, x_cont_permuted, training_data.x_cont)

if training_data.x_cat is not None:
    x_cat_permuted = torch.gather(training_data.x_cat, 0, x_cat_perm)
    x_cat = torch.where(x_cat_mask, x_cat_permuted, training_data.x_cat)
else:
    x_cat = None

In [None]:
# Merge the masks column-wise as [mask_num, mask_cat]. The categorical part
# exists only when categorical features are present. Test the mask itself with
# an identity check (`is not None`) rather than comparing a tensor to None.
if x_cat_mask is not None:
    masks = torch.cat([x_cont_mask, x_cat_mask], dim=1)
else:
    masks = x_cont_mask

In [None]:
# Split rows into train (first 80 %) and val (last 20 %), keeping row order.
train_frac = 0.8
n_rows = x_cont.shape[0]
idx = int(n_rows * train_frac)
# Explicit section sizes always yield exactly two parts. An int argument to
# torch.split is a *chunk size*, which errors for idx == 0 and is not
# guaranteed to produce two chunks.
split_sizes = [idx, n_rows - idx]

x_cont_train, x_cont_val = torch.split(x_cont, split_sizes, dim=0)
masks_train, masks_val = torch.split(masks, split_sizes, dim=0)

if x_cat is not None:
    x_cat_train, x_cat_val = torch.split(x_cat, split_sizes, dim=0)
else:
    x_cat_train, x_cat_val = None, None


## Runfield 🛫

In [None]:
device = "cuda"
batch_size = 16192
epochs = 100

d_token = 192
n_blocks = 3
attention_dropout = 0.2
ffn_dropout = 0.1
residual_dropout = 0.0
attention_heads = 8

# reduction applied to the per-cell BCE losses (passed to `criterion` below)
reduction = "mean"

# single definition; the device comes from `device` rather than a second literal
dl_params = {
    "batch_size": batch_size,  # note: under DataParallel a batch is split across devices
    "shuffle": False,
    "device": device,
}

# every feature column is tokenized as a continuous feature; no categorical ones
feature_tokenizer_kwargs = {
    "num_continous": X_train.shape[1],
    "cat_cardinalities": (),
    "d_token": d_token,
}

transformer_kwargs = {
    "d_token": d_token,
    "n_blocks": n_blocks,
    "attention_n_heads": attention_heads,
    "attention_initialization": "kaiming",
    "ffn_activation": ReGLU,
    "attention_normalization": nn.LayerNorm,
    "ffn_normalization": nn.LayerNorm,
    "ffn_dropout": ffn_dropout,
    # fix at 4/3, as activation (see search space B in
    # https://arxiv.org/pdf/2106.11959v2.pdf)
    # is static with ReGLU / GeGLU
    "ffn_d_hidden": int(d_token * (4 / 3)),
    "attention_dropout": attention_dropout,
    "residual_dropout": residual_dropout,  # see search space (B)
    "prenormalization": True,
    "first_prenormalization": False,
    "last_layer_query_idx": None,
    "n_tokens": None,
    "kv_compression_ratio": None,
    "kv_compression_sharing": None,
    "head_activation": nn.GELU,  # nn.ReLU
    "head_normalization": nn.LayerNorm,
    "d_out": 1,  # fix at 1, due to binary classification
}

head_kwargs = {"d_in": d_token, "d_hidden": 32}

optim_params = {"lr": 1e-4, "weight_decay": 0.00001}

module_params = {
    "transformer": Transformer(**transformer_kwargs),  # type: ignore
    "feature_tokenizer": FeatureTokenizer(**feature_tokenizer_kwargs),  # type: ignore # noqa: E501
    "cat_features": None,
    "cat_cardinalities": [],
}

clf = FTTransformer(**module_params)
criterion = nn.BCEWithLogitsLoss(reduction=reduction)

# swap the target head for the pre-training head; it is restored after training
target_head = clf.transformer.head
clf_head = CLSHead(**head_kwargs)
clf.transformer.head = clf_head

# trailing `;` suppresses the (very long) module repr in the notebook output
clf.to(device);


In [None]:
# one loader per split over (categorical, continuous, mask) tensors; train first
train_loader, val_loader = (
    TabDataLoader(cat_part, cont_part, mask_part, **dl_params)
    for cat_part, cont_part, mask_part in (
        (x_cat_train, x_cont_train, masks_train),
        (x_cat_val, x_cont_val, masks_val),
    )
)

In [None]:
# AdamW over all parameters currently attached to `clf` (incl. the pre-training head)
optimizer = optim.AdamW(clf.parameters(), **optim_params)

# total optimizer steps; warm-up over the first 5 % of them
# (common advice: 5 - 10 % of the training budget or 100 to 500 steps)
max_iters = len(train_loader) * epochs
warmup = int(0.05 * max_iters)
print(f"warmup steps: {warmup}")
print(max_iters)

scheduler = CosineWarmupScheduler(optimizer=optimizer, warmup=warmup, max_iters=max_iters)

In [None]:
def checkpoint(model, path=None, *, run_id=None):
    """
    Save ``model.state_dict()`` as the only checkpoint kept for a W&B run.

    Files in ``checkpoints/`` whose name starts with the run id are deleted
    first, so repeated calls keep just the newest checkpoint.

    Args:
        model: ``torch.nn.Module`` whose ``state_dict`` is saved.
        path: target file; defaults to ``checkpoints/<run_id>.ptx``.
        run_id: file-name prefix; defaults to the global ``run.id``.

    Returns:
        The path the checkpoint was written to.
    """
    if run_id is None:
        run_id = run.id

    dir_checkpoints = "checkpoints"
    if path is None:
        # concrete extension; a literal "*" in the file name would act as a
        # wildcard in the glob below and is invalid on some file systems
        path = os.path.join(dir_checkpoints, f"{run_id}.ptx")

    # remove old checkpoints of this run (run id escaped so it matches literally)
    for fn in glob.glob(os.path.join(dir_checkpoints, f"{glob.escape(str(run_id))}*")):
        if os.path.isfile(fn):
            os.remove(fn)

    # create the target directory
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)

    # save new file
    print("saving new checkpoints.")
    torch.save(model.state_dict(), path)
    return path

In [None]:
# mixed precision (float16 autocast + loss scaling), see https://pytorch.org/docs/stable/amp.html
scaler = torch.cuda.amp.GradScaler()

# NOTE(review): created but not yet consulted in the loop; wire it up once the
# EarlyStopping API in otc.optim.early_stopping is confirmed.
early_stopping = EarlyStopping(patience=15)

step = 0
best_accuracy = -1
best_step = -1

for epoch in tqdm(range(epochs)):

    # --- training ---
    clf.train()
    loss_in_epoch_train = 0.0
    batch = 0

    # Loop variables get their own names so the module-level `x_cat`, `x_cont`
    # and `masks` tensors are not overwritten by the last mini-batch.
    for x_cat_batch, x_cont_batch, masks_batch in train_loader:
        optimizer.zero_grad()

        with torch.autocast(device_type="cuda", dtype=torch.float16):
            logits = clf(x_cat_batch, x_cont_batch)
            train_loss_batch = criterion(logits, masks_batch.float())

        scaler.scale(train_loss_batch).backward()
        scaler.step(optimizer)
        scaler.update()

        scheduler.step()

        # one .item() call -> a single host/device sync per step
        train_loss_value = train_loss_batch.item()
        loss_in_epoch_train += train_loss_value
        wandb.log({"train_loss_step": train_loss_value, "epoch": epoch, "batch": batch})

        batch += 1
        step += 1

    # --- validation ---
    clf.eval()
    loss_in_epoch_val = 0.0
    correct = 0
    n_cells = 0

    with torch.no_grad():
        for x_cat_batch, x_cont_batch, masks_batch in val_loader:
            logits = clf(x_cat_batch, x_cont_batch)
            val_loss_batch = criterion(logits, masks_batch.float())

            # BCEWithLogitsLoss requires logits and targets of equal shape, so the
            # comparison is element-wise: a positive logit means "replaced"
            hard_predictions = logits > 0
            correct += (hard_predictions == masks_batch.bool()).sum().item()
            n_cells += masks_batch.numel()

            val_loss_value = val_loss_batch.item()
            loss_in_epoch_val += val_loss_value
            wandb.log({"val_loss_step": val_loss_value, "epoch": epoch, "batch": batch})

            batch += 1

    # losses averaged over batches; accuracy over all (row, feature) cells
    train_loss = loss_in_epoch_train / len(train_loader)
    val_loss = loss_in_epoch_val / len(val_loader)
    val_accuracy = correct / n_cells if n_cells else float("nan")

    print(f"train loss: {train_loss}")
    print(f"val loss: {val_loss}")
    print(f"val accuracy: {val_accuracy}")

    if val_accuracy > best_accuracy:
        best_accuracy = val_accuracy
        best_step = step

    wandb.log(
        {
            "train_loss": train_loss,
            "val_loss": val_loss,
            "val_accuracy": val_accuracy,
            "epoch": epoch,
        }
    )

## Prepare for Fine-tuning 🕰️

In [None]:
# Swap the pre-training head back for the original target head and save the model.
# `target_head` was detached from `clf` before `clf.to(device)`, so it was never
# moved to the device. Move it now so the re-assembled model lives on one device.
clf.transformer.head = target_head.to(device)
checkpoint(clf)

# FIXME: Think about which weights to freeze and which to update
# https://ai.stackexchange.com/questions/23884/why-arent-the-bert-layers-frozen-during-fine-tuning-tasks