In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os

import torch
import torch.nn.functional as F
import numpy as np

from torch_frame.data import DataLoader
from torch_geometric.data import Data
from torch_geometric.utils import negative_sampling
from transformers import get_inverse_sqrt_schedule

from src.datasets import IBMTransactionsAML
from src.nn.gnn.model import GINe
from src.utils import lp_loss

from tqdm import tqdm
import wandb
from icecream import ic

In [3]:
seed = 42
batch_size = 1024
lr = 5e-4
eps = 1e-8
epochs = 20

compile = True
data_split = [0.6, 0.2, 0.2]

device = 'cuda' if torch.cuda.is_available() else 'cpu'
args = {
    'testing': True,
    'batch_size': batch_size,
    'seed': seed,
    'device': device,
    'lr': lr,
    'eps': eps,
    'epochs': epochs,
    'compile': compile,
    'data_split': data_split
}

In [4]:
wandb.login()
run = wandb.init(
    mode="disabled" if args['testing'] else "online",
    project=f"rel-mm", 
    name="model=GINe,dataset=IBM-AML_Hi_Sm,objective=lp", 
    config=args
)

[34m[1mwandb[0m: Currently logged in as: [33maakyildiz[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [5]:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# When running on the CuDNN backend, two further options must be set
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Set a fixed value for the hash seed
os.environ["PYTHONHASHSEED"] = str(seed)

In [6]:
dataset = IBMTransactionsAML(root='/mnt/data/ibm-transactions-for-anti-money-laundering-aml/dummy.csv')
#dataset = IBMTransactionsAML(root='/mnt/data/ibm-transactions-for-anti-money-laundering-aml/HI-Small_Trans-cleaned.csv', pretrain=pretrain, split_type='temporal', splits=data_split)
ic(dataset)
dataset.materialize()
dataset.df.head(5)

ic| dataset: IBMTransactionsAML()


In [None]:
train_dataset, val_dataset, test_dataset = dataset.split()

In [None]:
train_tensor_frame = train_dataset.tensor_frame
train_loader = DataLoader(train_tensor_frame, batch_size=batch_size, shuffle=True)
val_tensor_frame = val_dataset.tensor_frame
val_loader = DataLoader(val_tensor_frame, batch_size=batch_size, shuffle=True)
test_tensor_frame = test_dataset.tensor_frame
test_loader = DataLoader(test_tensor_frame, batch_size=batch_size, shuffle=True)

In [None]:
# TODO: generalize the trainable columns
source = train_tensor_frame.get_col_feat('From ID')
destination = train_tensor_frame.get_col_feat('To ID')

#create dummy node features
num_nodes = np.unique(np.concatenate([source, destination])).shape[0]
ic(num_nodes)
node_feat = torch.ones(num_nodes)

edge_index = torch.cat([source, destination], dim=1).t()
ic(edge_index.shape)
g = Data(node_feat, edge_index=edge_index, edge_attr=train_tensor_frame)

ic| num_nodes: 298015
ic| edge_index.shape: torch.Size([2, 499843])


In [None]:
model = GINe(num_features=1, num_gnn_layers=3, edge_dim=train_dataset.tensor_frame.num_cols-3)
model = torch.compile(model, dynamic=True) if compile else model
model.to(args['device'])
learnable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
ic(learnable_params)
wandb.log({"learnable_params": learnable_params})

ic| learnable_params: 125177


In [None]:
batch = next(iter(train_loader))
ic(batch)

ic| batch: TensorFrame(
             num_cols=10,
             num_rows=1024,
             categorical (7): ['From Bank', 'From ID',

 'Payment Currency', 'Payment Format', 'Receiving Currency', 'To Bank', 'To ID'],
             timestamp (1): ['Timestamp'],
             numerical (2): ['Amount Paid', 'Amount Received'],
             has_target=True,
             device='cpu',
           )


TensorFrame(
  num_cols=10,
  num_rows=1024,
  categorical (7): ['From Bank', 'From ID', 'Payment Currency', 'Payment Format', 'Receiving Currency', 'To Bank', 'To ID'],
  timestamp (1): ['Timestamp'],
  numerical (2): ['Amount Paid', 'Amount Received'],
  has_target=True,
  device='cpu',
)

In [None]:
def get_gnn_inputs(batch):
    source = batch.get_col_feat('From ID')
    destination = batch.get_col_feat('To ID')
    #ic(source, destination)
    feat_cols = train_dataset.feat_cols

    # TODO: generalize the trainable columns
    feat_cols.remove('Timestamp')
    feat_cols.remove('From ID')
    feat_cols.remove('To ID')

    # TODO: fix, a very crude approach
    feats = [batch.get_col_feat(col_name) for col_name in feat_cols]
    edge_attr = torch.cat(feats, dim=1).to(device)
    nodes = torch.unique(torch.cat([source, destination]))
    num_nodes = nodes.shape[0]

    n_id_map = {value.item(): index for index, value in enumerate(nodes)}
    local_source = torch.tensor([n_id_map[node.item()] for node in source], dtype=torch.long)
    local_destination = torch.tensor([n_id_map[node.item()] for node in destination], dtype=torch.long)
    edge_index = torch.cat((local_source.unsqueeze(0), local_destination.unsqueeze(0))).to(device)
    node_feats = torch.ones(num_nodes).view(-1,num_nodes).t().to(device)
    neg_edge_index = negative_sampling(edge_index=edge_index, num_nodes=num_nodes, num_neg_samples=local_source.shape[0])
    return edge_index, edge_attr, node_feats, neg_edge_index
edge_index, edge_attr, node_feats, neg_edge_index = get_gnn_inputs(batch)
ic(edge_index, edge_attr, node_feats, neg_edge_index)

ic| edge_index: tensor([[1029, 1641, 1738,  ...,  722, 1722,  285],
                        [1035, 1704,  706,  ...,  204, 1688,  495]], device='cuda:0')
    edge_attr: tensor([[5.3000e+01, 5.0000e+01, 0.0000e+00,  ..., 0.0000e+00, 1.3525e-10,
                        1.3525e-10],
                       [4.9100e+02, 4.8900e+02, 0.0000e+00,  ..., 0.0000e+00, 2.5417e-07,
                        2.5417e-07],
                       [8.2400e+02, 7.1400e+02, 4.0000e+00,  ..., 0.0000e+00, 1.5371e-06,
                        1.5371e-06],
                       ...,
                       [1.0000e+00, 0.0000e+00, 1.4000e+01,  ..., 0.0000e+00, 8.4651e-07,
                        8.4651e-07],
                       [4.5900e+02, 4.3600e+02, 4.0000e+00,  ..., 0.0000e+00, 5.9275e-11,
                        5.9275e-11],
                       [2.4800e+02, 7.0000e+02, 4.0000e+00,  ..., 1.0000e+00, 8.8068e-08,
                        8.8068e-08]], device='cuda:0')
    node_feats: tensor([[1.],
        

(tensor([[1029, 1641, 1738,  ...,  722, 1722,  285],
         [1035, 1704,  706,  ...,  204, 1688,  495]], device='cuda:0'),
 tensor([[5.3000e+01, 5.0000e+01, 0.0000e+00,  ..., 0.0000e+00, 1.3525e-10,
          1.3525e-10],
         [4.9100e+02, 4.8900e+02, 0.0000e+00,  ..., 0.0000e+00, 2.5417e-07,
          2.5417e-07],
         [8.2400e+02, 7.1400e+02, 4.0000e+00,  ..., 0.0000e+00, 1.5371e-06,
          1.5371e-06],
         ...,
         [1.0000e+00, 0.0000e+00, 1.4000e+01,  ..., 0.0000e+00, 8.4651e-07,
          8.4651e-07],
         [4.5900e+02, 4.3600e+02, 4.0000e+00,  ..., 0.0000e+00, 5.9275e-11,
          5.9275e-11],
         [2.4800e+02, 7.0000e+02, 4.0000e+00,  ..., 1.0000e+00, 8.8068e-08,
          8.8068e-08]], device='cuda:0'),
 tensor([[1.],
         [1.],
         [1.],
         ...,
         [1.],
         [1.],
         [1.]], device='cuda:0'),
 tensor([[1150,  178, 1016,  ...,  356,  882, 1080],
         [ 449,  235,  489,  ...,  567,  170,  120]], device='cuda:0'))

In [None]:
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=lr, eps=eps)
scheduler = get_inverse_sqrt_schedule(optimizer, num_warmup_steps=0, timescale=1000)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

def train(epoc: int) -> float:
    model.train()
    loss_accum = total_count = 0

    with tqdm(train_loader, desc=f'Epoch {epoc}') as t:
        for tf in t:
            tf = tf.to(device)
            edge_index, edge_attr, node_feats, neg_edge_index = get_gnn_inputs(tf)
            pred = model(node_feats, edge_index, edge_attr)
            neg_pred = model(node_feats, neg_edge_index, edge_attr)
            #loss = calc_loss(pred, tf.y)
            loss = lp_loss(pred, neg_pred)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_accum += float(loss) * len(tf.y)
            total_count += len(tf.y)
            t.set_postfix(loss=f'{loss_accum/total_count:.4f}')
            del pred
            del tf
        wandb.log({"train_loss": loss_accum/total_count})
    return loss_accum / total_count

@torch.no_grad()
def test(loader: DataLoader, dataset_name) -> float:
    model.eval()
    accum_acc = 0
    loss_accum = 0
    total_count = 0
    with tqdm(loader, desc=f'Evaluating') as t:
        for tf in t:
            tf = tf.to(device)
            edge_index, edge_attr, node_feats, neg_edge_index = get_gnn_inputs(tf)
            pred = model(node_feats, edge_index, edge_attr)
            neg_pred = model(node_feats, neg_edge_index, edge_attr)
            loss = lp_loss(pred, neg_pred)
            loss_accum += float(loss) * (2 * len(pred))
            accum_acc += pred.argmax(dim=1).sum().item()
            accum_acc += len(neg_pred) - neg_pred.argmax(dim=1).sum().item()
            total_count += len(pred) + len(neg_pred)
            t.set_postfix(accuracy=f'{accum_acc/total_count:.4f}')
        wandb.log({f"{dataset_name}_accuracy": accum_acc/total_count})
        del tf
        del pred
        accuracy = accum_acc / total_count
        return accuracy

In [None]:
for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    train_metric = test(train_loader, "train")
    val_metric = test(val_loader, "val")
    test_metric = test(test_loader, "test")
    ic(
        train_loss, 
        train_metric, 
        val_metric, 
        test_metric
    )

Epoch 1: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 489/489 [00:14<00:00, 33.71it/s, loss=1.3584]
Evaluating: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 489/489 [00:12<00:00, 38.27it/s, accuracy=0.5131]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 250.57it/s, accuracy=0.5164]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 226.16it/s, accuracy=0.4947]
ic| train_loss: 1.358403

KeyboardInterrupt: 