In [3]:
# Download the pre-trained word2vec model (word2vec-google-news-300) with gensim's
# downloader CLI. The load cell further down reads it from the gensim-data directory.
!python -m gensim.downloader --download word2vec-google-news-300

In [22]:
from __future__ import annotations

import json
import os
import random
import socket
from collections import deque, namedtuple
from math import log10
from threading import Lock, Thread
from typing import Dict, List, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from gensim.models import KeyedVectors

from action_matrix import ACTION_ID2NAME, ACTION_LEN, ACTIVITY_IDF2ID, ACTIVITY_IDF2NAME, ACTIVITY_LEN, MAPPING_ACTIVITY2ACTION

In [23]:
# Pick the compute device, falling back to CPU when CUDA is unavailable.
# NOTE: nothing below moves the networks or tensors to DEVICE yet, so training
# currently runs on the CPU regardless of this setting.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

In [24]:
# Load the word2vec vectors fetched by the download cell above.
# gensim's downloader keeps its data under $GENSIM_DATA_DIR (default ~/gensim-data),
# so build the path from that instead of hard-coding the root user's home directory.
GENSIM_DATA_DIR = os.environ.get('GENSIM_DATA_DIR', os.path.expanduser('~/gensim-data'))
WORD2VEC_PATH = os.path.join(GENSIM_DATA_DIR, 'word2vec-google-news-300', 'word2vec-google-news-300')
# NOTE(review): confirm the on-disk file name; the downloader may keep it gzipped (`.gz` suffix).
word2vec_model = KeyedVectors.load_word2vec_format(WORD2VEC_PATH, binary=True)

In [25]:
# Endpoint the engagement server binds to ('0.0.0.0' = every interface).
SERVER_ADDRESS = '0.0.0.0'
SERVER_PORT = 6416

# Display names for the login-state values (presumably 0, 1, 2).
LOGIN_STATE_ID2NAME = ('Trying', 'Normal User', 'Root User')

# Attribute slots carried in each state, and the value each slot starts with.
ATTRIBUTE_NAME = (
    'exec', 'mail_sender', 'mail_receiver', 'username', 'password',
    'interface', 'domain', 'ip', 'port', 'path',
)
ATTRIBUTE_DEFAULT = (
    '', '', '', '', '',
    'enp1s0', 'localhost', '127.0.0.1', '', '',
)

# Training schedule.
NUM_EPISODES = 1000        # NOTE(review): not referenced in the cells shown here
TARGET_UPDATE_STEP = 10    # copy policy -> target network every N environment steps

In [26]:
# Define the environment.
class CustomEnvironment:
    """Per-connection episode bookkeeping and reward shaping.

    A state is the triple ``(login_state, attributes, suggested_activities)`` as
    produced by ``EngageHandler.decode_state``; reaching a state whose login state
    is ``-1`` ends the episode.
    """

    @staticmethod
    def reset() -> Tuple[int, List[str], List[int]]:
        """Return a fresh initial state: login state 0, default attributes, no suggestions."""
        return 0, list(ATTRIBUTE_DEFAULT), []

    @staticmethod
    def action_evaluation(suggested_activities: List[int], encoded_action: int) -> float:
        """Score an action against the activities suggested in the state it was taken from.

        Args:
            suggested_activities: numeric parts of activity IDFs (``'EAC0042'`` -> ``42``).
            encoded_action: ``login_state * 100 + action`` of the action that was taken.

        Returns:
            1 if ``encoded_action`` equals one of the ``login_state * 100 + action_id``
            pairs that MAPPING_ACTIVITY2ACTION assigns to the suggested activities,
            otherwise -1.
        """
        suggested_actions = set()

        for act in suggested_activities:
            idf = f'EAC{act:04d}'
            # ACTIVITY_IDF2ID is searched by IDF string (see the .index call below), so
            # membership must be tested with that same string, not the raw integer.
            if idf not in ACTIVITY_IDF2ID:
                continue

            act_id = ACTIVITY_IDF2ID.index(idf)
            if 1 <= act_id <= ACTIVITY_LEN:
                login_state, action_id = MAPPING_ACTIVITY2ACTION[str(act_id)]
                # Use a separate name so the `encoded_action` argument stays intact
                # for the final comparison.
                suggested_actions.add(login_state * 100 + action_id)

        return 1 if encoded_action in suggested_actions else -1

    def __init__(self):
        self.previous_state = self.reset()
        self.current_state = self.reset()
        self.done = False        # set once a state with login state -1 is reached
        self.step_counter = 0    # number of step() calls so far

    def step(self, action: int, state: Tuple[int, List[str], List[int]]) -> float:
        """Advance to ``state`` after taking ``action`` and return the shaped reward.

        The action is scored against the suggestions of the state it was taken from
        (the previous ``current_state``), encoded as ``login_state * 100 + action``;
        ``log10(step_counter)`` is added on top. Reaching a state whose login state
        is ``-1`` sets ``done``.
        """
        self.previous_state = self.current_state
        self.current_state = state
        self.step_counter += 1

        encoded_action = self.previous_state[0] * 100 + action
        reward = self.action_evaluation(self.previous_state[2], encoded_action) + log10(self.step_counter)

        if state[0] == -1:
            self.done = True

        return reward


# Define the neural network architecture.
class DQN(nn.Module):
    """Three-layer fully connected Q-network: state vector in, one Q-value per action out."""

    def __init__(self, input_size, output_size, hidden_size=128):
        """Build the layers.

        Args:
            input_size: length of the state vector.
            output_size: number of actions (Q-values produced).
            hidden_size: width of both hidden layers (defaults to 128).
        """
        super(DQN, self).__init__()

        # Layer names are kept as fc1..fc3 so existing state_dicts stay loadable.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return raw Q-values; ReLU is applied only between layers, not on the output."""
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)


# Define the double DQN with replay buffer.
class DoubleDQN:
    """Double DQN agent with an experience-replay buffer.

    Actions are exposed as 1-based ids (``1..output_size``): ``select_action``
    returns them and ``train`` expects them, converting to 0-based Q-value
    indices internally.
    """

    transition = namedtuple('transition', ('state', 'action', 'reward', 'next_state', 'done'))

    def __init__(self, lr=0.001, gamma=0.99, epsilon=1.0, epsilon_decay=0.999, epsilon_min=0.01,
                 batch_size=32, buffer_size=100, input_size=317, output_size=16):
        """Build the policy/target networks, optimizer and replay buffer.

        Args:
            lr: Adam learning rate.
            gamma: discount factor.
            epsilon, epsilon_decay, epsilon_min: epsilon-greedy schedule; epsilon is
                multiplied by ``epsilon_decay`` after each training batch while it is
                still above ``epsilon_min``.
            batch_size: transitions sampled per training call.
            buffer_size: replay buffer capacity (oldest transitions are dropped).
            input_size: length of the state vectors (default 317).
            output_size: number of actions (default 16).
        """
        self.input_size = input_size
        self.output_size = output_size
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        self.batch_size = batch_size

        self.policy_net = DQN(self.input_size, self.output_size)
        self.target_net = DQN(self.input_size, self.output_size)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        self.criterion = nn.MSELoss()

        self.replay_buffer = deque(maxlen=buffer_size)

    def select_action(self, state: torch.Tensor) -> int:
        """Pick an action epsilon-greedily; returns a 1-based action id in ``1..output_size``."""
        if random.random() < self.epsilon:
            return random.randint(0, self.output_size - 1) + 1

        with torch.no_grad():
            q_values = self.policy_net(state)
            return q_values.argmax().item() + 1

    def store_transition(self, state: torch.Tensor, action: int, reward: float, next_state: torch.Tensor, done: bool) -> None:
        """Append one transition (``action`` is 1-based) to the replay buffer."""
        self.replay_buffer.append(DoubleDQN.transition(state, action, reward, next_state, done))

    def sample_batch(self) -> List[DoubleDQN.transition]:
        """Return ``batch_size`` transitions drawn uniformly without replacement."""
        return random.sample(self.replay_buffer, self.batch_size)

    def train(self, state: torch.Tensor, action: int, reward: float, next_state: torch.Tensor, done: bool) -> None:
        """Store one transition and, once the buffer holds a full batch, learn from a sample.

        Each sampled transition gets its own optimizer step.

        Args:
            state: encoded state the action was taken in.
            action: 1-based action id, as returned by ``select_action``.
            reward: scalar reward received for the transition.
            next_state: encoded state reached after the action.
            done: whether ``next_state`` is terminal (no bootstrapping from it).
        """
        self.store_transition(state, action, reward, next_state, done)

        if len(self.replay_buffer) < self.batch_size:
            return

        for s, a, r, s_next, is_done in self.sample_batch():
            try:
                q_values = self.policy_net(s)
            except RuntimeError:
                # e.g. a state whose length differs from input_size: show it, then fail
                # loudly instead of continuing with an unbound or stale q_values.
                print('State =', s)
                raise

            # Double DQN target: the policy net picks the next action, the target net
            # evaluates it. Computed without autograd since it is a fixed regression target.
            with torch.no_grad():
                best_next = self.policy_net(s_next).argmax()
                next_q_value = self.target_net(s_next)[best_next]
                expected_q_value = r + (1.0 - float(is_done)) * self.gamma * next_q_value

            # Stored actions are 1-based (see select_action); Q-value rows are 0-based.
            q_value = q_values[a - 1]
            loss = self.criterion(q_value, expected_q_value)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def update_target_network(self) -> None:
        """Copy the policy network's weights into the target network."""
        self.target_net.load_state_dict(self.policy_net.state_dict())
        
        
class EngageHandler:
    """TCP server that trains one ``DoubleDQN`` agent per connected client.

    Protocol as implemented below: the server sends the selected action id as ASCII
    text; the client answers with a JSON object ``{"state": [...]}`` (see
    ``decode_state``). An empty read means the client closed the connection.
    """

    @staticmethod
    def decode_state(encoded_state: bytes) -> Tuple[int, List[str], List[int]]:
        """Parse a client reply into ``(login_state, attributes, suggested_activities)``.

        The ``state`` list holds the login state at index 0 and the attribute values
        at indices 1-10; any further entries are activity IDFs such as ``'EAC0042'``,
        reduced to their integer part (``42``).
        """
        state: List[str] = json.loads(encoded_state.decode())['state']

        login_state = int(state[0])
        attr = state[1:11]
        # state[11:] is simply empty when no activities were suggested.
        suggested_activities = [int(act.replace('EAC', '')) for act in state[11:]]

        return login_state, attr, suggested_activities

    def __init__(self):
        self.keep_running = True
        self.server: Union[socket.socket, None] = None
        self.clients: List[Thread] = []
        self.states: List[Union[Tuple[int, List[str], List[int]], None]] = []
        self.actions: List[int] = []
        # Guards client-id allocation: handle_con runs on one thread per client.
        self._clients_lock = Lock()

        # Blocks until the accept loop in init_server ends.
        self.init_server()

    def init_server(self):
        """Bind the listening socket and accept clients until stopped (blocking).

        The server is always closed on exit, including after a kernel interrupt or an
        unexpected error, so the port is released and client threads are told to stop.
        """
        self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

        try:
            # Allow re-running the cell right after a previous server on this port closed.
            self.server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            self.server.bind((SERVER_ADDRESS, SERVER_PORT))
            self.server.listen(5)

            print(f'Server is listening on "{SERVER_ADDRESS}:{SERVER_PORT}".')

            while self.keep_running:
                try:
                    con, client = self.server.accept()
                except KeyboardInterrupt:
                    break

                t = Thread(target=self.handle_con, args=(client, con))
                t.start()

                self.clients.append(t)
        finally:
            self.close_server()

    def close_server(self):
        """Signal client threads to stop (via keep_running) and close the listening socket."""
        self.keep_running = False

        if self.server is not None:
            self.server.close()

        print('Server closed.')

    def _recv_reply(self, con: socket.socket) -> bytes:
        """Wait for the client's next reply, polling so that ``keep_running`` is honoured.

        Returns ``b''`` when the peer closed or reset the connection, or when the
        server stops before anything arrived.
        """
        while self.keep_running:
            try:
                return con.recv(1024)
            except ConnectionResetError:
                return b''
            except socket.timeout:
                # socket.timeout is an alias of TimeoutError only from Python 3.10 on.
                continue
        return b''

    def handle_con(self, client: Tuple[str, int], con: socket.socket):
        """Run one training episode with a connected client until it ends or the server stops.

        Each iteration sends the agent's action id, waits for the resulting state,
        computes the reward with ``CustomEnvironment.step`` and trains the agent.
        """
        with self._clients_lock:
            client_id = len(self.states)
            self.states.append(None)
            self.actions.append(-1)

        print(f'[{client_id}] Accepted connection from "{client[0]}:{client[1]}". Client id is {client_id}.')

        env = CustomEnvironment()
        agent = DoubleDQN()
        total_reward = 0

        with con:
            # Short timeout so the receive loop can notice keep_running turning False.
            con.settimeout(1)

            while self.keep_running and not env.done:
                action = agent.select_action(state_to_tensor(env.current_state))

                con.sendall(str(action).encode())
                print(f'[{client_id}] Selected action is {action} "{ACTION_ID2NAME[action]}".')

                encoded_state = self._recv_reply(con)
                if not self.keep_running:
                    break

                if encoded_state:
                    # NOTE(review): assumes a whole JSON reply arrives in one recv(1024) call.
                    next_state = EngageHandler.decode_state(encoded_state)
                else:
                    # The client closed or reset the connection: record a terminal state
                    # (login state -1) so the last transition is trained with done=True;
                    # env.done then ends the loop.
                    next_state = (-1, list(ATTRIBUTE_DEFAULT), [])

                print(f'[{client_id}] Next state is {next_state}.')

                # NOTE(review): select_action returns 1-based ids but env.step receives
                # action - 1; confirm this matches the action ids in MAPPING_ACTIVITY2ACTION.
                reward = env.step(action - 1, next_state)
                total_reward += reward

                print(f'[{client_id}] [{env.step_counter}] From state {env.previous_state}, do action {action}, to state {env.current_state}. (reward={reward}, total={total_reward})')

                agent.train(state_to_tensor(env.previous_state), action, reward, state_to_tensor(env.current_state), env.done)

                if env.step_counter % TARGET_UPDATE_STEP == 0:
                    agent.update_target_network()

        print(f'[{client_id}] Disconnected from {client[0]}:{client[1]} .')


def encode_login_state(state: int) -> np.ndarray:
    """One-hot encode a login state into a length-3 float64 vector.

    Negative values index from the end (NumPy semantics), so ``-1`` encodes like ``2``;
    values outside ``-3..2`` raise IndexError.
    """
    one_hot = np.zeros(3)
    one_hot[state] = 1.0
    return one_hot


def embed_attributes(attributes: List[str]) -> np.ndarray:
    """Embed all attribute strings as one fixed-size word2vec vector.

    Every whitespace-separated word found in ``word2vec_model`` contributes its
    vector, and the result is their element-wise mean (a zero vector when no word
    is known). Mean-pooling instead of concatenating keeps the output length equal
    to ``word2vec_model.vector_size`` no matter how many words are known, which the
    fixed-size network input requires. With exactly one known word the result is
    that word's vector.

    Args:
        attributes: attribute values; empty strings are allowed.

    Returns:
        1-D array of length ``word2vec_model.vector_size``.
    """
    vectors = [
        word2vec_model[word]
        for attr in attributes
        for word in attr.split()
        if word in word2vec_model
    ]

    if not vectors:
        return np.zeros(word2vec_model.vector_size, dtype=np.float32)

    return np.mean(vectors, axis=0)


def encode_suggested_activities(activities: List[int]) -> np.ndarray:
    """Multi-hot encode activity numbers into a float64 vector of length ACTIVITY_LEN.

    Activity ``k`` sets position ``k - 1``; numbers outside ``1..ACTIVITY_LEN`` are ignored.
    """
    # NOTE(review): this indexes by the numeric IDF part directly, whereas
    # CustomEnvironment.action_evaluation maps IDFs to ids via ACTIVITY_IDF2ID first;
    # confirm the two numbering schemes coincide.
    multi_hot = np.zeros(ACTIVITY_LEN)
    positions = np.asarray([k - 1 for k in activities if 1 <= k <= ACTIVITY_LEN], dtype=np.intp)
    multi_hot[positions] = 1
    return multi_hot


def state_to_tensor(state: Tuple[int, List[str], List[int]]) -> torch.Tensor:
    """Flatten a ``(login_state, attributes, suggested_activities)`` state into one float32 vector.

    Layout: the 3-way login one-hot, then the attribute embedding, then the
    ACTIVITY_LEN-long multi-hot of suggested activities. The network's 317-wide
    input corresponds to 3 + 300 + 14, i.e. a 300-long attribute embedding and 14
    activities; any other lengths will not fit DoubleDQN's default input size.
    """
    login_state, attributes, suggested_activities = state
    parts = (
        encode_login_state(login_state),
        embed_attributes(attributes),
        encode_suggested_activities(suggested_activities),
    )
    return torch.cat([torch.tensor(part, dtype=torch.float32) for part in parts])

In [None]:
# Start the engagement server. EngageHandler.__init__ runs the accept loop itself,
# so this cell does not finish until that loop stops (interrupting the kernel
# closes the server).
eh = EngageHandler()