In [31]:
# Enable IPython's autoreload extension in mode 2, so imported modules
# (e.g. the project's `src.*` code) are re-imported before each cell executes.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [32]:
import os

import torch
import torch.nn.functional as F
import numpy as np

from torch_frame.data import DataLoader
from torch_frame import TensorFrame
from torch_geometric.data import Data
from transformers import get_inverse_sqrt_schedule

from src.datasets import IBMTransactionsAML
from src.nn.gnn.model import GINe
from src.utils.loss import lp_loss
from src.utils.metric import mrr

from tqdm import tqdm
import wandb
from icecream import ic
import sys

# Allow float32 matrix multiplications to use faster, lower-precision internal
# formats where the hardware supports them ('high' setting of the PyTorch knob).
torch.set_float32_matmul_precision('high')

In [33]:
# --- Optimisation --------------------------------------------------------
seed = 42
batch_size = 1024
lr = 5e-4
eps = 1e-8
epochs = 5

# --- Runtime / data ------------------------------------------------------
# NOTE: `compile` shadows the Python builtin of the same name; the name is
# kept because the model-setup cell reads it.
compile = True
data_split = [0.6, 0.2, 0.2]
split_type = 'temporal'

# --- Link-prediction sampling and model width ----------------------------
pos_sample_prob = 0.15
num_neg_samples = 64
channels = 256

# Prefer the GPU whenever one is visible to PyTorch.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Snapshot of the settings above, handed to Weights & Biases as run config.
args = dict(
    testing=False,
    batch_size=batch_size,
    seed=seed,
    device=device,
    lr=lr,
    eps=eps,
    epochs=epochs,
    compile=compile,
    data_split=data_split,
    pos_sample_prob=pos_sample_prob,
    channels=channels,
    split_type=split_type,
    num_neg_samples=num_neg_samples,
)

In [34]:
# Authenticate with Weights & Biases and open the run that receives all
# metrics logged below. Setting args['testing'] disables logging entirely.
wandb.login()
wandb_mode = "online"
if args['testing']:
    wandb_mode = "disabled"
run = wandb.init(
    mode=wandb_mode,
    project="rel-mm",
    # Alternative descriptive run name:
    #   "model=GINe,dataset=IBM-AML_Hi_Sm,objective=lp"
    name="debug-temporal-LOL-channels256-LOL",
    config=args,
)



In [35]:
# Seed every random number generator this notebook (or the libraries it calls)
# may draw from, so that sampling and initialisation are repeatable.
torch.manual_seed(seed)
# Seeds all visible GPUs rather than only the current one; silently ignored
# when CUDA is unavailable.
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
# When running on the CuDNN backend, two further options must be set
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Set a fixed value for the hash seed.
# NOTE: Python reads PYTHONHASHSEED at interpreter start-up, so this assignment
# only reaches child processes started from here on, not the running kernel.
os.environ["PYTHONHASHSEED"] = str(seed)

In [36]:
# Load the cleaned HI-Small IBM AML transaction table with the configured
# split strategy. For a quick smoke test, point `root` at
# '/mnt/data/ibm-transactions-for-anti-money-laundering-aml/dummy.csv' instead.
dataset = IBMTransactionsAML(
    root='/mnt/data/ibm-transactions-for-anti-money-laundering-aml/HI-Small_Trans-cleaned.csv',
    split_type=split_type,
    splits=data_split,
)
ic(dataset)
# Materialise before the tensor frames / column statistics are read below.
dataset.materialize()
# Preview the first rows of the underlying dataframe as the cell output.
dataset.df.head(5)

ic| dataset: IBMTransactionsAML()


Unnamed: 0,Timestamp,From Bank,From ID,To Bank,To ID,Amount Received,Receiving Currency,Amount Paid,Payment Currency,Payment Format,Is Laundering,split
0,1200,B_10,8000EBD30,B_10,8000EBD30,0.296848,US Dollar,0.296848,US Dollar,Reinvestment,0,0
1,1200,B_3208,8000F4580,B_1,8000F5340,0.000359,US Dollar,0.000359,US Dollar,Cheque,0,0
2,0,B_3209,8000F4670,B_3209,8000F4670,0.346651,US Dollar,0.346651,US Dollar,Reinvestment,0,0
3,120,B_12,8000F5030,B_12,8000F5030,0.286896,US Dollar,0.286896,US Dollar,Reinvestment,0,0
4,360,B_10,8000F5200,B_10,8000F5200,0.379751,US Dollar,0.379751,US Dollar,Reinvestment,0,0


In [37]:
# Unpack the three partitions produced by the dataset's own split() (configured
# above via `split_type` / `splits`), bound here as train / validation / test.
train_dataset, val_dataset, test_dataset = dataset.split()

In [38]:
# Tensor-frame view of each partition.
train_tensor_frame = train_dataset.tensor_frame
val_tensor_frame = val_dataset.tensor_frame
test_tensor_frame = test_dataset.tensor_frame

# One shuffled loader per partition, all sharing the same batch size.
# NOTE(review): the validation and test loaders are shuffled too, so the edges
# grouped into each evaluation batch differ from run to run - confirm intended.
train_loader, val_loader, test_loader = (
    DataLoader(frame, batch_size=batch_size, shuffle=True)
    for frame in (train_tensor_frame, val_tensor_frame, test_tensor_frame)
)

In [39]:
# # TODO: generalize the trainable columns
# source = train_tensor_frame.get_col_feat('From ID')
# destination = train_tensor_frame.get_col_feat('To ID')

# #create dummy node features
# num_nodes = np.unique(np.concatenate([source, destination])).shape[0]
# ic(num_nodes)
# node_feat = torch.ones(num_nodes)

# edge_index = torch.cat([source, destination], dim=1).t()
# ic(edge_index.shape)
# g = Data(node_feat, edge_index=edge_index, edge_attr=train_tensor_frame)

In [40]:
from torch_frame.nn.encoder.stypewise_encoder import StypeWiseFeatureEncoder
from torch_frame import stype
from torch_frame.nn import (
    EmbeddingEncoder,
    LinearEncoder,
    TimestampEncoder,
)
# One column encoder per semantic type present in the table.
stype_encoder_dict = {}
stype_encoder_dict[stype.categorical] = EmbeddingEncoder()
stype_encoder_dict[stype.numerical] = LinearEncoder()
stype_encoder_dict[stype.timestamp] = TimestampEncoder()

# Encoder that embeds every column of a TensorFrame into `channels` features,
# using the dataset-wide column statistics and the training frame's layout.
encoder = StypeWiseFeatureEncoder(
    out_channels=channels,
    col_stats=dataset.col_stats,
    col_names_dict=train_tensor_frame.col_names_dict,
    stype_encoder_dict=stype_encoder_dict,
)
def get_gnn_inputs(tf: TensorFrame, pos_sample_prob=0.15):
    """Turn one batch of transactions into link-prediction inputs for the GNN.

    Every row of ``tf`` is treated as a directed edge from its 'From ID' node
    to its 'To ID' node. A random fraction of the edges is held out as positive
    targets, the remainder forms the message-passing graph, and for every
    positive edge ``num_neg_samples`` corrupted edges are generated (half with
    a replaced destination, half with a replaced source).

    Relies on the notebook-level ``encoder``, ``channels``, ``num_neg_samples``
    and ``device``.

    Args:
        tf: Batch of transactions.
        pos_sample_prob: Fraction of the batch's edges held out as positives.

    Returns:
        A 9-tuple ``(node_feats, edge_index, edge_attr, input_edge_index,
        input_edge_attr, pos_edge_index, pos_edge_attr, neg_edge_index,
        neg_edge_attr)``. ``edge_index`` and ``edge_attr`` describe the full
        batch and are returned as built (not moved to ``device``); all other
        tensors are moved to ``device``. Node ids are local to the batch
        (``0 .. num_nodes - 1``). When no positive edge is sampled, the
        positive and negative tensors are empty but correctly shaped.
    """
    source = tf.get_col_feat('From ID')
    destination = tf.get_col_feat('To ID')

    # Encode the columns of each row and flatten the per-column embeddings
    # into a single attribute vector per edge.
    edge_attr, col_names = encoder(tf)
    edge_attr = edge_attr.view(-1, len(col_names) * channels)

    # Relabel the raw account ids to consecutive batch-local ids. torch.unique
    # returns the sorted ids plus, for each input element, its index into that
    # sorted list - exactly the local id we need.
    nodes, inverse = torch.unique(torch.cat([source, destination]), return_inverse=True)
    num_nodes = nodes.shape[0]
    node_feats = torch.ones((num_nodes, 1))  # constant dummy node feature

    # Keep the index tensors on the CPU so they match the CPU-side sampling
    # tensors created below.
    inverse = inverse.reshape(-1).cpu()
    num_edges = source.numel()
    edge_index = torch.stack((inverse[:num_edges], inverse[num_edges:]))

    # sample positive edges
    E = edge_index.shape[1]
    num_samples = int(E * pos_sample_prob)
    if num_samples > 0:
        drop_idxs = torch.multinomial(torch.full((E,), 1.0), num_samples, replacement=False)
    else:
        drop_idxs = torch.empty(0, dtype=torch.long)

    mask = torch.zeros(E, dtype=torch.bool)  # [E, ], True = held-out positive
    mask[drop_idxs] = True
    input_edge_index = edge_index[:, ~mask]
    input_edge_attr = edge_attr[~mask]

    pos_edge_index = edge_index[:, mask]
    pos_edge_attr = edge_attr[mask]

    # generate/sample negative edges
    # could choose false negatives, the entire graph is not used!
    half = num_neg_samples // 2  # negatives that keep the source node
    neg_edge_chunks = []
    neg_attr_chunks = []
    for i, edge in enumerate(pos_edge_index.t()):
        src, dst = edge[0], edge[1]

        # Choose negative endpoints among nodes that are not adjacent (within
        # this batch) to either endpoint of the positive edge; the endpoints
        # themselves are excluded as well.
        touching = (edge_index == src).any(dim=0) | (edge_index == dst).any(dim=0)
        avail_mask = torch.ones(num_nodes, dtype=torch.bool)
        avail_mask[edge_index[:, touching].reshape(-1)] = False
        avail_nodes = avail_mask.nonzero().flatten()
        if avail_nodes.numel() == 0:
            # Degenerate batch: every node touches this edge. Fall back to all
            # nodes and accept possible false negatives.
            avail_nodes = torch.arange(num_nodes)

        # Draw distinct candidates when enough exist; otherwise draw with
        # replacement so that every positive still gets the full quota.
        if avail_nodes.numel() >= num_neg_samples:
            picks = torch.randperm(avail_nodes.numel())[:num_neg_samples]
        else:
            picks = torch.randint(avail_nodes.numel(), (num_neg_samples,))
        neg_nodes = avail_nodes[picks]

        # First half: same source, corrupted destination.
        neg_dsts = neg_nodes[:half]
        neg_edge_chunks.append(torch.stack([src.repeat(half), neg_dsts], dim=0))

        # Remainder: corrupted source, same destination.
        neg_srcs = neg_nodes[half:]
        neg_edge_chunks.append(torch.stack([neg_srcs, dst.repeat(num_neg_samples - half)], dim=0))

        # Each negative edge reuses the attributes of the positive edge it was
        # derived from.
        neg_attr_chunks.append(pos_edge_attr[i].unsqueeze(0).repeat(num_neg_samples, 1))

    if neg_edge_chunks:
        neg_edge_index = torch.cat(neg_edge_chunks, dim=1)
        neg_edge_attr = torch.cat(neg_attr_chunks, dim=0)
    else:
        # No positive edge was sampled (tiny batch): return empty tensors
        # instead of letting torch.cat fail on an empty list.
        neg_edge_index = torch.empty((2, 0), dtype=torch.long)
        neg_edge_attr = edge_attr.new_empty((0, edge_attr.shape[1]))

    input_edge_index = input_edge_index.to(device)
    input_edge_attr = input_edge_attr.to(device)
    pos_edge_index = pos_edge_index.to(device)
    pos_edge_attr = pos_edge_attr.to(device)
    node_feats = node_feats.to(device)
    neg_edge_index = neg_edge_index.to(device)
    neg_edge_attr = neg_edge_attr.to(device)
    return node_feats, edge_index, edge_attr, input_edge_index, input_edge_attr, pos_edge_index, pos_edge_attr, neg_edge_index, neg_edge_attr
# Smoke test: build the GNN inputs for a single training batch.
batch = next(iter(train_loader))
node_feats, edge_index, edge_attr, input_edge_index, input_edge_attr, pos_edge_index, pos_edge_attr, neg_edge_index, neg_edge_attr = get_gnn_inputs(batch)
# Dump the tensors themselves, then only their shapes (the more readable check).
ic( node_feats, edge_index, edge_attr, input_edge_index, input_edge_attr, pos_edge_index, pos_edge_attr, neg_edge_index, neg_edge_attr)
ic( node_feats.shape, edge_index.shape, edge_attr.shape, input_edge_index.shape, input_edge_attr.shape, pos_edge_index.shape, pos_edge_attr.shape, neg_edge_index.shape, neg_edge_attr.shape)

ic| node_feats: tensor([[1.],
                        [1.],
                        [1.],
                        ...,
                        [1.],
                        [1.],
                        [1.]], device='cuda:0')
    edge_index: tensor([[1697,  710, 1900,  ...,  140,  209, 1084],
                        [1743,  735, 1897,  ...,  534,  428,  903]])
    edge_attr: tensor([[-0.0009,  0.0013,  0.0005,  ..., -0.0527, -0.0719,  0.0330],
                       [-0.0027,  0.0039,  0.0015,  ..., -0.0527, -0.0719,  0.0330],
                       [ 0.0038, -0.0054, -0.0022,  ..., -0.0527, -0.0719,  0.0330],
                       ...,
                       [-0.0020,  0.0029,  0.0011,  ..., -0.0527, -0.0719,  0.0330],
                       [ 0.0055, -0.0078, -0.0031,  ..., -0.0527, -0.0719,  0.0330],
                       [ 0.0029, -0.0041, -0.0016,  ..., -0.0527, -0.0719,  0.0330]],
                      grad_fn=<ViewBackward0>)
    input_edge_index: tensor([[1697,  710, 1900,  

(torch.Size([1962, 1]),
 torch.Size([2, 1024]),
 torch.Size([1024, 320]),
 torch.Size([2, 871]),
 torch.Size([871, 320]),
 torch.Size([2, 153]),
 torch.Size([153, 320]),
 torch.Size([2, 9792]),
 torch.Size([9792, 320]))

In [50]:
def train(epoc: int, model, optimizer) -> dict:
    """Run one optimisation pass over the notebook-level ``train_loader``.

    For every batch the link-prediction inputs are built with
    ``get_gnn_inputs``, the model scores the positive and negative edges, and
    ``lp_loss`` is minimised with one optimizer step. The running loss is
    weighted by the number of positive predictions per batch.

    Args:
        epoc: Epoch number, used only for the progress-bar label.
        model: Model called with the node features, the message-passing graph
            and the positive / negative edges; returns ``(pred, neg_pred)``.
        optimizer: Optimizer that updates ``model``.

    Returns:
        ``{'loss': mean_loss}``; ``mean_loss`` is NaN if no prediction was
        accumulated (e.g. an empty loader). The same value is logged to
        Weights & Biases as ``train_loss``.
    """
    model.train()
    loss_accum = 0.0
    total_count = 0
    mean_loss = float('nan')

    with tqdm(train_loader, desc=f'Epoch {epoc}') as t:
        for tf in t:
            (node_feats, _, _, input_edge_index, input_edge_attr,
             pos_edge_index, pos_edge_attr,
             neg_edge_index, neg_edge_attr) = get_gnn_inputs(tf)
            if pos_edge_index.shape[1] == 0:
                # Nothing to predict in this batch; skip it rather than
                # back-propagating a loss computed on empty predictions.
                continue
            pred, neg_pred = model(node_feats, input_edge_index, input_edge_attr, pos_edge_index, pos_edge_attr, neg_edge_index, neg_edge_attr)
            loss = lp_loss(pred, neg_pred)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_accum += float(loss) * len(pred)
            total_count += len(pred)
            if total_count:
                mean_loss = loss_accum / total_count
                t.set_postfix(loss=f'{mean_loss:.4f}')
        wandb.log({"train_loss": mean_loss})
    return {'loss': mean_loss}

@torch.no_grad()
def test(loader: DataLoader, model, dataset_name) -> dict:
    """Evaluate ``model`` on every batch of ``loader`` without gradients.

    Per batch, the link-prediction inputs are rebuilt with ``get_gnn_inputs``
    (so the held-out positives and sampled negatives are random), the loss is
    computed with ``lp_loss`` and ranking metrics with ``mrr``. Results are
    logged to Weights & Biases under ``f"{dataset_name}_..."`` keys.

    Args:
        loader: Batches to evaluate.
        model: Model returning ``(pred, neg_pred)``.
        dataset_name: Prefix for the logged metric names.

    Returns:
        Dict with keys ``"mrr"``, ``"hits@1"``, ``"hits@2"``, ``"hits@5"`` and
        ``"hits@10"``. MRR is the unweighted mean of the per-batch values; each
        hits@k is the sum of the per-batch values divided by the number of
        positive predictions (this assumes ``mrr`` reports hits as per-batch
        counts - TODO confirm against ``src.utils.metric.mrr``). All values are
        NaN when nothing could be scored.
    """
    model.eval()
    ks = [1, 2, 5, 10]
    mrrs = []
    hit_values = {k: [] for k in ks}  # per-batch hits@k, keyed by k
    loss_accum = 0.0
    total_count = 0
    with tqdm(loader, desc='Evaluating') as t:
        for tf in t:
            (node_feats, _, _, input_edge_index, input_edge_attr,
             pos_edge_index, pos_edge_attr,
             neg_edge_index, neg_edge_attr) = get_gnn_inputs(tf)
            if pos_edge_index.shape[1] == 0:
                # No held-out edge in this batch, hence nothing to rank.
                continue
            pred, neg_pred = model(node_feats, input_edge_index, input_edge_attr, pos_edge_index, pos_edge_attr, neg_edge_index, neg_edge_attr)
            loss = lp_loss(pred, neg_pred)
            loss_accum += float(loss) * len(pred)
            total_count += len(pred)
            mrr_score, hits = mrr(pred, neg_pred, ks, num_neg_samples)
            mrrs.append(mrr_score)
            batch_hits = {k: hits[f'hits@{k}'] for k in ks}
            for k in ks:
                hit_values[k].append(batch_hits[k])
            running_loss = loss_accum / total_count if total_count else float('nan')
            t.set_postfix(
                loss=f'{running_loss:.4f}',
                mrr=f'{mrr_score:.4f}',
                **{f'hits{k}': f'{batch_hits[k]:.4f}' for k in ks}
            )

    # Aggregate over batches; fall back to NaN when no batch was scored so an
    # empty loader does not raise.
    nan = float('nan')
    mean_loss = loss_accum / total_count if total_count else nan
    mrr_score = np.mean(mrrs) if mrrs else nan
    hits_at = {k: (np.sum(hit_values[k]) / total_count if total_count else nan) for k in ks}

    wandb.log({
        f"{dataset_name}_loss": mean_loss,
        f"{dataset_name}_mrr": mrr_score,
        f"{dataset_name}_hits@1": hits_at[1],
        f"{dataset_name}_hits@2": hits_at[2],
        f"{dataset_name}_hits@5": hits_at[5],
        f"{dataset_name}_hits@10": hits_at[10]
    })
    return {"mrr": mrr_score, "hits@1": hits_at[1], "hits@2": hits_at[2], "hits@5": hits_at[5], "hits@10": hits_at[10]}

In [51]:
# Build the GINe link predictor: one constant input feature per node, and edge
# attributes of width (number of columns) * (encoder channels).
model = GINe(num_features=1, num_gnn_layers=3, edge_dim=train_dataset.tensor_frame.num_cols*channels, n_classes=1)
model = torch.compile(model, dynamic=True) if compile else model
model.to(device)
learnable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
ic(learnable_params)
wandb.log({"learnable_params": learnable_params})

# A single AdamW optimizer over the model parameters. Previously a grouped
# optimizer (both groups with identical settings) and its scheduler were built
# and then `optimizer` was overwritten, leaving the scheduler bound to a
# discarded optimizer. Weight decay stays at AdamW's default, matching the
# optimizer that was effectively used for training.
# NOTE(review): only `model.parameters()` are optimised; the column `encoder`
# that produces the edge attributes is not included - confirm this is intended.
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, eps=eps)
# Inverse-sqrt schedule attached to the live optimizer.
# NOTE(review): `train()` never calls `scheduler.step()`, so the learning rate
# currently stays constant.
scheduler = get_inverse_sqrt_schedule(optimizer, num_warmup_steps=0, timescale=1000)

def _evaluate_all_splits():
    """Evaluate the model on the train, validation and test loaders, in that order."""
    return (
        test(train_loader, model, "train"),
        test(val_loader, model, "val"),
        test(test_loader, model, "test"),
    )


# Baseline metrics of the untrained model.
train_metric, val_metric, test_metric = _evaluate_all_splits()
ic(train_metric, val_metric, test_metric)

# One optimisation pass per epoch, followed by a full re-evaluation.
for epoch in range(1, epochs + 1):
    train_loss = train(epoch, model, optimizer)
    train_metric, val_metric, test_metric = _evaluate_all_splits()
    ic(train_loss, train_metric, val_metric, test_metric)

ic| learnable_params: 243251
Evaluating:   0%|                                                                                | 0/3173 [00:00<?, ?it/s]

Evaluating: 100%|█| 3173/3173 [01:55<00:00, 27.53it/s, hits1=0.0000, hits10=7.0000, hits2=0.0000, hits5=2.0000, loss=1.396
Evaluating: 100%|█| 943/943 [00:34<00:00, 27.53it/s, hits1=0.0000, hits10=12.0000, hits2=0.0000, hits5=2.0000, loss=1.3962
Evaluating: 100%|█| 844/844 [00:30<00:00, 27.56it/s, hits1=2.0000, hits10=11.0000, hits2=3.0000, hits5=3.0000, loss=1.3962
ic| train_metric: {'hits@1': 0.004665927808929741,
                   'hits@10': 0.09710485874495813,
                   'hits@2': 0.007949587379540782,
                   'hits@5': 0.020663983157339617,
                   'mrr': 0.06022523790879262}
    val_metric: {'hits@1': 0.005053270762426956,
                 'hits@10': 0.09622009801543015,
                 'hits@2': 0.008158710133575484,
                 'hits@5': 0.020677512598517984,
                 'mrr': 0.060144905547873995}
    test_metric: {'hits@1': 0.00471804088968771,
                  'hits@10': 0.09677794219044152,
                  'hits@2': 0.007832412

In [None]:
# Mark the Weights & Biases run as finished so remaining data is uploaded.
wandb.finish()

VBox(children=(Label(value='0.004 MB of 0.004 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
learnable_params,▁█
test_hits@1,▁█▇▇▆▆
test_hits@10,▁█▅▆▃▃
test_hits@2,▁█▇▇▆▆
test_hits@5,▁█▆▇▆▅
test_loss,█▁▁▁▁▁
test_mrr,▁█▆▇▆▆
train_hits@1,▁█▆▇▆▆
train_hits@10,▁█▅▆▄▃
train_hits@2,▁█▆▇▆▆

0,1
learnable_params,284211.0
test_hits@1,0.09626
test_hits@10,0.119
test_hits@2,0.09721
test_hits@5,0.09888
test_loss,1.05289
test_mrr,0.1333
train_hits@1,0.09975
train_hits@10,0.11772
train_hits@2,0.10013
