In [1]:
import torch
import numpy as np
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv

# === 1. Define your GCN model architecture ===
# Must match your training setup
class GCN(torch.nn.Module):
    """Two-layer graph convolutional network producing per-node class scores.

    Architecture: GCNConv(in -> hidden) -> ReLU -> GCNConv(hidden -> out).
    The forward pass applies no softmax; it returns raw logits, one row per
    node and one column per output class.

    Layer sizes must equal those used at training time, otherwise the saved
    per-layer weights will not load.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # The attribute names conv1/conv2 matter: weights are restored
        # per layer by these names.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Return logits of shape (num_nodes, out_channels) for node features ``x``."""
        hidden = self.conv1(x, edge_index).relu()
        return self.conv2(hidden, edge_index)

# === 2. Initialize and load trained model ===
model = GCN(in_channels=166, hidden_channels=64, out_channels=2)

# map_location="cpu": without it, a checkpoint saved from a GPU training run
# fails to deserialize on a CPU-only machine. The prediction code in this
# file creates its input tensors without a device (i.e. on the CPU) anyway.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load("gnn_elliptic_model.pth", map_location="cpu")

# The checkpoint holds a separate state dict per layer under the keys
# "conv1" and "conv2", rather than one model-level state dict.
model.conv1.load_state_dict(checkpoint["conv1"])
model.conv2.load_state_dict(checkpoint["conv2"])
model.eval()  # evaluation mode before any inference

# === 3. Load mean vector for filling features ===
# Per-feature fallback values: every model input the user does not supply is
# taken from this array. Expected shape is (166,), one entry per GCN input
# channel -- TODO confirm it was computed on the same (possibly normalized)
# features the model was trained on.
mean_vector = np.load(file="mean_vector.npy")

# === 4. Map user input to full feature vector ===
def map_user_input_to_features(amount, time_step):
    """Expand the two user-supplied values into a full model feature row.

    Starts from a copy of the module-level ``mean_vector``, so the global is
    never mutated. Column 0 is overwritten with ``amount`` and column 1 with
    ``time_step``, and the row is returned as a plain Python list
    (``ndarray.tolist``).
    """
    # NOTE(review): these column positions are an unverified guess
    # ("Suppose feature 0 is ..."). In the public Elliptic feature file the
    # features are anonymized, and the first one after the transaction id
    # appears to be the time step. Confirm this mapping against the training
    # preprocessing before trusting predictions.
    row = np.array(mean_vector)  # independent copy of the fill values
    row[0], row[1] = amount, time_step
    return row.tolist()

# === 5. Predict user transaction ===
def predict_user_transaction(amount, time_step):
    """Classify a single user-described transaction with the loaded GCN.

    The transaction becomes a one-node graph with no edges. Presumably this
    relies on GCNConv adding self-loops by default, so the node still
    aggregates its own features -- confirm if the layer config changes.

    Args:
        amount: Value placed in feature column 0.
        time_step: Value placed in feature column 1.

    Returns:
        A ``(label, confidence)`` tuple. ``label`` is ``"Licit (Legit)"``
        for class 0 or ``"Illicit (Fraud)"`` for class 1. ``confidence`` is
        the softmax probability of the predicted class, as a Python float.
    """
    row = map_user_input_to_features(amount, time_step)
    node_x = torch.tensor([row], dtype=torch.float)  # (1, num_features)
    # Zero columns means no edges: the graph is just this one node.
    edge_index = torch.empty((2, 0), dtype=torch.long)

    with torch.no_grad():
        logits = model(node_x, edge_index)
        # The class is chosen from the logits themselves; softmax is used
        # only for the reported confidence.
        winner = logits.argmax(dim=1).item()
        certainty = torch.softmax(logits, dim=1)[0, winner].item()

    # NOTE(review): raw, unscaled inputs (e.g. an amount like 2350) can
    # saturate the softmax and report ~100% confidence. Confirm the inputs
    # are on the same scale as the training features.
    return {0: "Licit (Legit)", 1: "Illicit (Fraud)"}[winner], certainty

# === 6. Example: Run prediction ===
if __name__ == "__main__":
    # Interactive demo: prompt for the two directly supplied inputs.
    # float()/int() raise ValueError on non-numeric input.
    # NOTE(review): Elliptic time steps appear to run 1-49; confirm the
    # advertised 0-49 range against the training data.
    amount = float(input("Enter transaction amount: "))
    time_step = int(input("Enter time step (0-49): "))

    prediction, confidence = predict_user_transaction(amount, time_step)
    print("\n🧾 Prediction: " + prediction)
    # The ".2%" spec scales by 100 and appends "%".
    print(f"🔍 Confidence: {confidence:.2%}")

Enter transaction amount:  2350
Enter time step (0-49):  30



🧾 Prediction: Licit (Legit)
🔍 Confidence: 100.00%
