In [10]:
import gym
import tianshou as ts
import numpy as np
import torch                                    # 导入torch
import torch.nn as nn                           # 导入torch.nn
import torch.nn.functional as F                 # 导入torch.nn.functional

import gym_waf.envs.wafEnv
from gym_waf.envs.wafEnv  import samples_test, samples_train
from gym_waf.envs.features import Features
from gym_waf.envs.waf import Waf_Check
from gym_waf.envs.xss_manipulator import Xss_Manipulator

In [11]:
class Net(nn.Module):
    def __init__(self, state_shape, action_shape):
        super().__init__()
        self.model = nn.Sequential(*[
            nn.Linear(np.prod(state_shape), 16), nn.ReLU(inplace=True),
            nn.Linear(16, 32), nn.ReLU(inplace=True),
            nn.Linear(32, 16), nn.ReLU(inplace=True),
            nn.Linear(16, np.prod(action_shape))
        ])
    def forward(self, obs, state=None, info={}):
        if not isinstance(obs, torch.Tensor):
            obs = torch.tensor(obs, dtype=torch.float)
        batch = obs.shape[0]
        logits = self.model(obs.view(batch, -1))
        return logits, state

In [12]:
# 创建虚拟环境
ENV_NAME = 'Waf-v0' 
train_envs = gym.make(ENV_NAME)   
test_envs = gym.make(ENV_NAME) 

state_shape = train_envs.observation_space.shape or train_envs.observation_space.n
action_shape = test_envs.action_space.shape or test_envs.action_space.n
print("state_shape:{}".format(state_shape))
print("action_shape:{}".format(action_shape))

state_shape:(1, 257)
action_shape:4


In [13]:
# 超参数
BATCH_SIZE = 32                                 # 样本数量
LR = 0.01                                       # 学习率
EPSILON = 0.9                                   # greedy policy
GAMMA = 0.9                                     # reward discount
TARGET_REPLACE_ITER = 100                       # 目标网络更新频率
MEMORY_CAPACITY = 2000                          # 记忆库容量

In [14]:
# 尝试的最大次数
max_episode_steps = 5     # fit训练时用到，在一次学习周期中的最大步数(默认一直学习直到“死”)

# 构造动作速查表
ACTION_LOOKUP = {i:act for i,act in enumerate(Xss_Manipulator.ACTION_TABLE.keys())} # key为原动作字典的下标0123，value为原动作字典的key即免杀操作名

In [23]:
def train_dqn_model(layers, rounds=10000):
    env = gym.make(ENV_NAME)    # 创建环境
    env.seed(1)
    window_length = 1       # 窗口长度，后面创建记忆体时用，通常设置为1

    # 打印动作、观测值相关信息
    print("免杀操作的个数：")
    print(state_shape)   # 免杀操作的个数，可自行增加
    print("观测值空间形状：")
    print(action_shape)      # 为(1,257)

    # 创建神经网络模型
    model = Net(state_shape,action_shape)
    optim = torch.optim.Adam(model.parameters(), lr=1e-3)
    # 声明训练策略
    policy = ts.policy.DQNPolicy(
        model,
        optim, 
        discount_factor=GAMMA,                                 # 奖励衰减率
        estimation_step=window_length,                         # 窗口长度，后面创建记忆体时用，通常设置为1
        target_update_freq=TARGET_REPLACE_ITER,                # target网络更新频率
    )
    
    # 声明与环境直接进行交互的collectoer
    train_collector = ts.data.Collector(policy, train_envs, ts.data.ReplayBuffer(size=MEMORY_CAPACITY))
    test_collector = ts.data.Collector(policy, test_envs)
    
    # 训练
    result = ts.trainer.offpolicy_trainer(
        policy, train_collector, test_collector,
        max_epoch=1,                                         # 训练的最大轮数
        step_per_epoch=10000,                                    # 每轮训练要使用多少个随机样本进行强化学习
        step_per_collect=1,                                    # 每收集多少个样本更新到eval网络一次
        episode_per_test=100,                                  # 随机训练多少个样本以后在测试集上进行效果测试
        batch_size=BATCH_SIZE,                                 
        train_fn=lambda epoch, env_step: policy.set_eps(0.1),
        test_fn=lambda epoch, env_step: policy.set_eps(0.05),
        stop_fn = None 
#         stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold
        )
    
    
    features_extra = Features()     # 特征向量
    waf_checker = Waf_Check()   # waf检验免杀效果
    xss_manipulatorer = Xss_Manipulator()   # 根据动作修改当前样本，来达到免杀

    success = 0     # 免杀成功数
    sum = 0     # 总数目


    for sample in samples_test:
        sum += 1

        for _ in range(max_episode_steps):
            if not waf_checker.check_xss(sample) :
                success += 1
                break

            #f = features_extra.extract(sample).reshape(shp)
            f = features_extra.extract(sample)
            act_values = model(f)
            action = np.argmax(act_values[0].detach().numpy())
            sample = xss_manipulatorer.modify(sample,ACTION_LOOKUP[action])

    print("总数量：{} 成功：{}".format(sum,success))


if __name__ == '__main__':
    agent1, model1 = train_dqn_model([5], rounds=100)
    model1.save('waf-v0.h5', overwrite=True)

免杀操作的个数：
(1, 257)
观测值空间形状：
4


Epoch #1: 10001it [02:55, 56.99it/s, env_step=10000, len=4, loss=7.797, n/ep=1, n/st=1, rew=0.00]                           


Epoch #1: test_reward: 5.300000 ± 4.990992, best_reward: 5.700000 ± 4.950758 in #0
总数量：51 成功：28


TypeError: cannot unpack non-iterable NoneType object

In [None]:
def test_rl(features_extra = Features()     # 特征向量
    waf_checker = Waf_Check()   # waf检验免杀效果
    xss_manipulatorer = Xss_Manipulator()   # 根据动作修改当前样本，来达到免杀):
            
    success = 0     # 免杀成功数
    sum = 0     # 总数目
            
    for sample in samples_test:
        sum += 1

        for _ in range(max_episode_steps):
            if not waf_checker.check_xss(sample) :
                success += 1
                break

            f = features_extra.extract(sample)
            act_values = model(f)
            action = np.argmax(act_values[0].detach().numpy())
            sample = xss_manipulatorer.modify(sample,ACTION_LOOKUP[action])

    print("总数量：{} 成功：{}".format(sum,success))