In [None]:
import glob
import os
import math
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import wandb
import torch
from torch import optim, nn
from tqdm.auto import tqdm


In [None]:
sys.path.append("..")
from otc.models.fttransformer import FeatureTokenizer, FTTransformer, Transformer
from otc.models.activation import ReGLU
from otc.data.dataset import TabDataset
from otc.data.dataloader import TabDataLoader
from otc.features.build_features import features_classical_size
from otc.optim.early_stopping import EarlyStopping
from otc.optim.scheduler import CosineWarmupScheduler

In [None]:
# Export the GCP project id into this process's environment.
os.environ.update(GCLOUD_PROJECT="flowing-mantis-239216")

In [None]:
# W&B dataset artifact to train on; `data_dir` is where it gets downloaded.
dataset = "fbv/thesis/ise_supervised_log_standardized_clipped:latest"

# Start the W&B run, register the artifact as an input and fetch it locally.
run = wandb.init(entity="fbv", project="thesis")
artifact = run.use_artifact(dataset)
data_dir = artifact.download()


In [None]:
# For quick experiments, train and validation are subsampled to `frac` of their
# rows while keeping the rows' relative (file) order; the test set is used in full.
frac = 0.05


def sample_in_order(frame, frac, random_state=42):
    """Return a reproducible random fraction of ``frame``'s rows in their original order.

    ``DataFrame.sample`` returns the drawn rows in random order, so row *positions*
    are sampled and sorted instead. Working on positions (not index labels) also
    keeps this correct for non-unique or unsorted indices.
    """
    positions = (
        pd.Series(np.arange(len(frame)))
        .sample(frac=frac, random_state=random_state)
        .sort_values()
        .to_numpy()
    )
    return frame.iloc[positions]


def load_split(name, frac=None):
    """Read ``<name>_set.parquet`` from ``data_dir`` and return ``(features, labels)``.

    Features are the ``features_classical_size`` columns, labels the ``buy_sell``
    column. If ``frac`` is given, only that fraction of rows is kept, in file order.
    """
    frame = pd.read_parquet(Path(data_dir, f"{name}_set.parquet"), engine="fastparquet")
    if frac is not None:
        frame = sample_in_order(frame, frac)
    return frame[features_classical_size], frame["buy_sell"]


X_train, y_train = load_split("train", frac)
X_val, y_val = load_split("val", frac)
X_test, y_test = load_split("test")

# Optional label smoothing (currently disabled):
# eps = 0.1
# y_train[np.where(y_train == 0)] = eps
# y_train[np.where(y_train == 1)] = 1.0 - eps


## Training Run

In [None]:
# Training configuration. The sampling fraction of the data lives in the
# data-loading cell above; it has already been applied at this point.
device = "cuda"
# NOTE(review): 16192 is not a power of two (2**14 = 16384) — confirm this is intended.
batch_size = 16192
epochs = 100

# transformer hyperparameters
d_token = 192
n_blocks = 3
attention_dropout = 0.2
ffn_dropout = 0.1
residual_dropout = 0.0
attention_heads = 8

# reduction the loss applies over each mini-batch
reduction = "mean"

feature_tokenizer_kwargs = {
    # all selected features are treated as continuous (no categorical features)
    "num_continous": X_train.shape[1],
    "cat_cardinalities": (),
    "d_token": d_token,
}

dl_params = {
    "batch_size": batch_size,  # DataParallel splits batches across devices
    "shuffle": False,
    "device": device,
}

transformer_kwargs = {
    "d_token": d_token,
    "n_blocks": n_blocks,
    "attention_n_heads": attention_heads,
    "attention_initialization": "kaiming",
    "ffn_activation": ReGLU,
    "attention_normalization": nn.LayerNorm,
    "ffn_normalization": nn.LayerNorm,
    "ffn_dropout": ffn_dropout,
    # fix at 4/3, as activation (see search space B in
    # https://arxiv.org/pdf/2106.11959v2.pdf)
    # is static with ReGLU / GeGLU
    "ffn_d_hidden": int(d_token * (4 / 3)),
    "attention_dropout": attention_dropout,
    "residual_dropout": residual_dropout,  # see search space (B)
    "prenormalization": True,
    "first_prenormalization": False,
    "last_layer_query_idx": None,
    "n_tokens": None,
    "kv_compression_ratio": None,
    "kv_compression_sharing": None,
    "head_activation": nn.GELU,  # nn.ReLU
    "head_normalization": nn.LayerNorm,
    "d_out": 1,  # fix at 1, due to binary classification
}

optim_params = {"lr": 1e-4, "weight_decay": 0.00001}

module_params = {
    "transformer": Transformer(**transformer_kwargs),
    "feature_tokenizer": FeatureTokenizer(**feature_tokenizer_kwargs),
    "cat_features": None,
    "cat_cardinalities": [],
}

clf = FTTransformer(**module_params)
# use multiple GPUs, if available
clf = nn.DataParallel(clf).to(device)

# BCEWithLogitsLoss applies the sigmoid itself, so the model emits raw logits;
# this is numerically more stable than a separate sigmoid + BCELoss.
criterion = nn.BCEWithLogitsLoss(reduction=reduction)

# TODO: log the configuration dicts above to W&B (e.g. via wandb.config.update).


In [None]:
# Wrap each split in a TabDataset, then build one loader per split
# (all loaders share the settings in `dl_params`).
training_data, val_data, test_data = (
    TabDataset(X_train, y_train),
    TabDataset(X_val, y_val),
    TabDataset(X_test, y_test),
)

train_loader, val_loader, test_loader = (
    TabDataLoader(split.x_cat, split.x_cont, split.weight, split.y, **dl_params)
    for split in (training_data, val_data, test_data)
)


In [None]:
# AdamW; learning rate and weight decay are taken from `optim_params`.
optimizer = optim.AdamW(clf.parameters(), **optim_params)

# The scheduler is stepped once per mini-batch, so the budget is batches x epochs.
max_iters = len(train_loader) * epochs
# Warm up over 5 % of the budget (common advice: 5-10 % of training, or 100-500 steps).
warmup = int(max_iters * 0.05)
print(f"warmup steps: {warmup}")
print(max_iters)

scheduler = CosineWarmupScheduler(optimizer=optimizer, warmup=warmup, max_iters=max_iters)


In [None]:
def checkpoint(model, filename, run_id=None):
    """Save ``model``'s state dict to ``filename``, replacing this run's older checkpoints.

    Args:
        model: module whose ``state_dict()`` is saved.
        filename: target path, e.g. ``checkpoints/<run id>-<step>.ptx``. Its
            directory is created if missing.
        run_id: file-name prefix identifying this run's checkpoints; files in the
            target directory starting with it are deleted first. Defaults to the
            active W&B run's ``run.id``.
    """
    if run_id is None:
        run_id = run.id

    dir_checkpoints = os.path.dirname(filename) or "."
    os.makedirs(dir_checkpoints, exist_ok=True)

    # Remove older checkpoints of this run. A separate loop variable is used so
    # the requested `filename` is not clobbered (it previously was, and the model
    # ended up saved to a file literally named "<run id>*").
    pattern = os.path.join(glob.escape(dir_checkpoints), glob.escape(run_id) + "*")
    for old_file in glob.glob(pattern):
        os.remove(old_file)

    print("saving new checkpoints.")
    torch.save(model.state_dict(), filename)


In [None]:
# Mixed precision, see https://pytorch.org/docs/stable/amp.html: the forward pass
# runs under autocast (float16) and GradScaler scales the loss so small float16
# gradients do not underflow.
scaler = torch.cuda.amp.GradScaler()

# Monitors the negated validation accuracy below; stopping semantics are defined
# in otc.optim.early_stopping.
early_stopping = EarlyStopping(patience=15)

step = 0  # optimizer steps across all epochs
best_accuracy = -1
best_step = -1

for epoch in tqdm(range(epochs)):

    # --- training ---
    clf.train()
    loss_in_epoch_train = 0
    batch = 0

    # the third item (sample weights) is not used by the loss
    for x_cat, x_cont, _, targets in train_loader:
        optimizer.zero_grad()

        with torch.autocast(device_type="cuda", dtype=torch.float16):
            # flatten to 1-D to match `targets`; the loss applies the sigmoid
            logits = clf(x_cat, x_cont).flatten()
            train_loss = criterion(logits, targets)

        scaler.scale(train_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # one scheduler step per mini-batch (max_iters = epochs * len(train_loader))
        scheduler.step()

        # fetch the scalar once (each .item() synchronizes with the GPU)
        batch_loss = train_loss.item()
        loss_in_epoch_train += batch_loss
        wandb.log({"train_loss_step": batch_loss, "epoch": epoch, "batch": batch})

        batch += 1
        step += 1

    # --- validation ---
    clf.eval()
    loss_in_epoch_val = 0.0
    correct = 0

    with torch.no_grad():
        for x_cat, x_cont, _, targets in val_loader:
            logits = clf(x_cat, x_cont).flatten()
            val_loss = criterion(logits, targets)

            # probability > 0.5 -> class 1, otherwise class 0
            preds = torch.sigmoid(logits).round()
            correct += (preds == targets).sum().item()

            batch_loss = val_loss.item()
            loss_in_epoch_val += batch_loss
            # `batch` keeps counting on from this epoch's training batches
            wandb.log({"val_loss_step": batch_loss, "epoch": epoch, "batch": batch})

            batch += 1

    # mean loss over all batches
    train_loss = loss_in_epoch_train / len(train_loader)
    val_loss = loss_in_epoch_val / len(val_loader)

    # fraction of correctly classified validation samples
    val_accuracy = correct / len(X_val)
    if best_accuracy < val_accuracy:
        # keep only the weights with the best validation accuracy on disk
        checkpoint(clf, f"checkpoints/{run.id}-{step}.ptx")
        best_accuracy = val_accuracy
        best_step = step

    wandb.log({"train_loss": train_loss, "epoch": epoch})
    wandb.log({"val_loss": val_loss, "epoch": epoch})

    print(f"train:{train_loss} val:{val_loss}")
    print(f"val accuracy:{val_accuracy}")

    # EarlyStopping minimizes its input, hence the negated accuracy; NaN losses
    # also abort training.
    early_stopping(-val_accuracy)
    if early_stopping.early_stop or math.isnan(train_loss) or math.isnan(val_loss):
        print("meh... early stopping")
        break


In [None]:
# Locate the checkpoint(s) written for this run during training. Sorted for a
# deterministic order; fail here with a clear message rather than with an
# IndexError when the next cell reads cp[0].
cp = sorted(glob.glob(f"checkpoints/{run.id}*"))
print(cp)
assert cp, f"no checkpoint found for run {run.id}"


In [None]:
# restore the weights saved for the best validation accuracy
best_state_dict = torch.load(cp[0])
clf.load_state_dict(best_state_dict)


In [None]:
# Evaluate the restored model on the (full) test set.
# Set eval mode explicitly (disables dropout) instead of relying on the mode the
# training loop happened to leave the model in.
clf.eval()

y_pred, y_true = [], []

# inference only: no autograd graph is needed
with torch.no_grad():
    for x_cat, x_cont, _, targets in test_loader:
        # 1-D logits; no .squeeze(), which would turn a size-1 last batch into a
        # 0-d array that np.concatenate cannot join
        logits = clf(x_cat, x_cont).flatten()

        # map between zero and one, sigmoid is otherwise included in loss already
        # https://stackoverflow.com/a/66910866/5755604
        preds = torch.sigmoid(logits)
        y_pred.append(preds.cpu().numpy())
        y_true.append(targets.cpu().numpy())

# round probabilities to the nearest class label
y_pred = np.rint(np.concatenate(y_pred))
y_true = np.concatenate(y_true)

acc = (y_pred == y_true).mean()
print(acc)
