In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn.functional as F
from torch_frame import stype
from torch_frame.data import DataLoader
from torch_frame.nn import (
    EmbeddingEncoder,
    LinearEncoder,
    TimestampEncoder,
)
from tqdm import tqdm

from transformers import get_inverse_sqrt_schedule

import sys
from icecream import ic
import wandb

In [3]:
seed = 42
batch_size = 512
channels = 256
num_layers = 4

pretrain = True
compile = True
lr = 5e-4
eps = 1e-8
epochs = 10
args = {
    "seed": seed,
    "batch_size": batch_size,
    "channels": channels,
    "num_layers": num_layers,
    "pretrain": pretrain,
    "compile": compile,
    "lr": lr,
    "eps": eps,
    "epochs": epochs,
}


In [4]:
wandb.login()
run = wandb.init(project=f"rel-mm", name="model=fttransformer,dataset=ibm-aml_hi_sm", config=args)

[34m[1mwandb[0m: Currently logged in as: [33maakyildiz[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [5]:
from torch_frame.datasets import IBMTransactionsAML
#dataset = IBMTransactionsAML(root='/mnt/data/ibm-transactions-for-anti-money-laundering-aml/dummy.csv', pretrain=pretrain)
dataset = IBMTransactionsAML(root='/mnt/data/ibm-transactions-for-anti-money-laundering-aml/HI-Small_Trans-cleaned.csv', pretrain=pretrain)
ic(dataset)
dataset.materialize()
num_numerical = len(dataset.tensor_frame.col_names_dict[stype.numerical])
num_categorical = len(dataset.tensor_frame.col_names_dict[stype.categorical])
dataset.df.head(5)

ic| dataset: IBMTransactionsAML()
ic| list(self._col_names_dict[stype.numerical]) + list(self._col_names_dict[stype.categorical]): ['Amount Paid',
                                                                                                  'Amount Received',
                                                                                                  'From Bank',
                                                                                                  'From ID',
                                                                                                  'Payment Currency',
                                                                                                  'Payment Format',
                                                                                                  'Receiving Currency',
                                                                                                  'To Bank',
                                                     

Unnamed: 0,Timestamp,From Bank,From ID,To Bank,To ID,Amount Received,Receiving Currency,Amount Paid,Payment Currency,Payment Format,Is Laundering,MASK,split
0,1200,B_10,8000EBD30,B_10,8000EBD30,0.296848,US Dollar,,US Dollar,Reinvestment,0,"[0.2968476112178767, 0]",0
1,1200,B_3208,8000F4580,B_1,8000F5340,0.000359,US Dollar,,US Dollar,Cheque,0,"[0.0003594894238955, 0]",0
2,0,B_3209,8000F4670,B_3209,8000F4670,0.346651,US Dollar,0.346651,US Dollar,,0,"[4, 5]",0
3,120,B_12,8000F5030,B_12,8000F5030,0.286896,US Dollar,0.286896,US Dollar,,0,"[4, 5]",0
4,360,B_10,8000F5200,B_10,8000F5200,,US Dollar,0.379751,US Dollar,Reinvestment,0,"[0.3797509348152993, 1]",0


In [6]:
torch.manual_seed(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
wandb.log({"device": str(device)})

In [7]:
train_dataset, val_dataset, test_dataset = dataset.split()

In [8]:
train_tensor_frame = train_dataset.tensor_frame
val_tensor_frame = val_dataset.tensor_frame
test_tensor_frame = test_dataset.tensor_frame
train_loader = DataLoader(train_tensor_frame, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_tensor_frame, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_tensor_frame, batch_size=batch_size, shuffle=False)
ic(len(train_loader), len(val_loader), len(test_loader))
wandb.log({
    "train_loader size": len(train_loader), 
    "val_loader size": len(val_loader), 
    "test_loader size": len(test_loader)
})

ic| len(train_loader): 6346
    len(val_loader): 1886
    len(test_loader): 1688


In [9]:
# print an example batch
ic(next(iter(train_loader)).feat_dict)
ic(next(iter(train_loader)).y)

ic| next(iter(train_loader)).feat_dict: {<stype.numerical: 'numerical'>: tensor([[0.3537, 0.3537],
                                                [0.2395, 0.2395],
                                                [0.3307, 0.3307],
                                                ...,
                                                [0.3662, 0.3662],
                                                [0.0811, 0.0811],
                                                [   nan, 0.2744]]),
                                         <stype.categorical: 'categorical'>: tensor([[   167,  42250,      1,  ...,     -1,    337,  68473],
                                                [   845,  64259,     -1,  ...,      1,    929,  11430],
                                                [   488,  17825,      3,  ...,      3,   1070, 104175],
                                                ...,
                                                [     4,   8224,      8,  ...,     -1,     94,   3956],
          

tensor([[0.3935, 1.0000],
        [0.0000, 4.0000],
        [0.2136, 1.0000],
        ...,
        [0.3796, 1.0000],
        [0.2273, 0.0000],
        [0.0819, 0.0000]])

In [10]:

stype_encoder_dict = {
    stype.categorical: EmbeddingEncoder(),
    stype.numerical: LinearEncoder(),
    stype.timestamp: TimestampEncoder(),
}

from models.ft_transformer import FTTransformer 
model = FTTransformer(
    channels=channels,
    out_channels=None,
    num_layers=num_layers,
    col_stats=dataset.col_stats,
    col_names_dict=train_tensor_frame.col_names_dict,
    stype_encoder_dict=stype_encoder_dict,
    pretrain = pretrain
).to(device)

model = torch.compile(model, dynamic=True) if compile else model
learnable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
ic(learnable_params)
wandb.log({"learnable_params": learnable_params})

# Prepare optimizer and lr scheduler
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=lr, eps=eps)
scheduler = get_inverse_sqrt_schedule(optimizer, num_warmup_steps=0, timescale=1000)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

def calc_loss(pred, y):
    accum_n = accum_c = t_n = t_c = 0
    for i, ans in enumerate(y):
        # ans --> [val, idx]
        # pred --> feature_type_num X type_num X batch_size
        if ans[1] > (num_numerical-1):
            t_c += 1
            a = torch.tensor(int(ans[0])).to(device)
            accum_c += F.cross_entropy(pred[1][int(ans[1])-num_numerical][i], a)
            del a
        else:
            t_n += 1
            accum_n += torch.square(pred[0][i][int(ans[1])] - ans[0]) #mse
    return (accum_n / t_n) + torch.sqrt(accum_c / t_c), (accum_c, t_c), (accum_n, t_n)

def train(epoc: int) -> float:
    model.train()
    loss_accum = loss_c_accum = loss_n_accum = total_count = t_c = t_n = 0

    with tqdm(train_loader, desc=f'Epoch {epoc}') as t:
        for tf in t:
            tf = tf.to(device)
            pred = model(tf)
            loss, loss_c, loss_n = calc_loss(pred, tf.y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_accum += float(loss) * len(tf.y)
            loss_c_accum += loss_c[0]
            loss_n_accum += loss_n[0]
            total_count += len(tf.y)
            t_c += loss_c[1]
            t_n += loss_n[1]
            t.set_postfix(loss=f'{loss_accum/total_count:.4f}', loss_c = f'{loss_c_accum/t_c:.4f}', loss_n = f'{loss_n_accum/t_n:.4f}')
            wandb.log({"train_loss": loss_accum/total_count, "train_loss_c": loss_c_accum/t_c, "train_loss_n": loss_n_accum/t_n})
    return (loss_c_accum/t_c) + (loss_n_accum/t_n)

@torch.no_grad()
def test(loader: DataLoader, dataset_name) -> float:
    model.eval()
    accum_acc = accum_l2 = 0
    loss_c_accum = loss_n_accum = 0
    t_n = t_c = 0
    with tqdm(loader, desc=f'Evaluating') as t:
        for tf in t:
            tf = tf.to(device)
            pred = model(tf)
            _, loss_c, loss_n = calc_loss(pred, tf.y)
            loss_c_accum += loss_c[0]
            loss_n_accum += loss_n[0]
            t_c += loss_c[1]
            t_n += loss_n[1]
            for i, ans in enumerate(tf.y):
                # ans --> [val, idx]
                # pred --> feature_type_num X type_num X batch_size
                if ans[1] > (num_numerical-1):
                    accum_acc += (pred[1][int(ans[1])-num_numerical][i].argmax() == int(ans[0]))
                else:
                    accum_l2 += torch.square(ans[0] - pred[0][i][int(ans[1])]) #rmse
            
            t.set_postfix(accuracy=f'{accum_acc/t_c:.4f}', rmse=f'{torch.sqrt(accum_l2/t_n):.4f}', loss=f'{(loss_c_accum/t_c) + (loss_n_accum/t_n):.4f}', loss_c = f'{loss_c_accum/t_c:.4f}', loss_n = f'{loss_n_accum/t_n:.4f}')
            wandb.log({f"{dataset_name}_accuracy": accum_acc/t_c, f"{dataset_name}_rmse": torch.sqrt(accum_l2/t_n), f"{dataset_name}_loss": (loss_c_accum/t_c) + (loss_n_accum/t_n), f"{dataset_name}_loss_c": loss_c_accum/t_c, f"{dataset_name}_loss_n": loss_n_accum/t_n})
        del pred
        accuracy = accum_acc / t_c
        rmse = torch.sqrt(accum_l2 / t_n)
        return [rmse, accuracy]

ic| learnable_params: 496110191


In [11]:
torch.cuda.empty_cache()

In [12]:
for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    train_metric = test(train_loader, "train")
    val_metric = test(val_loader, "val")
    test_metric = test(test_loader, "test")
    #ic(train_loss, train_metric, val_metric, test_metric)

Epoch 1:   0%|                                                                                                                                                                                                              | 0/6346 [00:00<?, ?it/s]

Epoch 1: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6346/6346 [11:36<00:00,  9.11it/s, loss=0.7418, loss_c=0.5362, loss_n=0.0131]
Evaluating: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6346/6346 [04:34<00:00, 23.15it/s, accuracy=0.8015, loss=0.4942, loss_c=0.4823, loss_n=0.0119, rmse=0.1088]
Evaluating: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1886/1886 [01:20<00:00, 23.39it/s, accuracy=0.8011, loss=0.5009, loss_c=0.4915, loss_n=0.0095, rmse=0.0973]
Evaluating: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1688/1688 [01:12<00:00, 23.31it/s, accuracy=0.7965, loss=0.5110, loss_c=0.5015, loss_n=0.0094, rmse=0.0970]
Epoch 2:  68%|██