In [1]:
import numpy as np
from scipy.stats import entropy
import torch
from cgan_code_2 import Discriminator  # 假设你有一个定义好的 Discriminator 类
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import json
import struct
from sklearn.metrics import classification_report, roc_auc_score

In [2]:
TOTAL_LEN = 114
# NPRINT_REAL_WIDTH = 50*8
NPRINT_REAL_WIDTH = 22*8
# LABEL_DICT = {'facebook': 0, 'skype': 1, 'aim': 2, 'email': 3, 'voipbuster': 4, 'hangouts': 5, 'youtube': 6, 'sftp': 7, 'icq': 8,  'ftps': 9, 'vimeo': 10, 'spotify': 11, 'netflix': 12, 'bittorrent': 13}
LABEL_DICT = {'facebook': 0, 'skype': 1}
# LABEL_DICT = {'facebook': 0, 'skype': 1, 'email': 2, 'voipbuster': 3, 'hangouts': 4, 'youtube': 5, 'ftps': 6, 'vimeo': 7, 'spotify': 8, 'netflix': 9, 'bittorrent': 10}
# LABEL_DICT = {'facebook': 0, 'skype': 1, 'email': 2, 'voipbuster': 3, 'youtube': 4, 'ftps': 5, 'vimeo': 6, 'spotify': 7, 'netflix': 8, 'bittorrent': 9}
# LABEL_DICT = {'email': 0, 'youtube': 1, 'ftps': 2, 'vimeo': 3, 'spotify': 4, 'netflix': 5, 'bittorrent': 6}

SEQ_DIM = 3
MAX_PKT_LEN = 1500
MAX_TIME = 1000
MAX_PORT = 65536
MAX_SEQ_LEN = 16

label_dim = len(LABEL_DICT) 
image_dim = (1, NPRINT_REAL_WIDTH, NPRINT_REAL_WIDTH)  # 生成单通道图像
noise_dim = 128  # 噪声维度
batch_size = 128
epochs = 200
source_name = './vpn_data_small.json'
bins_name = './bins_small.json'
model_name = './save_all/discriminator_v7.pth'

In [3]:
bins_data = {}
with open(bins_name, 'r') as f_bin:
    bins_data = json.load(f_bin)
    
port_intervals = bins_data['port']['intervals']
pkt_len_intervals = []
for bins in bins_data['packet_len']:
    pkt_len_intervals.append(bins['intervals'])
time_intervals = []
for bins in bins_data['time']:
    time_intervals.append(bins['intervals'])

In [4]:
discriminator = Discriminator(label_dim, SEQ_DIM, MAX_SEQ_LEN, 'cpu')

# 加载模型权重
checkpoint = torch.load(model_name)  # 加载保存的权重字典
discriminator.load_state_dict(checkpoint)  # 将权重字典加载到模型中

# 切换到评估模式
discriminator.eval()

  checkpoint = torch.load(model_name)  # 加载保存的权重字典


Discriminator(
  (lstm): LSTM(3, 512, num_layers=4, batch_first=True)
  (length_fc): Sequential(
    (0): Linear(in_features=16, out_features=128, bias=True)
    (1): ReLU(inplace=True)
  )
  (fc): Sequential(
    (0): Linear(in_features=640, out_features=512, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (heads): ModuleList(
    (0-1): 2 x Sequential(
      (0): Linear(in_features=512, out_features=1, bias=True)
      (1): Identity()
    )
  )
)

In [5]:
checkpoint

OrderedDict([('lstm.weight_ih_l0',
              tensor([[-0.0573, -0.1975, -0.0593],
                      [-0.1327, -0.1763, -0.2676],
                      [ 0.0375, -0.0923,  0.1357],
                      ...,
                      [ 0.0051,  0.1225, -0.1281],
                      [-0.1791,  0.0186,  0.0578],
                      [-0.1543, -0.0544, -0.0151]], device='cuda:0')),
             ('lstm.weight_hh_l0',
              tensor([[-0.2117,  0.1721, -0.1151,  ...,  0.0525, -0.0521, -0.2237],
                      [ 0.0655, -0.0545,  0.0064,  ...,  0.1071,  0.0044,  0.0445],
                      [ 0.0699, -0.0598, -0.0008,  ..., -0.0155,  0.1059,  0.0137],
                      ...,
                      [-0.0335, -0.0617, -0.0134,  ..., -0.0827,  0.0627, -0.0322],
                      [-0.0976,  0.0960, -0.2138,  ...,  0.1979, -0.1245, -0.2666],
                      [ 0.0127,  0.0096,  0.0347,  ..., -0.0569,  0.0229,  0.0709]],
                     device='cuda:0')),
     

In [6]:
class MixDataset(Dataset):
    def __init__(self, json_file, class_mapping, max_seq_len, bins_file, label_str, transform=None):
        """
        :param json_file: 存储数据的JSON文件路径
        :param class_mapping: 类别名到整数标签的映射
        :param nprint_width: nprint的固定宽度
        :param transform: 图像预处理转换
        """
        self.json_file = json_file
        self.class_mapping = class_mapping  # 类别映射
        self.max_seq_len = max_seq_len
        self.label_str = label_str
        label_int = self.class_mapping[label_str]
        self.label_one_hot = F.one_hot(torch.tensor(label_int), num_classes=len(self.class_mapping)).float()
        self.transform = transform
        
        # 读取JSON文件
        with open(json_file, 'r') as f:
            self.data = json.load(f)['data']
        
        with open(bins_file, 'r') as f_bin:
            self.bins_data = json.load(f_bin)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        
        def find_interval(value, intervals):
            for idx, [start, end] in enumerate(intervals):
                if start <= value <= end:
                    return idx  # 返回所在区间的下标
            return None
        # 解析每一条数据
        item = self.data[idx]

        metadata = np.array(list(item['meta'].values()), dtype=np.float32)
        length = min(metadata[1],self.max_seq_len)

        labelstr = item['labels'][0]  # 假设 labels 是字符串类型
        
        is_real = np.zeros(2)
        
        if labelstr == self.label_str:
            is_real[0] = 1
        else:
            is_real[1] = 1
        #     print(labelstr,self.label_str)
        # print(is_real)
        

        port_intervals = self.bins_data['port']['intervals']
        pkt_len_intervals = []
        for bins in self.bins_data['packet_len']:
            pkt_len_intervals.append(bins['intervals'])
        time_intervals = []
        for bins in self.bins_data['time']:
            time_intervals.append(bins['intervals'])

        seq = []

        im = bytes.fromhex(item['nprint'])
        # def split_bytes_by_length(data, chunk_size):
        #     return [data[i:i + chunk_size] for i in range(0, len(data), chunk_size)]
        # lines = split_bytes_by_length(im, NPRINT_LINE_LEN)

        line = im[0:TOTAL_LEN]
        tcp_dport = line[32:34]
        udp_dport = line[92:94]
        dport = bytearray(a | b for a, b in zip(tcp_dport, udp_dport))
        dport = int.from_bytes(dport, 'big')
        
        
        dport_id = find_interval(dport,port_intervals)
        dport = dport_id/len(port_intervals)
        # dport /= MAX_PORT
        dport = dport * 2 - 1
        
        
        
        count = 0
        for i in range(0, len(im), TOTAL_LEN):
            # new_line = line[:22]+line[34:46]+line[98:]
            # line = bytes(line) 
            line = im[i:i+TOTAL_LEN]
            # print(line[0:8])
            time_h,time_l, pkt_len = struct.unpack("IIh", line[:10])
            time_l //= 1e4
            time = time_h + time_l/100
            
            time_id = find_interval(time,time_intervals[count])
            pkt_len_id = find_interval(pkt_len,pkt_len_intervals[count])
            # sign = -1
            
            # if pkt_len < 0:
            #     sign = 1
            #     pkt_len = -pkt_len
            
            time = time_id/len(time_intervals[count])
            pkt_len = pkt_len_id/len(pkt_len_intervals[count])
            
            time = time * 2 - 1
            pkt_len = pkt_len * 2 - 1
            
            seq.append([time,pkt_len,dport])
            count += 1
            if count >= self.max_seq_len:
                break
        
        # 填充 nprint，使其宽度固定
        if len(seq) < self.max_seq_len:
            seq = np.pad(seq, ((0, self.max_seq_len - len(seq)), (0, 0)), mode='constant', constant_values=0)
        
        
        # 转换为PyTorch的Tensor
        labels_one_hot = torch.tensor(self.label_one_hot, dtype=torch.float32)  # 转换为Tensor
        seq = torch.tensor(seq, dtype=torch.float32)
        length = torch.tensor(length, dtype=torch.float32)
        is_real = torch.tensor(is_real, dtype=torch.float32)
        
        return seq, labels_one_hot, length, is_real

In [7]:
def test_data(label_str):
    dataset = MixDataset(source_name,LABEL_DICT,MAX_SEQ_LEN,bins_name,label_str)
    dataloader = DataLoader(dataset, batch_size=128, shuffle=False)
    real_score = 0
    fake_score = 0
    total_score = 0
    
    valid_list = []
    real_list = []
    with torch.no_grad():
        for seqs, labels, lengths, is_reals in dataloader:
            seqs = seqs.to(torch.device("cpu"))
            lengths = lengths.to(torch.device("cpu"))  # 确保在同一个设备上
            labels = labels.to(torch.device("cpu"))
            is_reals = is_reals.to(torch.device("cpu"))
            # 生成随机噪声向量
            # noise = torch.randn(len(lengths), noise_dim)
            # 输入生成器生成数据
            # fake_data = generator(labels, noise, lengths)
            validity = discriminator(labels, seqs, lengths)
            
            validity = torch.sigmoid(validity).squeeze(1)
            
            real_score += torch.sum(validity * (is_reals[:,0]))
            fake_score += torch.sum(validity * (is_reals[:,1]))
            total_score += torch.sum(validity)
            
            valid_list += validity.tolist()
            real_list += is_reals[:,0].tolist()
    
    auc = roc_auc_score(real_list, valid_list)
            
    return real_score,fake_score, total_score, auc

In [8]:
for label in LABEL_DICT.keys():
    print(label,":")
    real_score,fake_score,total_score,auc = test_data(label)
    print(f"Real:{real_score}",f"Fake:{fake_score}",f"Total:{total_score}",f"AUC:{auc}")

facebook :


  labels_one_hot = torch.tensor(self.label_one_hot, dtype=torch.float32)  # 转换为Tensor


Real:1991.2022705078125 Fake:1528.9072265625 Total:3520.10986328125 AUC:0.303364259508925
skype :


  labels_one_hot = torch.tensor(self.label_one_hot, dtype=torch.float32)  # 转换为Tensor


Real:1624.4080810546875 Fake:2048.10791015625 Total:3672.51611328125 AUC:0.7923031624673058
