# GNN Anomaly Detection with GAE + GraphSAGE using BankSim

In [1]:

import pandas as pd

df = pd.read_csv('../data/bs140513_032310.csv')
print(df.head())
print(df['fraud'].value_counts())


   step       customer  age gender zipcodeOri       merchant zipMerchant  \
0     0  'C1093826151'  '4'    'M'    '28007'   'M348934600'     '28007'   
1     0   'C352968107'  '2'    'M'    '28007'   'M348934600'     '28007'   
2     0  'C2054744914'  '4'    'F'    '28007'  'M1823072687'     '28007'   
3     0  'C1760612790'  '3'    'M'    '28007'   'M348934600'     '28007'   
4     0   'C757503768'  '5'    'M'    '28007'   'M348934600'     '28007'   

              category  amount  fraud  
0  'es_transportation'    4.55      0  
1  'es_transportation'   39.68      0  
2  'es_transportation'   26.89      0  
3  'es_transportation'   17.25      0  
4  'es_transportation'   35.72      0  
fraud
0    587443
1      7200
Name: count, dtype: int64


In [2]:

import re

def get_numeric_id(customer_id):
    numeric_only = re.sub(r'[^0-9]', '', str(customer_id))
    return int(numeric_only) if numeric_only else 0

df['device_fp'] = df['customer'].apply(lambda x: f"fp_{get_numeric_id(x) % 1000}")


In [3]:

import networkx as nx

G = nx.Graph()
for _, row in df.iterrows():
    if row['fraud'] == 0:  # Use only normal transactions for training
        u = f"user_{row['customer']}"
        m = f"merch_{row['merchant']}"
        G.add_node(u, type='user')
        G.add_node(m, type='merchant')
        G.add_edge(u, m)


In [5]:

import torch
from torch_geometric.utils import from_networkx
from torch_geometric.data import Data

data = from_networkx(G)
data.x = torch.eye(data.num_nodes)  # Identity features as placeholder


In [6]:

from torch_geometric.nn import GAE, SAGEConv

class GNNEncoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, 2 * out_channels)
        self.conv2 = SAGEConv(2 * out_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

encoder = GNNEncoder(data.num_node_features, 32)
model = GAE(encoder)


In [7]:

from torch_geometric.utils import train_test_split_edges

data = train_test_split_edges(data)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def train():
    model.train()
    optimizer.zero_grad()
    z = model.encode(data.x, data.train_pos_edge_index)
    loss = model.recon_loss(z, data.train_pos_edge_index)
    loss.backward()
    optimizer.step()
    return loss.item()

for epoch in range(1, 101):
    loss = train()
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss:.4f}")




Epoch 10, Loss: 1.0605
Epoch 20, Loss: 0.8523
Epoch 30, Loss: 1.1409
Epoch 40, Loss: 0.9314
Epoch 50, Loss: 0.8461
Epoch 60, Loss: 0.8142
Epoch 70, Loss: 0.7971
Epoch 80, Loss: 0.7859
Epoch 90, Loss: 0.7766
Epoch 100, Loss: 0.7718


In [8]:

from sklearn.metrics import roc_auc_score
from torch_geometric.utils import negative_sampling

model.eval()
with torch.no_grad():
    z = model.encode(data.x, data.train_pos_edge_index)
    pos_pred = model.decoder(z, data.test_pos_edge_index).squeeze()

    neg_edge_index = negative_sampling(
        edge_index=data.train_pos_edge_index,
        num_nodes=z.size(0),
        num_neg_samples=data.test_pos_edge_index.size(1)
    )
    neg_pred = model.decoder(z, neg_edge_index).squeeze()

    y_true = torch.cat([torch.ones(pos_pred.size(0)), torch.zeros(neg_pred.size(0))])
    y_score = torch.cat([pos_pred, neg_pred])

    auc = roc_auc_score(y_true.cpu(), y_score.cpu())
    print(f"AUC Score: {auc:.4f}")


AUC Score: 0.9960


In [9]:
# Save the trained model and create node mapping for production use
import pickle

# Create node mapping for faster lookup
node_mapping = {}
for i, node in enumerate(G.nodes()):
    node_mapping[node] = i

# Save model and necessary data
model_data = {
    'model_state_dict': model.state_dict(),
    'node_mapping': node_mapping,
    'data': data,
    'encoder_config': {'in_channels': data.num_node_features, 'out_channels': 32}
}

torch.save(model_data, '../outputs/fraud_detection_model.pth')
print("Model saved successfully!")

Model saved successfully!


In [10]:
class FraudDetectionWrapper:
    """
    Production wrapper for the GNN fraud detection model
    Handles preprocessing, prediction, and output formatting for web API integration
    """
    
    def __init__(self, model_path: str = None):
        self.model = None
        self.node_mapping = {}
        self.data = None
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        if model_path:
            self.load_model(model_path)
    
    def get_numeric_id(self, customer_id):
        """Extract numeric ID from customer string"""
        numeric_only = re.sub(r'[^0-9]', '', str(customer_id))
        return int(numeric_only) if numeric_only else 0
    
    def preprocess_transaction(self, transaction_data: dict) -> dict:
        """
        Preprocess a single transaction for model input
        
        Required inputs in transaction_data:
        - customer: Customer ID (string) - e.g., "C1093826151"
        - merchant: Merchant ID (string) - e.g., "M348934600"
        - amount: Transaction amount (float) - e.g., 156.50
        - category: Transaction category (string) - e.g., "es_transportation"
        - age: Customer age (int) - e.g., 4
        - gender: Customer gender (string) - e.g., "M" or "F"
        - zipcodeOri: Customer zipcode (string) - e.g., "28007"
        - zipMerchant: Merchant zipcode (string) - e.g., "28007"
        """
        processed = {
            'customer_node': f"user_{transaction_data['customer']}",
            'merchant_node': f"merch_{transaction_data['merchant']}",
            'amount': float(transaction_data['amount']),
            'customer_age': int(transaction_data['age']),
            'same_zipcode': transaction_data['zipcodeOri'] == transaction_data['zipMerchant'],
            'category': transaction_data['category'],
            'gender': transaction_data['gender'],
            'device_fp': f"fp_{self.get_numeric_id(transaction_data['customer']) % 1000}"
        }
        return processed
    
    def calculate_anomaly_score(self, transaction_data: dict) -> float:
        """
        Calculate anomaly score for a transaction
        Returns: Float between 0-1 (higher = more suspicious)
        """
        if not self.model:
            raise ValueError("Model not loaded. Call load_model() first.")
            
        self.model.eval()
        processed = self.preprocess_transaction(transaction_data)
        
        customer_node = processed['customer_node']
        merchant_node = processed['merchant_node']
        
        # Check if nodes exist in trained graph
        if customer_node not in self.node_mapping or merchant_node not in self.node_mapping:
            return self._calculate_new_entity_score(processed)
        
        # Get node embeddings and calculate link probability
        with torch.no_grad():
            z = self.model.encode(self.data.x, self.data.train_pos_edge_index)
            
            customer_idx = self.node_mapping[customer_node]
            merchant_idx = self.node_mapping[merchant_node]
            
            # Calculate reconstruction probability
            edge_index = torch.tensor([[customer_idx], [merchant_idx]], dtype=torch.long)
            link_prob = torch.sigmoid(self.model.decoder(z, edge_index)).item()
            
            # Convert to anomaly score (lower link probability = higher anomaly)
            base_anomaly = 1 - link_prob
            
            # Apply business rules
            anomaly_score = self._apply_business_rules(base_anomaly, processed)
            
        return min(max(anomaly_score, 0.0), 1.0)  # Clamp between 0-1
    
    def _calculate_new_entity_score(self, processed: dict) -> float:
        """Handle transactions with new customers/merchants"""
        base_score = 0.3  # Base suspicion for new entities
        
        # Adjust based on amount
        if processed['amount'] > 1000:
            base_score += 0.2
        elif processed['amount'] > 500:
            base_score += 0.1
            
        # Adjust based on location
        if not processed['same_zipcode']:
            base_score += 0.1
            
        return min(base_score, 1.0)
    
    def _apply_business_rules(self, base_score: float, processed: dict) -> float:
        """Apply business rules to adjust anomaly score"""
        adjusted_score = base_score
        
        # High amount transactions
        if processed['amount'] > 1000:
            adjusted_score += 0.2
        elif processed['amount'] > 500:
            adjusted_score += 0.1
            
        # Cross-location transactions
        if not processed['same_zipcode']:
            adjusted_score += 0.15
            
        # Suspicious categories (customize based on your domain)
        suspicious_categories = ['es_health', 'es_hyper', 'es_wellnessandbeauty']
        if processed['category'] in suspicious_categories and processed['amount'] > 200:
            adjusted_score += 0.1
            
        return adjusted_score
    
    def predict_fraud(self, transaction_data: dict) -> dict:
        """
        Main prediction method for web API
        
        Input: Dictionary with transaction details
        Output: Dictionary with fraud prediction results
        """
        anomaly_score = self.calculate_anomaly_score(transaction_data)
        
        # Define risk thresholds
        high_risk_threshold = 0.7
        medium_risk_threshold = 0.4
        
        # Determine fraud prediction
        is_fraud = anomaly_score > high_risk_threshold
        fraud_probability = anomaly_score
        
        # Risk level classification
        if anomaly_score > high_risk_threshold:
            risk_level = "HIGH"
        elif anomaly_score > medium_risk_threshold:
            risk_level = "MEDIUM"
        else:
            risk_level = "LOW"
        
        # Identify risk factors for explainability
        risk_factors = self._identify_risk_factors(transaction_data, anomaly_score)
        
        return {
            'is_fraud': is_fraud,
            'fraud_probability': round(fraud_probability, 3),
            'risk_level': risk_level,
            'anomaly_score': round(anomaly_score, 3),
            'risk_factors': risk_factors,
            'transaction_id': transaction_data.get('transaction_id', 'N/A')
        }
    
    def _identify_risk_factors(self, transaction_data: dict, anomaly_score: float) -> list:
        """Identify specific risk factors for explainability"""
        factors = []
        
        if float(transaction_data['amount']) > 1000:
            factors.append("High transaction amount (>$1000)")
        elif float(transaction_data['amount']) > 500:
            factors.append("Elevated transaction amount (>$500)")
        
        if transaction_data['zipcodeOri'] != transaction_data['zipMerchant']:
            factors.append("Cross-location transaction")
        
        if anomaly_score > 0.5:
            factors.append("Unusual customer-merchant relationship pattern")
        
        # Check for new entities
        customer_node = f"user_{transaction_data['customer']}"
        merchant_node = f"merch_{transaction_data['merchant']}"
        
        if customer_node not in self.node_mapping:
            factors.append("New customer (not seen in training data)")
        if merchant_node not in self.node_mapping:
            factors.append("New merchant (not seen in training data)")
        
        return factors
    
    def load_model(self, model_path: str):
        """Load a pre-trained model from file"""
        try:
            checkpoint = torch.load(model_path, map_location=self.device)
            
            # Reconstruct model architecture
            encoder_config = checkpoint.get('encoder_config', {'in_channels': 1000, 'out_channels': 32})
            encoder = GNNEncoder(encoder_config['in_channels'], encoder_config['out_channels'])
            self.model = GAE(encoder)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            
            self.node_mapping = checkpoint['node_mapping']
            self.data = checkpoint['data']
            self.model.to(self.device)
            self.model.eval()
            
            print(f"Model loaded successfully from {model_path}")
            print(f"Graph has {len(self.node_mapping)} nodes")
            
        except Exception as e:
            raise ValueError(f"Failed to load model: {str(e)}")

# Create an instance of the wrapper
fraud_detector = FraudDetectionWrapper()
fraud_detector.model = model
fraud_detector.node_mapping = node_mapping
fraud_detector.data = data

print("Fraud detection wrapper initialized!")

Fraud detection wrapper initialized!


In [11]:
# Test the fraud detection wrapper with sample data
print("=== Testing Fraud Detection Wrapper ===")

# Example transaction data (based on your dataset structure)
sample_transaction = {
    'customer': 'C1093826151',
    'merchant': 'M348934600', 
    'amount': 156.50,
    'category': 'es_transportation',
    'age': 4,
    'gender': 'M',
    'zipcodeOri': '28007',
    'zipMerchant': '28007',
    'transaction_id': 'TXN_001'
}

# Get fraud prediction
result = fraud_detector.predict_fraud(sample_transaction)

print(f"Transaction: ${sample_transaction['amount']} from {sample_transaction['customer']} to {sample_transaction['merchant']}")
print(f"Fraud Prediction: {result['is_fraud']}")
print(f"Risk Level: {result['risk_level']}")
print(f"Fraud Probability: {result['fraud_probability']}")
print(f"Anomaly Score: {result['anomaly_score']}")
print(f"Risk Factors: {result['risk_factors']}")
print()

# Test with a high-amount transaction
high_amount_transaction = sample_transaction.copy()
high_amount_transaction['amount'] = 1500.0
high_amount_transaction['zipMerchant'] = '90210'  # Different zipcode

result_high = fraud_detector.predict_fraud(high_amount_transaction)
print(f"High Amount Transaction: ${high_amount_transaction['amount']}")
print(f"Fraud Prediction: {result_high['is_fraud']}")
print(f"Risk Level: {result_high['risk_level']}")
print(f"Risk Factors: {result_high['risk_factors']}")

=== Testing Fraud Detection Wrapper ===
Transaction: $156.5 from C1093826151 to M348934600
Fraud Prediction: False
Risk Level: LOW
Fraud Probability: 0.3
Anomaly Score: 0.3
Risk Factors: ['New customer (not seen in training data)', 'New merchant (not seen in training data)']

High Amount Transaction: $1500.0
Fraud Prediction: False
Risk Level: MEDIUM
Risk Factors: ['High transaction amount (>$1000)', 'Cross-location transaction', 'Unusual customer-merchant relationship pattern', 'New customer (not seen in training data)', 'New merchant (not seen in training data)']


In [None]:
# Flask API example for web integration
flask_api_code = '''
from flask import Flask, request, jsonify
from flask_cors import CORS
import torch
from fraud_detection_wrapper import FraudDetectionWrapper

app = Flask(__name__)
CORS(app)  # Enable CORS for web frontend integration

# Load the trained model (update path as needed)
fraud_detector = FraudDetectionWrapper('models/fraud_detection_model.pth')

@app.route('/api/predict_fraud', methods=['POST'])
def predict_fraud():
    """
    API endpoint for fraud prediction
    
    Expected JSON input:
    {
        "customer": "C1093826151",
        "merchant": "M348934600", 
        "amount": 156.50,
        "category": "es_transportation",
        "age": 4,
        "gender": "M",
        "zipcodeOri": "28007",
        "zipMerchant": "28007",
        "transaction_id": "TXN_001"  // optional
    }
    
    Returns JSON:
    {
        "is_fraud": false,
        "fraud_probability": 0.234,
        "risk_level": "LOW",
        "anomaly_score": 0.234,
        "risk_factors": ["High transaction amount"],
        "transaction_id": "TXN_001"
    }
    """
    try:
        # Validate request
        if not request.is_json:
            return jsonify({'error': 'Request must be JSON'}), 400
        
        transaction_data = request.json
        
        # Validate required fields
        required_fields = ['customer', 'merchant', 'amount', 'category', 'age', 'gender', 'zipcodeOri', 'zipMerchant']
        for field in required_fields:
            if field not in transaction_data:
                return jsonify({'error': f'Missing required field: {field}'}), 400
        
        # Get fraud prediction
        result = fraud_detector.predict_fraud(transaction_data)
        
        return jsonify(result), 200
        
    except ValueError as ve:
        return jsonify({'error': f'Validation error: {str(ve)}'}), 400
    except Exception as e:
        return jsonify({'error': f'Internal server error: {str(e)}'}), 500

@app.route('/api/health', methods=['GET'])
def health_check():
    """Health check endpoint"""
    return jsonify({'status': 'healthy', 'model_loaded': fraud_detector.model is not None}), 200

@app.route('/api/model_info', methods=['GET'])
def model_info():
    """Get model information"""
    try:
        info = {
            'model_type': 'GAE + GraphSAGE',
            'num_nodes': len(fraud_detector.node_mapping) if fraud_detector.node_mapping else 0,
            'device': str(fraud_detector.device),
            'required_fields': ['customer', 'merchant', 'amount', 'category', 'age', 'gender', 'zipcodeOri', 'zipMerchant']
        }
        return jsonify(info), 200
    except Exception as e:
        return jsonify({'error': str(e)}), 500

if __name__ == '__main__':
    print("Starting Fraud Detection API...")
    print("Available endpoints:")
    print("  POST /api/predict_fraud - Main prediction endpoint")
    print("  GET  /api/health - Health check")
    print("  GET  /api/model_info - Model information")
    app.run(host='0.0.0.0', port=5000, debug=True)
'''

print("Flask API code generated!")
print("To use this API:")
print("1. Save the above code as 'fraud_api.py'")
print("2. Install Flask: pip install flask flask-cors")
print("3. Run: python fraud_api.py")
print("4. API will be available at http://localhost:5000")

# Model Integration Summary

## Required Inputs for Fraud Detection

Your GNN model requires the following fields for each transaction:

| Field | Type | Description | Example |
|-------|------|-------------|---------|
| `customer` | string | Customer ID | "C1093826151" |
| `merchant` | string | Merchant ID | "M348934600" |
| `amount` | float | Transaction amount | 156.50 |
| `category` | string | Transaction category | "es_transportation" |
| `age` | int | Customer age | 4 |
| `gender` | string | Customer gender | "M" or "F" |
| `zipcodeOri` | string | Customer zipcode | "28007" |
| `zipMerchant` | string | Merchant zipcode | "28007" |
| `transaction_id` | string | Optional transaction ID | "TXN_001" |

## Model Outputs

The model returns a JSON object with:

| Field | Type | Description |
|-------|------|-------------|
| `is_fraud` | boolean | Binary fraud prediction (true/false) |
| `fraud_probability` | float | Fraud probability score (0.0 - 1.0) |
| `risk_level` | string | Risk classification ("LOW", "MEDIUM", "HIGH") |
| `anomaly_score` | float | Raw anomaly score (0.0 - 1.0) |
| `risk_factors` | array | List of identified risk factors |
| `transaction_id` | string | Transaction ID (if provided) |

## Integration Steps

1. **Save the model**: Run the model saving cell to create `fraud_detection_model.pth`
2. **Deploy the wrapper**: Use the `FraudDetectionWrapper` class in your application
3. **Create API endpoint**: Use the Flask example to create a REST API
4. **Connect to frontend**: Your web app can POST transaction data to `/api/predict_fraud`

## Risk Thresholds

- **HIGH RISK** (≥0.7): Block transaction, require manual review
- **MEDIUM RISK** (0.4-0.7): Flag for additional verification
- **LOW RISK** (<0.4): Allow transaction to proceed

## Business Rules Applied

- High amount transactions (>$500, >$1000)
- Cross-location transactions (different zipcodes)
- New customers/merchants not in training data
- Suspicious categories (configurable)
- Unusual customer-merchant relationship patterns