# Credit Card Fraud Detection using GNN and XAI

In [None]:

# Import necessary libraries
import pandas as pd
import torch
from torch_geometric.data import Data
from sklearn.preprocessing import StandardScaler
from src.models.gearsage import GEARSage
from src.models.tgat import TGAT
from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt
import seaborn as sns


## 1. Data Loading and Preprocessing

In [None]:

# Read the raw transactions table from disk
file_path = 'data/creditcard.csv'
data = pd.read_csv(file_path)

# Standardize 'Amount' into a new 'normAmount' column, then discard the raw
# 'Time' and 'Amount' columns so they do not end up in the feature matrix
data['normAmount'] = StandardScaler().fit_transform(data[['Amount']].to_numpy())
data = data.drop(columns=['Time', 'Amount'])

# Features (X) are every remaining column except the label; labels (y) are 'Class'
X = data.drop(columns='Class').to_numpy()
y = data['Class'].to_numpy()

# Wrap features and labels as tensors inside a PyTorch Geometric Data object.
# NOTE(review): no edge_index is attached here — confirm the models build or
# do not need graph connectivity.
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.int64)
graph_data = Data(x=X_tensor, y=y_tensor)


## 2. Data Exploration

In [None]:

# 1. Class balance: number of transactions per label
plt.figure(figsize=(6, 4))
class_ax = sns.countplot(data=data, x='Class')
class_ax.set_title("Fraud vs Non-Fraud Transaction Distribution")
class_ax.set_xlabel("Transaction Class (0 = Non-Fraud, 1 = Fraud)")
class_ax.set_ylabel("Count")
plt.show()

# 2. Histogram (with KDE overlay) of the standardized transaction amount
plt.figure(figsize=(6, 4))
amount_ax = sns.histplot(data['normAmount'], bins=50, kde=True)
amount_ax.set_title("Distribution of Normalized Transaction Amounts")
amount_ax.set_xlabel("Normalized Amount")
amount_ax.set_ylabel("Count")
plt.show()

# 3. Correlation Heatmap (Only for features V1 to V28)
# Both non-PCA columns are excluded: 'Class' is the label and 'normAmount' is
# the standardized amount added during preprocessing. Leaving 'normAmount' in
# would contradict the plot title. The remaining columns are presumably
# V1..V28 — confirm against the CSV header.
plt.figure(figsize=(14,8))
corr_matrix = data.drop(['Class', 'normAmount'], axis=1).corr()
sns.heatmap(corr_matrix, annot=False, cmap='coolwarm')
plt.title("Correlation Heatmap of PCA Features (V1 to V28)")
plt.show()


## 3. Model Selection and Training

In [None]:

# Choose the model: GEARSage (active) or TGAT (swap which of the two lines is
# commented out to switch). in_channels is the number of feature columns in X.
# out_channels=2 is presumably one output per value of 'Class' (0 = non-fraud,
# 1 = fraud) — confirm against the model definitions in src/models.
model = GEARSage(in_channels=X.shape[1], hidden_channels=64, out_channels=2)
# model = TGAT(in_channels=X.shape[1], hidden_channels=64, out_channels=2)

# Training Function with AUC and Time per Epoch
import time
def train_with_timing(model, data, epochs=10):
    """Train ``model`` full-batch on ``data``, recording time and AUC per epoch.

    Every epoch runs one forward/backward pass over the whole ``data`` object
    followed by a single Adam step (lr=0.001) on a cross-entropy loss, then
    prints the loss, the AUC and the elapsed time.

    Args:
        model: Module invoked as ``model(data)``. Its output is treated as
            per-class scores with the positive class in column 1.
        data: Object exposing the integer class labels as ``data.y``.
        epochs: Number of training epochs to run.

    Returns:
        A ``(times, auc_scores)`` tuple of lists with one entry per epoch:
        the seconds spent in the optimisation step, and the ROC AUC of that
        epoch's training-mode forward pass.

    Note:
        The AUC is measured on the same samples the model is fitted on, so it
        is a training metric and not a held-out estimate.
    """
    times = []
    auc_scores = []

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_fn = torch.nn.CrossEntropyLoss()

    # The labels are not modified by training, so convert them only once.
    y_true = data.y.cpu().numpy()

    for epoch in range(epochs):
        # perf_counter is a monotonic, high-resolution clock; time.time() can
        # jump when the system clock is adjusted and would skew the interval.
        start_time = time.perf_counter()
        model.train()
        optimizer.zero_grad()
        out = model(data)
        loss = loss_fn(out, data.y)
        loss.backward()
        optimizer.step()
        epoch_time = time.perf_counter() - start_time
        times.append(epoch_time)

        # Score samples by P(class 1). Column 1 of unnormalised logits is not
        # a valid ranking on its own because the probability depends on the
        # difference between the columns. Softmax is also correct if the model
        # already emits log-probabilities.
        fraud_prob = torch.softmax(out.detach(), dim=1)[:, 1].cpu().numpy()
        auc = roc_auc_score(y_true, fraud_prob)
        auc_scores.append(auc)

        print(f"Epoch {epoch+1}, Loss: {loss.item()}, AUC: {auc}, Time: {epoch_time:.2f} seconds")

    return times, auc_scores

# Train the model for 10 epochs on the full graph_data and capture the
# per-epoch times and AUC scores; both lists are plotted in the next section.
times, auc_scores = train_with_timing(model, graph_data, epochs=10)


## 4. Time and AUC per Epoch

In [None]:

# One line chart per recorded metric (time first, then AUC), epochs numbered
# from 1 on the x-axis
for series, title, ylabel in (
    (times, "Time Taken Per Epoch", "Time (seconds)"),
    (auc_scores, "AUC Score Per Epoch", "AUC Score"),
):
    plt.figure(figsize=(6, 4))
    sns.lineplot(x=range(1, len(series) + 1), y=series)
    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.show()


## 5. Confusion Matrix

In [None]:

from sklearn.metrics import confusion_matrix

def plot_confusion_matrix(model, data):
    """Show a heatmap of the confusion matrix of ``model``'s predictions.

    The model is put in eval mode and run once on ``data`` without gradient
    tracking. The predicted class of each sample is the argmax over dimension
    1 of the output, and it is compared against the labels in ``data.y``.

    Args:
        model: Module invoked as ``model(data)``.
        data: Object exposing the true labels as ``data.y``.
    """
    model.eval()
    with torch.no_grad():
        scores = model(data)

    predicted = scores.argmax(dim=1).cpu().numpy()
    actual = data.y.cpu().numpy()
    matrix = confusion_matrix(actual, predicted)

    # Rows are true labels, columns are predicted labels
    sns.heatmap(matrix, annot=True, fmt='d', cmap='Blues', cbar=False)
    plt.title("Confusion Matrix")
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.show()

# Plot the confusion matrix of the trained model, evaluated on the same
# graph_data it was trained on
plot_confusion_matrix(model, graph_data)


## 6. Precision-Recall Curve

In [None]:

from sklearn.metrics import precision_recall_curve

def plot_precision_recall_curve(model, data):
    """Plot the precision-recall curve of ``model`` for the positive class.

    The model is put in eval mode and run once on ``data`` without gradient
    tracking. Samples are scored by the softmax probability of class 1 and
    compared against the labels in ``data.y``.

    Args:
        model: Module invoked as ``model(data)``. Its output is treated as
            per-class scores with the positive class in column 1.
        data: Object exposing the true labels as ``data.y``.
    """
    model.eval()
    with torch.no_grad():
        out = model(data)
        y_true = data.y.cpu().numpy()
        # Score by P(class 1). Column 1 of unnormalised logits alone does not
        # rank samples correctly because the probability depends on the
        # difference between the columns. Softmax is also correct if the model
        # already emits log-probabilities.
        y_pred = torch.softmax(out, dim=1)[:, 1].cpu().numpy()

    precision, recall, _ = precision_recall_curve(y_true, y_pred)

    plt.plot(recall, precision, marker='.')
    plt.title("Precision-Recall Curve")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.show()

# Plot the precision-recall curve of the trained model, evaluated on the same
# graph_data it was trained on
plot_precision_recall_curve(model, graph_data)


## 7. ROC Curve

In [None]:

from sklearn.metrics import roc_curve

def plot_roc_curve(model, data):
    """Plot the ROC curve of ``model`` for the positive class.

    The model is put in eval mode and run once on ``data`` without gradient
    tracking. Samples are scored by the softmax probability of class 1 and
    compared against the labels in ``data.y``.

    Args:
        model: Module invoked as ``model(data)``. Its output is treated as
            per-class scores with the positive class in column 1.
        data: Object exposing the true labels as ``data.y``.
    """
    model.eval()
    with torch.no_grad():
        out = model(data)
        y_true = data.y.cpu().numpy()
        # Score by P(class 1). Column 1 of unnormalised logits alone does not
        # rank samples correctly because the probability depends on the
        # difference between the columns. Softmax is also correct if the model
        # already emits log-probabilities.
        y_pred = torch.softmax(out, dim=1)[:, 1].cpu().numpy()

    fpr, tpr, _ = roc_curve(y_true, y_pred)
    plt.plot(fpr, tpr, marker='.')
    plt.title("ROC Curve for Fraud Detection")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.show()

# Plot the ROC curve of the trained model, evaluated on the same graph_data
# it was trained on
plot_roc_curve(model, graph_data)


## 8. Loss per Epoch

In [None]:

def train_with_loss(model, data, epochs=10):
    """Train ``model`` full-batch on ``data``, recording loss and AUC per epoch.

    Every epoch runs one forward/backward pass over the whole ``data`` object
    followed by a single Adam step (lr=0.001) on a cross-entropy loss, then
    prints the loss and the AUC. A fresh optimizer is created on each call;
    the model's current weights are used as the starting point.

    Args:
        model: Module invoked as ``model(data)``. Its output is treated as
            per-class scores with the positive class in column 1.
        data: Object exposing the integer class labels as ``data.y``.
        epochs: Number of training epochs to run.

    Returns:
        A ``(losses, auc_scores)`` tuple of lists with one entry per epoch:
        the scalar training loss, and the ROC AUC of that epoch's
        training-mode forward pass.

    Note:
        The AUC is measured on the same samples the model is fitted on, so it
        is a training metric and not a held-out estimate.
    """
    losses = []
    auc_scores = []

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_fn = torch.nn.CrossEntropyLoss()

    # The labels are not modified by training, so convert them only once.
    y_true = data.y.cpu().numpy()

    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        out = model(data)
        loss = loss_fn(out, data.y)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        # Score samples by P(class 1). Column 1 of unnormalised logits is not
        # a valid ranking on its own because the probability depends on the
        # difference between the columns. Softmax is also correct if the model
        # already emits log-probabilities.
        fraud_prob = torch.softmax(out.detach(), dim=1)[:, 1].cpu().numpy()
        auc = roc_auc_score(y_true, fraud_prob)
        auc_scores.append(auc)

        print(f"Epoch {epoch+1}, Loss: {loss.item()}, AUC: {auc}")

    return losses, auc_scores

# Run 10 more epochs to record the loss curve. `model` is the same instance
# already optimised by train_with_timing earlier, so training resumes from its
# current weights, and `auc_scores` is overwritten with the new values.
losses, auc_scores = train_with_loss(model, graph_data, epochs=10)

# Loss curve, one point per epoch (epochs numbered from 1)
plt.figure(figsize=(6, 4))
loss_ax = sns.lineplot(x=range(1, len(losses) + 1), y=losses)
loss_ax.set_title("Loss Per Epoch")
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Loss")
plt.show()
