In [2]:
import torch
import torch.nn.functional as F
from torch_frame import stype
from torch_frame.datasets import Yandex
from torch_frame.data import DataLoader
from torch_frame.nn import (
    EmbeddingEncoder,
    FTTransformer,
    LinearBucketEncoder,
    LinearEncoder,
    LinearPeriodicEncoder,
    ResNet
)
from icecream import ic
from tqdm import tqdm

In [3]:
# Run configuration shared by the cells below.
seed = 42  # passed to torch.manual_seed and read by AML.random_split
batch_size = 512  # batch size of the train/val/test DataLoaders
numerical_encoder_type = 'linear'  # one of 'linear', 'linear_bucket', 'periodic'
model_type = 'fttransformer'  # one of 'fttransformer', 'resnet'
channels = 256  # `channels` argument of the model
num_layers = 4  # `num_layers` argument of the model

# NOTE(review): `compile` shadows Python's builtin compile() for the whole
# notebook; the model cell reads it to decide whether to call torch.compile.
compile = True
lr = 1e-3  # learning rate given to AdamW
epochs = 10  # number of passes of the training loop

In [37]:
# Load the 'adult' table through torch_frame's Yandex dataset (root='/tmp/yandex').
# NOTE(review): `dataset` and `is_classification` are assigned again by the AML
# cell further down, so the values set here are overwritten on a top-to-bottom run.
dataset = Yandex(root='/tmp/yandex', name='adult')
ic(dataset)
ic(dataset.feat_cols)
dataset.materialize()
is_classification = dataset.task_type.is_classification
# Last expression of the cell: preview the first five rows of the raw frame.
dataset.df.head(5)

ic| dataset: Yandex(name='adult')
ic| dataset.feat_cols: ['C_feature_0',
                        'C_feature_1',
                        'C_feature_2',
                        'C_feature_3',
                        'C_feature_4',
                        'C_feature_5',
                        'C_feature_6',
                        'C_feature_7',
                        'N_feature_0',
                        'N_feature_1',
                        'N_feature_2',
                        'N_feature_3',
                        'N_feature_4',
                        'N_feature_5']


Unnamed: 0,C_feature_0,C_feature_1,C_feature_2,C_feature_3,C_feature_4,C_feature_5,C_feature_6,C_feature_7,N_feature_0,N_feature_1,N_feature_2,N_feature_3,N_feature_4,N_feature_5,target_col,split_col
0,,Some-college,Never-married,,Other-relative,White,Female,United-States,19.0,140399.0,10.0,0.0,0.0,30.0,0,0
1,Private,Some-college,Married-civ-spouse,Machine-op-inspct,Husband,White,Male,United-States,50.0,158284.0,10.0,0.0,0.0,40.0,0,0
2,Private,Some-college,Married-civ-spouse,Exec-managerial,Husband,White,Male,United-States,62.0,183735.0,10.0,0.0,0.0,40.0,0,0
3,Private,HS-grad,Never-married,Adm-clerical,Not-in-family,White,Female,United-States,20.0,154781.0,9.0,0.0,0.0,40.0,0,0
4,Private,Bachelors,Never-married,Adm-clerical,Own-child,White,Female,United-States,25.0,356344.0,13.0,0.0,0.0,40.0,0,0


In [6]:
import torch_frame
import pandas as pd
import numpy as np
import itertools

class AML(torch_frame.data.Dataset):
    """Tabular dataset built from an anti-money-laundering transactions CSV.

    The CSV at ``root`` is read with a fixed set of column names, a ``split``
    column (0 = train, 1 = validation, 2 = test) is added by
    :meth:`temporal_balanced_split`, and the frame is handed to
    ``torch_frame.data.Dataset`` with ``'Is Laundering'`` as the target column.

    Args:
        root: Path of the CSV file; passed unchanged to ``pandas.read_csv``.
    """

    # Length of one day in the unit of the 'Timestamp' column
    # (assumed to be seconds -- TODO confirm against the CSV producer).
    SECONDS_PER_DAY = 24 * 3600

    def __init__(self, root):
        self.root = root
        names = [
            'Timestamp',
            'From Bank',
            'From ID',
            'To Bank',
            'To ID',
            'Amount Received',
            'Receiving Currency',
            'Amount Paid',
            'Payment Currency',
            'Payment Format',
            'Is Laundering',
        ]
        dtypes = {
            'From Bank': 'category',
            'From ID': 'category',
            'To Bank': 'category',
            'To ID': 'category',
            'Amount Received': 'float64',
            'Receiving Currency': 'category',
            'Amount Paid': 'float64',
            'Payment Currency': 'category',
            'Payment Format': 'category',
            'Is Laundering': 'category',
        }
        # header=0 together with names= replaces the file's own header row
        # with the column names listed above.
        self.df = pd.read_csv(root, names=names, dtype=dtypes, header=0)

        # 'Is Laundering' is read as a category of strings, hence the '1'.
        laundering = self.df[self.df['Is Laundering'] == '1']
        ic(laundering.shape)

        # The 'split' column must exist before the parent constructor runs,
        # because it is passed as split_col below.
        self.temporal_balanced_split()

        col_to_stype = {
            'Is Laundering': torch_frame.categorical,
            'From ID': torch_frame.categorical,
            'To ID': torch_frame.categorical,
            'From Bank': torch_frame.categorical,
            'To Bank': torch_frame.categorical,
            'Payment Currency': torch_frame.categorical,
            'Receiving Currency': torch_frame.categorical,
            'Payment Format': torch_frame.categorical,
            'Timestamp': torch_frame.numerical,
            'Amount Paid': torch_frame.numerical,
            'Amount Received': torch_frame.numerical,
        }
        super().__init__(self.df, col_to_stype, split_col='split', target_col='Is Laundering')

    def random_split(self):
        """Fill the ``split`` column with a random split.

        Uses ``torch_frame.utils.generate_random_split`` with the number of
        rows and the notebook-level ``seed`` variable, which must be defined
        before this method is called.
        """
        self.df['split'] = torch_frame.utils.generate_random_split(self.df.shape[0], seed)

    def temporal_split(self):
        """Sort rows by ``Timestamp`` and split them by position.

        The first 30% of the sorted rows become train (0), the next 10%
        validation (1) and the remainder test (2).
        """
        self.df = self.df.sort_values(by='Timestamp')
        train_size = int(self.df.shape[0] * 0.3)
        validation_size = int(self.df.shape[0] * 0.1)
        test_size = self.df.shape[0] - train_size - validation_size

        # Assigning a plain list is positional, so the labels follow the
        # sorted row order established above.
        self.df['split'] = [0] * train_size + [1] * validation_size + [2] * test_size

    def temporal_balanced_split(self):
        """Split whole days into train/validation/test as close to 60/20/20 as possible.

        ``Timestamp`` is shifted in place so that the earliest row is 0, rows
        are bucketed into days of ``SECONDS_PER_DAY``, and every pair of day
        boundaries ``i < j`` is scored by the largest relative deviation of the
        three resulting row proportions from 0.6 / 0.2 / 0.2. Days ``[0, i)``
        become train (0), ``[i, j)`` validation (1) and ``[j, n_days)`` test (2).

        Raises:
            AssertionError: If the frame has no ``Timestamp`` column.
            ValueError: If the data covers fewer than two days, so no pair of
                boundaries exists.
        """
        assert 'Timestamp' in self.df.columns, \
            '"transaction" split is only available for datasets with a "Timestamp" column'
        self.df['Timestamp'] = self.df['Timestamp'] - self.df['Timestamp'].min()

        # Day index of every row, in the frame's current row order.
        row_day = (self.df['Timestamp'].to_numpy() // self.SECONDS_PER_DAY).astype(np.int64)
        n_days = int(row_day.max()) + 1
        # Number of transactions that fall on each day.
        daily_totals = np.bincount(row_day, minlength=n_days)

        split_per = [0.6, 0.2, 0.2]
        total = daily_totals.sum()
        best_bounds, best_score = None, float('inf')
        # combinations() only yields i < j, so every candidate is a valid
        # ordered pair of boundaries; the strict '<' keeps the first best pair.
        for i, j in itertools.combinations(range(n_days), 2):
            split_totals = [daily_totals[:i].sum(), daily_totals[i:j].sum(), daily_totals[j:].sum()]
            score = max(abs(v / total - t) / t for v, t in zip(split_totals, split_per))
            if score < best_score:
                best_bounds, best_score = (i, j), score

        if best_bounds is None:
            raise ValueError(
                'temporal_balanced_split needs data spanning at least two days, '
                f'got {n_days}'
            )

        i, j = best_bounds
        # Label each row by its own day rather than by its position in the
        # frame, so the result does not depend on the CSV being sorted by time.
        # 0 = train, 1 = validation, 2 = test.
        self.df['split'] = np.where(row_day < i, 0, np.where(row_day < j, 1, 2))

# Build the AML dataset from the cleaned HI-Small transactions CSV.
# NOTE(review): absolute, machine-specific path -- consider a configurable data directory.
dataset = AML(root='/mnt/data/ibm-transactions-for-anti-money-laundering-aml/HI-Small_Trans-cleaned.csv')
ic(dataset)
# Materialize before task_type is read and the dataset is split
# (presumably required by torch_frame -- confirm against its docs).
dataset.materialize()
is_classification = dataset.task_type.is_classification
ic(is_classification)
# Last expression of the cell: preview the first five rows, including the 'split' column added by AML.
dataset.df.head(5)

ic| laundering.shape: (5177, 11)
ic| dataset: AML()
ic| is_classification: True


Unnamed: 0,Timestamp,From Bank,From ID,To Bank,To ID,Amount Received,Receiving Currency,Amount Paid,Payment Currency,Payment Format,Is Laundering,split
0,1200,B_10,8000EBD30,B_10,8000EBD30,3697.34,US Dollar,3697.34,US Dollar,Reinvestment,0,0
1,1200,B_3208,8000F4580,B_1,8000F5340,0.01,US Dollar,0.01,US Dollar,Cheque,0,0
2,0,B_3209,8000F4670,B_3209,8000F4670,14675.57,US Dollar,14675.57,US Dollar,Reinvestment,0,0
3,120,B_12,8000F5030,B_12,8000F5030,2806.97,US Dollar,2806.97,US Dollar,Reinvestment,0,0
4,360,B_10,8000F5200,B_10,8000F5200,36682.97,US Dollar,36682.97,US Dollar,Reinvestment,0,0


In [8]:
# Seed PyTorch's random number generator with the configured seed.
torch.manual_seed(seed)
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [9]:
# Unpack the three datasets returned by dataset.split() (train, validation, test).
train_dataset, val_dataset, test_dataset = dataset.split()
# Backward-compatible alias: other cells reference the original misspelled name.
trrain_dataset = train_dataset

In [10]:
# Tensor-frame view of each split.
train_tensor_frame, val_tensor_frame, test_tensor_frame = (
    trrain_dataset.tensor_frame,
    val_dataset.tensor_frame,
    test_dataset.tensor_frame,
)

# Only the training loader shuffles; validation and test keep their row order.
train_loader = DataLoader(
    train_tensor_frame,
    batch_size=batch_size,
    shuffle=True,
)
val_loader = DataLoader(
    val_tensor_frame,
    batch_size=batch_size,
    shuffle=False,
)
test_loader = DataLoader(
    test_tensor_frame,
    batch_size=batch_size,
    shuffle=False,
)

# Number of batches per loader.
ic(len(train_loader), len(val_loader), len(test_loader))

ic| len(train_loader): 6346
    len(val_loader): 1886
    len(test_loader): 1688


(6346, 1886, 1688)

In [11]:
# Instantiate the numerical-column encoder named in the config cell.
if numerical_encoder_type == 'periodic':
    numerical_encoder = LinearPeriodicEncoder()
elif numerical_encoder_type == 'linear_bucket':
    numerical_encoder = LinearBucketEncoder()
elif numerical_encoder_type == 'linear':
    numerical_encoder = LinearEncoder()
else:
    raise ValueError(f'Unknown numerical encoder type: {numerical_encoder_type}')

# One encoder per semantic type: embeddings for categorical columns and the
# encoder chosen above for numerical columns.
stype_encoder_dict = {
    stype.categorical: EmbeddingEncoder(),
    stype.numerical: numerical_encoder,
}

# Classification emits one output per class; regression emits a single value.
output_channels = dataset.num_classes if is_classification else 1

In [22]:
# Build the backbone selected by `model_type`. Both variants receive the same
# width, depth, column statistics, column names and per-stype encoders.
if model_type == 'fttransformer':
    model = FTTransformer(
        channels=channels,
        out_channels=output_channels,
        num_layers=num_layers,
        col_stats=dataset.col_stats,
        col_names_dict=train_tensor_frame.col_names_dict,
        stype_encoder_dict=stype_encoder_dict
    ).to(device)
elif model_type == 'resnet':
    model = ResNet(
        channels=channels,
        out_channels=output_channels,
        # ResNet requires the backbone depth as well; omitting it makes the
        # 'resnet' option fail at construction time.
        num_layers=num_layers,
        col_stats=dataset.col_stats,
        col_names_dict=train_tensor_frame.col_names_dict,
        stype_encoder_dict=stype_encoder_dict
    ).to(device)
else:
    raise ValueError(f'Unknown model type: {model_type}')

# `compile` is the notebook-level flag from the config cell (it shadows the builtin).
model = torch.compile(model, dynamic=True) if compile else model
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

def train(epoc: int) -> float:
    """Run one optimisation pass over ``train_loader``.

    Args:
        epoc: Epoch number, used only in the progress-bar description.

    Returns:
        The mean training loss over the pass, weighted by batch size.
    """
    model.train()
    running_loss = 0
    seen = 0

    with tqdm(train_loader, desc=f'Epoch {epoc}') as progress:
        for batch in progress:
            batch = batch.to(device)
            output = model(batch)

            # Cross-entropy for classification, mean squared error otherwise.
            loss = (
                F.cross_entropy(output, batch.y)
                if is_classification
                else F.mse_loss(output.view(-1), batch.y.view(-1))
            )

            optimizer.zero_grad()
            loss.backward()

            # Weight the batch-mean loss by batch size so the running average
            # is a per-sample mean.
            batch_len = len(batch.y)
            running_loss += float(loss) * batch_len
            seen += batch_len

            optimizer.step()
            progress.set_postfix(loss=f'{running_loss/seen:.4f}')

    return running_loss / seen

@torch.no_grad()
def test(loader: DataLoader) -> "float | list":
    """Evaluate the model on every batch of ``loader`` without gradients.

    Args:
        loader: Loader yielding the batches to evaluate.

    Returns:
        For classification, a two-element list ``[confusion_matrix, accuracy]``
        where ``confusion_matrix[true][predicted]`` is a nested list of counts.
        For regression, the root mean squared error as a float.
    """
    model.eval()
    accum = total_count = 0

    # The confusion matrix only exists for classification, so the class count
    # is read only on that path.
    if is_classification:
        num_classes = dataset.num_classes
        confusion = torch.zeros(num_classes, num_classes, dtype=torch.long, device=device)

    with tqdm(loader, desc='Evaluating') as t:
        for tf in t:
            tf = tf.to(device)
            pred = model(tf)
            total_count += len(tf.y)
            if is_classification:
                pred_class = pred.argmax(dim=-1)
                # Encode each (true, predicted) pair as one flat index and count
                # them in a single call instead of looping over samples.
                flat_index = tf.y.reshape(-1) * num_classes + pred_class.reshape(-1)
                confusion += torch.bincount(
                    flat_index, minlength=num_classes * num_classes
                ).view(num_classes, num_classes)
                accum += float((tf.y == pred_class).sum())
                t.set_postfix(accuracy=f'{accum/total_count:.4f}')
            else:
                # Sum of squared errors; turned into an RMSE after the loop.
                accum += float(F.mse_loss(pred.view(-1), tf.y.view(-1), reduction='sum'))

    if is_classification:
        accuracy = accum / total_count
        # tolist() yields nested Python ints, the shape callers index into.
        return [confusion.tolist(), accuracy]

    rmse = (accum / total_count) ** 0.5
    return rmse

In [26]:
# Metric label and "worst possible" starting values for model selection.
# Classification metrics are [confusion_matrix, accuracy] pairs, so the
# placeholder carries the accuracy in position 1.
metric = 'Acc' if is_classification else 'RMSE'
if is_classification:
    best_val_metric = best_test_metric = (None, 0)
else:
    best_val_metric = best_test_metric = float('inf')

for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    train_metric = test(train_loader)
    val_metric = test(val_loader)
    ic(val_metric)
    test_metric = test(test_loader)
    ic(test_metric)

    # Keep the test metric of the epoch with the best validation metric:
    # higher accuracy for classification, lower RMSE for regression.
    if is_classification:
        if val_metric[1] > best_val_metric[1]:
            best_val_metric = val_metric
            best_test_metric = test_metric
    else:
        if val_metric < best_val_metric:
            best_val_metric = val_metric
            best_test_metric = test_metric

    ic(train_loss, train_metric, val_metric, test_metric)

ic(best_val_metric, best_test_metric)

Epoch 1: 100%|██████████| 6346/6346 [03:16<00:00, 32.24it/s, loss=0.0077]
Evaluating: 100%|██████████| 6346/6346 [01:14<00:00, 85.35it/s, accuracy=0.9990]
Evaluating: 100%|██████████| 1886/1886 [00:21<00:00, 86.00it/s, accuracy=0.9990]
ic| val_metric: [[[964590, 0], [934, 0]], 0.9990326496285955]
Evaluating: 100%|██████████| 1688/1688 [00:19<00:00, 85.89it/s, accuracy=0.9987]
ic| test_metric: [[[862785, 0], [1115, 0]], 0.9987093413589536]
ic| train_loss: 0.007718086628077842
    train_metric: [[[3245793, 0], [3128, 0]], 0.9990372188181861]
    val_metric: [[[964590, 0], [934, 0]], 0.9990326496285955]
    test_metric: [[[862785, 0], [1115, 0]], 0.9987093413589536]
Epoch 2: 100%|██████████| 6346/6346 [03:17<00:00, 32.20it/s, loss=0.0077]
Evaluating:  10%|█         | 658/6346 [00:07<01:08, 83.35it/s, accuracy=0.9990]