In [38]:
# -------------------------
# IMPORTS AND SETUP
# -------------------------

import os
import random
import numpy as np
import networkx as nx
import torch
import torch.nn.functional as F
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.nn import SAGEConv, global_mean_pool
from torch_scatter import scatter_max

from rl_env_graph_obs_variable_action_space import GraphTraversalEnv
from collections import deque
import numpy as np
import random
import torch
import torch.nn.functional as F
from torch import optim
from collections import namedtuple, deque
#import range tqdm
from tqdm import tqdm
from tqdm import trange
from torch_geometric.loader import DataLoader
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv, global_mean_pool

from torch_geometric.nn import GCNConv, global_mean_pool, SAGPooling, global_add_pool, global_max_pool
from torch_geometric.data import Data
import torch_geometric.transforms as T
from torch_geometric.nn import GATConv
from torch.utils.tensorboard import SummaryWriter

import heapq  # For priority queue
import time
from agent_variable_action_space import Agent
from utils import preprocess_graph, convert_types, add_global_root_node, connect_components, remove_all_isolated_nodes
from networkx.drawing.nx_pydot import graphviz_layout
from matplotlib import pyplot as plt 

%matplotlib inline
import mpld3
mpld3.enable_notebook()

In [39]:
from agent_variable_action_space import GraphQNetwork, GraphQNetworkNew

# Sizes passed to the network constructor (presumably node-feature and
# edge-attribute dimensions -- TODO confirm against GraphQNetworkNew).
STATE_SPACE = 7
EDGE_ATTR_SIZE = 1
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): the meaning of the third positional argument (0) is not visible here.
model = GraphQNetworkNew(STATE_SPACE, EDGE_ATTR_SIZE, 0).to(device)
model_path = "/root/ssh-rlkex/models/rl/VACTION_SPACE_GOAL_GraphQNetworkNew_20240229-070657/460_6.57.pt"
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# from a GPU run can still be loaded on a CPU-only machine.
model.load_state_dict(torch.load(model_path, map_location=device))

<All keys matched successfully>

In [40]:

# -------------------------
# INFERENCE HELPERS
# -------------------------

def act(model, state, action_mask, goal, visited_subgraph, current_node):
    """Greedily select the action with the highest value predicted by ``model``.

    Args:
        model: network called as ``model(x, edge_index, edge_attr, None,
            action_mask, goal, visited_subgraph, None, current_node)``.
        state: object exposing ``x``, ``edge_index`` and ``edge_attr`` tensors.
        action_mask: tensor forwarded to the model (moved to the model's device).
        goal: tensor describing the goal; a leading dimension is added before
            it is forwarded.
        visited_subgraph: forwarded to the model unchanged.
        current_node: integer index, wrapped into a 1-element long tensor.

    Returns:
        tuple: ``(selected_action, return_values)`` where ``selected_action`` is
        the argmax index (Python int) over the model output and
        ``return_values`` is that output moved to the CPU.
    """
    # Use the device the model actually lives on instead of a notebook global,
    # so the inputs always match the model that was passed in.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')

    action_mask = action_mask.to(target_device)
    goal = goal.to(target_device).unsqueeze(0)
    x = state.x.to(target_device)
    edge_index = state.edge_index.to(target_device)
    edge_attr = state.edge_attr.to(target_device)
    current_node = torch.tensor([current_node], dtype=torch.long).to(target_device)

    model.eval()
    with torch.no_grad():  # inference only, no autograd bookkeeping
        action_values = model(x, edge_index, edge_attr, None, action_mask, goal, visited_subgraph, None, current_node)
    return_values = action_values.cpu()

    selected_action = torch.argmax(return_values).item()
    # Releasing cached blocks is only meaningful when running on CUDA.
    if target_device.type == 'cuda':
        torch.cuda.empty_cache()
    return selected_action, return_values



def define_targets(graph):
    """Return a ``{node: cat}`` mapping for every node whose ``cat`` is in 0..3.

    Only the encryption and initialization keys are kept; integrity keys
    (``cat`` above 3) and nodes with a negative ``cat`` are ignored.

    Args:
        graph: object whose ``nodes(data=True)`` yields ``(node, attributes)``
            pairs, each ``attributes`` holding a ``'cat'`` entry.

    Returns:
        dict: selected nodes mapped to their ``cat`` value, in iteration order.
    """
    return {
        node_id: node_attrs['cat']
        for node_id, node_attrs in graph.nodes(data=True)
        if 0 <= node_attrs['cat'] <= 3
    }


In [41]:
from root_heuristic_rf import GraphPredictor

# Path handed to GraphPredictor (a .joblib file, presumably the serialized
# root-detection model -- TODO confirm).
root_detection_model_path="/root/ssh-rlkex/models/root_heuristic_model.joblib"
 
# Passed to GraphTraversalEnv as `root_detector=` inside test_for_graph.
root_detector = GraphPredictor(root_detection_model_path)

In [42]:


def show_graph(graph, goal, current_node, neighbours_qvalues):
    """Draw ``graph`` with targets, the current node and candidate moves highlighted.

    Labels:
        * nodes with ``cat >= 0`` get a letter ('A' for 0, 'B' for 1, ...),
          with " G" appended when the node is ``goal``;
        * the current node (when it is not such a node) is labelled "X";
        * every node present in ``neighbours_qvalues`` gets " : <q-value>"
          (two decimals) appended to its label.

    Colours (RGB tuples):
        * nodes with ``cat >= 0`` are red, the current node is green;
        * nodes in ``neighbours_qvalues`` range from blue (lowest Q-value) to
          red (highest), after min-max normalisation over the given Q-values;
        * all other nodes are black.

    Args:
        graph: graph whose ``nodes(data=True)`` attributes contain ``'cat'``.
        goal: node identifier of the goal node.
        current_node: node identifier of the agent's position.
        neighbours_qvalues: mapping node -> Q-value (0-dim tensor or number).
    """
    # Convert the Q-values once and compute their range a single time; the
    # range does not depend on the node being coloured.
    q_by_node = {node: float(q) for node, q in neighbours_qvalues.items()}
    if q_by_node:
        q_min = min(q_by_node.values())
        q_max = max(q_by_node.values())
    else:
        q_min = q_max = 0.0

    labels = {}
    colors = []
    for node, attributes in graph.nodes(data=True):
        is_target = attributes['cat'] >= 0
        is_neighbour = node in q_by_node

        # --- label ---
        if is_target:
            label = chr(ord('A') + attributes['cat'])
            if node == goal:
                label = label + " G"
        elif node == current_node:
            label = "X"
        else:
            label = ""
        if is_neighbour:
            label = f"{label} : {q_by_node[node]:.2f}"
        # The drawn graph uses string node names, so labels are keyed by str(node).
        labels[str(node)] = label

        # --- colour ---
        if is_target:
            colors.append((1, 0, 0))
        elif node == current_node:
            colors.append((0, 1, 0))
        elif is_neighbour:
            # Min-max normalise; a constant set of Q-values maps to 0 (blue).
            if q_max == q_min:
                q_norm = 0
            else:
                q_norm = (q_by_node[node] - q_min) / (q_max - q_min)
            colors.append((q_norm, 0, 1 - q_norm))
        else:
            colors.append((0, 0, 0))

    # Rebuild the graph with string node names, inserted in the same order as
    # graph.nodes() so that `colors` lines up with the drawn nodes.
    G_temp = nx.DiGraph()
    G_temp.add_nodes_from(str(n) for n in graph.nodes())
    G_temp.add_edges_from((str(u), str(v)) for u, v in graph.edges())

    pos = graphviz_layout(G_temp, prog='dot')
    nx.draw(G_temp, labels=labels, node_color=colors, pos=pos)
    plt.show()


In [43]:
# When True, test_for_graph calls show_graph at every step of every episode.
# NOTE(review): the name is misspelled ("GAPH"); it is kept because
# test_for_graph reads it under this exact name.
SHOW_GAPH_TEST = False

In [44]:

def test_for_graph(file):
    """Evaluate the greedy policy on one graph file, one episode per target node.

    Mirrors the training loop but performs no learning: for each target the
    environment is reset, the goal is set from the target's category, and
    ``act`` picks actions until the environment reports ``done``.

    Args:
        file: path of the GraphML file to load.

    Returns:
        tuple: ``(total_reward, total_key_found, nb_targets)`` where
        ``total_reward`` is the sum of all step rewards over all episodes,
        ``total_key_found`` counts episodes that ended with
        ``info["found_target"]`` truthy, and ``nb_targets`` is the number of
        target nodes in the graph.
    """
    graph = nx.read_graphml(file)
    graph = preprocess_graph(graph)

    # node -> category for every node selected by define_targets
    target_nodes = define_targets(graph=graph)
    env = GraphTraversalEnv(graph, target_nodes, root_detector=root_detector, obs_is_full_graph=True)

    total_reward = 0
    total_key_found = 0

    for target in target_nodes:
        done = False

        goal = target_nodes[target]
        observation = env.reset()
        goal_one_hot = env.get_goal_one_hot(goal)
        env.set_target_goal(goal)
        display_graph = env.graph
        while not done:
            action_mask = env._get_action_mask()
            visited_subgraph = env.get_visited_subgraph()
            current_node = env.get_current_node()
            action, qvalues = act(model, observation, action_mask, goal_one_hot, visited_subgraph, current_node)

            if SHOW_GAPH_TEST:
                # The per-node Q-value map is only needed for the debug drawing,
                # so it is built only when drawing is enabled.
                node_qvalues_map = {
                    env.inverse_node_mapping[i]: qvalue
                    for i, qvalue in enumerate(qvalues)
                    if action_mask[i] == 1
                }
                show_graph(display_graph, target, env.current_node, node_qvalues_map)

            # The fifth value returned by step() is not used here.
            new_observation, reward, done, info, _ = env.step(action)
            total_reward += reward
            if done and info["found_target"]:
                total_key_found += 1

            observation = new_observation

    return total_reward, total_key_found, len(target_nodes)


In [45]:
FOLDER = "/root/ssh-rlkex/Test_Graphs"
# Walk FOLDER recursively and keep the full path of every GraphML file.
all_files = [
    os.path.join(dirpath, filename)
    for dirpath, _dirnames, filenames in os.walk(FOLDER)
    for filename in filenames
    if filename.endswith(".graphml")
]


print(f"Total files: {len(all_files)}")

Total files: 200


In [46]:

# Shuffle the file list in place. No seed is set in the visible cells, so the
# order presumably differs from run to run -- TODO confirm / seed if needed.
random.shuffle(all_files)

# nb_test_files equals the full list length, so the slice below keeps every
# file; lower nb_test_files to evaluate on a subset.
nb_test_files = len(all_files)
test_files = all_files[:nb_test_files]

In [47]:
# Evaluate the policy on every selected file and aggregate the results.
print("Executing Testing ...")
test_rewards = []
test_success_rate = []
for i, file in enumerate(test_files):
    if file.endswith(".graphml"):
        print(f"[{i} / {nb_test_files}] : Executing Testing for {file}")
        reward, nb_found_keys, nb_keys = test_for_graph(file)
        # NOTE: `reward` is the summed reward returned by test_for_graph, even
        # though the message below calls it a mean.
        print(f"Found {nb_found_keys} / {nb_keys} keys with a mean reward of {reward}")
        if nb_keys == 0:
            # A graph without target keys has no defined success rate; leave it
            # out of the statistics instead of dividing by zero.
            print(f"Skipping {file} in the statistics: no target keys in this graph")
            continue
        test_rewards.append(reward)
        test_success_rate.append(nb_found_keys / nb_keys)

if test_success_rate:
    print(f"Testing done with a mean reward of {np.mean(test_rewards)} and a success rate of {np.mean(test_success_rate)}")
else:
    # np.mean([]) would return NaN with a RuntimeWarning.
    print("Testing done: no graph with target keys was evaluated")

Executing Testing ...
[0 / 200] : Executing Testing for /root/ssh-rlkex/Test_Graphs/basic/V_7_8_P1/64/28527-1643975243.graphml
Model loaded!
-------------------- ASSESSING TARGET COMPLEXITY ----------------------
Has cycles: False
Number of targets: 2
Number of nodes in the graph: 24
Number of edges in the graph: 23
Path length from current node to target nodes: 4
Mean number of neighbors: 0.9583333333333334
Depth of the graph: 2
Number of neighbors of the root: 9
------------------------------------------------------------------------
Found 2 / 2 keys with a mean reward of 28
[1 / 200] : Executing Testing for /root/ssh-rlkex/Test_Graphs/basic/V_8_8_P1/32/17960-1643985448.graphml
Model loaded!
-------------------- ASSESSING TARGET COMPLEXITY ----------------------
Has cycles: False
Number of targets: 4
Number of nodes in the graph: 30
Number of edges in the graph: 29
Path length from current node to target nodes: 8
Mean number of neighbors: 0.9666666666666667
Depth of the graph: 4
Numb