## Import requirements

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv
import networkx as nx
import numpy as np


## Create sample data for testing

In [2]:
def create_sample_data(num_nodes=10, num_edges=15, num_node_features=16, num_edge_features=8):
    """Build a small random graph for exercising the rest of the notebook.

    Every tensor is drawn from torch's global RNG, in the order
    ``x``, ``edge_index``, ``edge_attr``, ``y``. Call ``torch.manual_seed``
    beforehand if repeatable data is needed.

    Args:
        num_nodes: Number of nodes in the graph (must be >= 1).
        num_edges: Number of directed edges to sample (must be >= 0).
        num_node_features: Width of each node feature vector.
        num_edge_features: Width of each edge feature vector.

    Returns:
        A tuple ``(x, edge_index, edge_attr, y)`` where
        ``x`` has shape ``(num_nodes, num_node_features)``,
        ``edge_index`` is an integer tensor of shape ``(2, num_edges)`` with
        entries in ``[0, num_nodes)``,
        ``edge_attr`` has shape ``(num_edges, num_edge_features)`` and
        ``y`` is a float tensor of 0/1 labels with shape ``(num_nodes,)``.

    Raises:
        ValueError: If ``num_nodes`` is below 1 or ``num_edges`` is negative.
    """
    if num_nodes < 1:
        raise ValueError("num_nodes must be at least 1")
    if num_edges < 0:
        raise ValueError("num_edges must be non-negative")

    # Node features (random for demonstration)
    x = torch.randn(num_nodes, num_node_features)

    # Random endpoints; self-loops and duplicate edges are possible because
    # both rows are sampled independently.
    edge_index = torch.randint(0, num_nodes, (2, num_edges))

    # Edge features
    edge_attr = torch.randn(num_edges, num_edge_features)

    # Binary labels (for supervised learning), cast to float so they can be
    # compared against logits by a BCE-style loss.
    y = torch.randint(0, 2, (num_nodes,)).float()

    return x, edge_index, edge_attr, y


## Graph Construction


In [3]:
class NetworkGraph:
    """Toy graph builder that turns a sized collection into node/edge dicts.

    Nodes and edges are plain dictionaries. Features are stored under the
    ``'features'`` key (dicts do not support attribute assignment).
    """

    def construct_graph(self, network_data):
        """Create nodes and edges for ``network_data`` and attach features.

        Args:
            network_data: Any object supporting ``len()``; only its length is
                used by the simplified node/edge builders.

        Returns:
            A tuple ``(nodes, edges)`` of lists of dicts. Each node dict has
            keys ``'id'`` and ``'features'``; each edge dict has keys
            ``'source'``, ``'target'`` and ``'features'``.
        """
        nodes = self.create_nodes(network_data)
        edges = self.create_edges(network_data)

        # Nodes/edges are dicts, so features must be stored by key;
        # attribute assignment would raise AttributeError.
        for node in nodes:
            node['features'] = self.extract_node_features(node)

        for edge in edges:
            edge['features'] = self.extract_edge_features(edge)

        return nodes, edges

    def create_nodes(self, network_data):
        """Return one ``{'id': i}`` dict per element of ``network_data``."""
        # Simplified example
        return [{'id': i} for i in range(len(network_data))]

    def create_edges(self, network_data):
        """Return edges linking each index ``i`` to ``i + 1`` (a simple chain).

        Yields an empty list when ``network_data`` has fewer than two items.
        """
        # Simplified example
        return [{'source': i, 'target': i+1} for i in range(len(network_data)-1)]

    def extract_node_features(self, node):
        """Return a random 16-dimensional feature vector (placeholder).

        ``node`` is currently ignored.
        """
        # Simplified feature extraction
        return torch.randn(16)  # 16-dimensional feature vector

    def extract_edge_features(self, edge):
        """Return a random 8-dimensional feature vector (placeholder).

        ``edge`` is currently ignored.
        """
        # Simplified feature extraction
        return torch.randn(8)  # 8-dimensional feature vector


## GNN Model Architecture

In [4]:
class APTDetectionModel(nn.Module):
    """Node-level scorer: stacked GAT layers, then a GRU and a linear head.

    Args:
        in_features: Width of the input node features.
        hidden_features: Output width of every GAT layer and of the GRU.
        num_layers: Number of ``GATConv`` layers to stack.
        edge_dim: Width of the edge feature vectors. Per the PyTorch
            Geometric ``GATConv`` documentation, edge features only take part
            in the attention computation when ``edge_dim`` is set. With the
            default ``None`` the ``edge_attr`` passed to ``forward`` has no
            effect on the output (this matches the previous behaviour).
    """

    def __init__(self, in_features, hidden_features, num_layers, edge_dim=None):
        super().__init__()
        self.gat_layers = nn.ModuleList([
            GATConv(
                # Only the first layer sees the raw input width.
                in_features if i == 0 else hidden_features,
                hidden_features,
                edge_dim=edge_dim,
            ) for i in range(num_layers)
        ])
        self.gru = nn.GRU(hidden_features, hidden_features)
        self.output = nn.Linear(hidden_features, 1)

    def forward(self, x, edge_index, edge_attr):
        """Compute one raw logit per node.

        Args:
            x: Node features; assumed shape ``(num_nodes, in_features)``.
            edge_index: Edge endpoints passed straight to each GAT layer.
            edge_attr: Edge features passed straight to each GAT layer
                (only used by the layers when ``edge_dim`` was given).

        Returns:
            Tensor of un-normalised scores with the trailing size-1
            dimension removed. Apply ``torch.sigmoid`` to get probabilities.
        """
        for gat in self.gat_layers:
            x = F.relu(gat(x, edge_index, edge_attr))
        # nn.GRU defaults to (seq_len, batch, features). unsqueeze(0) therefore
        # presents the nodes as a batch with a single time step, so the GRU
        # acts as a per-node gated transform rather than a temporal model.
        x, _ = self.gru(x.unsqueeze(0))
        return self.output(x.squeeze(0)).squeeze(-1)

    def get_attention_weights(self, node):
        """Return placeholder attention weights.

        NOTE: this is a stub. ``node`` is ignored and the result is 10 fresh
        random values, not the attention learned by the GAT layers.
        """
        # Simplified attention weight extraction
        return torch.randn(10)  # Random weights for demonstration


## Training Function

In [5]:
def train_model(model, labeled_data, unlabeled_data, num_epochs=10):
    """Fit ``model`` on a single labelled graph with full-batch Adam steps.

    Args:
        model: Module called as ``model(x, edge_index, edge_attr)`` and
            returning logits comparable with the labels.
        labeled_data: Tuple ``(x, edge_index, edge_attr, y)``.
        unlabeled_data: Accepted for interface compatibility; not used.
        num_epochs: Number of optimisation steps to run.

    Returns:
        None. The model is updated in place and the loss is printed on every
        even-numbered epoch.
    """
    adam = torch.optim.Adam(model.parameters())

    for epoch_idx in range(num_epochs):
        model.train()

        # One supervised step over the whole graph.
        adam.zero_grad()
        features, edges, edge_features, targets = labeled_data
        logits = model(features, edges, edge_features)
        bce = F.binary_cross_entropy_with_logits(logits, targets)
        bce.backward()
        adam.step()

        # Report progress every second epoch.
        is_even_epoch = epoch_idx % 2 == 0
        if is_even_epoch:
            print(f"Epoch {epoch_idx}, Loss: {bce.item():.4f}")


## APT Detection Function

In [6]:
def detect_apts(model, graph, threshold=0.5):
    """Score every node and report those whose score exceeds ``threshold``.

    Args:
        model: Module called as ``model(x, edge_index, edge_attr)`` that
            returns per-node logits and exposes ``get_attention_weights``.
        graph: Tuple ``(x, edge_index, edge_attr, labels)``; the labels are
            ignored.
        threshold: Sigmoid score a node must strictly exceed to be flagged.

    Returns:
        A list of dicts, one per flagged node in index order, with keys
        ``'node'`` (int index), ``'score'`` (float sigmoid score) and
        ``'attention'`` (whatever ``model.get_attention_weights`` returns).
    """
    model.eval()
    with torch.no_grad():
        node_feats, edges, edge_feats, _labels = graph
        probs = torch.sigmoid(model(node_feats, edges, edge_feats))
        above = probs > threshold
        flagged = above.nonzero().flatten()

        def describe(idx):
            # Fetch the attention first, then read the scalar values.
            attn = model.get_attention_weights(idx)
            return {
                'node': idx.item(),
                'score': probs[idx].item(),
                'attention': attn
            }

        return [describe(idx) for idx in flagged]


## Test the implementation

In [12]:
if __name__ == "__main__":
    # Seed torch's global RNG so the random sample graph, the initial model
    # weights and therefore the printed losses/scores are the same on every
    # Restart & Run All.
    torch.manual_seed(42)

    # Create sample data
    x, edge_index, edge_attr, y = create_sample_data()

    # Initialize model (16 matches the node feature width of the sample data)
    model = APTDetectionModel(in_features=16, hidden_features=32, num_layers=2)

    # Train model; no unlabeled data is used in this demo, hence None
    print("Training model...")
    train_model(model, (x, edge_index, edge_attr, y), None)

    # Detect APTs on the same graph the model was trained on (smoke test only)
    print("\nDetecting APTs...")
    results = detect_apts(model, (x, edge_index, edge_attr, y))

    # Print results
    print("\nDetection Results:\n")
    for result in results:
        print(f"Node {result['node']}: Anomaly Score = {result['score']:.4f}")


Training model...
Epoch 0, Loss: 0.6991
Epoch 2, Loss: 0.6949
Epoch 4, Loss: 0.6909
Epoch 6, Loss: 0.6869
Epoch 8, Loss: 0.6827

Detecting APTs...

Detection Results:

Node 0: Anomaly Score = 0.5238
Node 1: Anomaly Score = 0.5263
Node 2: Anomaly Score = 0.5223
Node 3: Anomaly Score = 0.5004
Node 4: Anomaly Score = 0.5336
Node 5: Anomaly Score = 0.5058
Node 7: Anomaly Score = 0.5324
Node 8: Anomaly Score = 0.5373
Node 9: Anomaly Score = 0.5334
