In [1]:
from ImportLocalData import loadData

In [2]:
from BalanceClassDistribution import AdjustClassSamples, NumberOfSamplesClass

In [None]:
# Load one of the custom STKG datasets from its four text files.
# NOTE: replace the '.../' placeholders with the paths to your local files.
data = loadData(
    '.../node_features.txt',
    '.../edges.txt',
    '.../edge_features.txt',
    '.../node_labels.txt',
)
# Optionally call AdjustClassSamples (from BalanceClassDistribution) to handle
# outliers. The step can be skipped, but it was used for the paper.
data = AdjustClassSamples(data)

Train and test the model for top-1 / top-5 accuracy

Import required packages

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.loader import DataLoader
from torch.nn import LayerNorm
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from torch_geometric.utils import subgraph
from torch_geometric.data import Data

StableGCN Model Structure

In [4]:
# Prefer the GPU when CUDA is available; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class StableGCN(torch.nn.Module):
    """Two-layer GCN with layer normalization, skip connections and dropout.

    Each graph-convolution stage applies ``GCNConv -> LayerNorm -> ReLU``, adds
    a residual branch and then applies dropout. The first residual is a linear
    projection of the raw input features (so the widths match); the second is
    the output of the first stage. A final linear layer maps the hidden
    representation to ``out_channels`` scores.

    Args:
        in_channels: Width of the input node features.
        hidden_channels: Width of both hidden stages.
        out_channels: Width of the output layer (number of classes).
        dropout: Dropout probability applied after each stage. Defaults to
            0.5, the value that was previously hard-coded.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.5):
        super().__init__()
        # Attribute names and registration order are kept unchanged so that
        # previously saved state_dicts still load.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.norm1 = LayerNorm(hidden_channels)
        self.norm2 = LayerNorm(hidden_channels)
        self.fc = torch.nn.Linear(hidden_channels, out_channels)
        self.dropout = torch.nn.Dropout(p=dropout)
        self.skip_proj = torch.nn.Linear(in_channels, hidden_channels)

    def forward(self, data):
        """Compute output scores for the graph held in ``data``.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``
                tensors; both are moved to the model's device if needed.

        Returns:
            The output of the final linear layer (``out_channels`` scores per
            row of the hidden representation).
        """
        # Follow the device of the model's own parameters instead of a
        # module-level global, so the model also works when it lives on a
        # device other than the notebook-wide default.
        target_device = self.fc.weight.device
        x = data.x.to(target_device)
        edge_index = data.edge_index.to(target_device)

        # Stage 1: residual comes from a linear projection of the raw input.
        residual = self.skip_proj(x)
        x = F.relu(self.norm1(self.conv1(x, edge_index)))
        x = self.dropout(x + residual)

        # Stage 2: residual is the (already hidden-width) output of stage 1.
        residual = x
        x = F.relu(self.norm2(self.conv2(x, edge_index)))
        x = self.dropout(x + residual)

        return self.fc(x)



Prepare data and subgraphs

In [5]:
def prepareData(data, test_size=0.2, random_state=None):
    """Standardize node features and build train/test subgraphs.

    Nodes are split with a stratified train/test split on ``data.y``. Each
    split keeps only the edges whose two endpoints both belong to it, with the
    nodes relabelled to a contiguous range.

    The feature scaler is fitted on the training nodes only and then applied
    to all nodes, so no statistics of the test nodes leak into training.
    ``data`` itself is not modified, which keeps the function safe to re-run.

    Args:
        data: Graph object exposing ``x``, ``y``, ``edge_index`` and
            ``num_nodes``.
        test_size: Fraction of nodes assigned to the test split
            (default 0.2).
        random_state: Optional seed forwarded to ``train_test_split`` to make
            the split reproducible. ``None`` keeps the split random.

    Returns:
        tuple: ``(train_data, test_data)`` as two ``Data`` objects whose ``x``
        lives on the module-level ``device``.
    """
    # Train-test split over node ids, stratified by label. Plain NumPy arrays
    # are used so sklearn never has to index a torch tensor.
    labels = data.y.cpu().numpy()
    node_ids = torch.arange(data.num_nodes).numpy()
    train_ids, test_ids = train_test_split(
        node_ids,
        test_size=test_size,
        stratify=labels,
        random_state=random_state,
    )

    # Fit the scaler on training nodes only, then transform every node.
    features = data.x.detach().cpu().numpy()
    scaler = StandardScaler()
    scaler.fit(features[train_ids])
    x_scaled = torch.tensor(scaler.transform(features), dtype=torch.float).to(device)

    # Index tensors live on the same device as edge_index so that `subgraph`
    # also works when the graph was loaded onto the GPU. as_tensor avoids the
    # copy-construct warning raised by torch.tensor(<tensor>).
    edge_index = data.edge_index
    train_idx = torch.as_tensor(train_ids, dtype=torch.long, device=edge_index.device)
    test_idx = torch.as_tensor(test_ids, dtype=torch.long, device=edge_index.device)

    # Create subgraphs: keep intra-split edges only and relabel the nodes.
    train_edges = subgraph(train_idx, edge_index, relabel_nodes=True, num_nodes=data.num_nodes)[0]
    test_edges = subgraph(test_idx, edge_index, relabel_nodes=True, num_nodes=data.num_nodes)[0]

    train_data = Data(
        x=x_scaled[train_idx.to(x_scaled.device)],
        edge_index=train_edges,
        y=data.y[train_idx.to(data.y.device)],
    )
    test_data = Data(
        x=x_scaled[test_idx.to(x_scaled.device)],
        edge_index=test_edges,
        y=data.y[test_idx.to(data.y.device)],
    )
    return train_data, test_data


Train / Evaluate Model 

In [None]:
def train_model(model, train_data, num_epochs=4000, lr=0.01, weight_decay=1e-4,
                lr_step_size=1000, lr_gamma=0.5, max_grad_norm=1.0, log_every=100):
    """Train ``model`` on ``train_data`` with Adam and a step LR schedule.

    The loss is cross-entropy between the model output and ``batch.y``.
    Gradients are clipped to ``max_grad_norm`` before every optimizer step and
    the learning rate is multiplied by ``lr_gamma`` every ``lr_step_size``
    epochs. All keyword defaults reproduce the previously hard-coded values.

    Args:
        model: Module to optimize; called as ``model(batch)``.
        train_data: A single graph ``Data`` object used for training.
        num_epochs: Number of passes over ``train_data``.
        lr: Initial Adam learning rate.
        weight_decay: Adam weight-decay coefficient.
        lr_step_size: Epoch interval between learning-rate decays.
        lr_gamma: Multiplicative learning-rate decay factor.
        max_grad_norm: Maximum gradient norm used for clipping.
        log_every: Print a progress line every this many epochs; a falsy
            value disables logging.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_step_size, gamma=lr_gamma)
    loss_fn = torch.nn.CrossEntropyLoss()
    # The dataset holds one graph, so every epoch yields exactly one batch.
    train_loader = DataLoader([train_data], batch_size=32, shuffle=True)

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            out = model(batch)
            loss = loss_fn(out, batch.y)
            loss.backward()
            # Clip before stepping to keep the update size bounded.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
            optimizer.step()
            epoch_loss += loss.item()
        # The scheduler is stepped once per epoch, after the optimizer.
        scheduler.step()
        if log_every and (epoch + 1) % log_every == 0:
            print(f'Epoch: {epoch+1}, Loss: {epoch_loss:.4f}, LR: {scheduler.get_last_lr()[0]:.6f}')

# Evaluation function with Top-1 and Top-5 accuracy
def evaluate_model(model, test_data):
    """Print and return the top-1 and top-5 accuracy of ``model`` on ``test_data``.

    A prediction counts as a top-5 hit when the true label is among the five
    highest-scoring classes. If the model has fewer than five output classes,
    all classes are considered instead of raising an error.

    Args:
        model: Trained module; called as ``model(batch)``.
        test_data: A single graph ``Data`` object with labels in ``y``.

    Returns:
        tuple: ``(top1Accuracy, top5Accuracy)`` as percentages.
    """
    model.eval()
    test_loader = DataLoader([test_data], batch_size=32)
    correctTop1 = 0
    correctTop5 = 0
    total = 0

    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            outputs = model(batch)

            # Top-1: highest-scoring class per node.
            predicted = outputs.argmax(dim=1)
            correctTop1 += (predicted == batch.y).sum().item()

            # Top-5: clamp k so topk does not fail with fewer than 5 classes.
            k = min(5, outputs.size(1))
            top5_pred = outputs.topk(k, dim=1).indices
            correctTop5 += top5_pred.eq(batch.y.view(-1, 1)).sum().item()

            total += batch.y.size(0)

    top1Accuracy = 100 * correctTop1 / total
    top5Accuracy = 100 * correctTop5 / total

    print(f'Top-1 Accuracy: {top1Accuracy:.2f}%')
    print(f'Top-5 Accuracy: {top5Accuracy:.2f}%')

    return top1Accuracy, top5Accuracy

if __name__ == "__main__":
    # Scale the features and build the train/test subgraphs.
    train_data, test_data = prepareData(data)

    # Initialize model
    num_node_features = data.num_node_features
    # CrossEntropyLoss requires every label to be smaller than the number of
    # output units. Counting unique labels undersizes the layer when labels are
    # not contiguous from 0, so size it from the largest label instead (this
    # gives the same value when labels are 0..C-1).
    num_classes = int(data.y.max().item()) + 1
    model = StableGCN(num_node_features, 256, num_classes).to(device)

    train_model(model, train_data)
    evaluate_model(model, test_data)

    # Persist only the learned weights; reload them with load_state_dict.
    torch.save(model.state_dict(), 'StableGCN_model.pth')