In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn.functional as F
from torch_frame import stype
from torch_frame.data import DataLoader
from torch_frame.nn import (
    EmbeddingEncoder,
    LinearEncoder,
    TimestampEncoder,
)
from tqdm import tqdm

from transformers import get_inverse_sqrt_schedule

import sys
from icecream import ic
import wandb

torch.set_float32_matmul_precision('high')

In [3]:
seed = 42
batch_size = 1024
channels = 256
num_layers = 4

pretrain = True
compile = True
lr = 5e-4
eps = 1e-8
epochs = 15
args = {
    "seed": seed,
    "batch_size": batch_size,
    "channels": channels,
    "num_layers": num_layers,
    "pretrain": pretrain,
    "compile": compile,
    "lr": lr,
    "eps": eps,
    "epochs": epochs,
}


In [17]:
wandb.login()
#run = wandb.init(project=f"rel-mm", name="model=fttransformer,dataset=IBM-AML_Hi_Sm,batch_size=1024,weighted_loss", config=args)
run = wandb.init(project=f"rel-mm", name="debug", config=args)



VBox(children=(Label(value='0.004 MB of 0.004 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112169191862146, max=1.0…

In [5]:
from torch_frame.datasets import IBMTransactionsAML
#dataset = IBMTransactionsAML(root='/mnt/data/ibm-transactions-for-anti-money-laundering-aml/dummy.csv', pretrain=pretrain)
dataset = IBMTransactionsAML(root='/mnt/data/ibm-transactions-for-anti-money-laundering-aml/HI-Small_Trans-cleaned.csv', pretrain=pretrain)
ic(dataset)
dataset.materialize()
num_numerical = len(dataset.tensor_frame.col_names_dict[stype.numerical])
num_categorical = len(dataset.tensor_frame.col_names_dict[stype.categorical])
dataset.df.head(5)

ic| dataset: IBMTransactionsAML()
ic| list(self._col_names_dict[stype.numerical]) + list(self._col_names_dict[stype.categorical]): ['Amount Paid',
                                                                                                  'Amount Received',
                                                                                                  'From Bank',
                                                                                                  'From ID',
                                                                                                  'Payment Currency',
                                                                                                  'Payment Format',
                                                                                                  'Receiving Currency',
                                                                                                  'To Bank',
                                                     

Unnamed: 0,Timestamp,From Bank,From ID,To Bank,To ID,Amount Received,Receiving Currency,Amount Paid,Payment Currency,Payment Format,Is Laundering,MASK,split
0,1200,B_10,8000EBD30,B_10,8000EBD30,0.296848,,0.296848,US Dollar,Reinvestment,0,"[0, 6]",0
1,1200,B_3208,8000F4580,B_1,8000F5340,0.000359,,0.000359,US Dollar,Cheque,0,"[0, 6]",0
2,0,B_3209,8000F4670,B_3209,8000F4670,0.346651,,0.346651,US Dollar,Reinvestment,0,"[0, 6]",0
3,120,B_12,8000F5030,B_12,8000F5030,0.286896,US Dollar,0.286896,,Reinvestment,0,"[0, 4]",0
4,360,B_10,8000F5200,B_10,8000F5200,0.379751,US Dollar,,US Dollar,Reinvestment,0,"[0.3797509348152993, 0]",0


In [18]:
num_columns = num_numerical + num_categorical
ic(
    num_numerical,
    num_categorical,
    num_columns,
)

ic| num_numerical: 2, num_categorical: 7, num_columns: 

9


(2, 7, 9)

In [19]:
torch.manual_seed(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
wandb.log({"device": str(device)})

In [20]:
train_dataset, val_dataset, test_dataset = dataset.split()

In [21]:
train_tensor_frame = train_dataset.tensor_frame
val_tensor_frame = val_dataset.tensor_frame
test_tensor_frame = test_dataset.tensor_frame
train_loader = DataLoader(train_tensor_frame, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_tensor_frame, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_tensor_frame, batch_size=batch_size, shuffle=False)
ic(len(train_loader), len(val_loader), len(test_loader))
wandb.log({
    "train_loader size": len(train_loader), 
    "val_loader size": len(val_loader), 
    "test_loader size": len(test_loader)
})

ic| len(train_loader): 4116, len(val_loader):

 844, len(test_loader): 0


In [22]:

stype_encoder_dict = {
    stype.categorical: EmbeddingEncoder(),
    stype.numerical: LinearEncoder(),
    stype.timestamp: TimestampEncoder(),
}

from models.ft_transformer import FTTransformer 
model = FTTransformer(
    channels=channels,
    out_channels=None,
    num_layers=num_layers,
    col_stats=dataset.col_stats,
    col_names_dict=train_tensor_frame.col_names_dict,
    stype_encoder_dict=stype_encoder_dict,
    pretrain = pretrain
).to(device)

model = torch.compile(model, dynamic=True) if compile else model
learnable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
ic(learnable_params)
wandb.log({"learnable_params": learnable_params})

# Prepare optimizer and lr scheduler
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=lr, eps=eps)
scheduler = get_inverse_sqrt_schedule(optimizer, num_warmup_steps=0, timescale=1000)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

def calc_loss(pred, y):
    accum_n = accum_c = t_n = t_c = 0
    for i, ans in enumerate(y):
        # ans --> [val, idx]
        # pred --> feature_type_num X type_num X batch_size
        if ans[1] > (num_numerical-1):
            t_c += 1
            a = torch.tensor(int(ans[0])).to(device)
            accum_c += F.cross_entropy(pred[1][int(ans[1])-num_numerical][i], a)
            del a
        else:
            t_n += 1
            accum_n += torch.square(pred[0][i][int(ans[1])] - ans[0]) #mse
    return (accum_n / t_n) + torch.sqrt(accum_c / t_c), (accum_c, t_c), (accum_n, t_n)

def train(epoc: int) -> float:
    model.train()
    loss_accum = loss_c_accum = loss_n_accum = total_count = t_c = t_n = 0

    with tqdm(train_loader, desc=f'Epoch {epoc}') as t:
        for tf in t:
            tf = tf.to(device)
            pred = model(tf)
            loss, loss_c, loss_n = calc_loss(pred, tf.y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_accum += float(loss) * len(tf.y)
            loss_c_accum += loss_c[0]
            loss_n_accum += loss_n[0]
            total_count += len(tf.y)
            t_c += loss_c[1]
            t_n += loss_n[1]
            t.set_postfix(loss=f'{loss_accum/total_count:.4f}', loss_c = f'{loss_c_accum/t_c:.4f}', loss_n = f'{loss_n_accum/t_n:.4f}')
            del pred
        wandb.log({"train_loss": loss_accum/total_count, "train_loss_c": loss_c_accum/t_c, "train_loss_n": loss_n_accum/t_n})
    return ((loss_c_accum/t_c) * (num_categorical/num_columns)) + ((loss_n_accum/t_n) * (num_numerical/num_columns))

@torch.no_grad()
def test(loader: DataLoader, dataset_name) -> float:
    model.eval()
    accum_acc = accum_l2 = 0
    loss_c_accum = loss_n_accum = 0
    t_n = t_c = 0
    with tqdm(loader, desc=f'Evaluating') as t:
        for tf in t:
            tf = tf.to(device)
            pred = model(tf)
            _, loss_c, loss_n = calc_loss(pred, tf.y)
            loss_c_accum += loss_c[0]
            loss_n_accum += loss_n[0]
            t_c += loss_c[1]
            t_n += loss_n[1]
            for i, ans in enumerate(tf.y):
                # ans --> [val, idx]
                # pred --> feature_type_num X type_num X batch_size
                if ans[1] > (num_numerical-1):
                    accum_acc += (pred[1][int(ans[1])-num_numerical][i].argmax() == int(ans[0]))
                else:
                    accum_l2 += torch.square(ans[0] - pred[0][i][int(ans[1])]) #rmse
            
            t.set_postfix(accuracy=f'{accum_acc/t_c:.4f}', rmse=f'{torch.sqrt(accum_l2/t_n):.4f}', loss=f'{(loss_c_accum/t_c) + (loss_n_accum/t_n):.4f}', loss_c = f'{loss_c_accum/t_c:.4f}', loss_n = f'{loss_n_accum/t_n:.4f}')
        wandb.log({f"{dataset_name}_accuracy": accum_acc/t_c, f"{dataset_name}_rmse": torch.sqrt(accum_l2/t_n), f"{dataset_name}_loss": ((loss_c_accum/t_c) * (num_categorical/num_columns)) + ((loss_n_accum/t_n) * (num_numerical/num_columns)), f"{dataset_name}_loss_c": loss_c_accum/t_c, f"{dataset_name}_loss_n": loss_n_accum/t_n})
        del pred
        accuracy = accum_acc / t_c
        rmse = torch.sqrt(accum_l2 / t_n)
        return [rmse, accuracy]

ic| learnable_params: 496110191


In [23]:
for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    train_metric = test(train_loader, "train")
    val_metric = test(val_loader, "val")
    #test_metric = test(test_loader, "test")
    ic(
        train_loss, 
        train_metric, 
        val_metric, 
        #test_metric
    )

Epoch 1:   0%|                                                                                                                                                                                                            | 0/4116 [00:00<?, ?it/s]

Epoch 1: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4116/4116 [11:07<00:00,  6.17it/s, loss=0.7395, loss_c=0.5340, loss_n=0.0129]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4116/4116 [05:48<00:00, 11.81it/s, accuracy=0.8105, loss=0.4740, loss_c=0.4630, loss_n=0.0110, rmse=0.1048]
Evaluating: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 844/844 [01:10<00:00, 12.02it/s, accuracy=0.8045, loss=0.4972, loss_c=0.4882, loss_n=0.0089, rmse=0.0946]
ic| train_loss: tensor(0.4182, device='cuda:0', grad_fn=<AddBackward0>)
    train_metric: [tensor(0.1048, device='cuda:0'), tensor(0.8105, device='cuda:0')]
    val_metric: [tensor(0.0946, device='cuda:0'), tensor(0.8045, device='cuda:0')]


In [24]:
wandb.finish()

VBox(children=(Label(value='0.004 MB of 0.004 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
learnable_params,▁
test_loader size,▁
train_accuracy,▁
train_loader size,▁
train_loss,█▁
train_loss_c,█▁
train_loss_n,█▁
train_rmse,▁
val_accuracy,▁
val_loader size,▁

0,1
device,cuda
learnable_params,496110191
test_loader size,0
train_accuracy,0.81049
train_loader size,4116
train_loss,0.36259
train_loss_c,0.46305
train_loss_n,0.01099
train_rmse,0.10478
val_accuracy,0.80455
