Import required packages

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GATConv, global_mean_pool
from sklearn.model_selection import train_test_split
import numpy as np
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.optim.lr_scheduler import LambdaLR
from torch.optim.lr_scheduler import OneCycleLR
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.metrics import top_k_accuracy_score
from torch_geometric.nn import GCNConv

Load Data

In [3]:
#give proper paths 

def load_data(data_dir='...'):
    """Read the graph text files and assemble a PyG ``Data`` object.

    Four whitespace-separated text files are read from ``data_dir``:

    * ``node_features.txt`` - the first column is skipped (presumably a node
      id - confirm against the data); remaining columns are float features.
    * ``edges.txt`` - every column is parsed as an int, one edge per line.
    * ``edge_features.txt`` - the first two columns are skipped (presumably
      the edge endpoints - confirm); remaining columns are float features.
    * ``node_labels.txt`` - the second column is the integer class label.

    Blank lines (for example a trailing newline at the end of a file) are
    skipped instead of raising ``IndexError`` or adding an empty row.

    Args:
        data_dir: Directory that contains the four files. The default
            ``'...'`` is the original placeholder and must be replaced with
            a real path.

    Returns:
        Data: object with ``x`` (float), ``edge_index`` (long, transposed so
        that each column is one edge), ``edge_attr`` (float), ``y`` (long)
        and ``batch`` (``arange`` over the nodes, i.e. each node gets its own
        batch id).
    """
    def _read_rows(filename, parse):
        # Parse every non-blank line of `filename` with `parse(tokens)`.
        rows = []
        with open(f"{data_dir}/{filename}", 'r') as f:
            for line in f:
                tokens = line.split()
                if not tokens:
                    continue
                rows.append(parse(tokens))
        return rows

    node_features = _read_rows('node_features.txt', lambda t: [float(v) for v in t[1:]])
    x = torch.tensor(node_features, dtype=torch.float)

    edges = _read_rows('edges.txt', lambda t: [int(v) for v in t])
    if edges:
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    else:
        # An empty edge file would otherwise give a 1-D tensor of shape (0,).
        edge_index = torch.empty((2, 0), dtype=torch.long)

    edge_features = _read_rows('edge_features.txt', lambda t: [float(v) for v in t[2:]])
    edge_attr = torch.tensor(edge_features, dtype=torch.float)

    labels = _read_rows('node_labels.txt', lambda t: int(t[1]))
    y = torch.tensor(labels, dtype=torch.long)

    # Create batch knowledge: one batch id per node.
    batch = torch.arange(x.size(0), dtype=torch.long)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, batch=batch)

We built this module to simulate the temporal continuity between frames, as in the proposed action recognition module in AKU graphs.

In [4]:
def add_temporal_edges(data, num_frames=8):
    """Append chain edges ``i -> i+1`` between consecutive nodes to ``data``.

    The graph object is modified in place and also returned.

    Args:
        data: Graph object exposing ``num_nodes``, ``edge_index`` (2 x E) and
            ``edge_attr`` (E x F tensor, or ``None``).
        num_frames: Currently unused; kept so existing callers keep working.

    Returns:
        The same ``data`` object. ``edge_index`` gains ``num_nodes - 1``
        temporal edges and ``edge_attr`` gains one all-ones row per temporal
        edge (the "time difference" attribute), with the same number of
        columns as the existing edge attributes.

    Notes:
        If ``edge_attr`` is ``None`` or not at least 2-D, the original edges
        get a single all-ones attribute column, as before.
    """
    # Build on the graph's own device so the concatenations below also work
    # when the data already lives on the GPU.
    device = data.edge_index.device
    num_original = data.edge_index.size(1)

    # Combine nodes in consecutive frames. With fewer than two nodes this is
    # an empty (2, 0) tensor, which still concatenates cleanly.
    sources = torch.arange(max(int(data.num_nodes) - 1, 0), dtype=torch.long, device=device)
    temporal_edges = torch.stack([sources, sources + 1], dim=0)

    # Make edge attribute size compatible.
    if data.edge_attr is not None and data.edge_attr.dim() > 1:
        original_attr = data.edge_attr
    else:
        original_attr = torch.ones(num_original, 1, dtype=torch.float, device=device)

    # Match the column count, dtype and device of the existing attributes;
    # a fixed (T, 1) tensor fails to concatenate with multi-column features.
    temporal_attr = torch.ones(
        temporal_edges.size(1),
        original_attr.size(1),
        dtype=original_attr.dtype,
        device=original_attr.device,
    )

    data.edge_index = torch.cat([data.edge_index, temporal_edges], dim=1)
    data.edge_attr = torch.cat([original_attr, temporal_attr], dim=0)

    return data


AKU-INSPIRED MODEL OVER OUR STKG DATA (since AKU is multimodal, we only concentrated on the action recognition module)

In [7]:
class SimpleGCNLayer(GCNConv):
    """GCNConv whose ``forward`` accepts, and discards, a third argument."""

    def forward(self, features, indices, _=None):
        # The third slot is accepted but never passed on: the parent
        # convolution only receives the node features and the edge indices.
        out = GCNConv.forward(self, features, indices)
        return out

class AKUModel(nn.Module):
    """Two-layer GCN encoder with a classification head and an uncertainty head.

    Args:
        input_dim: Number of input features per node.
        hidden_dim: Width of both GCN layers; each head uses ``hidden_dim // 2``
            units in its hidden layer.
        num_classes: Number of output logits of the classifier.
    """

    def __init__(self, input_dim, hidden_dim, num_classes):
        super().__init__()
        half_dim = hidden_dim // 2

        self.gcn1 = SimpleGCNLayer(input_dim, hidden_dim)
        self.gcn2 = SimpleGCNLayer(hidden_dim, hidden_dim)

        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, half_dim),
            nn.ReLU(),
            nn.Linear(half_dim, num_classes),
        )

        # Ends in a sigmoid, so the predicted uncertainty lies in (0, 1).
        self.uncertainty_head = nn.Sequential(
            nn.Linear(hidden_dim, half_dim),
            nn.ReLU(),
            nn.Linear(half_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, data):
        """Return ``(logits, uncertainty)`` for ``data``.

        Chain edges ``i -> i+1`` over all nodes are appended to
        ``data.edge_index`` before the convolutions; ``data`` itself is not
        modified. Edge attributes are not used by the GCN layers.
        """
        device = data.x.device
        n = data.num_nodes

        chain_src = torch.arange(n - 1, device=device)
        chain_dst = torch.arange(1, n, device=device)
        full_edges = torch.cat([data.edge_index, torch.stack([chain_src, chain_dst])], dim=1)

        hidden = self.gcn1(data.x, full_edges)
        hidden = F.relu(hidden)
        hidden = F.dropout(hidden, p=0.2, training=self.training)
        hidden = F.relu(self.gcn2(hidden, full_edges))

        # Pool by the batch vector when the object has one; otherwise treat
        # all nodes as belonging to a single graph.
        if hasattr(data, 'batch'):
            assignment = data.batch
        else:
            assignment = torch.zeros(n, dtype=torch.long, device=device)
        pooled = global_mean_pool(hidden, assignment)

        logits = self.classifier(pooled)
        uncertainty = self.uncertainty_head(pooled).squeeze(-1)
        return logits, uncertainty

Loss function: we mimic the uncertainty learning in the AKU paper.

In [8]:
def uncertainty_loss(pred_logits, uncertainty, targets, alpha=0.5):
    """Cross-entropy plus an MSE term that trains the uncertainty estimate.

    The uncertainty target for each sample is ``1 - max softmax probability``.
    The target is detached, so the MSE term only updates the uncertainty
    prediction and does not push gradients into the logits.

    Args:
        pred_logits: Class logits of shape (N, C).
        uncertainty: Predicted uncertainty of shape (N,).
        targets: Integer class labels of shape (N,).
        alpha: Weight of the uncertainty MSE term.

    Returns:
        Scalar tensor ``cross_entropy + alpha * mse``.
    """
    ce_loss = F.cross_entropy(pred_logits, targets)
    pred_probs = F.softmax(pred_logits, dim=1)
    max_probs = pred_probs.max(dim=1)[0]
    unc_loss = F.mse_loss(uncertainty, 1 - max_probs.detach())

    return ce_loss + alpha * unc_loss

TRAIN / TEST / EVAL

In [None]:

def train(model, data, optimizer, train_idx):
    """Run one optimisation step over the whole graph.

    The model is evaluated on all of ``data``; only the entries selected by
    ``train_idx`` contribute to the loss.

    Returns:
        float: the loss value of this step.
    """
    model.train()
    optimizer.zero_grad()

    class_scores, unc_scores = model(data)
    step_loss = uncertainty_loss(
        class_scores[train_idx],
        unc_scores[train_idx],
        data.y[train_idx],
    )

    step_loss.backward()
    optimizer.step()

    return step_loss.item()

def evaluate(model, data, idx, uncertainty_threshold=0.4):
    """Evaluate the model on the samples selected by ``idx``.

    Args:
        model: Model returning ``(logits, uncertainty)`` when called on ``data``.
        data: Graph object with labels in ``data.y``.
        idx: 1-D index tensor of the samples to score.
        uncertainty_threshold: Samples whose predicted uncertainty is strictly
            below this value count as "confident".

    Returns:
        tuple: ``(top1_acc, top5_acc, confident_acc, confident_top5,
        confident_ratio)`` where

        * ``top1_acc`` / ``top5_acc`` are computed over all samples in ``idx``;
        * ``confident_acc`` / ``confident_top5`` are the same metrics restricted
          to the confident samples (``0.0`` when there are none);
        * ``confident_ratio`` is the fraction of ``idx`` that is confident.
    """
    model.eval()
    with torch.no_grad():
        logits, uncertainty = model(data)
        pred = logits.argmax(dim=1)
        # Pass the full label set so top-k works even when a split is missing
        # some classes.
        all_labels = np.arange(logits.size(1))

        # Top-1 accuracy calculation
        top1_acc = (pred[idx] == data.y[idx]).float().mean().item()
        # Top-5 accuracy calculation
        top5_acc = float(top_k_accuracy_score(
            data.y[idx].cpu().numpy(), logits[idx].cpu().numpy(), k=5, labels=all_labels))

        # Filtering based on uncertainty
        mask = uncertainty < uncertainty_threshold
        filtered_idx = idx[mask[idx]]
        num_confident = len(filtered_idx)
        if num_confident > 0:
            acc = (pred[filtered_idx] == data.y[filtered_idx]).float().mean().item()
            top5_confident = float(top_k_accuracy_score(
                data.y[filtered_idx].cpu().numpy(), logits[filtered_idx].cpu().numpy(),
                k=5, labels=all_labels))
        else:
            acc = 0.0
            top5_confident = 0.0

        # Share of the evaluated samples that passed the uncertainty filter.
        confident_ratio = num_confident / max(len(idx), 1)

        # Callers unpack all five values; returning only the first two raised
        # "not enough values to unpack".
        return top1_acc, top5_acc, acc, top5_confident, confident_ratio



The main function

In [None]:
def main():
    """Train the AKU-inspired model and report test accuracy.

    Steps:
        1. Load the graph and move it to GPU when available.
        2. Make a stratified 80/10/10 train/validation/test split over nodes.
        3. Train with SGD (Nesterov momentum) and a one-cycle LR schedule,
           evaluating on the validation split after every epoch.
        4. Keep a snapshot of the weights with the best validation accuracy
           on confident samples; stop early after ``patience`` epochs without
           improvement.
        5. Restore the best snapshot and print top-1 / top-5 test accuracy.

    ``evaluate`` is expected to return five values:
    ``(top1, top5, confident_acc, confident_top5, confident_ratio)``.
    """
    import copy  # stdlib; local import so this cell does not depend on the import cell

    # One constant for both the scheduler and the loop, so they cannot drift apart.
    num_epochs = 4000
    patience = 3600

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    data = load_data()
    data = data.to(device)

    # Split data with proper distribution between classes
    indices = list(range(data.num_nodes))
    labels_np = data.y.cpu().numpy()
    train_idx, test_val_idx = train_test_split(indices, test_size=0.2, stratify=labels_np)
    val_idx, test_idx = train_test_split(test_val_idx, test_size=0.5, stratify=labels_np[test_val_idx])
    train_idx = torch.tensor(train_idx, dtype=torch.long, device=device)
    val_idx = torch.tensor(val_idx, dtype=torch.long, device=device)
    test_idx = torch.tensor(test_idx, dtype=torch.long, device=device)

    num_classes = int(data.y.max().item()) + 1
    model = AKUModel(input_dim=data.x.size(1), hidden_dim=64, num_classes=num_classes).to(device)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, nesterov=True, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=1, epochs=num_epochs, pct_start=0.3, anneal_strategy='cos')

    best_val_acc = 0
    best_val_top5 = 0
    best_model = None
    patience_counter = 0

    for epoch in range(1, num_epochs + 1):
        loss = train(model, data, optimizer, train_idx)
        val_top1, val_top5, val_acc, val_top5_confident, val_confident_ratio = evaluate(model, data, val_idx)

        scheduler.step()

        # Early stopping check
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_val_top5 = val_top5
            # state_dict() returns references to the live parameters, so a
            # deep copy is needed to freeze the best weights; otherwise the
            # "snapshot" keeps changing as training continues.
            best_model = copy.deepcopy(model.state_dict())
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"\nEarly stopping at epoch {epoch} as validation accuracy didn't improve for {patience} evaluations")
                break

        if epoch % 10 == 0 or epoch == 1:
            current_lr = optimizer.param_groups[0]['lr']
            print(f"Epoch {epoch:03d} | Loss: {loss:.4f} | "
                  f"Val Top1: {val_top1*100:.2f}% | Val Top5: {val_top5*100:.2f}% | "
                  f"Val Acc: {val_acc*100:.2f}% | Val Top5 Conf: {val_top5_confident*100:.2f}% | "
                  f"Confident: {val_confident_ratio*100:.1f}% | LR: {current_lr:.6f}")

    # best_model stays None when validation accuracy never rose above 0
    # (e.g. no sample ever passed the uncertainty filter); keep the final
    # weights in that case instead of crashing in load_state_dict(None).
    if best_model is not None:
        model.load_state_dict(best_model)
    else:
        print("Warning: validation accuracy never improved; evaluating the final weights.")

    test_top1, test_top5, test_acc, test_top5_confident, test_confident_ratio = evaluate(model, data, test_idx)
    print(f"\n✅ Final Test Results:")
    print(f"Top-1 Accuracy: {test_top1*100:.2f}%")
    print(f"Top-5 Accuracy: {test_top5*100:.2f}%")

if __name__ == '__main__':
    main()