In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import json
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import fbeta_score, classification_report, precision_recall_curve, confusion_matrix
from torch_geometric.nn import SAGEConv

In [3]:
data = torch.load("../data/graphs/alemari_graph.pt", weights_only=False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data = data.to(device)

In [4]:
class GraphSAGE(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=3, dropout=0.2, use_batchnorm=False):
        super(GraphSAGE, self).__init__()
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList() if use_batchnorm else None
        self.dropout = nn.Dropout(dropout)
        self.use_batchnorm = use_batchnorm

        self.convs.append(SAGEConv(in_channels, hidden_channels))
        if use_batchnorm:
            self.bns.append(nn.BatchNorm1d(hidden_channels))

        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
            if use_batchnorm:
                self.bns.append(nn.BatchNorm1d(hidden_channels))

        self.convs.append(SAGEConv(hidden_channels, out_channels))

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, edge_index)
            if self.use_batchnorm:
                x = self.bns[i](x)
            x = F.leaky_relu(x)
            x = self.dropout(x)
        return F.log_softmax(self.convs[-1](x, edge_index), dim=1)

In [None]:
def grid_search_graphsage(data, m_grid, fp_exact=146, beta=2.0): #change fp exact -> 146 for 1%, 73 for 0.5%
    split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
    train_idx, test_idx = next(split.split(torch.arange(len(data.y)), data.y.cpu()))
    train_mask = torch.tensor(train_idx, dtype=torch.long, device=device)
    test_mask = torch.tensor(test_idx, dtype=torch.long, device=device)
    torch.save(test_mask, "../configs/sage_test_mask_4feat_146fp.pt")

    best_result = None
    best_f2 = -1

    arch_grid = [
        {"hidden_dim": 128, "lr": 0.001,  "weight_decay": 5e-4, "layers": 3, "dropout": 0.2, "bn": False, "opt": "adam"},
        {"hidden_dim": 128, "lr": 0.0005, "weight_decay": 5e-4, "layers": 3, "dropout": 0.3, "bn": True,  "opt": "adam"},
        {"hidden_dim": 256, "lr": 0.001,  "weight_decay": 1e-4, "layers": 2, "dropout": 0.1, "bn": False, "opt": "adam"},
        {"hidden_dim": 128, "lr": 0.001,  "weight_decay": 1e-3, "layers": 4, "dropout": 0.2, "bn": True,  "opt": "adam"},
        {"hidden_dim": 64,  "lr": 0.001,  "weight_decay": 5e-4, "layers": 3, "dropout": 0.2, "bn": False, "opt": "sgd"},
        {"hidden_dim": 128, "lr": 0.001,  "weight_decay": 0.0,  "layers": 3, "dropout": 0.0, "bn": False, "opt": "adam"},
    ]

    for arch in arch_grid:
        for m in m_grid:
            model = GraphSAGE(
                in_channels=data.num_features,
                hidden_channels=arch["hidden_dim"],
                out_channels=2,
                num_layers=arch["layers"],
                dropout=arch["dropout"],
                use_batchnorm=arch["bn"]
            ).to(device)

            optimizer = torch.optim.Adam(model.parameters(), lr=arch["lr"], weight_decay=arch["weight_decay"])
            class_weights = 1.0 / torch.bincount(data.y).float()
            class_weights[1] *= m
            criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))

            for _ in range(200):
                model.train()
                optimizer.zero_grad()
                out = model(data)
                loss = criterion(out[train_mask], data.y[train_mask])
                loss.backward()
                optimizer.step()

            model.eval()
            with torch.no_grad():
                logits = model(data)[test_mask]
                probs = torch.exp(logits)[:, 1].cpu().numpy()
                labels = data.y[test_mask].cpu().numpy()

            prec, rec, thr = precision_recall_curve(labels, probs)
            f2_arr = (1 + beta**2) * prec * rec / (beta**2 * prec + rec + 1e-8)

            for i in range(len(thr)):
                if thr[i] < 1e-6:
                    continue
                preds = (probs > thr[i]).astype(int)
                fp = ((labels == 0) & (preds == 1)).sum()
                if fp == fp_exact and f2_arr[i] > best_f2:
                    best_f2 = f2_arr[i]
                    best_result = {
                        "m": float(m),
                        "thr": float(thr[i]),
                        "Fbeta": float(f2_arr[i]),
                        "precision": float(prec[i]),
                        "recall": float(rec[i]),
                        "false_positives": int(fp),
                        "state_dict": model.state_dict(),
                        "arch": arch
                    }

    if best_result:
        torch.save(best_result["state_dict"], "../models/best_graphsage_model_4feat_146fp.pth")
        with open("../configs/sage_best_config_4feat_146fp.json", "w") as f:
            json.dump({k: v for k, v in best_result.items() if k != "state_dict"}, f, indent=2)
        print(f"Saved best GraphSAGE model with 4 features (FP={fp_exact}, F2={best_f2:.4f})")
        return best_result
    else:
        raise ValueError("No configuration met the FP constraint.")

In [None]:
#run grid sweep
m_grid = np.round(np.arange(1.0, 4.0, 0.1), 2)
best_result = grid_search_graphsage(data, m_grid)

Saved best GraphSAGE model with 4 features (FP=146, F2=0.2990)


{'m': 1.0,
 'thr': 0.8539395928382874,
 'Fbeta': 0.2990033191631439,
 'precision': 0.10909090909090909,
 'recall': 0.5294117647058824,
 'false_positives': 146,
 'state_dict': OrderedDict([('convs.0.lin_l.weight',
               tensor([[ 2.5236e-01,  8.1764e-02, -1.4439e-01, -6.9121e-02],
                       [ 1.7084e-01, -1.7110e-01, -1.0833e-02, -3.5549e-04],
                       [ 6.9645e-02, -3.0716e-01,  1.0107e-01,  3.4181e-01],
                       [-2.2788e-01,  1.4115e-01,  1.0135e-01, -2.4547e-01],
                       [-2.0901e-01,  3.4390e-01, -2.3704e-01,  4.2807e-01],
                       [-7.9578e-02, -1.5358e-01, -2.8486e-01,  6.0347e-02],
                       [-2.2874e-01,  2.1388e-01, -2.4956e-01,  4.7980e-01],
                       [-1.1621e-01, -2.8689e-03, -1.8341e-02, -2.5177e-01],
                       [ 5.4514e-02,  2.7392e-01,  3.1855e-01,  4.4093e-01],
                       [ 1.2360e-02, -1.8251e-01, -6.1343e-02,  3.4851e-01],
                 

In [None]:
arch = best_result["arch"]
model = GraphSAGE(
    in_channels=data.num_features,
    hidden_channels=arch["hidden_dim"],
    out_channels=2,
    num_layers=arch["layers"],
    dropout=arch["dropout"],
    use_batchnorm=arch["bn"]
).to(device)

model.load_state_dict(best_result["state_dict"])
model.eval()
test_mask = torch.load("../configs/sage_test_mask_exact146.pt").to(device)

with torch.no_grad():
    logits = model(data)[test_mask]
    probs = torch.exp(logits)[:, 1].cpu().numpy()
    labels = data.y[test_mask].cpu().numpy()

#apply best threshold
threshold = best_result["thr"]
preds = (probs > threshold).astype(int)

#Classification report
print("\nClassification Report:")
print(classification_report(labels, preds, target_names=["Legit", "Phishing"]))
#compute F2-score
f2 = fbeta_score(labels, preds, beta=2.0)
print(f"\nF2-Score (beta=2.0): {f2:.4f}")


#Confusion Matrix
cm = confusion_matrix(labels, preds)
plt.figure(figsize=(6,5))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=["Legit", "Phishing"], yticklabels=["Legit", "Phishing"])
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix (Threshold = {:.4f})".format(threshold))
plt.show()

prec, rec, thr = precision_recall_curve(labels, probs)

plt.figure(figsize=(8,6))
plt.plot(rec, prec, label="PR Curve")
plt.scatter(best_result["recall"], best_result["precision"], color='red', label="Best Threshold = {:.4f}".format(threshold))
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.title("Precision-Recall Curve")
plt.legend()
plt.grid()
plt.show()