In [1]:
%load_ext autoreload
%autoreload 2

In [6]:
import torch
import torch.nn.functional as F
from torch_frame import stype
from torch_frame.datasets import Yandex
from torch_frame.data import DataLoader
from torch_frame.nn import (
    EmbeddingEncoder,
    FTTransformer,
    TimestampEncoder,
    LinearBucketEncoder,
    LinearEncoder,
    LinearPeriodicEncoder,
    ResNet
)
from icecream import ic
from tqdm import tqdm

In [3]:
seed = 42
batch_size = 512
numerical_encoder_type = 'linear'
model_type = 'fttransformer'
channels = 256
num_layers = 4

compile = True
lr = 1e-3
epochs = 10

In [4]:
dataset = Yandex(root='/tmp/yandex', name='adult')
ic(dataset)
ic(dataset.feat_cols)
dataset.materialize()
is_classification = dataset.task_type.is_classification
dataset.df.head(5)

ic| dataset: Yandex(name='adult')
ic| dataset.feat_cols: ['C_feature_0',
                        'C_feature_1',
                        'C_feature_2',
                        'C_feature_3',
                        'C_feature_4',
                        'C_feature_5',
                        'C_feature_6',
                        'C_feature_7',
                        'N_feature_0',
                        'N_feature_1',
                        'N_feature_2',
                        'N_feature_3',
                        'N_feature_4',
                        'N_feature_5']


Unnamed: 0,C_feature_0,C_feature_1,C_feature_2,C_feature_3,C_feature_4,C_feature_5,C_feature_6,C_feature_7,N_feature_0,N_feature_1,N_feature_2,N_feature_3,N_feature_4,N_feature_5,target_col,split_col
0,,Some-college,Never-married,,Other-relative,White,Female,United-States,19.0,140399.0,10.0,0.0,0.0,30.0,0,0
1,Private,Some-college,Married-civ-spouse,Machine-op-inspct,Husband,White,Male,United-States,50.0,158284.0,10.0,0.0,0.0,40.0,0,0
2,Private,Some-college,Married-civ-spouse,Exec-managerial,Husband,White,Male,United-States,62.0,183735.0,10.0,0.0,0.0,40.0,0,0
3,Private,HS-grad,Never-married,Adm-clerical,Not-in-family,White,Female,United-States,20.0,154781.0,9.0,0.0,0.0,40.0,0,0
4,Private,Bachelors,Never-married,Adm-clerical,Own-child,White,Female,United-States,25.0,356344.0,13.0,0.0,0.0,40.0,0,0


In [5]:
from torch_frame.datasets import IBMTransactionsAML
dataset = IBMTransactionsAML(root='/mnt/data/ibm-transactions-for-anti-money-laundering-aml/dummy.csv')
ic(dataset)
dataset.materialize()
is_classification = dataset.task_type.is_classification
ic(is_classification)
dataset.df.head(5)

ic| dataset: IBMTransactionsAML()
ic| is_classification: True


Unnamed: 0,Timestamp,From Bank,From ID,To Bank,To ID,Amount Received,Receiving Currency,Amount Paid,Payment Currency,Payment Format,Is Laundering,split
0,1200,B_10,8000EBD30,B_10,8000EBD30,3697.34,US Dollar,3697.34,US Dollar,Reinvestment,0,0
1,1200,B_3208,8000F4580,B_1,8000F5340,0.01,US Dollar,0.01,US Dollar,Cheque,0,0
2,0,B_3209,8000F4670,B_3209,8000F4670,14675.57,US Dollar,14675.57,US Dollar,Reinvestment,0,0
3,120,B_12,8000F5030,B_12,8000F5030,2806.97,US Dollar,2806.97,US Dollar,Reinvestment,0,0
4,360,B_10,8000F5200,B_10,8000F5200,36682.97,US Dollar,36682.97,US Dollar,Reinvestment,0,0


In [7]:
torch.manual_seed(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [8]:
train_dataset, val_dataset, test_dataset = dataset.split()

In [9]:
train_tensor_frame = train_dataset.tensor_frame
val_tensor_frame = val_dataset.tensor_frame
test_tensor_frame = test_dataset.tensor_frame
train_loader = DataLoader(train_tensor_frame, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_tensor_frame, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_tensor_frame, batch_size=batch_size, shuffle=False)
ic(len(train_loader), len(val_loader), len(test_loader))

ic| len(train_loader): 977, len(val_loader): 1, len(test_loader): 1


(977, 1, 1)

In [10]:
# print an example batch
ic(next(iter(train_loader)).feat_dict)
ic(next(iter(train_loader)).y)

ic| next(iter(train_loader)).feat_dict: {<stype.numerical: 'numerical'>: tensor([[84261.8672, 84261.8672],
                                                [  206.3000,   206.3000],
                                                [  327.1900,   327.1900],
                                                ...,
                                                [31433.6992, 31433.6992],
                                                [ 9969.0400,  9969.0400],
                                                [21614.0293, 21614.0293]]),
                                         <stype.categorical: 'categorical'>: tensor([[    36,  29950,      1,  ...,      1,     37, 102997],
                                                [     0,      0,      0,  ...,      0,     72,  23533],
                                                [   348,  31404,      1,  ...,      1,    350, 131441],
                                                ...,
                                                [   508,  34393,  

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [11]:
if numerical_encoder_type == 'linear':
    numerical_encoder = LinearEncoder()
elif numerical_encoder_type == 'linear_bucket':
    numerical_encoder = LinearBucketEncoder()
elif numerical_encoder_type == 'periodic':
    numerical_encoder = LinearPeriodicEncoder()
else:
    raise ValueError(f'Unknown numerical encoder type: {numerical_encoder_type}')

stype_encoder_dict = {
    stype.categorical: EmbeddingEncoder(),
    stype.numerical: numerical_encoder,
    stype.timestamp: TimestampEncoder(),
}

if is_classification:
    output_channels = dataset.num_classes
else:
    output_channels = 1

In [12]:
if model_type == 'fttransformer':
    model = FTTransformer(
        channels=channels,
        out_channels=output_channels,
        num_layers=num_layers,
        col_stats=dataset.col_stats,
        col_names_dict=train_tensor_frame.col_names_dict,
        stype_encoder_dict=stype_encoder_dict
    ).to(device)
elif model_type == 'resnet':
    model = ResNet(
        channels=channels,
        out_channels=output_channels,
        col_stats=dataset.col_stats,
        col_names_dict=train_tensor_frame.col_names_dict,
        stype_encoder_dict=stype_encoder_dict
    ).to(device)
else:
    raise ValueError(f'Unknown model type: {model_type}')

model = torch.compile(model, dynamic=True) if compile else model
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

def train(epoc: int) -> float:
    model.train()
    loss_accum = total_count = 0

    with tqdm(train_loader, desc=f'Epoch {epoc}') as t:
        for tf in t:
            tf = tf.to(device)
            pred = model(tf)
            if is_classification:
                loss = F.cross_entropy(pred, tf.y)
            else:
                loss = F.mse_loss(pred.view(-1), tf.y.view(-1))
            optimizer.zero_grad()
            loss.backward()
            loss_accum += float(loss) * len(tf.y)
            total_count += len(tf.y)
            optimizer.step()
            t.set_postfix(loss=f'{loss_accum/total_count:.4f}')
    return loss_accum / total_count

@torch.no_grad()
def test(loader: DataLoader) -> float:
    model.eval()
    accum = total_count = 0
    confusion_matrix = [[0 for _ in range(dataset.num_classes)] for _ in range(dataset.num_classes)]
    with tqdm(loader, desc=f'Evaluating') as t:
        for tf in t:
            tf = tf.to(device)
            pred = model(tf)
            total_count += len(tf.y)
            if is_classification:
                pred_class = pred.argmax(dim=-1)
                #update confusion matrix
                for r, p in zip(tf.y, pred_class):
                    confusion_matrix[r][p] += 1
                #display confusion matrix
                #t.set_postfix(confusion_matrix=confusion_matrix)
                accum += float((tf.y == pred_class).sum())
                t.set_postfix(accuracy=f'{accum/total_count:.4f}')
            else:
                accum += float(F.mse_loss(pred.view(-1), tf.y.view(-1), reduction='sum'))

        if is_classification:
            accuracy = accum / total_count
            return [confusion_matrix, accuracy]
        else:
            rmse = (accum / total_count) **0.5
            return rmse

In [13]:
if is_classification:
    metric = 'Acc'
    best_val_metric = (None, 0)
    best_test_metric = (None, 0)
else:
    metric = 'RMSE'
    best_val_metric = float('inf')
    best_test_metric = float('inf')

for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    train_metric = test(train_loader)
    val_metric = test(val_loader)
    ic(val_metric)
    test_metric = test(test_loader)
    ic(test_metric)

    if is_classification and val_metric[1] > best_val_metric[1]:
        best_val_metric = val_metric
        best_test_metric = test_metric
    elif not is_classification and val_metric < best_val_metric:
        best_val_metric = val_metric
        best_test_metric = test_metric

    ic(train_loss, train_metric, val_metric, test_metric)

ic(best_val_metric, best_test_metric)

Epoch 1:   0%|          | 0/977 [00:00<?, ?it/s]

Epoch 1: 100%|██████████| 977/977 [00:34<00:00, 28.56it/s, loss=0.0040]
Evaluating: 100%|██████████| 977/977 [00:13<00:00, 71.29it/s, accuracy=0.9996]
Evaluating: 100%|██████████| 1/1 [00:00<00:00, 393.91it/s, accuracy=1.0000]
ic| val_metric: [[[61, 0], [0, 0]], 1.0]
Evaluating: 100%|██████████| 1/1 [00:00<00:00, 334.90it/s, accuracy=1.0000]
ic| test_metric: [[[95, 0], [0, 0]], 1.0]
ic| train_loss: 0.003999752386029505
    train_metric: [[[499650, 0], [193, 0]], 0.99961387875793]
    val_metric: [[[61, 0], [0, 0]], 1.0]
    test_metric: [[[95, 0], [0, 0]], 1.0]
Epoch 2: 100%|██████████| 977/977 [00:21<00:00, 44.63it/s, loss=0.0027]
Evaluating: 100%|██████████| 977/977 [00:11<00:00, 84.95it/s, accuracy=0.9996]
Evaluating: 100%|██████████| 1/1 [00:00<00:00, 397.83it/s, accuracy=1.0000]
ic| val_metric: [[[61, 0], [0, 0]], 1.0]
Evaluating: 100%|██████████| 1/1 [00:00<00:00, 341.69it/s, accuracy=1.0000]
ic| test_metric: [[[95, 0], [0, 0]], 1.0]
ic| train_loss: 0.0027073161381999684
    trai

([[[61, 0], [0, 0]], 1.0], [[[95, 0], [0, 0]], 1.0])