In [17]:
import pandas as pd
import torch
import torch.nn.functional as F
import torch_geometric.transforms as T

from torch_geometric.data import Data

account_df = pd.read_csv('data/processed/kaggle_transaction_dataset_centrality.csv')
address_mapping = pd.DataFrame(data={
    'address': account_df['address'], 
    'mapped_id': account_df['address'].index
})

transaction_df = pd.read_csv('data/queried/full_transactions.csv')
transaction_mapping = transaction_df \
    .merge(
        address_mapping, 
        left_on='address from', 
        right_on='address'
    ) \
    .merge(
        address_mapping, 
        left_on='address to', 
        right_on='address', 
        suffixes=('_from', '_to')
    )

x = F.normalize(
    torch.tensor(
        account_df.drop(columns=['address', 'flag']).to_numpy(), 
        dtype=torch.float
    )
)
y = torch.tensor(
    account_df['flag'].to_numpy(), 
    dtype=torch.long
)
edge_index = torch.tensor(
    transaction_mapping[['mapped_id_from', 'mapped_id_to']].to_numpy().T, 
    dtype=torch.long
)

data = Data(x=x, y=y, edge_index=edge_index)
data = T.ToUndirected()(data)
data = T.RandomNodeSplit(num_val=0, num_test=2000)(data)
data

Data(x=[20302, 22], edge_index=[2, 197481], y=[20302], train_mask=[20302], val_mask=[20302], test_mask=[20302])

In [18]:
classes, counts = torch.unique(y, return_counts=True)

# Calculate class weights based on their frequency
total_samples = torch.sum(counts).float()
class_weights = total_samples / (classes.numel() * counts.float())

print("Class Weights:", class_weights)

Class Weights: tensor([0.6940, 1.7887])


In [19]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv, GCNConv

class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(data.num_node_features, 64)
        self.conv2 = GCNConv(64, 64)
        self.conv3 = GCNConv(64, 2)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv3(x, edge_index)

        return F.log_softmax(x, dim=1)

class GAT(torch.nn.Module):
    def __init__(self, heads=4):
        super().__init__()
        self.conv1 = GATConv(data.num_node_features, 32, heads=heads)
        self.conv2 = GATConv(32 * heads, 32, heads=heads)
        self.conv3 = GATConv(32 * heads, 2, heads=1)  # Last layer typically uses a single head

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv3(x, edge_index)

        return F.log_softmax(x, dim=1)

In [20]:
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score

def get_results(model, data):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    data = data.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

    model.train()
    for epoch in range(251):
        optimizer.zero_grad()
        out = model(data)
        loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask], weight=class_weights)
        loss.backward()
        optimizer.step()
        
        if epoch % 10 == 0:
            pred = model(data).argmax(dim=1)
            correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
            acc = int(correct) / int(data.test_mask.sum())
            precision = precision_score(data.y[data.test_mask], pred[data.test_mask], average='weighted')
            recall = recall_score(data.y[data.test_mask], pred[data.test_mask], average='weighted')
            f1 = f1_score(data.y[data.test_mask], pred[data.test_mask], average='weighted')
            auroc = roc_auc_score(data.y[data.test_mask], pred[data.test_mask])
            print(f'Epoch {str(epoch).zfill(3)}: Accuracy {acc:.3f} Precision: {precision:.3f} Recall: {recall:.3f} F1 Score {f1:.3f} AUROC: {auroc:.3f}')
    return model

In [21]:
get_results(GCN(), data)

Epoch 000: Accuracy 0.552 Precision: 0.753 Recall: 0.552 F1 Score 0.568 AUROC: 0.650
Epoch 010: Accuracy 0.754 Precision: 0.797 Recall: 0.754 F1 Score 0.765 AUROC: 0.755
Epoch 020: Accuracy 0.818 Precision: 0.839 Recall: 0.818 F1 Score 0.824 AUROC: 0.811
Epoch 030: Accuracy 0.861 Precision: 0.872 Recall: 0.861 F1 Score 0.864 AUROC: 0.852
Epoch 040: Accuracy 0.874 Precision: 0.881 Recall: 0.874 F1 Score 0.876 AUROC: 0.859
Epoch 050: Accuracy 0.883 Precision: 0.889 Recall: 0.883 F1 Score 0.885 AUROC: 0.868
Epoch 060: Accuracy 0.881 Precision: 0.888 Recall: 0.881 F1 Score 0.883 AUROC: 0.869
Epoch 070: Accuracy 0.887 Precision: 0.895 Recall: 0.887 F1 Score 0.889 AUROC: 0.879
Epoch 080: Accuracy 0.892 Precision: 0.895 Recall: 0.892 F1 Score 0.893 AUROC: 0.873
Epoch 090: Accuracy 0.886 Precision: 0.892 Recall: 0.886 F1 Score 0.888 AUROC: 0.874
Epoch 100: Accuracy 0.892 Precision: 0.900 Recall: 0.892 F1 Score 0.894 AUROC: 0.886
Epoch 110: Accuracy 0.882 Precision: 0.892 Recall: 0.882 F1 Score

GCN(
  (conv1): GCNConv(22, 64)
  (conv2): GCNConv(64, 64)
  (conv3): GCNConv(64, 2)
)

In [22]:
get_results(GAT(), data)

Epoch 000: Accuracy 0.617 Precision: 0.740 Recall: 0.617 F1 Score 0.639 AUROC: 0.664
Epoch 010: Accuracy 0.818 Precision: 0.831 Recall: 0.818 F1 Score 0.822 AUROC: 0.796
Epoch 020: Accuracy 0.845 Precision: 0.857 Recall: 0.845 F1 Score 0.849 AUROC: 0.831
Epoch 030: Accuracy 0.875 Precision: 0.877 Recall: 0.875 F1 Score 0.875 AUROC: 0.848
Epoch 040: Accuracy 0.890 Precision: 0.891 Recall: 0.890 F1 Score 0.891 AUROC: 0.864
Epoch 050: Accuracy 0.882 Precision: 0.887 Recall: 0.882 F1 Score 0.884 AUROC: 0.865
Epoch 060: Accuracy 0.845 Precision: 0.879 Recall: 0.845 F1 Score 0.852 AUROC: 0.867
Epoch 070: Accuracy 0.907 Precision: 0.906 Recall: 0.907 F1 Score 0.906 AUROC: 0.870
Epoch 080: Accuracy 0.908 Precision: 0.907 Recall: 0.908 F1 Score 0.908 AUROC: 0.878
Epoch 090: Accuracy 0.903 Precision: 0.903 Recall: 0.903 F1 Score 0.903 AUROC: 0.873
Epoch 100: Accuracy 0.889 Precision: 0.903 Recall: 0.889 F1 Score 0.893 AUROC: 0.895
Epoch 110: Accuracy 0.909 Precision: 0.910 Recall: 0.909 F1 Score

GAT(
  (conv1): GATConv(22, 32, heads=4)
  (conv2): GATConv(128, 32, heads=4)
  (conv3): GATConv(128, 2, heads=1)
)