In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import f1_score
import argparse
from torch_geometric.data import Data
import joblib

# **NEW: Import the GCNConv layer from PyTorch Geometric**
from torch_geometric.nn import GCNConv

class GCN(nn.Module):
    def __init__(self, in_feats, hidden_size, out_feats, dropout=0.5):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_feats, hidden_size)
        self.conv2 = GCNConv(hidden_size, out_feats)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()

    def forward(self, x, edge_index):
        # The GCNConv layer handles both transformation and aggregation
        h = self.conv1(x, edge_index)
        h = self.relu(h)
        h = self.dropout(h)
        h = self.conv2(h, edge_index)
        return h

def fix_graph_data(graph):
    """Convert CogDL graph to PyTorch Geometric Data format"""
    
    # Extract features and labels
    x = graph.x  # Node features
    y = graph.y  # Node labels (multi-label)
    
    # Extract edge information and ensure it's in the correct format
    
    edge_index = torch.stack(list(graph.edge_index), dim=0)
    
    # Extract masks
    train_mask = graph.train_mask
    val_mask = graph.val_mask
    test_mask = graph.test_mask
    
    print(f"Converted edge_index shape: {edge_index.shape}")
    print(f"Converted edge_index dtype: {edge_index.dtype}")
    
    # Create PyTorch Geometric Data object
    data = Data(
        x=x,
        edge_index=edge_index,
        y=y,
        train_mask=train_mask,
        val_mask=val_mask,
        test_mask=test_mask
    )
    
    return data

scaler = joblib.load('scaler.gz')

In [2]:
import torch

if torch.cuda.is_available() and torch.cuda.device_count() > 1:
  device = torch.device('cuda:4')
  print("✅ Using second GPU: cuda:4")
else:
  device = torch.device('cpu')
  print("⚠️ Second GPU not found, falling back to CPU.")

✅ Using second GPU: cuda:4


In [None]:
data = torch.load('./experiment_runs/run_2025-10-02_14-47-01/final_graph.pt')
data = fix_graph_data(data)
data.to(device)

Converted edge_index shape: torch.Size([2, 124056])
Converted edge_index dtype: torch.int64


Data(x=[389066, 768], edge_index=[2, 124056], y=[389066, 54], train_mask=[389066], val_mask=[389066], test_mask=[389066])

In [4]:


loaded_model = GCN(
    in_feats=data.x.size(1),
    hidden_size=256,
    out_feats=data.y.size(1),
    dropout=0.5
).to(device)

print("\nModel Architecture (True GCN):")
print(loaded_model)


# 2. Load the saved state dictionary
state_dict = torch.load('trained_gcn.pt')

# 3. Apply the weights to the model
loaded_model.load_state_dict(state_dict)

# 4. Set the model to evaluation mode
loaded_model.eval()

# Now the model is ready for inference


Model Architecture (True GCN):
GCN(
  (conv1): GCNConv(768, 256)
  (conv2): GCNConv(256, 54)
  (dropout): Dropout(p=0.5, inplace=False)
  (relu): ReLU()
)


GCN(
  (conv1): GCNConv(768, 256)
  (conv2): GCNConv(256, 54)
  (dropout): Dropout(p=0.5, inplace=False)
  (relu): ReLU()
)

In [8]:
# --- 3. Apply the *Saved* Scaler ---
new_features_scaled = scaler.transform(data.x.cpu())
new_features_tensor = torch.from_numpy(new_features_scaled).float().to(device)

In [9]:
logits = loaded_model(new_features_tensor, data.edge_index)

In [10]:

criterion = nn.BCEWithLogitsLoss(reduction='mean')

mask = data.test_mask
eval_logits = logits[mask]
eval_labels = data.y[mask].float()

loss = criterion(eval_logits, eval_labels)

probs = torch.sigmoid(eval_logits)
preds = (probs > 0.5).int()

f1 = f1_score(
    eval_labels.cpu().numpy(), 
    preds.cpu().numpy(), 
    average="micro", 
    zero_division=0
)
print(f1)

0.48817966903073284


In [11]:
import numpy as np

np.where(preds.sum(axis=1).cpu()>0)

(array([   0,    2,    3, ..., 6073, 6074, 6077]),)

In [12]:
preds[3]

tensor([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0], device='cuda:4', dtype=torch.int32)