In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os

import torch
import torch.nn.functional as F
import numpy as np

from torch_frame.data import DataLoader
from torch_geometric.data import Data
from transformers import get_inverse_sqrt_schedule

from src.datasets import IBMTransactionsAML
from src.nn.gnn.model import GINe
from src.utils.loss import lp_loss
from src.utils.metric import mrr

from tqdm import tqdm
import wandb
from icecream import ic
import sys

torch.set_float32_matmul_precision('high')

In [3]:
seed = 42
batch_size = 1024
lr = 5e-4
eps = 1e-8
epochs = 5

compile = False
data_split = [0.6, 0.2, 0.2]

device = 'cuda' if torch.cuda.is_available() else 'cpu'
args = {
    'testing': True,
    'batch_size': batch_size,
    'seed': seed,
    'device': device,
    'lr': lr,
    'eps': eps,
    'epochs': epochs,
    'compile': compile,
    'data_split': data_split
}

In [40]:
wandb.login()
run = wandb.init(
    mode="disabled" if args['testing'] else "online",
    project=f"rel-mm", 
    name="model=GINe,dataset=IBM-AML_Hi_Sm,objective=lp", 
    config=args
)

In [5]:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# When running on the CuDNN backend, two further options must be set
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Set a fixed value for the hash seed
os.environ["PYTHONHASHSEED"] = str(seed)

In [6]:
dataset = IBMTransactionsAML(root='/mnt/data/ibm-transactions-for-anti-money-laundering-aml/dummy.csv')
#dataset = IBMTransactionsAML(root='/mnt/data/ibm-transactions-for-anti-money-laundering-aml/HI-Small_Trans-cleaned.csv', pretrain=pretrain, split_type='temporal', splits=data_split)
ic(dataset)
dataset.materialize()
dataset.df.head(5)

ic| dataset: IBMTransactionsAML()


Unnamed: 0,Timestamp,From Bank,From ID,To Bank,To ID,Amount Received,Receiving Currency,Amount Paid,Payment Currency,Payment Format,Is Laundering,split
0,1200,B_10,8000EBD30,B_10,8000EBD30,3.53372e-09,US Dollar,3.53372e-09,US Dollar,Reinvestment,0,0
1,1200,B_3208,8000F4580,B_1,8000F5340,9.556511e-15,US Dollar,9.556511e-15,US Dollar,Cheque,0,0
2,0,B_3209,8000F4670,B_3209,8000F4670,1.402613e-08,US Dollar,1.402613e-08,US Dollar,Reinvestment,0,0
3,120,B_12,8000F5030,B_12,8000F5030,2.682752e-09,US Dollar,2.682752e-09,US Dollar,Reinvestment,0,0
4,360,B_10,8000F5200,B_10,8000F5200,3.505963e-08,US Dollar,3.505963e-08,US Dollar,Reinvestment,0,0


In [7]:
train_dataset, val_dataset, test_dataset = dataset.split()

In [8]:
train_tensor_frame = train_dataset.tensor_frame
train_loader = DataLoader(train_tensor_frame, batch_size=batch_size, shuffle=True)
val_tensor_frame = val_dataset.tensor_frame
val_loader = DataLoader(val_tensor_frame, batch_size=batch_size, shuffle=True)
test_tensor_frame = test_dataset.tensor_frame
test_loader = DataLoader(test_tensor_frame, batch_size=batch_size, shuffle=True)

In [9]:
# TODO: generalize the trainable columns
source = train_tensor_frame.get_col_feat('From ID')
destination = train_tensor_frame.get_col_feat('To ID')

#create dummy node features
num_nodes = np.unique(np.concatenate([source, destination])).shape[0]
ic(num_nodes)
node_feat = torch.ones(num_nodes)

edge_index = torch.cat([source, destination], dim=1).t()
ic(edge_index.shape)
g = Data(node_feat, edge_index=edge_index, edge_attr=train_tensor_frame)

ic| num_nodes: 298015
ic| edge_index.shape: torch.Size([2, 499843])


In [30]:
model = GINe(num_features=1, num_gnn_layers=3, edge_dim=train_dataset.tensor_frame.num_cols-3)
model = torch.compile(model, dynamic=True) if compile else model
model.to(args['device'])
learnable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
ic(learnable_params)
wandb.log({"learnable_params": learnable_params})

ic| learnable_params: 125177


In [31]:
def get_gnn_inputs(batch):
    source = batch.get_col_feat('From ID')
    destination = batch.get_col_feat('To ID')
    #ic(source, destination)
    feat_cols = train_dataset.feat_cols

    # TODO: generalize the trainable columns
    feat_cols.remove('Timestamp')
    feat_cols.remove('From ID')
    feat_cols.remove('To ID')

    # TODO: fix, a very crude approach
    feats = [batch.get_col_feat(col_name) for col_name in feat_cols]
    edge_attr = torch.cat(feats, dim=1)
    nodes = torch.unique(torch.cat([source, destination]))
    num_nodes = nodes.shape[0]

    n_id_map = {value.item(): index for index, value in enumerate(nodes)}
    local_source = torch.tensor([n_id_map[node.item()] for node in source], dtype=torch.long)
    local_destination = torch.tensor([n_id_map[node.item()] for node in destination], dtype=torch.long)
    edge_index = torch.cat((local_source.unsqueeze(0), local_destination.unsqueeze(0)))
    node_feats = torch.ones(num_nodes).view(-1,num_nodes).t()

    # TODO: could choose false negatives, the entire graph is not used
    neg_edges = []
    neg_edge_attr = []
    nodeset = set(range(edge_index.max()+1))
    for i, edge in enumerate(edge_index.t()):
        src, dst = edge[0], edge[1]

        # Chose negative examples in a smart way
        unavail_mask = (edge_index == src).any(dim=0) | (edge_index == dst).any(dim=0)
        unavail_nodes = torch.unique(edge_index[:, unavail_mask])
        unavail_nodes = set(unavail_nodes.tolist())
        avail_nodes = nodeset - unavail_nodes
        avail_nodes = torch.tensor(list(avail_nodes))
        # Finally, emmulate np.random.choice() to chose randomly amongst available nodes
        indices = torch.randperm(len(avail_nodes))[:64]
        neg_nodes = avail_nodes[indices]
        
        # Generate 32 negative edges with the same source but different destinations
        neg_dsts = neg_nodes[:32]  # Selecting 32 random destination nodes for the source
        neg_edges_src = torch.stack([src.repeat(32), neg_dsts], dim=0)
        
        # Generate 32 negative edges with the same destination but different sources
        neg_srcs = neg_nodes[32:]  # Selecting 32 random source nodes for the destination
        neg_edges_dst = torch.stack([neg_srcs, dst.repeat(32)], dim=0)

        # Add these negative edges to the list
        neg_edges.append(neg_edges_src)
        neg_edges.append(neg_edges_dst)
        # Replicate the positive edge attribute for each of the negative edges generated from this edge
        pos_attr = edge_attr[i].unsqueeze(0)  # Get the attribute of the current positive edge
        replicated_attr = pos_attr.repeat(64, 1)  # Replicate it 64 times (for each negative edge)
        neg_edge_attr.append(replicated_attr)
    
    edge_index = edge_index.to(device)
    edge_attr = edge_attr.to(device)
    node_feats = node_feats.to(device)
    neg_edge_index = torch.cat(neg_edges, dim=1).to(device)
    neg_edge_attr = torch.cat(neg_edge_attr, dim=0).to(device)
    return edge_index, edge_attr, node_feats, neg_edge_index, neg_edge_attr
# batch = next(iter(train_loader))
# edge_index, edge_attr, node_feats, neg_edge_index, neg_edge_attr = get_gnn_inputs(batch)
# ic(edge_index, edge_attr, node_feats, neg_edge_index, neg_edge_attr)

In [58]:
def train(epoc: int, model, optimizer) -> float:
    model.train()
    loss_accum = total_count = 0

    with tqdm(train_loader, desc=f'Epoch {epoc}') as t:
        for tf in t:
            tf = tf.to(device)
            edge_index, edge_attr, node_feats, neg_edge_index, neg_edge_attr = get_gnn_inputs(tf)
            pred = model(node_feats, edge_index, edge_attr)
            neg_pred = model(node_feats, neg_edge_index, neg_edge_attr)
            loss = lp_loss(pred, neg_pred)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_accum += float(loss) * len(pred)
            total_count += len(pred)
            t.set_postfix(loss=f'{loss_accum/total_count:.4f}')
            del pred
            del tf
        wandb.log({"train_loss": loss_accum/total_count})
    return loss_accum / total_count

@torch.no_grad()
def test(loader: DataLoader, model, dataset_name) -> float:
    model.eval()
    accum_acc = 0
    mrrs = []
    hits1 = []
    hits2 = []
    hits5 = []
    hits10 = []
    loss_accum = 0
    total_count = 0
    with tqdm(loader, desc=f'Evaluating') as t:
        for tf in t:
            tf = tf.to(device)
            edge_index, edge_attr, node_feats, neg_edge_index, neg_edge_attr = get_gnn_inputs(tf)
            pred = model(node_feats, edge_index, edge_attr)
            neg_pred = model(node_feats, neg_edge_index, neg_edge_attr)
            mrr_score, hits = mrr(pred, neg_pred, [1,2,5,10])
            mrrs.append(mrr_score)
            hits1.append(hits['hits@1'])
            hits2.append(hits['hits@2'])
            hits5.append(hits['hits@5'])
            hits10.append(hits['hits@10'])
            loss = lp_loss(pred, neg_pred)
            loss_accum += float(loss) * len(pred)
            accum_acc += pred.sum().item()
            accum_acc += len(neg_pred) - neg_pred.sum().item()
            total_count += len(pred) + len(neg_pred)
            t.set_postfix(
                accuracy=f'{accum_acc/total_count:.4f}',
                loss=f'{loss_accum/total_count:.4f}',
                mrr=f'{mrr_score:.4f}',
                hits1=f'{hits["hits@1"]:.4f}',
                hits2=f'{hits["hits@2"]:.4f}',
                hits5=f'{hits["hits@5"]:.4f}',
                hits10=f'{hits["hits@10"]:.4f}'
            )
        mrr_score = np.mean(mrrs)
        hits1 = np.mean(hits1)
        hits2 = np.mean(hits2)
        hits5 = np.mean(hits5)
        hits10 = np.mean(hits10)
        wandb.log({
            f"{dataset_name}_accuracy": accum_acc/total_count, 
            f"{dataset_name}_loss": loss_accum/total_count,
            f"{dataset_name}_mrr": mrr,
            f"{dataset_name}_hits@1": hits1,
            f"{dataset_name}_hits@2": hits2,
            f"{dataset_name}_hits@5": hits5,
            f"{dataset_name}_hits@10": hits10
        })
        del tf
        del pred
        accuracy = accum_acc / total_count
        return accuracy, mrr_score, hits1, hits2, hits5, hits10

In [59]:
model = GINe(num_features=1, num_gnn_layers=3, edge_dim=train_dataset.tensor_frame.num_cols-3, n_classes=1)
model = torch.compile(model, dynamic=True) if compile else model
model.to(device)
learnable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
ic(learnable_params)
wandb.log({"learnable_params": learnable_params})

no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=lr, eps=eps)
scheduler = get_inverse_sqrt_schedule(optimizer, num_warmup_steps=0, timescale=1000)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

train_metric = test(train_loader, model, "train")
val_metric = test(val_loader, model, "val")
test_metric = test(test_loader, model, "test")
ic(
        train_metric, 
        val_metric, 
        test_metric
)

for epoch in range(1, epochs + 1):
    train_loss = train(epoch, model, optimizer)
    train_metric = test(train_loader, model, "train")
    val_metric = test(val_loader, model, "val")
    test_metric = test(test_loader, model, "test")
    ic(
        train_loss, 
        train_metric, 
        val_metric, 
        test_metric
    )

ic| learnable_params: 125151
Evaluating:   0%|                                                                                                                                                                                                       | 0/489 [00:00<?, ?it/s]

Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 489/489 [01:53<00:00,  4.29it/s, accuracy=0.0387, hits1=0.0000, hits10=0.0992, hits2=0.0000, hits5=0.0000, loss=0.3702, mrr=0.0267]
Evaluating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 127.12it/s, accuracy=0.0220, hits1=0.0000, hits10=0.0492, hits2=0.0000, hits5=0.0000, loss=0.4090, mrr=0.0218]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 91.07it/s, accuracy=0.0232, hits1=0.0000, hits10=0.0632, hits2=0.0000, hits5=0.0000, loss=0.4177, mrr=0.0238]
Epoch 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 489/489 [01:52<00:00,  4.35it/s, loss=1.4193]
Evaluating: 100%|███████████████████

In [37]:
wandb.finish()