In [51]:
# -------------------------
# IMPORTS AND SETUP
# -------------------------

import os
import random
import numpy as np
import networkx as nx

import matplotlib.pyplot as plt
import numpy as np
import random

#import range tqdm
from tqdm import tqdm
from tqdm import trange

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import DataLoader
from torch_geometric.data import Data
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv, global_mean_pool



In [57]:


class CustomGNN(nn.Module):
    def __init__(self, num_node_features, num_edge_features, num_classes):
        super(CustomGNN, self).__init__()
        self.conv1 = GATv2Conv(num_node_features, 16, edge_dim=num_edge_features)
        self.conv2 = GATv2Conv(16, 32, edge_dim=num_edge_features)
        self.fc1 = nn.Linear(32, 16)
        self.fc2 = nn.Linear(16, num_classes)

    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch

        x = F.relu(self.conv1(x, edge_index, edge_attr))
        x = F.dropout(x, p=0.5, training=self.training)

        x = F.relu(self.conv2(x, edge_index, edge_attr))
        x = global_mean_pool(x, batch)  # Only if it's a graph classification task

        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.fc2(x)

        return x


In [61]:



# -------------------------
# GRAPH PROCESSING
# -------------------------


def graph_to_data(graph):
    # Get a mapping from old node indices to new ones
    node_mapping = {node: i for i, node in enumerate(graph.nodes())}

    # Use the node mapping to convert node indices
    edge_index = torch.tensor([(node_mapping[u], node_mapping[v]) for u, v in graph.edges()], dtype=torch.long).t().contiguous()


    x = torch.tensor([[
        attributes['struct_size'],
        attributes['valid_pointer_count'],
        attributes['invalid_pointer_count'],
        attributes['first_pointer_offset'],
        attributes['last_pointer_offset'],
        attributes['first_valid_pointer_offset'],
        attributes['last_valid_pointer_offset'],
    ] for _, attributes in graph.nodes(data=True)], dtype=torch.float)

    edge_attr = torch.tensor([data['offset'] for u, v, data in graph.edges(data=True)], dtype=torch.float).unsqueeze(1)
    
    # Convert x to a numpy array for normalization
    x_np = x.numpy()

    # Standardize features (subtract mean, divide by standard deviation)
    x_np = (x_np - np.mean(x_np, axis=0)) / np.std(x_np, axis=0)

    # Convert back to tensor
    x = torch.tensor(x_np, dtype=torch.float)

    edge_attr_np = edge_attr.numpy()
    edge_attr_np = (edge_attr_np - np.mean(edge_attr_np, axis=0)) / np.std(edge_attr_np, axis=0)
    edge_attr = torch.tensor(edge_attr_np, dtype=torch.float)

    # if there are 2 keys then y = 0, if there are 4 keys then y = 1, if there are 6 keys then y = 2
    key_count = len([node for node in graph.nodes() if graph.nodes[node]['cat'] == 1])

    if key_count == 2:
        y = torch.tensor(0, dtype=torch.long)  # Class index for 2 keys
    elif key_count == 4:
        y = torch.tensor(1, dtype=torch.long)  # Class index for 4 keys
    elif key_count == 6:
        y = torch.tensor(2, dtype=torch.long)  # Class index for 6 keys
    else:
        raise ValueError(f"Invalid number of keys: {key_count}")
    

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)

def remove_all_isolated_nodes(graph):
    graph.remove_nodes_from(list(nx.isolates(graph)))
    return graph

def convert_types(G):
    # Convert the string attributes to their corresponding types
    for node, data in G.nodes(data=True):
        # The label remains a string, so no conversion is needed for 'label'
        # Convert struct_size, valid_pointer_count, invalid_pointer_count,
        # first_pointer_offset, last_pointer_offset, first_valid_pointer_offset,
        # last_valid_pointer_offset, and address to int
        data['struct_size'] = int(data['struct_size'])
        data['valid_pointer_count'] = int(data['valid_pointer_count'])
        data['invalid_pointer_count'] = int(data['invalid_pointer_count'])
        data['first_pointer_offset'] = int(data['first_pointer_offset'])
        data['last_pointer_offset'] = int(data['last_pointer_offset'])
        data['first_valid_pointer_offset'] = int(data['first_valid_pointer_offset'])
        data['last_valid_pointer_offset'] = int(data['last_valid_pointer_offset'])
        data['address'] = int(data['address'])

        # Convert cat to an integer and ensure it's within the range of a byte (0-255)
        data['cat'] = int(data['cat'])
        if not (0 <= data['cat'] <= 255):
            raise ValueError(f"Value of 'cat' out of range for u8: {data['cat']}")

    # Convert edges to their corresponding types
    for u, v, data in G.edges(data=True):
        # Convert offset to int
        data['offset'] = int(data['offset'])
    return G



def load_graphs(root_folder, max_per_subfolder=10, shuffle=False):
    all_graphs = []

    for subdir, dirs, files in os.walk(root_folder):
        print(f"Processing {subdir}...")
        graph_count = 0
        for file in files:
            if file.endswith('.graphml') and (max_per_subfolder == -1 or graph_count < max_per_subfolder ):
                file_path = os.path.join(subdir, file)
                try:
                    graph = nx.read_graphml(file_path)
                    all_graphs.append(graph)
                    graph_count += 1
                except Exception as e:
                    print(f"Error loading {file_path}: {e}")

    if shuffle:
        random.shuffle(all_graphs)

    return all_graphs


def train(dataset):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = CustomGNN(num_node_features=7, num_classes=3, num_edge_features=1).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = torch.nn.CrossEntropyLoss()
    # DataLoader
    loader = DataLoader(dataset, batch_size=64, shuffle=True)
    print(f"loader: {loader}")
    # Training Loop
    for epoch in range(1000):
        
        model.train()
        total_loss = 0
        for data in loader:
            data = data.to(device)
            optimizer.zero_grad()
            out = model(data)
            loss = criterion(out, data.y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f'Epoch {epoch}, Loss: {total_loss / len(loader)}')
    
    return model

def test(dataset, model):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.eval()
    loader = DataLoader(dataset, batch_size=20, shuffle=False)
    correct = 0
    for data in loader:
        data = data.to(device)
        out = model(data)
        pred = out.argmax(dim=1)
        correct += int((pred == data.y).sum())
    return correct / len(dataset)




In [46]:

folder = '/home/cyril/ssh-rlkex/Generated_Graphs/output'
graphs = load_graphs(folder)
print(f"Loaded {len(graphs)} graphs")
graphs = [convert_types(graph) for graph in graphs]
graphs = [remove_all_isolated_nodes(graph) for graph in graphs]
print(f"Removed isolated nodes from graphs")
print(f"Loaded {len(graphs)} graphs")


Processing /home/cyril/ssh-rlkex/Generated_Graphs/output...
Processing /home/cyril/ssh-rlkex/Generated_Graphs/output/port-forwarding...
Processing /home/cyril/ssh-rlkex/Generated_Graphs/output/port-forwarding/V_7_8_P1...
Processing /home/cyril/ssh-rlkex/Generated_Graphs/output/port-forwarding/V_7_8_P1/64...
Processing /home/cyril/ssh-rlkex/Generated_Graphs/output/port-forwarding/V_7_8_P1/16...
Processing /home/cyril/ssh-rlkex/Generated_Graphs/output/port-forwarding/V_7_8_P1/32...
Processing /home/cyril/ssh-rlkex/Generated_Graphs/output/port-forwarding/V_7_8_P1/24...
Processing /home/cyril/ssh-rlkex/Generated_Graphs/output/port-forwarding/V_8_0_P1...
Processing /home/cyril/ssh-rlkex/Generated_Graphs/output/port-forwarding/V_8_0_P1/64...
Processing /home/cyril/ssh-rlkex/Generated_Graphs/output/port-forwarding/V_8_0_P1/16...
Processing /home/cyril/ssh-rlkex/Generated_Graphs/output/port-forwarding/V_8_0_P1/32...
Processing /home/cyril/ssh-rlkex/Generated_Graphs/output/port-forwarding/V_8_0

In [63]:
dataset = [graph_to_data(graph) for graph in graphs]
print(f"Converted graphs to data")

#split the dataset into train and test and shuffle
random.shuffle(dataset)
train_factor = 0.75
train_dataset = dataset[:int(len(dataset)*train_factor)]
test_dataset = dataset[int(len(dataset)*train_factor):]


Converted graphs to data


In [64]:

print(f"Loaded {len(dataset)} graphs")
model = train(train_dataset)
print(f"Trained model")


accuracy = test(test_dataset, model)
print(f"Test accuracy: {accuracy}")


Loaded 670 graphs
loader: <torch_geometric.deprecation.DataLoader object at 0x7f12c817b430>
Epoch 0, Loss: 1.0937160849571228
Epoch 1, Loss: 1.0611453615128994
Epoch 2, Loss: 1.066734004765749
Epoch 3, Loss: 1.0591284334659576
Epoch 4, Loss: 1.0581242069602013
Epoch 5, Loss: 1.0503958389163017
Epoch 6, Loss: 1.0264973491430283
Epoch 7, Loss: 0.9937889464199543
Epoch 8, Loss: 0.9446803852915764
Epoch 9, Loss: 0.9085889086127281
Epoch 10, Loss: 0.8577879555523396
Epoch 11, Loss: 0.854322075843811
Epoch 12, Loss: 0.812379963696003
Epoch 13, Loss: 0.7697853036224842
Epoch 14, Loss: 0.6817994341254234
Epoch 15, Loss: 0.6911997906863689
Epoch 16, Loss: 0.6573486085981131
Epoch 17, Loss: 0.5941939689218998
Epoch 18, Loss: 0.5819779094308615
Epoch 19, Loss: 0.5373622830957174
Epoch 20, Loss: 0.5365627594292164
Epoch 21, Loss: 0.49963308311998844
Epoch 22, Loss: 0.4870981313288212
Epoch 23, Loss: 0.46803306229412556
Epoch 24, Loss: 0.5055736117064953
Epoch 25, Loss: 0.4719718210399151
Epoch 26,

In [22]:
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.grad.norm())


conv1.bias tensor(0., device='cuda:0')
conv1.lin.weight tensor(0., device='cuda:0')
conv2.bias tensor(0., device='cuda:0')
conv2.lin.weight tensor(0., device='cuda:0')
fc.weight tensor(0., device='cuda:0')
fc.bias tensor(0., device='cuda:0')


In [None]:
#test GCNCOnv didn't work well, output was all 0, I don't really know why
#Then I tried GATConv, it worked well, but the accuracy is not very high, I think it's because the dataset is too small
#TRying gatv2conv with edge attributes, it worked well
