In [None]:
from ImportLocalData import loadData

In [None]:
from BalanceClassDistribution import AdjustClassSamples, NumberOfSamplesClass

In [None]:
# Load one of the custom STKG datasets from four local text files
# (node features, edges, edge features, node labels).
# The '.../' prefixes below are placeholders: replace them with the paths to your local copies.
# AdjustClassSamples comes from BalanceClassDistribution; the original note says it is there to
# handle outliers -- presumably by adjusting per-class sample counts (TODO confirm in that module).
data = loadData('.../node_features.txt', '.../edges.txt', '.../edge_features.txt', '.../node_labels.txt')
data = AdjustClassSamples(data) # optional step; it was applied for the experiments reported in the paper

Import required packages

In [None]:
from torch_geometric.nn import SAGEConv
import torch.nn.functional as F
from torch.nn import LSTM, BatchNorm1d
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from torch_geometric.utils import subgraph, add_self_loops
from torch_geometric.loader import DataLoader
import torch
import numpy as np
from torch_geometric.data import Data
from sklearn.metrics.pairwise import cosine_similarity
import networkx as nx


TemporalSAGE model structure

In [None]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class TemporalSAGE(torch.nn.Module):
    """Two GraphSAGE layers followed by an LSTM and a linear classifier.

    Pipeline: SAGEConv -> BatchNorm -> ReLU -> Dropout (twice), then the node
    embeddings are reshaped into sequences for the LSTM, and the LSTM output at
    the last time step is projected to class logits.

    Args:
        in_channels: Number of input features per node.
        hidden_channels: Width of both GraphSAGE layers (and the LSTM input).
        lstm_hidden_size: Hidden size of the LSTM.
        out_channels: Number of output logits per node.
    """

    def __init__(self, in_channels, hidden_channels, lstm_hidden_size, out_channels):
        super().__init__()
        # Attribute names are part of the saved state_dict keys; keep them stable.
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.bn1 = BatchNorm1d(hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
        self.bn2 = BatchNorm1d(hidden_channels)
        self.lstm = LSTM(hidden_channels, lstm_hidden_size, batch_first=True)
        self.fc = torch.nn.Linear(lstm_hidden_size, out_channels)
        self.dropout = torch.nn.Dropout(p=0.5)

    def forward(self, data):
        """Return per-node logits for a graph object exposing x, edge_index and num_nodes.

        Inputs are moved to whatever device the model's weights are on, so the
        module works regardless of the notebook-level ``device`` variable.
        """
        # Use the model's own device rather than a global: avoids a device
        # mismatch when the model was not moved to the global `device`.
        model_device = self.fc.weight.device
        x = data.x.to(model_device)
        edge_index = data.edge_index.to(model_device)

        x = self.dropout(F.relu(self.bn1(self.conv1(x, edge_index))))
        x = self.dropout(F.relu(self.bn2(self.conv2(x, edge_index))))

        # (rows, hidden) -> (num_nodes, seq_len, hidden). When data.num_nodes
        # equals the number of rows in x, seq_len is 1, i.e. every node is fed
        # to the LSTM as a length-1 sequence.
        x = x.view(data.num_nodes, -1, x.size(1))
        lstm_out, _ = self.lstm(x)
        # Keep only the output at the final time step of each sequence.
        return self.fc(lstm_out[:, -1, :])
        


Graph augmentation (cosine-similarity-based edge addition and centrality-based edge removal) and focal loss

In [None]:

def addEdgesSimilarity(data, similarity_threshold=0.8, max_new_edges=100):
    """Append edges between node pairs whose feature vectors are highly similar.

    Cosine similarity is computed between all rows of ``data.x``. Pairs (i, j)
    with i < j whose similarity is strictly greater than
    ``similarity_threshold`` are visited in row-major order (ascending i, then
    ascending j) and the first ``max_new_edges`` of them are appended to
    ``data.edge_index`` as (i, j) columns.

    Args:
        data: Graph object exposing ``x``, ``edge_index`` and ``num_nodes``.
        similarity_threshold: Minimum (exclusive) cosine similarity for a pair.
        max_new_edges: Upper bound on the number of appended edges; values
            <= 0 append nothing.

    Returns:
        The same ``data`` object, modified in place.

    Notes:
        Each pair is added in one direction only, and pairs already present in
        ``edge_index`` are not filtered out.
    """
    features = data.x.detach().cpu().numpy()
    similarity = cosine_similarity(features)
    num_nodes = data.num_nodes

    # Strict upper triangle == every unordered pair (i, j) with i < j exactly once.
    above_threshold = np.triu(
        similarity[:num_nodes, :num_nodes] > similarity_threshold, k=1
    )
    # np.nonzero reports hits in row-major order, matching a nested i/j loop.
    rows, cols = np.nonzero(above_threshold)

    limit = max(int(max_new_edges), 0)
    rows, cols = rows[:limit], cols[:limit]
    if rows.size == 0:
        # Nothing to add: leave edge_index untouched.
        return data

    new_edges = torch.as_tensor(
        np.stack([rows, cols]), dtype=torch.long, device=data.edge_index.device
    )
    data.edge_index = torch.cat([data.edge_index, new_edges], dim=1)
    return data

def removeEdgesCentrality(data, remove_ratio=0.1):
    """Drop the fraction of edges whose endpoints have the lowest degree centrality.

    The edges in ``data.edge_index`` are loaded into an undirected
    ``networkx.Graph``. Each edge is scored by the sum of the degree
    centralities of its two endpoints; the ``remove_ratio`` lowest-scoring
    edges are removed and ``data.edge_index`` is rebuilt from what is left.

    Args:
        data: Graph object exposing ``edge_index`` of shape (2, num_edges).
        remove_ratio: Fraction of (undirected) edges to remove. The resulting
            count is clamped to the range [0, number of edges].

    Returns:
        The same ``data`` object, modified in place. ``edge_index`` always has
        shape (2, num_remaining_edges), including when no edge remains.

    Notes:
        Because an undirected ``nx.Graph`` is used, duplicate edges and the
        two directions (u, v) / (v, u) collapse into a single edge, so the
        rebuilt ``edge_index`` lists every surviving edge once.
    """
    target_device = data.edge_index.device

    graph = nx.Graph()
    graph.add_edges_from(data.edge_index.t().tolist())

    # Score every edge by the combined degree centrality of its endpoints.
    centrality = nx.degree_centrality(graph)
    ranked_edges = sorted(
        graph.edges, key=lambda edge: centrality[edge[0]] + centrality[edge[1]]
    )

    # Clamp so a negative or >1 ratio cannot produce a surprising slice bound.
    remove_count = int(len(ranked_edges) * remove_ratio)
    remove_count = min(max(remove_count, 0), len(ranked_edges))
    graph.remove_edges_from(ranked_edges[:remove_count])

    # reshape(-1, 2) keeps the result two-dimensional even with zero edges.
    remaining = torch.tensor(list(graph.edges), dtype=torch.long).reshape(-1, 2)
    data.edge_index = remaining.t().contiguous().to(target_device)
    return data


# For balance class distribution    
class FocalLoss(torch.nn.Module):
    """Focal loss built on top of per-sample cross-entropy.

    For each sample the loss is ``alpha * (1 - p_t) ** gamma * CE`` where
    ``CE`` is the cross-entropy and ``p_t = exp(-CE)`` is the probability the
    model assigns to the true class.

    Args:
        alpha: Scalar weight applied to every sample's loss.
        gamma: Focusing exponent; larger values down-weight easy samples more.
        reduction: ``'mean'`` or ``'sum'`` to aggregate; any other value
            returns the unreduced per-sample losses.
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, inputs, targets):
        """Compute the focal loss of logits ``inputs`` against class ids ``targets``."""
        per_sample_ce = F.cross_entropy(inputs, targets, reduction='none')
        # exp(-CE) recovers the predicted probability of the true class.
        true_class_prob = torch.exp(-per_sample_ce)
        focal = self.alpha * (1 - true_class_prob) ** self.gamma * per_sample_ce

        if self.reduction == 'mean':
            return torch.mean(focal)
        if self.reduction == 'sum':
            return torch.sum(focal)
        return focal


Function for Top-1 / Top-5 accuracy

In [None]:

def calculate_accuracy(outputs, labels):
    """Return (top-1 accuracy, top-5 accuracy) for a batch of logits.

    Args:
        outputs: Tensor of shape (num_samples, num_classes) with class scores.
        labels: Tensor of shape (num_samples,) with the true class indices.

    Returns:
        A tuple of two Python floats in [0, 1]. When there are fewer than five
        classes, the "top-5" figure uses all available classes instead of
        raising. An empty batch yields (0.0, 0.0).
    """
    num_samples = len(labels)
    if num_samples == 0:
        # Avoid dividing by zero on an empty batch.
        return 0.0, 0.0

    # Top-1: the highest-scoring class must equal the label.
    preds = outputs.argmax(dim=1)
    top1_acc = preds.eq(labels).sum().item() / num_samples

    # Top-5: the label must be among the k best classes; topk raises if k
    # exceeds the number of classes, so clamp it.
    k = min(5, outputs.size(1))
    _, topk_preds = outputs.topk(k, dim=1, largest=True, sorted=True)
    top5_acc = topk_preds.eq(labels.view(-1, 1)).sum().item() / num_samples
    return top1_acc, top5_acc

Data preparation module

In [None]:
# Standardize node features column-wise with StandardScaler.
# NOTE(review): the scaler is fit on every node, including those that later end
# up in the validation/test splits -- confirm this is intended.
scaler = StandardScaler()
feature_device = data.x.device
x_scaled = scaler.fit_transform(data.x.cpu().numpy())  # .cpu() so GPU-resident features also work
data.x = torch.tensor(x_scaled, dtype=torch.float, device=feature_device)

# Graph augmentation: add similarity edges, then prune low-centrality edges.
data = addEdgesSimilarity(data, similarity_threshold=0.8, max_new_edges=100)
data = removeEdgesCentrality(data, remove_ratio=0.1)

# Stratified node split. With these ratios the result is
# train 70% / validation 10% / test 20%:
#   - test_size=0.2 holds out 20% of all nodes for testing,
#   - test_size=0.125 of the remaining 80% gives 0.125 * 0.8 = 10% for validation.
# NOTE(review): an earlier comment described this as 80/10/10; the ratios
# themselves are left unchanged here.
SPLIT_SEED = 42  # fixed seed so the split is reproducible across kernel restarts
labels_np = data.y.cpu().numpy()
all_idx = np.arange(data.num_nodes)
train_idx, test_idx = train_test_split(
    all_idx, test_size=0.2, stratify=labels_np, random_state=SPLIT_SEED
)
train_idx, val_idx = train_test_split(
    train_idx, test_size=0.125, stratify=labels_np[train_idx], random_state=SPLIT_SEED
)
train_idx = torch.as_tensor(train_idx, dtype=torch.long)
val_idx = torch.as_tensor(val_idx, dtype=torch.long)
test_idx = torch.as_tensor(test_idx, dtype=torch.long)

# Create subgraphs for each split: only edges whose two endpoints are both in
# the split are kept, and node ids are relabelled to 0..len(split)-1.
trainSubGraph = subgraph(train_idx, data.edge_index, relabel_nodes=True, num_nodes=data.num_nodes)
valSubGraph = subgraph(val_idx, data.edge_index, relabel_nodes=True, num_nodes=data.num_nodes)
testSubGraph = subgraph(test_idx, data.edge_index, relabel_nodes=True, num_nodes=data.num_nodes)
# subgraph() returns a tuple; element 0 is the filtered edge_index.
train_data = Data(x=data.x[train_idx], edge_index=trainSubGraph[0], y=data.y[train_idx])
val_data = Data(x=data.x[val_idx], edge_index=valSubGraph[0], y=data.y[val_idx])
test_data = Data(x=data.x[test_idx], edge_index=testSubGraph[0], y=data.y[test_idx])


Initialization of model parameters

In [None]:
batch_size = 64
# Each loader wraps a list containing a single graph, so iterating a loader
# yields exactly one batch per pass.
train_loader = DataLoader([train_data], batch_size=batch_size, shuffle=True)
val_loader = DataLoader([val_data], batch_size=batch_size)
test_loader = DataLoader([test_data], batch_size=batch_size)
numNodeFeatures = data.num_node_features
# Size the output layer from the largest label id rather than the number of
# distinct labels: identical when labels are contiguous from 0, and still valid
# for cross-entropy when some class ids are missing.
num_classes = int(data.y.max().item()) + 1
hidden_channels = 256
lstm_hidden_size = 128

model = TemporalSAGE(numNodeFeatures, hidden_channels, lstm_hidden_size, num_classes).to(device)
optim= torch.optim.Adam(model.parameters(), lr=0.02, weight_decay=1e-4)
# Halve the learning rate every 1000 scheduler steps (one step per epoch in the training loop).
scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=1000, gamma=0.5)
# The training objective mixes these two losses.
loss_fn = torch.nn.CrossEntropyLoss()
loss_focal = FocalLoss()

Training Phase

In [None]:
for epoch in range(2000):
    # ---- optimisation pass over the training loader ----
    model.train()
    epoch_loss = 0
    for batch in train_loader:
        targets = batch.y.to(device)
        optim.zero_grad()
        out = model(batch)
        loss_ce = loss_fn(out, targets)
        loss_focal_value = loss_focal(out, targets)
        # Weighted mix: 10% cross-entropy, 90% focal loss.
        loss = 0.1 * loss_ce + 0.9 * loss_focal_value
        loss.backward()
        optim.step()
        epoch_loss += loss.item()
    scheduler.step()

    # Validate and report only on every 100th epoch.
    if (epoch + 1) % 100 != 0:
        continue

    # ---- validation pass (no gradients, eval-mode layers) ----
    model.eval()
    val_losses = []
    val_top1_correct = 0
    val_top5_correct = 0
    val_sum = 0
    with torch.no_grad():
        for batch in val_loader:
            targets = batch.y.to(device)
            out_val = model(batch)
            val_loss1 = loss_fn(out_val, targets)
            val_loss2 = loss_focal(out_val, targets)
            val_loss = 0.1 * val_loss1 + 0.9 * val_loss2
            val_losses.append(val_loss.item())
            top1_acc, top5_acc = calculate_accuracy(out_val, targets)
            # Convert per-batch accuracies back to counts so batches are weighted by size.
            val_top1_correct += top1_acc * len(batch.y)
            val_top5_correct += top5_acc * len(batch.y)
            val_sum += len(batch.y)
        val_loss = np.mean(val_losses)
        val_top1Accuracy = val_top1_correct / val_sum
        val_top5Accuracy = val_top5_correct / val_sum

    print(f'Epoch: {epoch+1}, Loss: {epoch_loss:.4f}, Val Loss: {val_loss:.4f}')
    print(f'Val Top-1 Acc: {val_top1Accuracy:.4f}, Val Top-5 Acc: {val_top5Accuracy:.4f}, LR: {scheduler.get_last_lr()[0]:.6f}')


Model evaluation

In [None]:
# Final evaluation on the held-out test loader, followed by saving the weights.
model.eval()
test_top1_correct, test_top5_correct, test_total = 0, 0, 0

with torch.no_grad():
    for batch in test_loader:
        targets = batch.y.to(device)
        out_test = model(batch)
        top1_acc, top5_acc = calculate_accuracy(out_test, targets)
        # Turn per-batch accuracies into counts so the final ratio is size-weighted.
        test_top1_correct += top1_acc * len(batch.y)
        test_top5_correct += top5_acc * len(batch.y)
        test_total += len(batch.y)

test_top1Accuracy = test_top1_correct / test_total
test_top5Accuracy = test_top5_correct / test_total

print(f'\nTest Top-1 Accuracy: {test_top1Accuracy:.4f}')
print(f'Test Top-5 Accuracy: {test_top5Accuracy:.4f}')
# Persist only the learned parameters (state_dict), not the whole module object.
torch.save(model.state_dict(), 'TemporalSAGE_model.pth')