In [1]:
import math
import pandas as pd
import torch
import torch.nn.functional as F
import torch_geometric.transforms as T

from torch_geometric.data import Data

account_df = pd.read_csv('data/processed/kaggle_transaction_dataset_centrality.csv')
address_mapping = pd.DataFrame(data={
    'address': account_df['address'], 
    'mapped_id': account_df['address'].index
})

transaction_df = pd.read_csv('data/queried/full_transactions.csv')
transaction_mapping = transaction_df \
    .merge(
        address_mapping, 
        left_on='address from', 
        right_on='address'
    ) \
    .merge(
        address_mapping, 
        left_on='address to', 
        right_on='address', 
        suffixes=('_from', '_to')
    )

x = F.normalize(
    torch.tensor(
        account_df.drop(columns=['address', 'flag']).to_numpy(), 
        dtype=torch.float
    )
)
y = torch.tensor(
    account_df['flag'].to_numpy(), 
    dtype=torch.long
)
edge_index = torch.tensor(
    transaction_mapping[['mapped_id_from', 'mapped_id_to']].to_numpy().T, 
    dtype=torch.long
)

data = Data(x=x, y=y, edge_index=edge_index)
data = T.ToUndirected()(data)
data = T.RandomNodeSplit(num_val=0, num_test=2000)(data)
data

Data(x=[20302, 22], edge_index=[2, 197481], y=[20302], train_mask=[20302], val_mask=[20302], test_mask=[20302])

In [2]:
classes, counts = torch.unique(y, return_counts=True)

# Calculate class weights based on their frequency
total_samples = torch.sum(counts).float()
class_weights = total_samples / (classes.numel() * counts.float())

print("Class Weights:", class_weights)

Class Weights: tensor([0.6940, 1.7887])


In [3]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv, GCNConv

class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(data.num_node_features, 64)
        self.conv2 = GCNConv(64, 64)
        self.conv3 = GCNConv(64, 2)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv3(x, edge_index)

        return F.log_softmax(x, dim=1)

class GAT(torch.nn.Module):
    def __init__(self, heads=4):
        super().__init__()
        self.conv1 = GATConv(data.num_node_features, 32, heads=heads)
        self.conv2 = GATConv(32 * heads, 32, heads=heads)
        self.conv3 = GATConv(32 * heads, 2, heads=1)  # Last layer typically uses a single head

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv3(x, edge_index)

        return F.log_softmax(x, dim=1)

In [4]:
from sklearn.metrics import f1_score, roc_auc_score

def get_results(model, data):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    data = data.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

    model.train()
    for epoch in range(251):
        optimizer.zero_grad()
        out = model(data)
        loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask], weight=class_weights)
        loss.backward()
        optimizer.step()
        
        if epoch % 10 == 0:
            pred = model(data).argmax(dim=1)
            correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
            acc = int(correct) / int(data.test_mask.sum())
            f1 = f1_score(data.y[data.test_mask], pred[data.test_mask], average='weighted')
            auroc = roc_auc_score(data.y[data.test_mask], pred[data.test_mask])
            print(f'Epoch {str(epoch).zfill(3)}: Accuracy {acc:.3f} F1 Score {f1:.3f} AUROC: {auroc:.3f}')
    return model

In [5]:
get_results(GCN(), data)

Epoch 000: Accuracy 0.590 F1 Score 0.606 AUROC: 0.660
Epoch 010: Accuracy 0.762 F1 Score 0.771 AUROC: 0.757
Epoch 020: Accuracy 0.802 F1 Score 0.809 AUROC: 0.805
Epoch 030: Accuracy 0.858 F1 Score 0.861 AUROC: 0.851
Epoch 040: Accuracy 0.859 F1 Score 0.863 AUROC: 0.857
Epoch 050: Accuracy 0.864 F1 Score 0.867 AUROC: 0.864
Epoch 060: Accuracy 0.869 F1 Score 0.871 AUROC: 0.858
Epoch 070: Accuracy 0.876 F1 Score 0.878 AUROC: 0.866
Epoch 080: Accuracy 0.858 F1 Score 0.861 AUROC: 0.858
Epoch 090: Accuracy 0.868 F1 Score 0.871 AUROC: 0.862
Epoch 100: Accuracy 0.884 F1 Score 0.886 AUROC: 0.873
Epoch 110: Accuracy 0.872 F1 Score 0.874 AUROC: 0.860
Epoch 120: Accuracy 0.876 F1 Score 0.879 AUROC: 0.867
Epoch 130: Accuracy 0.858 F1 Score 0.862 AUROC: 0.855
Epoch 140: Accuracy 0.867 F1 Score 0.870 AUROC: 0.867
Epoch 150: Accuracy 0.879 F1 Score 0.882 AUROC: 0.877
Epoch 160: Accuracy 0.872 F1 Score 0.875 AUROC: 0.875
Epoch 170: Accuracy 0.879 F1 Score 0.881 AUROC: 0.866
Epoch 180: Accuracy 0.875 F1

GCN(
  (conv1): GCNConv(22, 64)
  (conv2): GCNConv(64, 64)
  (conv3): GCNConv(64, 2)
)

In [6]:
get_results(GAT(), data)

Epoch 000: Accuracy 0.671 F1 Score 0.682 AUROC: 0.642
Epoch 010: Accuracy 0.798 F1 Score 0.802 AUROC: 0.771
Epoch 020: Accuracy 0.849 F1 Score 0.851 AUROC: 0.826
Epoch 030: Accuracy 0.872 F1 Score 0.873 AUROC: 0.853
Epoch 040: Accuracy 0.893 F1 Score 0.892 AUROC: 0.863
Epoch 050: Accuracy 0.875 F1 Score 0.877 AUROC: 0.866
Epoch 060: Accuracy 0.895 F1 Score 0.895 AUROC: 0.874
Epoch 070: Accuracy 0.894 F1 Score 0.895 AUROC: 0.878
Epoch 080: Accuracy 0.870 F1 Score 0.873 AUROC: 0.869
Epoch 090: Accuracy 0.900 F1 Score 0.901 AUROC: 0.879
Epoch 100: Accuracy 0.900 F1 Score 0.902 AUROC: 0.888
Epoch 110: Accuracy 0.887 F1 Score 0.890 AUROC: 0.896
Epoch 120: Accuracy 0.889 F1 Score 0.892 AUROC: 0.897
Epoch 130: Accuracy 0.909 F1 Score 0.910 AUROC: 0.895
Epoch 140: Accuracy 0.903 F1 Score 0.905 AUROC: 0.903
Epoch 150: Accuracy 0.905 F1 Score 0.907 AUROC: 0.906
Epoch 160: Accuracy 0.913 F1 Score 0.914 AUROC: 0.899
Epoch 170: Accuracy 0.905 F1 Score 0.907 AUROC: 0.899
Epoch 180: Accuracy 0.920 F1

GAT(
  (conv1): GATConv(22, 32, heads=4)
  (conv2): GATConv(128, 32, heads=4)
  (conv3): GATConv(128, 2, heads=1)
)