# GNN node classification for AML detection
### The main code for the GNN was referenced from, https://github.com/IBM/Pattern-GNN 

In [1]:
import datetime
import os
from typing import Callable, Optional
import pandas as pd
from sklearn import preprocessing
import numpy as np
import torch

In [2]:
from torch_geometric.data import (
    Data,
    InMemoryDataset
)

In [3]:
pd.set_option('display.max_columns', None)
path = 'data\\raw\LI-Small_Trans.csv'
df = pd.read_csv(path)

# Pre-processing the data

In [4]:
print(df.head())

          Timestamp  From Bank    Account  To Bank  Account.1  \
0  2022/09/01 00:08         11  8000ECA90       11  8000ECA90   
1  2022/09/01 00:21       3402  80021DAD0     3402  80021DAD0   
2  2022/09/01 00:00         11  8000ECA90     1120  8006AA910   
3  2022/09/01 00:16       3814  8006AD080     3814  8006AD080   
4  2022/09/01 00:00         20  8006AD530       20  8006AD530   

   Amount Received Receiving Currency  Amount Paid Payment Currency  \
0       3195403.00          US Dollar   3195403.00        US Dollar   
1          1858.96          US Dollar      1858.96        US Dollar   
2        592571.00          US Dollar    592571.00        US Dollar   
3            12.32          US Dollar        12.32        US Dollar   
4          2941.56          US Dollar      2941.56        US Dollar   

  Payment Format  Is Laundering  
0   Reinvestment              0  
1   Reinvestment              0  
2         Cheque              0  
3   Reinvestment              0  
4   Reinvest

In [5]:
print(df.dtypes)

Timestamp              object
From Bank               int64
Account                object
To Bank                 int64
Account.1              object
Amount Received       float64
Receiving Currency     object
Amount Paid           float64
Payment Currency       object
Payment Format         object
Is Laundering           int64
dtype: object


The dataset would be turned into a ndoe classification problem 

Accounts as nodes and transactions as edges. Objects column would needed to be encoded using LabelEncoder.

Check if there are any null values

In [6]:
print(df.isnull().sum())

Timestamp             0
From Bank             0
Account               0
To Bank               0
Account.1             0
Amount Received       0
Receiving Currency    0
Amount Paid           0
Payment Currency      0
Payment Format        0
Is Laundering         0
dtype: int64


Check for imbalances in transactions

In [7]:
print('Amount Received equals to Amount Paid:')
print(df['Amount Received'].equals(df['Amount Paid']))
print('Receiving Currency equals to Payment Currency:')
print(df['Receiving Currency'].equals(df['Payment Currency']))

Amount Received equals to Amount Paid:
False
Receiving Currency equals to Payment Currency:
False


It seens involved the transcations between different currency, let's print it out

In [8]:
not_equal1 = df.loc[~(df['Amount Received'] == df['Amount Paid'])]
not_equal2 = df.loc[~(df['Receiving Currency'] == df['Payment Currency'])]
print(not_equal1)
print('---------------------------------------------------------------------------')
print(not_equal2)

                Timestamp  From Bank    Account  To Bank  Account.1  \
2770     2022/09/01 00:12        394  80056EDE0      394  80056EDE0   
8081     2022/09/01 00:28      11701  800C95BF0    11701  800C95BF0   
10451    2022/09/01 00:18      22481  80105E630    22481  80105E630   
12948    2022/09/01 00:17       1439  8014545C0     1439  8014545C0   
13799    2022/09/01 00:02         20  8015D68E0       20  8015D68E0   
...                   ...        ...        ...      ...        ...   
6924007  2022/09/10 23:57       9096  80356BD61     9096  80356BD60   
6924009  2022/09/10 23:30       9096  80356BD61     9096  80356BD60   
6924019  2022/09/10 23:38      13474  803A93631    13474  803A93630   
6924021  2022/09/10 23:31      13474  803A93631    13474  803A93630   
6924023  2022/09/10 23:56      13474  803A93631    13474  803A93630   

         Amount Received Receiving Currency  Amount Paid Payment Currency  \
2770           47.610000               Euro        55.79        US Dol

The amount received and paid are not the same meaning that there might be transactiuon fees involved. cannot remove and use only 1 amount column.

In [9]:
print(sorted(df['Receiving Currency'].unique()))
print(sorted(df['Payment Currency'].unique()))

['Australian Dollar', 'Bitcoin', 'Brazil Real', 'Canadian Dollar', 'Euro', 'Mexican Peso', 'Ruble', 'Rupee', 'Saudi Riyal', 'Shekel', 'Swiss Franc', 'UK Pound', 'US Dollar', 'Yen', 'Yuan']
['Australian Dollar', 'Bitcoin', 'Brazil Real', 'Canadian Dollar', 'Euro', 'Mexican Peso', 'Ruble', 'Rupee', 'Saudi Riyal', 'Shekel', 'Swiss Franc', 'UK Pound', 'US Dollar', 'Yen', 'Yuan']


In the data preprocessing, these transformations were done:  
1. Transform the Timestamp with min max normalization.  
2. Create unique ID for each account by adding bank code with account number.  
3. Create receiving_df with the information of receiving accounts, received amount and currency
4. Create paying_df with the information of payer accounts, paid amount and currency
5. Create a list of currency used among all transactions
6. Label the 'Payment Format', 'Payment Currency', 'Receiving Currency' by classes with sklearn LabelEncoder


In [10]:
def df_label_encoder(df, columns):
        le = preprocessing.LabelEncoder()
        for i in columns:
            df[i] = le.fit_transform(df[i].astype(str))
        return df

def preprocess(df):
        df = df_label_encoder(df,['Payment Format', 'Payment Currency', 'Receiving Currency'])
        df['Timestamp'] = pd.to_datetime(df['Timestamp'])
        df['Timestamp'] = df['Timestamp'].apply(lambda x: x.value)
        df['Timestamp'] = (df['Timestamp']-df['Timestamp'].min())/(df['Timestamp'].max()-df['Timestamp'].min())

        df['Account'] = df['From Bank'].astype(str) + '_' + df['Account']
        df['Account.1'] = df['To Bank'].astype(str) + '_' + df['Account.1']
        df = df.sort_values(by=['Account'])
        receiving_df = df[['Account.1', 'Amount Received', 'Receiving Currency']]
        paying_df = df[['Account', 'Amount Paid', 'Payment Currency']]
        receiving_df = receiving_df.rename({'Account.1': 'Account'}, axis=1)
        currency_ls = sorted(df['Receiving Currency'].unique())

        return df, receiving_df, paying_df, currency_ls

Let's have a look of processed df

In [11]:
df, receiving_df, paying_df, currency_ls = preprocess(df = df)
print(df.head())

         Timestamp  From Bank      Account  To Bank        Account.1  \
3408783   0.266147          0  0_800060CE0    11314  11314_800990320   
3986981   0.318925          0  0_800060CE0    11314  11314_800990320   
4804475   0.393400          0  0_800060CE0    11314  11314_800990320   
4804474   0.394151          0  0_800060CE0    11314  11314_800990320   
6690464   0.547730          0  0_800060CE0     1390   1390_800E49870   

         Amount Received  Receiving Currency  Amount Paid  Payment Currency  \
3408783          8081.58                   4      8081.58                 4   
3986981         47468.31                   4     47468.31                 4   
4804475          8081.58                   4      8081.58                 4   
4804474         47468.31                   4     47468.31                 4   
6690464           787.72                   4       787.72                 4   

         Payment Format  Is Laundering  
3408783               4              0  
3986981   

paying df and receiving df:

In [12]:
print(receiving_df.head())
print(paying_df.head())

                 Account  Amount Received  Receiving Currency
3408783  11314_800990320          8081.58                   4
3986981  11314_800990320         47468.31                   4
4804475  11314_800990320          8081.58                   4
4804474  11314_800990320         47468.31                   4
6690464   1390_800E49870           787.72                   4
             Account  Amount Paid  Payment Currency
3408783  0_800060CE0      8081.58                 4
3986981  0_800060CE0     47468.31                 4
4804475  0_800060CE0      8081.58                 4
4804474  0_800060CE0     47468.31                 4
6690464  0_800060CE0       787.72                 4


currency_ls:

In [13]:
print(currency_ls)

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]


For the node of the graph, the unique accounts for payer and receiver would be extracted, this would include its id and bank code, and if it is laundering. both payer and receiver for the laundering trasaction would be marked as laundering.

In [14]:
def get_all_account(df):
        ldf = df[['Account', 'From Bank']]
        rdf = df[['Account.1', 'To Bank']]
        suspicious = df[df['Is Laundering']==1]
        s1 = suspicious[['Account', 'Is Laundering']]
        s2 = suspicious[['Account.1', 'Is Laundering']]
        s2 = s2.rename({'Account.1': 'Account'}, axis=1)
        suspicious = pd.concat([s1, s2], join='outer')
        suspicious = suspicious.drop_duplicates()

        ldf = ldf.rename({'From Bank': 'Bank'}, axis=1)
        rdf = rdf.rename({'Account.1': 'Account', 'To Bank': 'Bank'}, axis=1)
        df = pd.concat([ldf, rdf], join='outer')
        df = df.drop_duplicates()

        df['Is Laundering'] = 0
        df.set_index('Account', inplace=True)
        df.update(suspicious.set_index('Account'))
        df = df.reset_index()
        return df

Take a look of the account list:

In [15]:
accounts = get_all_account(df)
print(accounts.head())

       Account  Bank  Is Laundering
0  0_800060CE0     0              0
1  0_800061260     0              0
2  0_800062D90     0              0
3  0_800062F80     0              0
4  0_800064980     0              0


# Node features
For node features, we would like to aggregate the mean of paid and received amount with different types of currency as the new features of each node. 

In [16]:
def paid_currency_aggregate(currency_ls, paying_df, accounts):
        for i in currency_ls:
            temp = paying_df[paying_df['Payment Currency'] == i]
            accounts['avg paid '+str(i)] = temp['Amount Paid'].groupby(temp['Account']).transform('mean')
        return accounts

def received_currency_aggregate(currency_ls, receiving_df, accounts):
    for i in currency_ls:
        temp = receiving_df[receiving_df['Receiving Currency'] == i]
        accounts['avg received '+str(i)] = temp['Amount Received'].groupby(temp['Account']).transform('mean')
    accounts = accounts.fillna(0)
    return accounts

Now we can define the node attributes by the bank code and the mean of paid and received amount with different types of currency.

In [17]:
def get_node_attr(currency_ls, paying_df,receiving_df, accounts):
        node_df = paid_currency_aggregate(currency_ls, paying_df, accounts)
        node_df = received_currency_aggregate(currency_ls, receiving_df, node_df)
        node_label = torch.from_numpy(node_df['Is Laundering'].values).to(torch.float)
        node_df = node_df.drop(['Account', 'Is Laundering'], axis=1)
        node_df = df_label_encoder(node_df,['Bank'])
#         node_df = torch.from_numpy(node_df.values).to(torch.float)  # comment for visualization
        return node_df, node_label

node_df:

In [18]:
node_df, node_label = get_node_attr(currency_ls, paying_df,receiving_df, accounts)
print(node_df.head())

   Bank  avg paid 0  avg paid 1  avg paid 2  avg paid 3  avg paid 4  \
0     0         0.0         0.0         0.0         0.0         0.0   
1     0         0.0         0.0         0.0         0.0         0.0   
2     0         0.0         0.0         0.0         0.0         0.0   
3     0         0.0         0.0         0.0         0.0         0.0   
4     0         0.0         0.0         0.0         0.0         0.0   

   avg paid 5  avg paid 6  avg paid 7  avg paid 8  avg paid 9  avg paid 10  \
0         0.0         0.0         0.0         0.0         0.0          0.0   
1         0.0         0.0         0.0         0.0         0.0          0.0   
2         0.0         0.0         0.0         0.0         0.0          0.0   
3         0.0         0.0         0.0         0.0         0.0          0.0   
4         0.0         0.0         0.0         0.0         0.0          0.0   

   avg paid 11    avg paid 12  avg paid 13  avg paid 14  avg received 0  \
0          0.0  307628.336486

# Edge features
What is edge features?

Edge features are the attributes of the edges in the graph. In the context of the AML detection, the edge features are the attributes of the transactions between the accounts. To be simple its like the connection between the nodes.

For edge index, we replace all account with index and stack into a list with size of [2, num of transcation]

For edge attributes, 

'Timestamp', 'Amount Received', 'Receiving Currency', 'Amount Paid', 'Payment Currency' and 'Payment Format'


In [19]:
def get_edge_df(accounts, df):
        accounts = accounts.reset_index(drop=True)
        accounts['ID'] = accounts.index
        mapping_dict = dict(zip(accounts['Account'], accounts['ID']))
        df['From'] = df['Account'].map(mapping_dict)
        df['To'] = df['Account.1'].map(mapping_dict)
        df = df.drop(['Account', 'Account.1', 'From Bank', 'To Bank'], axis=1)

        edge_index = torch.stack([torch.from_numpy(df['From'].values), torch.from_numpy(df['To'].values)], dim=0)

        df = df.drop(['Is Laundering', 'From', 'To'], axis=1)

#         edge_attr = torch.from_numpy(df.values).to(torch.float)  # comment for visualization

        edge_attr = df  # for visualization
        return edge_attr, edge_index

edge_attr:

In [20]:
edge_attr, edge_index = get_edge_df(accounts, df)
print(edge_attr.head())

         Timestamp  Amount Received  Receiving Currency  Amount Paid  \
3408783   0.266147          8081.58                   4      8081.58   
3986981   0.318925         47468.31                   4     47468.31   
4804475   0.393400          8081.58                   4      8081.58   
4804474   0.394151         47468.31                   4     47468.31   
6690464   0.547730           787.72                   4       787.72   

         Payment Currency  Payment Format  
3408783                 4               4  
3986981                 4               3  
4804475                 4               4  
4804474                 4               3  
6690464                 4               3  


edge_index:

In [21]:
print(edge_index)

tensor([[     0,      0,      0,  ..., 681281, 681282, 681282],
        [ 22343,  22343,  22343,  ..., 681281, 681282, 681282]])


# Model Architecture
Graph Attention Networks was used for the model.
This GAT class defines a two-layer Graph Attention Network for processing graph data. It uses attention mechanisms to focus on important parts of the graph and applies dropout for regularization. The network architecture is designed to transform node features through two GAT layers followed by a linear layer and a sigmoid activation to produce the final output.

In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.nn import GATConv, Linear

class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads, dropout=0.6)
        self.conv2 = GATConv(hidden_channels * heads, int(hidden_channels/4), heads=1, concat=False, dropout=0.6)
        self.lin = Linear(int(hidden_channels/4), out_channels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, edge_index, edge_attr):
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x, edge_index, edge_attr))
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv2(x, edge_index, edge_attr))
        x = self.lin(x)
        x = self.sigmoid(x)
        
        return x

## PyG InMemoryDataset
The Dataset would be built using the functions made above.

The AMLtoGraph class is designed to process transaction data from a CSV file and convert it into a graph format suitable for graph neural network (GNN) training. It inherits from InMemoryDataset, which is a part of the PyTorch Geometric library.

In [23]:
class AMLtoGraph(InMemoryDataset):

    def __init__(self, root: str, edge_window_size: int = 10,
                 transform: Optional[Callable] = None,
                 pre_transform: Optional[Callable] = None):
        self.edge_window_size = edge_window_size
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> str:
        return 'HI-Small_Trans.csv'

    @property
    def processed_file_names(self) -> str:
        return 'data.pt'

    @property
    def num_nodes(self) -> int:
        return self._data.edge_index.max().item() + 1

    def df_label_encoder(self, df, columns):
        le = preprocessing.LabelEncoder()
        for i in columns:
            df[i] = le.fit_transform(df[i].astype(str))
        return df


    def preprocess(self, df):
        df = self.df_label_encoder(df,['Payment Format', 'Payment Currency', 'Receiving Currency'])
        df['Timestamp'] = pd.to_datetime(df['Timestamp'])
        df['Timestamp'] = df['Timestamp'].apply(lambda x: x.value)
        df['Timestamp'] = (df['Timestamp']-df['Timestamp'].min())/(df['Timestamp'].max()-df['Timestamp'].min())

        df['Account'] = df['From Bank'].astype(str) + '_' + df['Account']
        df['Account.1'] = df['To Bank'].astype(str) + '_' + df['Account.1']
        df = df.sort_values(by=['Account'])
        receiving_df = df[['Account.1', 'Amount Received', 'Receiving Currency']]
        paying_df = df[['Account', 'Amount Paid', 'Payment Currency']]
        receiving_df = receiving_df.rename({'Account.1': 'Account'}, axis=1)
        currency_ls = sorted(df['Receiving Currency'].unique())

        return df, receiving_df, paying_df, currency_ls

    def get_all_account(self, df):
        ldf = df[['Account', 'From Bank']]
        rdf = df[['Account.1', 'To Bank']]
        suspicious = df[df['Is Laundering']==1]
        s1 = suspicious[['Account', 'Is Laundering']]
        s2 = suspicious[['Account.1', 'Is Laundering']]
        s2 = s2.rename({'Account.1': 'Account'}, axis=1)
        suspicious = pd.concat([s1, s2], join='outer')
        suspicious = suspicious.drop_duplicates()

        ldf = ldf.rename({'From Bank': 'Bank'}, axis=1)
        rdf = rdf.rename({'Account.1': 'Account', 'To Bank': 'Bank'}, axis=1)
        df = pd.concat([ldf, rdf], join='outer')
        df = df.drop_duplicates()

        df['Is Laundering'] = 0
        df.set_index('Account', inplace=True)
        df.update(suspicious.set_index('Account'))
        df = df.reset_index()
        return df
    
    def paid_currency_aggregate(self, currency_ls, paying_df, accounts):
        for i in currency_ls:
            temp = paying_df[paying_df['Payment Currency'] == i]
            accounts['avg paid '+str(i)] = temp['Amount Paid'].groupby(temp['Account']).transform('mean')
        return accounts

    def received_currency_aggregate(self, currency_ls, receiving_df, accounts):
        for i in currency_ls:
            temp = receiving_df[receiving_df['Receiving Currency'] == i]
            accounts['avg received '+str(i)] = temp['Amount Received'].groupby(temp['Account']).transform('mean')
        accounts = accounts.fillna(0)
        return accounts

    def get_edge_df(self, accounts, df):
        accounts = accounts.reset_index(drop=True)
        accounts['ID'] = accounts.index
        mapping_dict = dict(zip(accounts['Account'], accounts['ID']))
        df['From'] = df['Account'].map(mapping_dict)
        df['To'] = df['Account.1'].map(mapping_dict)
        df = df.drop(['Account', 'Account.1', 'From Bank', 'To Bank'], axis=1)

        edge_index = torch.stack([torch.from_numpy(df['From'].values), torch.from_numpy(df['To'].values)], dim=0)

        df = df.drop(['Is Laundering', 'From', 'To'], axis=1)

        edge_attr = torch.from_numpy(df.values).to(torch.float)
        return edge_attr, edge_index

    def get_node_attr(self, currency_ls, paying_df,receiving_df, accounts):
        node_df = self.paid_currency_aggregate(currency_ls, paying_df, accounts)
        node_df = self.received_currency_aggregate(currency_ls, receiving_df, node_df)
        node_label = torch.from_numpy(node_df['Is Laundering'].values).to(torch.float)
        node_df = node_df.drop(['Account', 'Is Laundering'], axis=1)
        node_df = self.df_label_encoder(node_df,['Bank'])
        node_df = torch.from_numpy(node_df.values).to(torch.float)
        return node_df, node_label

    def process(self):
        df = pd.read_csv(self.raw_paths[0])
        df, receiving_df, paying_df, currency_ls = self.preprocess(df)
        accounts = self.get_all_account(df)
        node_attr, node_label = self.get_node_attr(currency_ls, paying_df,receiving_df, accounts)
        edge_attr, edge_index = self.get_edge_df(accounts, df)

        data = Data(x=node_attr,
                    edge_index=edge_index,
                    y=node_label,
                    edge_attr=edge_attr
                    )
        
        data_list = [data] 
        if self.pre_filter is not None:
            data_list = [d for d in data_list if self.pre_filter(d)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

# Model Training 

This code sets up a graph neural network (GNN) using PyTorch and PyTorch Geometric, specifically focusing on a Graph Attention Network (GAT). It processes transaction data, splits it into training and validation sets, and trains the model to detect suspicious transactions.

In [24]:
import torch
import torch_geometric.transforms as T
from torch_geometric.loader import NeighborLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = AMLtoGraph('data')
data = dataset[0]
epoch = 150

model = GAT(in_channels=data.num_features, hidden_channels=16, out_channels=1, heads=8)
model = model.to(device)
criterion = torch.nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)

split = T.RandomNodeSplit(split='train_rest', num_val=0.1, num_test=0)
data = split(data)

train_loader = loader = NeighborLoader(
    data,
    num_neighbors=[30] * 2,
    batch_size=256,
    input_nodes=data.train_mask,
)

test_loader = loader = NeighborLoader(
    data,
    num_neighbors=[30] * 2,
    batch_size=256,
    input_nodes=data.val_mask,
)

for i in range(epoch):
    total_loss = 0
    model.train()
    for data in train_loader:
        optimizer.zero_grad()
        data.to(device)
        pred = model(data.x, data.edge_index, data.edge_attr)
        ground_truth = data.y
        loss = criterion(pred, ground_truth.unsqueeze(1))
        loss.backward()
        optimizer.step()
        total_loss += float(loss)
    if epoch%10 == 0:
        print(f"Epoch: {i:03d}, Loss: {total_loss:.4f}")
        model.eval()
        acc = 0
        total = 0
        for test_data in test_loader:
            test_data.to(device)
            pred = model(test_data.x, test_data.edge_index, test_data.edge_attr)
            ground_truth = test_data.y
            correct = (pred == ground_truth.unsqueeze(1)).sum().item()
            total += len(ground_truth)
            acc += correct
        acc = acc/total
        print('accuracy:', acc)

Epoch: 000, Loss: 3539.0887
accuracy: 0.9687609162133839
Epoch: 001, Loss: 2196.2791
accuracy: 0.9699783920275065
Epoch: 002, Loss: 2104.6358
accuracy: 0.9699872700496718
Epoch: 003, Loss: 2004.3377
accuracy: 0.9702978697917707
Epoch: 004, Loss: 1938.8116
accuracy: 0.9699355012495697
Epoch: 005, Loss: 1893.9508
accuracy: 0.9701303430421357
Epoch: 006, Loss: 1842.7416
accuracy: 0.9706239172437207
Epoch: 007, Loss: 1799.4301
accuracy: 0.9707260595958739
Epoch: 008, Loss: 1750.1756
accuracy: 0.9707752030816518
Epoch: 009, Loss: 1760.9437
accuracy: 0.9708991673360241
Epoch: 010, Loss: 1720.9987
accuracy: 0.9709595959595959
Epoch: 011, Loss: 1692.9881
accuracy: 0.9711886491964673
Epoch: 012, Loss: 1685.2074
accuracy: 0.9712663764179581
Epoch: 013, Loss: 1642.8474
accuracy: 0.9710630353613002
Epoch: 014, Loss: 1635.3046
accuracy: 0.9709446845893522
Epoch: 015, Loss: 1619.7822
accuracy: 0.9709762401333187
Epoch: 016, Loss: 1597.0071
accuracy: 0.9713059319707519
Epoch: 017, Loss: 1578.0267
acc

In [25]:
import torch
import numpy as np
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support


model.eval()
all_preds = []
all_labels = []

for test_data in test_loader:
    test_data.to(device)
    with torch.no_grad():
        pred = model(test_data.x, test_data.edge_index, test_data.edge_attr)
    predicted_classes = (pred.sigmoid() > 0.5).int()  # Convert probabilities to 0 or 1
    all_preds.extend(predicted_classes.cpu().numpy())
    all_labels.extend(test_data.y.cpu().numpy())

all_preds = np.array(all_preds).flatten()
all_labels = np.array(all_labels).flatten()

# Calculate confusion matrix and other metrics
conf_mat = confusion_matrix(all_labels, all_preds)
precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')

print("Confusion Matrix:\n", conf_mat)
print("Precision: {:.4f}".format(precision))
print("Recall: {:.4f}".format(recall))
print("F1 Score: {:.4f}".format(f1))


Confusion Matrix:
 [[195338     54]
 [  5002      2]]
Precision: 0.0357
Recall: 0.0004
F1 Score: 0.0008


In [26]:
import joblib

# Assuming 'model' is your trained model
joblib.dump(model, 'GNN_model.joblib')


['GNN_model.joblib']

In [27]:
torch.save(model.state_dict(), 'GNN_model.pth')
