In [1]:
# -------------------------
# IMPORTS AND SETUP
# -------------------------

import os
import random
import numpy as np
import networkx as nx
import torch
import torch.nn.functional as F
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.nn import SAGEConv, global_mean_pool
from torch_scatter import scatter_max

from rl_env_graph_obs_variable_action_space import GraphTraversalEnv
from collections import deque
import numpy as np
import random
import torch
import torch.nn.functional as F
from torch import optim
from collections import namedtuple, deque
#import range tqdm
from tqdm import tqdm
from tqdm import trange
from torch_geometric.loader import DataLoader
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv, global_mean_pool

from torch_geometric.nn import GCNConv, global_mean_pool, SAGPooling, global_add_pool, global_max_pool
from torch_geometric.data import Data
import torch_geometric.transforms as T
from torch_geometric.nn import GATConv
from torch.utils.tensorboard import SummaryWriter

import heapq  # For priority queue
import time
from agent_variable_action_space import Agent
from utils import preprocess_graph, convert_types, add_global_root_node, connect_components, remove_all_isolated_nodes
from networkx.drawing.nx_pydot import graphviz_layout



KeyboardInterrupt: 

In [None]:
from agent_variable_action_space import GraphQNetwork

# Sizes passed to GraphQNetwork below; presumably they have to match the
# architecture the checkpoint was trained with — TODO confirm.
STATE_SPACE = 7
EDGE_ATTR_SIZE = 1
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GraphQNetwork(STATE_SPACE, EDGE_ATTR_SIZE, 0).to(device)
model_path = "/root/ssh-rlkex/models/rl/VACTION_SPACE_GOAL_GraphQNetwork_20240212-065428/350_19.96.pt"
# map_location remaps the stored tensors onto the device selected above, so a
# checkpoint written on a GPU machine still loads when only the CPU is available.
model.load_state_dict(torch.load(model_path, map_location=device))

<All keys matched successfully>

In [None]:

# -------------------------
# INFERENCE HELPERS
# -------------------------

def act(model, state, action_mask, goal, visited_subgraph):
    """Choose the greedy action for the current observation.

    The tensors are moved onto the module-level ``device``, the model is put
    in eval mode and evaluated without gradient tracking, and the index of the
    largest output value is returned.

    Args:
        model: Network invoked as ``model(x, edge_index, edge_attr, None,
            action_mask, goal, visited_subgraph, None)``.
        state: Observation exposing ``x``, ``edge_index`` and ``edge_attr``
            tensors.
        action_mask: Tensor forwarded to the model after being moved to
            ``device``.
        goal: Tensor describing the goal; a leading dimension is added before
            it is handed to the model.
        visited_subgraph: Passed through to the model unchanged.

    Returns:
        tuple: ``(selected_action, return_values)`` where ``selected_action``
        is the arg-max over the flattened model output as a Python number and
        ``return_values`` is that output copied to the CPU.
    """
    mask_on_device = action_mask.to(device)
    goal_batch = goal.to(device).unsqueeze(0)
    model.eval()

    node_features = state.x.to(device)
    edges = state.edge_index.to(device)
    edge_features = state.edge_attr.to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        action_values = model(
            node_features,
            edges,
            edge_features,
            None,
            mask_on_device,
            goal_batch,
            visited_subgraph,
            None,
        )

    return_values = action_values.cpu()
    selected_action = return_values.argmax().item()
    # Hand cached GPU blocks back to the allocator after every decision.
    torch.cuda.empty_cache()
    return selected_action, return_values



def define_targets(graph):
    """Collect the nodes of ``graph`` whose ``'cat'`` attribute lies in 0..3.

    Per the original author's note, categories 0-3 are the encryption and
    initialization keys; integrity keys (other categories) are ignored.

    Args:
        graph: Object whose ``nodes(data=True)`` yields ``(node, attributes)``
            pairs; every attribute mapping must contain a ``'cat'`` entry.

    Returns:
        dict: ``{node: cat}`` for each selected node, in iteration order.
    """
    return {
        node_id: attrs['cat']
        for node_id, attrs in graph.nodes(data=True)
        if 0 <= attrs['cat'] <= 3
    }


In [None]:
from root_heuristic_rf import GraphPredictor

# Path of the serialized root-heuristic model handed to GraphPredictor.
root_detection_model_path = "/root/ssh-rlkex/models/root_heuristic_model.joblib"

# One shared predictor instance for the whole notebook.
root_detector = GraphPredictor(root_detection_model_path)

In [None]:

def test_for_graph(file):
    """Evaluate the loaded model on one GraphML file, without any training.

    One greedy episode is played per target node returned by
    ``define_targets``: the environment is reset, the target's ``'cat'`` value
    is set as the goal, and actions chosen by ``act`` are applied until the
    environment reports ``done``.

    Args:
        file: Path of the ``.graphml`` file to evaluate.

    Returns:
        tuple: ``(total_reward, total_key_found, nb_targets)`` — the reward
        summed over all episodes, the number of episodes whose final ``info``
        had ``found_target`` set, and the number of target nodes.
    """
    graph = preprocess_graph(nx.read_graphml(file))

    # Mapping of target node -> its 'cat' value, as built by define_targets.
    target_nodes = define_targets(graph=graph)
    env = GraphTraversalEnv(graph, target_nodes, root_detector=root_detector, obs_is_full_graph=True)

    total_reward = 0
    total_key_found = 0

    for goal in target_nodes.values():
        observation = env.reset()
        goal_one_hot = env.get_goal_one_hot(goal)
        env.set_target_goal(goal)

        done = False
        while not done:
            action_mask = env._get_action_mask()
            visited_subgraph = env.get_visited_subgraph()
            # Only the chosen action is needed here; the q-values were solely
            # used to feed a disabled visualisation, so they are discarded.
            action, _ = act(model, observation, action_mask, goal_one_hot, visited_subgraph)

            observation, reward, done, info, _ = env.step(action)
            total_reward += reward
            if done and info["found_target"]:
                total_key_found += 1

    return total_reward, total_key_found, len(target_nodes)


In [None]:
FOLDER = "/root/ssh-rlkex/Test_Graphs"

# Walk FOLDER recursively and keep the full path of every GraphML file found.
all_files = []
for root, dirs, files in os.walk(FOLDER):
    all_files.extend(
        os.path.join(root, name) for name in files if name.endswith(".graphml")
    )


print(f"Total files: {len(all_files)}")

Total files: 200


In [None]:

# Shuffle the files with a dedicated, seeded generator so the evaluation order
# (and the subset selected if nb_test_files is lowered) is identical on every
# run. Sorting first removes the dependence on the order os.walk returned the
# paths in and makes re-running this cell idempotent. The global `random`
# state is left untouched.
SHUFFLE_SEED = 42
all_files.sort()
random.Random(SHUFFLE_SEED).shuffle(all_files)

nb_test_files = len(all_files)
test_files = all_files[:nb_test_files]

In [None]:
# Evaluate every selected file and aggregate reward and success rate.
print(f"Executing Testing ...")
test_rewards = []
test_success_rate = []
for i, file in enumerate(test_files):
    if not file.endswith(".graphml"):
        continue
    print(f"[{i} / {nb_test_files}] : Executing Testing for {file}")
    reward, nb_found_keys, nb_keys = test_for_graph(file)
    print(f"Found {nb_found_keys} / {nb_keys} keys with a mean reward of {reward}")
    test_rewards.append(reward)
    if nb_keys > 0:
        test_success_rate.append(nb_found_keys / nb_keys)
    else:
        # A graph without target nodes has no defined success rate; dividing
        # would raise ZeroDivisionError and abort the whole evaluation.
        print(f"No target keys in {file}; excluded from the success rate")

print(f"Testing done with a mean reward of {np.mean(test_rewards)} and a success rate of {np.mean(test_success_rate)}")

Executing Testing ...
[0 / 200] : Executing Testing for /root/ssh-rlkex/Test_Graphs/basic/V_8_8_P1/64/20670-1643986141.graphml
Model loaded!
-------------------- ASSESSING TARGET COMPLEXITY ----------------------
Has cycles: False
Number of targets: 2
Number of nodes in the graph: 34
Number of edges in the graph: 33
Path length from current node to target nodes: 4
Mean number of neighbors: 0.9705882352941176
Depth of the graph: 4
Number of neighbors of the root: 9
------------------------------------------------------------------------
Found 1 / 2 keys with a mean reward of 4
[1 / 200] : Executing Testing for /root/ssh-rlkex/Test_Graphs/basic/V_8_7_P1/64/8757-1643986137.graphml
Model loaded!
-------------------- ASSESSING TARGET COMPLEXITY ----------------------
Has cycles: False
Number of targets: 2
Number of nodes in the graph: 34
Number of edges in the graph: 33
Path length from current node to target nodes: 4
Mean number of neighbors: 0.9705882352941176
Depth of the graph: 4
Number