<a href="https://colab.research.google.com/github/aishwarya-walimbe/Fraud-Detetection-Using-GNN/blob/main/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
# Colab-only: mount Google Drive at /content/drive so files stored on Drive
# can be accessed through the regular filesystem.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
import os
# Project folder on the mounted Drive. Making it the working directory lets the
# relative paths used elsewhere in this notebook (e.g. "outputs/...") resolve here.
cur_path = '/content/drive/MyDrive/Fraud_Detection_Project'
os.chdir(cur_path)
# Print the working directory to confirm the chdir took effect.
!pwd

/content/drive/MyDrive/Fraud_Detection_Project


In [6]:
# Execute the sibling notebooks inside this kernel. Later cells call names that are
# not defined in this notebook (load_data, build_graph, ImprovedGraphSAGE,
# GATFraudNet) and rely on these runs to have put them in the global namespace.
%run preprocess.ipynb
%run model.ipynb

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/Fraud_Detection_Project
Collecting torch_geometric
  Downloading torch_geometric-2.7.0-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.7.0-py3-none-any.whl (1.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m22.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.7.0
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/Fraud_Detection_Project
ImprovedGraphSAGE: ImprovedGraphSAGE(
  (conv1): SAGEConv(12, 128, aggr=mean)
  (conv2): SAGEConv(128, 128, aggr=mean)
  (conv3): SAGEConv(128, 128, aggr=me

In [7]:
import os
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns



In [8]:
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    roc_auc_score,
    f1_score,
    roc_curve,
)
from sklearn.utils.class_weight import compute_class_weight

# preprocess.ipynb and model.ipynb are executed with `%run` in an earlier cell, which
# is expected to leave their functions and classes in this kernel's global namespace.
# They are notebooks rather than importable modules, so an explicit `import` of them
# is unnecessary and would raise ModuleNotFoundError.

# Create the folder that plots and the saved model are written to;
# exist_ok=True makes re-running this cell harmless.
os.makedirs("outputs", exist_ok=True)

In [9]:
# Absolute location (on the mounted Drive) of the CSV that is passed to load_data().
file_path = "/content/drive/MyDrive/Fraud_Detection_Project/paysim.csv"

In [16]:
# load_data / build_graph are not defined in this notebook; they are expected to be
# in scope from the `%run` cell above.
df = load_data(file_path)
# `data` is the graph object used by every later cell: the training and evaluation
# code reads data.y, data.num_features and the train/val/test masks from it.
data = build_graph(df)

[load_data] rows=200,000  fraud=8,213  (4.11%)
[build_graph] nodes=374,979  edges=200,000  fraud_nodes=8,213
  train=262,485  val=37,460  test=75,034


In [17]:
class FocalLoss(torch.nn.Module):
    """
    Focal Loss (Lin et al., 2017) built on multi-class cross-entropy.

    Per-sample loss:  alpha_t * (1 - p_t) ** gamma * CE
    where CE is the cross-entropy of the sample and p_t = exp(-CE) is the
    softmax probability the model assigns to the true class. The result is
    averaged over all samples.

    Why better than plain cross-entropy for fraud?
    The dataset is ~98% non-fraud. With plain loss the model
    gets away with memorising "everything is normal."

    The (1 - p_t) ** gamma factor DOWN-WEIGHTS easy (confident) examples
    and keeps the loss on hard (uncertain) examples comparatively large.
    gamma=2 is the standard starting point; gamma=0 removes the effect.

    Parameters
    ----------
    alpha:
        * scalar -- one constant multiplied into every sample's loss. This
          only rescales the loss; it does NOT favour either class.
        * sequence / 1-D tensor with one entry per class -- per-class
          weights, looked up with each sample's target index. Use this form
          to actually weight the fraud class differently, e.g.
          ``alpha=[0.25, 0.75]`` gives class 1 three times the weight of class 0.
    gamma:
        Focusing exponent applied to (1 - p_t).
    """

    def __init__(self, alpha: "float | list[float] | torch.Tensor" = 0.25,
                 gamma: float = 2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

        # Per-class weights are only materialised when alpha is not a plain scalar;
        # None means "use self.alpha as a uniform multiplier" (the original behaviour).
        self._class_alpha = None
        if not isinstance(alpha, (int, float)):
            candidate = torch.as_tensor(alpha, dtype=torch.float32)
            if candidate.ndim > 0:          # 0-dim tensors still act as scalars
                self._class_alpha = candidate.flatten()

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the mean focal loss for `logits` (N, C) against class indices `targets` (N,)."""
        ce_loss = F.cross_entropy(logits, targets, reduction="none")
        pt      = torch.exp(-ce_loss)                          # probability of correct class

        if self._class_alpha is None:
            alpha_t = self.alpha
        else:
            # Match the loss tensor's dtype/device, then pick each sample's class weight.
            alpha_t = self._class_alpha.to(ce_loss)[targets]

        focal = alpha_t * (1 - pt) ** self.gamma * ce_loss
        return focal.mean()

In [18]:
# THRESHOLD TUNING
# ──────────────────────────────────────────────────────────────
def find_best_threshold(probs: np.ndarray, labels: np.ndarray,
                        step: float = 0.01) -> tuple[float, float]:
    """Scan thresholds on validation set, return (best_threshold, best_f1).

    Thresholds ``np.arange(0.1, 0.95, step)`` are tried in increasing order; a
    sample is predicted positive when ``probs >= threshold``. The first
    threshold reaching the highest F1 wins (later ties do not replace it).
    If no threshold yields an F1 above 0, ``(0.5, 0.0)`` is returned.

    Both values are returned as plain Python floats, matching the annotation,
    so the threshold can be stored or serialised without carrying a NumPy
    scalar type along.
    """
    best_t, best_f1 = 0.5, 0.0
    for t in np.arange(0.1, 0.95, step):
        preds = (probs >= t).astype(int)
        f1    = f1_score(labels, preds, zero_division=0)
        if f1 > best_f1:
            best_f1, best_t = f1, t
    # np.arange yields np.float64 values; convert so callers get builtin floats.
    return float(best_t), float(best_f1)

In [19]:
# EVALUATION
# ──────────────────────────────────────────────────────────────
def evaluate(model, data, mask, threshold: float = 0.5) -> dict:
    """Score the nodes selected by `mask` and return predictions plus metrics.

    The model is switched to eval mode and run once on the whole graph without
    gradient tracking. Column 1 of the softmax output is used as the
    positive-class score; a node is predicted positive when that score is
    >= `threshold`.

    Returns a dict with keys "probs", "labels", "preds" (NumPy arrays restricted
    to the masked nodes), "f1" (computed on the thresholded predictions) and
    "auc" (ROC-AUC computed on the raw scores).
    """
    model.eval()
    with torch.no_grad():
        class_probs  = torch.softmax(model(data), dim=1)
        fraud_scores = class_probs[:, 1].cpu().numpy()
        y_true       = data.y[mask].cpu().numpy()

    # Restrict the full-graph scores to the requested split.
    selected = fraud_scores[mask.cpu().numpy()]
    y_pred   = (selected >= threshold).astype(int)

    metrics = {"probs": selected, "labels": y_true, "preds": y_pred}
    metrics["f1"]  = f1_score(y_true, y_pred, zero_division=0)
    metrics["auc"] = roc_auc_score(y_true, selected)
    return metrics


In [23]:
# ── Model selection ───────────────────────────────────────
# End-to-end run: build the model, train it full-batch with focal loss and early
# stopping on validation F1, tune the decision threshold on the validation split,
# report metrics on the test split, and save the checkpoint.
# Relies on names from other cells: data, FocalLoss, evaluate, find_best_threshold,
# the plot_* helpers (defined in a separate cell of this notebook, which must have
# been run first) and the model classes brought in by the `%run` cell.

# Define model parameters
use_gat = False # Set to True to use GATFraudNet, False for ImprovedGraphSAGE
hidden = 64 # Number of hidden units
lr = 0.001 # Learning rate
epochs = 100 # Number of training epochs
patience = 20 # Early stopping patience
seed = 42 # Seed for torch's RNG so weight init is repeatable

print(f"[config] use_gat={use_gat}, hidden={hidden}, lr={lr}, epochs={epochs}, patience={patience}")

# Seed before the model is constructed so parameter initialisation is reproducible.
torch.manual_seed(seed)

if use_gat:
    # hasattr() is not a sufficient guard: on a torch_geometric Data object the
    # attribute can exist while being None, and None has no .shape.
    edge_attr = getattr(data, "edge_attr", None)
    edge_dim = edge_attr.shape[1] if edge_attr is not None else 5
    model    = GATFraudNet(in_channels=data.num_features,
                           hidden_channels=hidden,
                           heads=4, edge_dim=edge_dim)
    print(f"[train] Using GATFraudNet  | features={data.num_features} | "
          f"hidden={hidden} | edge_dim={edge_dim}")
else:
    model = ImprovedGraphSAGE(in_channels=data.num_features,
                              hidden_channels=hidden)
    print(f"[train] Using ImprovedGraphSAGE | features={data.num_features} | "
          f"hidden={hidden}")

# ── Focal loss (replaces plain NLLLoss) ───────────────────
criterion = FocalLoss(alpha=0.25, gamma=2.0)

# ── Optimiser + LR scheduler ──────────────────────────────
optimizer = torch.optim.AdamW(model.parameters(), lr=lr,
                              weight_decay=1e-4)
# mode="max": the monitored quantity is validation F1, where higher is better.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", factor=0.5, patience=10
)

# ── Training loop ─────────────────────────────────────────
train_losses, val_f1s   = [], []
best_val_f1, best_epoch = 0.0, 0
patience_counter        = 0
best_state              = None

for epoch in range(1, epochs + 1):

    # --- Forward + loss ---
    # Full-batch step: the model sees the whole graph, the loss only the train nodes.
    model.train()
    optimizer.zero_grad()
    logits = model(data)
    loss   = criterion(logits[data.train_mask], data.y[data.train_mask])
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)   # gradient clipping
    optimizer.step()

    train_losses.append(loss.item())

    # --- Validation F1 every epoch ---
    result = evaluate(model, data, data.val_mask, threshold=0.5)
    val_f1 = result["f1"]
    val_f1s.append(val_f1)

    scheduler.step(val_f1)

    if epoch % 10 == 0:
        # Single quotes inside the f-string: reusing the outer quote type inside a
        # replacement field is a SyntaxError on Python versions before 3.12.
        print(f"Epoch {epoch:03d} | Loss: {loss.item():.4f} | "
              f"Val F1: {val_f1:.4f} | LR: {optimizer.param_groups[0]['lr']:.5f}")

    # --- Early stopping ---
    if val_f1 > best_val_f1:
        best_val_f1   = val_f1
        best_epoch    = epoch
        # Clone the tensors: state_dict() returns references that later steps overwrite.
        best_state    = {k: v.clone() for k, v in model.state_dict().items()}
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"\n[early stop] No improvement for {patience} epochs. "
                  f"Best Val F1={best_val_f1:.4f} at epoch {best_epoch}")
            break

# ── Restore best weights
if best_state is not None:
    model.load_state_dict(best_state)
    print(f"[train] Restored best model from epoch {best_epoch}")
else:
    # Validation F1 never rose above 0, so no snapshot was taken. Fall back to the
    # current weights so the checkpoint below never stores None as the model state.
    best_state = {k: v.clone() for k, v in model.state_dict().items()}

plot_loss_curve(train_losses, val_f1s)

# ── Tune threshold on validation set
val_result = evaluate(model, data, data.val_mask, threshold=0.5)
best_thresh, best_val_f1_tuned = find_best_threshold(
    val_result["probs"], val_result["labels"]
)
print(f"\n[threshold] Best threshold on val set: {best_thresh:.2f} "
      f"(val F1 = {best_val_f1_tuned:.4f})")

# ── Final evaluation on TEST set
test_result = evaluate(model, data, data.test_mask, threshold=best_thresh)

print("\n" + "=" * 55)
print("  FINAL TEST SET RESULTS")
print("=" * 55)
print(classification_report(
    test_result["labels"], test_result["preds"],
    target_names=["Normal", "Fraud"], digits=4
))
print(f"  F1  (fraud): {test_result['f1']:.4f}")
print(f"  ROC-AUC    : {test_result['auc']:.4f}")
print("=" * 55)

plot_confusion_matrix(test_result["labels"], test_result["preds"])
plot_roc_curve(test_result["labels"], test_result["probs"],
               test_result["auc"])

# ── Save model
torch.save({
    "model_state": best_state,
    # Store a builtin float so the checkpoint does not embed a NumPy scalar.
    "threshold":   float(best_thresh),
    "num_features": data.num_features,
    "hidden":       hidden,
    "use_gat":      use_gat,
}, "outputs/fraud_model.pth")
print("[train] Model saved → outputs/fraud_model.pth")


[config] use_gat=False, hidden=64, lr=0.001, epochs=100, patience=20
[train] Using ImprovedGraphSAGE | features=12 | hidden=64
Epoch 010 | Loss: 0.0070 | Val F1: 0.5168 | LR: 0.00100
Epoch 020 | Loss: 0.0049 | Val F1: 0.6067 | LR: 0.00100
Epoch 030 | Loss: 0.0041 | Val F1: 0.6413 | LR: 0.00100
Epoch 040 | Loss: 0.0038 | Val F1: 0.6577 | LR: 0.00100
Epoch 050 | Loss: 0.0035 | Val F1: 0.6796 | LR: 0.00100
Epoch 060 | Loss: 0.0033 | Val F1: 0.6873 | LR: 0.00100
Epoch 070 | Loss: 0.0031 | Val F1: 0.6917 | LR: 0.00100
Epoch 080 | Loss: 0.0030 | Val F1: 0.6974 | LR: 0.00100
Epoch 090 | Loss: 0.0029 | Val F1: 0.6982 | LR: 0.00100
Epoch 100 | Loss: 0.0029 | Val F1: 0.6983 | LR: 0.00100
[train] Restored best model from epoch 99
[plot] Saved outputs/loss_f1_curve.png

[threshold] Best threshold on val set: 0.45 (val F1 = 0.7020)

  FINAL TEST SET RESULTS
              precision    recall  f1-score   support

      Normal     0.9930    0.9941    0.9935     73391
       Fraud     0.7220    0.6878 

In [8]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, roc_curve

# PLOT HELPERS

def plot_loss_curve(train_losses: list, val_f1s: list):
    """Plot per-epoch training loss and validation F1 in two side-by-side panels.

    The figure is written to outputs/loss_f1_curve.png and then displayed.
    """
    fig, (loss_ax, f1_ax) = plt.subplots(1, 2, figsize=(12, 4))

    # (axis, series, line colour, panel title, y-axis label)
    panels = (
        (loss_ax, train_losses, "steelblue",  "Training Loss (Focal)",       "Loss"),
        (f1_ax,   val_f1s,      "darkorange", "Validation F1 Score (Fraud)", "F1"),
    )
    for axis, series, colour, title, y_label in panels:
        axis.plot(series, color=colour)
        axis.set_title(title)
        axis.set_xlabel("Epoch")
        axis.set_ylabel(y_label)

    plt.tight_layout()
    plt.savefig("outputs/loss_f1_curve.png", dpi=150)
    plt.show()
    print("[plot] Saved outputs/loss_f1_curve.png")


def plot_confusion_matrix(labels: np.ndarray, preds: np.ndarray):
    """Draw the confusion matrix as an annotated heatmap, save it, then show it.

    Args:
        labels: true class labels.
        preds:  predicted class labels, same length as `labels`.

    The figure is written to outputs/confusion_matrix.png.
    """
    cm = confusion_matrix(labels, preds)
    plt.figure(figsize=(5, 4))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=["Normal", "Fraud"],
                yticklabels=["Normal", "Fraud"])
    plt.ylabel("True Label")
    plt.xlabel("Predicted Label")
    plt.title("Confusion Matrix — Test Set")
    plt.tight_layout()
    # Save BEFORE showing: once plt.show() has rendered the figure, a following
    # savefig() can end up writing an empty canvas instead of the plot.
    plt.savefig("outputs/confusion_matrix.png", dpi=150)
    plt.show()
    plt.close()
    print("[plot] Saved outputs/confusion_matrix.png")


def plot_roc_curve(labels: np.ndarray, probs: np.ndarray, auc_score: float):
    """Plot the ROC curve with a chance diagonal, save it, then show it.

    Args:
        labels:    true binary labels.
        probs:     scores for the positive class, same length as `labels`.
        auc_score: already-computed ROC-AUC, only used in the legend text.

    The figure is written to outputs/roc_curve.png.
    """
    fpr, tpr, _ = roc_curve(labels, probs)
    plt.figure(figsize=(6, 5))
    plt.plot(fpr, tpr, color="crimson",
             label=f"ROC Curve (AUC = {auc_score:.4f})")
    plt.plot([0, 1], [0, 1], "k--", linewidth=0.8)   # chance-level reference line
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curve — Test Set")
    plt.legend()
    plt.tight_layout()
    plt.savefig("outputs/roc_curve.png", dpi=150)
    # Show after saving, like the other plot helpers; without this the figure is
    # closed before the notebook ever displays it.
    plt.show()
    plt.close()
    print("[plot] Saved outputs/roc_curve.png")