In [None]:
import os, glob, json

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from sklearn.metrics import precision_score, recall_score, f1_score
from sklearn.model_selection import train_test_split
from torch.nn import Linear
from torch_geometric.data import DataLoader
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv, global_mean_pool
from torch_geometric.transforms import ToUndirected

# Seed torch's global RNG so weight initialisation and DataLoader shuffling
# are repeatable under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Pick the best available accelerator. The MPS backend attribute is looked up
# defensively because it is absent on torch builds that predate it.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Directory scanned for serialized graphs (*.pt files).
GRAPH_FOLDER = "heterographs/"

In [None]:
# Load data
# labels.json maps each graph file's basename to its label.
with open("data/labels.json") as f:
    labels_dict = json.load(f)

# glob returns paths in filesystem-dependent order; sorting makes the
# seeded train/test split below identical on every machine.
graph_files = sorted(glob.glob(os.path.join(GRAPH_FOLDER, "*.pt")))
if not graph_files:
    raise FileNotFoundError(f"No .pt graph files found in {GRAPH_FOLDER!r}")

to_undirected = ToUndirected()  # built once and reused for every graph
graphs, labels = [], []
for file in graph_files:
    graph_name = os.path.basename(file)
    if graph_name not in labels_dict:
        raise KeyError(f"No label for graph file {graph_name!r} in data/labels.json")
    label = labels_dict[graph_name]

    # NOTE(review): weights_only=False unpickles arbitrary Python objects;
    # only load graph files from a trusted source.
    graph = torch.load(file, weights_only=False)
    graph = to_undirected(graph)
    # Store the target on the graph as a 1-element float tensor (the loss
    # used later is BCEWithLogitsLoss, which needs float targets).
    graph['label'] = torch.tensor([label], dtype=torch.float)
    graphs.append(graph)
    labels.append(label)

# Stratified 80/20 split on the raw labels with a fixed seed.
train_graphs, test_graphs = train_test_split(graphs, test_size=0.2, stratify=labels, random_state=42)
train_loader = DataLoader(train_graphs, batch_size=2, shuffle=True)
test_loader = DataLoader(test_graphs, batch_size=2)

In [None]:
# Define model
class RGCN(torch.nn.Module):
    """One-layer relational graph network producing one logit per graph.

    Despite the name, this is a ``HeteroConv`` wrapper holding one
    convolution per edge type (not ``RGCNConv``). Per-relation messages are
    summed, passed through ReLU, mean-pooled per graph for every node type,
    summed across node types and fed to a linear head.

    Args:
        metadata: ``(node_types, edge_types)`` tuple; only the edge types
            (``metadata[1]``) are read.
        hidden_channels: Width of the convolution output / linear head input.
    """

    def __init__(self, metadata, hidden_channels=32):
        super().__init__()
        self.conv1 = HeteroConv({
            edge_type: self._make_conv(edge_type, hidden_channels)
            for edge_type in metadata[1]
        }, aggr='sum')
        self.lin = Linear(hidden_channels, 1)

    @staticmethod
    def _make_conv(edge_type, hidden_channels):
        """Build a lazily-initialised convolution suited to one edge type.

        ``GCNConv`` takes a single integer input size and has no bipartite
        mode, so it is only used when source and destination node types are
        the same. Relations between different node types use ``SAGEConv``,
        which accepts a ``(src, dst)`` pair of input sizes.
        """
        src_type, _, dst_type = edge_type
        if src_type == dst_type:
            return GCNConv(-1, hidden_channels)
        return SAGEConv((-1, -1), hidden_channels)

    def forward(self, x_dict, edge_index_dict, batch_dict):
        """Return a flat tensor with one logit per graph in the batch.

        Args:
            x_dict: Node features keyed by node type.
            edge_index_dict: Edge indices keyed by edge type.
            batch_dict: Node-to-graph assignment vectors keyed by node type.
        """
        x_dict = self.conv1(x_dict, edge_index_dict)
        x_dict = {k: F.relu(v) for k, v in x_dict.items()}

        # Derive the batch's graph count from all assignment vectors so every
        # pooled tensor has the same number of rows, even when the trailing
        # graphs in the batch contain no nodes of a given type. Without an
        # explicit size, torch.stack below would fail on mismatched shapes.
        num_graphs = max(
            (int(b.max()) + 1 for b in batch_dict.values() if b.numel() > 0),
            default=0,
        )
        pooled = [
            global_mean_pool(x_dict[ntype], batch_dict[ntype], size=num_graphs)
            for ntype in x_dict
        ]
        out = torch.stack(pooled).sum(dim=0)
        return self.lin(out).view(-1)

In [None]:
# Training and testing
# Build the network from the first training graph's (node_types, edge_types)
# metadata and move it to the selected device.
graph_metadata = train_graphs[0].metadata()
model = RGCN(graph_metadata)
model = model.to(device)

# Binary classification on raw logits; Adam with a fixed learning rate.
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

def train():
    """Run one optimisation pass over ``train_loader``.

    Uses the module-level ``model``, ``optimizer``, ``criterion`` and
    ``device``. Returns the per-graph average loss for the epoch: each batch
    loss is weighted by the number of graphs in the batch and the sum is
    divided by the dataset size.
    """
    model.train()
    running_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        logits = model(batch.x_dict, batch.edge_index_dict, batch.batch_dict)
        targets = batch['label'].to(device)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()
        # Weight by batch size so a smaller final batch does not skew the mean.
        running_loss += batch_loss.item() * batch.num_graphs
    dataset_size = len(train_loader.dataset)
    return running_loss / dataset_size

def test(loader):
    """Evaluate the module-level ``model`` on every batch in ``loader``.

    Logits are passed through a sigmoid and thresholded at 0.5 to obtain
    binary predictions.

    Args:
        loader: Iterable of batches exposing ``x_dict``, ``edge_index_dict``,
            ``batch_dict`` and a ``'label'`` entry.

    Returns:
        Tuple ``(accuracy, precision, recall, f1)``.
    """
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = torch.sigmoid(model(data.x_dict, data.edge_index_dict, data.batch_dict))
            pred = (out > 0.5).float()
            all_preds.extend(pred.cpu().numpy())
            all_labels.extend(data['label'].cpu().numpy())
    accuracy = (np.array(all_preds) == np.array(all_labels)).mean()
    # zero_division=0 returns the same 0.0 as sklearn's default when a
    # metric's denominator is zero (e.g. no positive predictions), but
    # without raising an UndefinedMetricWarning on every such epoch.
    return (
        accuracy,
        precision_score(all_labels, all_preds, zero_division=0),
        recall_score(all_labels, all_preds, zero_division=0),
        f1_score(all_labels, all_preds, zero_division=0),
    )

# Train loop
# Make sure the output directory exists before training starts, so a missing
# folder cannot discard the results after the full run.
os.makedirs("model", exist_ok=True)

# Collect one record per epoch and build the DataFrame once at the end;
# concatenating inside the loop is quadratic and yields object-dtype columns.
history = []
for epoch in range(1, 201):
    loss = train()
    accuracy, precision, recall, f1 = test(test_loader)
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss:.4f}, Acc: {accuracy:.4f}, P: {precision:.4f}, R: {recall:.4f}, F1: {f1:.4f}")
    history.append({
        'epoch': epoch, 'loss': loss, 'accuracy': accuracy,
        'precision': precision, 'recall': recall, 'f1': f1
    })
    # Stop early once the average training loss is effectively zero.
    if loss < 0.01:
        break

df = pd.DataFrame(history, columns=['epoch', 'loss', 'accuracy', 'precision', 'recall', 'f1'])
df.to_csv("model/rgcn_training_results.csv", index=False)
# The checkpoint name carries the last epoch actually run.
torch.save(model.state_dict(), f"model/rgcn_epoch_{epoch}.pth")

In [None]:
# Plot
# Min-max scale the loss onto [0, 1] so it shares an axis with the other
# metrics. The scaling is applied to a copy: df keeps the raw losses, so
# re-running this cell or reusing df later is safe.
loss_values = df['loss'].astype(float)
loss_span = loss_values.max() - loss_values.min()
if loss_span > 0:
    loss_scaled = (loss_values - loss_values.min()) / loss_span
else:
    # Constant loss (e.g. a single recorded epoch): avoid 0/0 -> NaN.
    loss_scaled = loss_values * 0.0
plot_df = df.assign(loss=loss_scaled)

fig, ax = plt.subplots(figsize=(10, 6))
for col in ['loss', 'accuracy', 'precision', 'recall', 'f1']:
    ax.plot(plot_df['epoch'], plot_df[col], label=col)
ax.set_xlabel('Epoch')
ax.set_ylabel('Metric')
ax.set_title('R-GCN Training Metrics')
ax.legend()
ax.grid(True)
plt.show()
