Install required libraries

In [None]:
# Install required libraries
!pip install torch torch-geometric matplotlib scikit-learn


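This cell loads the pre-trained global GraphSAGE model produced by the federated-learning (NVFlare) simulation and runs it on a small dummy transaction graph. It classifies each transaction and reports which ones are predicted to be fraudulent.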

In [None]:
import torch
from torch_geometric.nn import SAGEConv
from torch_geometric.data import Data

# Step 1: Recreate the GraphSAGE model architecture
class GraphSAGEModel(torch.nn.Module):
    """Stack of GraphSAGE convolutions used for per-node (transaction) classification.

    Layer widths run in_channels -> hidden_channels -> ... -> hidden_channels
    -> out_channels. A ReLU follows every convolution except the last, so the
    output holds raw per-node scores with ``out_channels`` columns.

    Note: at least two convolutions are always created, so ``num_layers`` values
    below 2 still produce the two-layer in -> hidden -> out network. This must
    mirror the training-time architecture exactly for the checkpoint to load.
    """

    def __init__(self, in_channels, hidden_channels, num_layers, out_channels):
        super().__init__()
        # One input width plus (num_layers - 1) hidden widths (minimum one)
        # plus the output width; consecutive pairs define each SAGEConv.
        hidden_count = max(num_layers - 2, 0) + 1
        widths = [in_channels] + [hidden_channels] * hidden_count + [out_channels]
        # Registered in order, so parameter names are convs.0, convs.1, ...
        self.convs = torch.nn.ModuleList(
            SAGEConv(width_in, width_out)
            for width_in, width_out in zip(widths, widths[1:])
        )

    def forward(self, x, edge_index):
        """Propagate node features ``x`` over ``edge_index`` and return per-node logits."""
        *body_layers, output_layer = self.convs
        for layer in body_layers:
            x = torch.relu(layer(x, edge_index))
        # The output layer is left without an activation.
        return output_layer(x, edge_index)

# Step 2: Instantiate the model with the same hyperparameters used in the job
# creation step. They must match the training-time architecture exactly,
# otherwise load_state_dict() below fails on missing or mis-shaped weights.
NUM_FEATURES = 165  # per-transaction feature width expected by the first layer
model = GraphSAGEModel(in_channels=NUM_FEATURES, hidden_channels=256, num_layers=3, out_channels=2)

# Step 3: Load the global model checkpoint written by the FL simulation.
model_path = '/tmp/nvflare/gnn/finance_fl_workspace/simulate_job/app_server/FL_global_model.pt'
# weights_only=True restricts unpickling to tensors and primitive containers
# (safer loading); map_location="cpu" lets a checkpoint whose tensors were
# saved on a GPU be loaded on a CPU-only machine.
checkpoint = torch.load(model_path, map_location="cpu", weights_only=True)

# The model's state_dict is stored under the 'model' key of the checkpoint.
state_dict = checkpoint['model']
model.load_state_dict(state_dict)

# Step 4: Switch the model to evaluation mode for inference.
model.eval()

# Step 5: Build a small dummy transaction graph.
# A dedicated fixed-seed generator makes the dummy features reproducible
# without changing the global RNG state used by other cells.
# NOTE(review): uniform [0, 1) values are placeholders; real inputs should be
# preprocessed the same way as the training features.
dummy_rng = torch.Generator().manual_seed(0)
node_features = torch.rand((5, NUM_FEATURES), generator=dummy_rng)  # 5 transactions

# Edges as [source_nodes, target_nodes]; each pair is listed in both
# directions (0<->1, 2<->3). Node 4 has no edges.
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 0, 3, 2]], dtype=torch.long)

# Create a PyTorch Geometric Data object
new_transaction_data = Data(x=node_features, edge_index=edge_index)

# Step 6: Run inference without tracking gradients.
with torch.no_grad():
    prediction = model(new_transaction_data.x, new_transaction_data.edge_index)

# Step 7: Interpret the results.
# The final layer has no activation, so `prediction` holds raw logits (one row
# per transaction, one column per class); softmax converts them to probabilities.
probabilities = torch.softmax(prediction, dim=1)
is_fraudulent = prediction.argmax(dim=1)  # predicted class index per transaction
# NOTE(review): assumes class index 1 means "fraudulent" — confirm against the
# label encoding used during training.
flagged_indices = (is_fraudulent == 1).nonzero(as_tuple=True)[0].tolist()
print("Prediction scores:", prediction)
print("Class probabilities:", probabilities)
print("Fraudulent transactions:", is_fraudulent)
print("Indices of transactions flagged as fraudulent:", flagged_indices)
