In [None]:
import torch
import torch.nn.functional as F
from torch.nn import Linear, ModuleDict
from torch_geometric.data import DataLoader
from torch_geometric.nn import HeteroConv, GATConv, global_mean_pool
from torch_geometric.transforms import ToUndirected
import os, glob, json
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, f1_score
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Pick the compute backend in order of preference: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    _device_name = "cuda"
elif torch.backends.mps.is_available():
    _device_name = "mps"
else:
    _device_name = "cpu"
device = torch.device(_device_name)

# Folder that is scanned for the serialized graph files (*.pt).
GRAPH_FOLDER = "heterographs/"

In [None]:
# Load data
# labels.json maps each graph file's basename to its label; a graph file
# without an entry raises KeyError.
with open("data/labels.json") as f:
    labels_dict = json.load(f)

# glob returns paths in arbitrary filesystem order. Sort them so the seeded
# train/test split below picks the same graphs on every machine and every run.
graph_files = sorted(glob.glob(os.path.join(GRAPH_FOLDER, "*.pt")))
if not graph_files:
    raise FileNotFoundError(f"No '*.pt' graph files found in {GRAPH_FOLDER!r}")

to_undirected = ToUndirected()  # build the transform once instead of once per file
graphs, labels = [], []
for file in graph_files:
    label = labels_dict[os.path.basename(file)]
    # NOTE(review): weights_only=False unpickles arbitrary Python objects;
    # only load graph files from a trusted source.
    graph = torch.load(file, weights_only=False)
    graph = to_undirected(graph)
    # Store the label on the graph as a 1-element float tensor so that batching
    # concatenates it into one target per graph.
    graph['label'] = torch.tensor([label], dtype=torch.float)
    graphs.append(graph)
    labels.append(label)

# Stratified 80/20 split; the raw labels list drives the stratification.
train_graphs, test_graphs = train_test_split(graphs, test_size=0.2, stratify=labels, random_state=42)
train_loader = DataLoader(train_graphs, batch_size=2, shuffle=True)
test_loader = DataLoader(test_graphs, batch_size=2)

In [None]:
# Define simplified MAGNN
class MAGNN(torch.nn.Module):
    """Simplified MAGNN-style graph classifier over hand-picked meta-paths.

    One ``HeteroConv`` (holding a ``GATConv`` per edge type) is built for each
    meta-path. For every meta-path the resulting node embeddings are passed
    through ReLU, mean-pooled per graph and summed across node types. The
    per-meta-path graph vectors are then fused with a learned softmax
    attention, and a linear head emits one raw logit per graph.

    Args:
        metadata: Graph metadata supplied by the caller. Currently unused;
            kept so existing call sites keep working.
        meta_paths: Sequence of meta-paths, each an iterable of
            ``(src, relation, dst)`` edge-type tuples.
        hidden_channels: Output width of every ``GATConv`` and therefore of
            the pooled graph vectors.
    """

    def __init__(self, metadata, meta_paths, hidden_channels=32):
        super().__init__()
        self.meta_paths = meta_paths
        self.gnn_per_metapath = ModuleDict()
        for i, path in enumerate(meta_paths):
            self.gnn_per_metapath[f'meta_{i}'] = HeteroConv({
                # (-1, -1) lets the layer infer source/destination feature
                # sizes on the first forward pass. Self-loops are only
                # meaningful when source and destination are the same node
                # type; for bipartite relations GATConv's default would link
                # source node i to destination node i, so disable it there.
                edge_type: GATConv(
                    (-1, -1),
                    hidden_channels,
                    add_self_loops=(edge_type[0] == edge_type[-1]),
                )
                for edge_type in path
            }, aggr='sum')
        self.attn = Linear(hidden_channels, 1)
        self.final = Linear(hidden_channels, 1)

    def forward(self, x_dict, edge_index_dict, batch_dict):
        """Compute one raw (pre-sigmoid) logit per graph in the batch.

        Args:
            x_dict: Node feature tensors keyed by node type.
            edge_index_dict: Edge index tensors keyed by edge type.
            batch_dict: For each node type, the tensor assigning every node
                to its graph in the batch.

        Returns:
            A 1-D tensor with one logit per graph.
        """
        # Fix the number of graphs up front so every pooled tensor has the
        # same number of rows, even when some graph in the batch has no nodes
        # of a given type (otherwise the stacks below fail on a shape mismatch).
        num_graphs = max(
            (int(b.max()) + 1 for b in batch_dict.values() if b.numel() > 0),
            default=0,
        )

        meta_outs = []
        for conv in self.gnn_per_metapath.values():
            x = conv(x_dict, edge_index_dict)
            pooled = [
                global_mean_pool(F.relu(h), batch_dict[ntype], size=num_graphs)
                for ntype, h in x.items()
            ]
            if pooled:
                meta_outs.append(torch.stack(pooled).sum(dim=0))
            else:
                # None of this meta-path's edge types occur in the batch:
                # contribute a zero vector instead of crashing on an empty stack.
                meta_outs.append(
                    self.final.weight.new_zeros((num_graphs, self.final.in_features))
                )

        # (num_graphs, num_meta_paths, hidden_channels)
        meta_outs = torch.stack(meta_outs, dim=1)
        # Softmax over the meta-path axis gives each graph its own mixing weights.
        attn_weights = F.softmax(self.attn(meta_outs).squeeze(-1), dim=1)
        out = (meta_outs * attn_weights.unsqueeze(-1)).sum(dim=1)
        return self.final(out).view(-1)

# Meta-paths are declared by hand. Each one is a list of
# (source node type, relation, destination node type) edge types.
_wallet_token_wallet = [
    ('wallet', 'wallet_token', 'token'),
    ('token', 'token_wallet', 'wallet'),
]
_wallet_dev_wallet = [
    ('wallet', 'wallet_dev', 'dev'),
    ('dev', 'dev_wallet', 'wallet'),
]
meta_paths = [_wallet_token_wallet, _wallet_dev_wallet]

In [None]:
# Training and testing
# Seed torch so weight initialization and loader shuffling are repeatable
# when this cell (and the cells below) are re-run.
torch.manual_seed(42)

model = MAGNN(train_graphs[0].metadata(), meta_paths).to(device)

# The GATConv layers are declared with lazy (-1, -1) input sizes, so their
# parameters only get concrete shapes on the first forward pass. Do one dry
# run before the optimizer is created so it is built over initialized parameters.
with torch.no_grad():
    _init_batch = next(iter(train_loader)).to(device)
    model(_init_batch.x_dict, _init_batch.edge_index_dict, _init_batch.batch_dict)

optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
# The model returns raw logits, so the sigmoid is folded into the loss.
criterion = torch.nn.BCEWithLogitsLoss()

def train():
    """Run one optimization pass over ``train_loader``.

    Uses the module-level ``model``, ``optimizer``, ``criterion`` and
    ``device``.

    Returns:
        The training loss for this pass, averaged over all graphs in
        ``train_loader.dataset``.
    """
    model.train()
    running_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        logits = model(batch.x_dict, batch.edge_index_dict, batch.batch_dict)
        targets = batch['label'].to(device)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()
        # Weight each batch loss by its graph count so the final division
        # yields a per-graph average even when the last batch is smaller.
        running_loss += batch_loss.item() * batch.num_graphs
    num_samples = len(train_loader.dataset)
    return running_loss / num_samples

def test(loader):
    """Evaluate the module-level ``model`` on every batch of ``loader``.

    Logits are passed through a sigmoid and thresholded at 0.5 to obtain
    hard predictions, which are compared with each batch's ``'label'``.

    Args:
        loader: Iterable of batches exposing ``x_dict``, ``edge_index_dict``,
            ``batch_dict`` and a ``'label'`` entry.

    Returns:
        A tuple ``(accuracy, precision, recall, f1)``.
    """
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            probs = torch.sigmoid(model(data.x_dict, data.edge_index_dict, data.batch_dict))
            pred = (probs > 0.5).float()
            all_preds.extend(pred.cpu().numpy())
            all_labels.extend(data['label'].cpu().numpy())
    accuracy = (np.array(all_preds) == np.array(all_labels)).mean()
    # zero_division=0 returns the same 0 that sklearn already falls back to
    # when a metric is undefined (e.g. no positive predictions), but without
    # emitting an UndefinedMetricWarning on every such epoch.
    return (
        accuracy,
        precision_score(all_labels, all_preds, zero_division=0),
        recall_score(all_labels, all_preds, zero_division=0),
        f1_score(all_labels, all_preds, zero_division=0),
    )

# Train loop
# Collect one record per epoch and build the DataFrame once at the end.
# Concatenating onto an initially empty frame every epoch is quadratic in the
# number of epochs and relies on a concat pattern recent pandas versions warn about.
history = []
for epoch in range(1, 201):
    loss = train()
    accuracy, precision, recall, f1 = test(test_loader)
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss:.4f}, Acc: {accuracy:.4f}, P: {precision:.4f}, R: {recall:.4f}, F1: {f1:.4f}")
    history.append({
        'epoch': epoch, 'loss': loss, 'accuracy': accuracy,
        'precision': precision, 'recall': recall, 'f1': f1
    })
    # Stop early once the training loss has dropped below 0.01.
    if loss < 0.01:
        break

df = pd.DataFrame(history, columns=['epoch', 'loss', 'accuracy', 'precision', 'recall', 'f1'])

# Create the output directory first so a missing folder cannot discard the
# results of a finished training run.
os.makedirs("model", exist_ok=True)
df.to_csv("model/magnn_training_results.csv", index=False)
# `epoch` is the last epoch that actually ran (200, or earlier on early stop).
torch.save(model.state_dict(), f"model/magnn_epoch_{epoch}.pth")

In [None]:
# Plot
# Min-max normalize the loss into [0, 1] so it can share an axis with the
# other metrics. The normalization is applied to a copy: `df` keeps the raw
# loss values, and re-running this cell always starts from the same data.
loss_values = df['loss'].astype(float)
loss_range = loss_values.max() - loss_values.min()
if loss_range > 0:
    normalized_loss = (loss_values - loss_values.min()) / loss_range
else:
    # A single epoch or a constant loss would divide by zero and yield NaNs;
    # draw a flat line at 0 instead.
    normalized_loss = loss_values * 0.0
plot_df = df.assign(loss=normalized_loss)

fig, ax = plt.subplots(figsize=(10, 6))
for col in ['loss', 'accuracy', 'precision', 'recall', 'f1']:
    ax.plot(plot_df['epoch'], plot_df[col], label=col)
ax.set_xlabel('Epoch')
ax.set_ylabel('Metric')
ax.set_title('MAGNN Training Metrics')
ax.legend()
ax.grid(True)
plt.show()