In [1]:
import mmh3
import pickle
import numpy as np
from tqdm import tqdm
from collections import defaultdict
import pandas as pd
import matplotlib.pyplot as plt

In [2]:
# Human-readable title for each task id 0-4. Ids 0-2 match the `task`
# argument of MIME.estimate (DC / DPC / SC); ids 3 and 4 presumably refer to
# the per-service and per-source-service estimators -- confirm against the
# plotting cells, which are not visible here.
plot_titles = dict(enumerate((
    "per-source destination flow",
    "per-source port flow",
    "per-destination source flow",
    "per-service source flow",
    "per-source service flow",
)))

In [3]:
class MIME:
    '''
    Bitmap-filtered sampler over (src, dst, port) records.

    Each record is hashed into a bitmap of `bits` positions. Only records that
    hit a still-unset bit are considered, and a fraction of those is stored in
    a nested table `bitcube[src][dst] -> {port, ...}`. Spread estimates for
    several flow definitions are then derived from the stored records.

     prob: the overall sampling rate of MIME
     bits: the number of total bits in MIME
     seed: seed passed to mmh3.hash when mapping a record to a bit
    '''
    def __init__(self, prob, bits, seed):
        self.prob = prob
        self.bits = bits
        # One int8 cell per bit; 1 means "some record already hashed here".
        self.B = np.zeros(shape=(self.bits,), dtype=np.int8)
        self.hash_seed = seed

        # Sampled records: src -> dst -> {port}.
        self.bitcube = dict()
        # Inverted view built by build_inv_table: dst -> src -> {port}.
        self.invbitcube = dict()
        # View built by build_service_table: dst -> port -> {src}.
        self.servicebitcube = dict()

        self.sample_nums = 0   # number of records stored in bitcube
        self.count_one = 0     # number of bits currently set in B
        # Acceptance threshold for the next record that hits a free bit; it is
        # re-computed in sample() as more bits fill up.
        self.post_sampling_rate = prob
        # Boundary used by estimate() to switch between its two formulas.
        self.split_num = int(round(1 / self.prob))

        # Exact (ground-truth) element sets per flow, filled in by run().
        self.real_spread_set_dc = defaultdict(set)
        self.real_spread_set_dpc = defaultdict(set)
        self.real_spread_set_sc = defaultdict(set)
        self.real_spread_set_service = defaultdict(set)
        self.real_spread_set_sservice = defaultdict(set)

        # The attributes below are not written by any method of this class;
        # presumably later notebook cells fill them in during evaluation.
        self.real_spreads_dc = defaultdict(int)
        self.real_spreads_dpc = defaultdict(int)
        self.real_spreads_sc = defaultdict(int)
        self.real_spread_service = defaultdict(int)
        self.real_spread_sservice = defaultdict(int)

        self.pred_spreads_dc = defaultdict(int)
        self.pred_spreads_dpc = defaultdict(int)
        self.pred_spreads_sc = defaultdict(int)
        self.pred_spreads_service = defaultdict(int)
        self.pred_spreads_sservice = defaultdict(int)

        self.are_dc = 0
        self.are_dpc = 0
        self.are_sc = 0
        self.are_service = 0
        self.are_sservice = 0

        self.dc_count = 0
        self.dpc_count = 0
        self.sc_count = 0
        self.service_count = 0
        self.sservice_count = 0

        # Seven zero-initialised buckets (keys 0..6) per task.
        self.are_range_dc = {i: 0 for i in range(7)}
        self.are_range_dpc = {i: 0 for i in range(7)}
        self.are_range_sc = {i: 0 for i in range(7)}
        self.are_range_service = {i: 0 for i in range(7)}
        self.are_range_sservice = {i: 0 for i in range(7)}

        self.are_count_dc = {i: 0 for i in range(7)}
        self.are_count_dpc = {i: 0 for i in range(7)}
        self.are_count_sc = {i: 0 for i in range(7)}
        self.are_count_service = {i: 0 for i in range(7)}
        self.are_count_sservice = {i: 0 for i in range(7)}

    def sample(self, src, dst, port):
        '''
        Offer one (src, dst, port) record (three strings) to the sampler.

        The record is ignored if its bit is already set. Otherwise the bit is
        set, and the record is stored in `bitcube` when its bit index falls
        below the current `post_sampling_rate * bits` threshold.
        '''
        # NOTE(review): the three fields are concatenated without a separator,
        # so different triples can produce the same key (e.g. "1.2.3.4" +
        # "56.7.8.9" vs "1.2.3.45" + "6.7.8.9"). Left unchanged because
        # altering the key changes every hash and every saved result.
        key = src + dst + port
        hash_idx = mmh3.hash(key, seed=self.hash_seed) % self.bits
        if self.B[hash_idx] != 0:
            return

        self.B[hash_idx] = 1
        self.count_one += 1
        # NOTE(review): the accept/reject decision reuses the bitmap index
        # instead of an independent hash value -- confirm this is intended.
        if hash_idx <= self.post_sampling_rate * self.bits:
            self.bitcube.setdefault(src, {}).setdefault(dst, set()).add(port)
            self.sample_nums += 1

        # Raise the acceptance rate as free bits become scarcer. When the
        # bitmap is completely full no further record can reach this point, so
        # the rate is simply left alone instead of dividing by zero.
        free_bits = self.bits - self.count_one
        if free_bits > 0:
            self.post_sampling_rate = (self.prob * self.bits) / free_bits

    def build_inv_table(self):
        '''
        Populate `invbitcube` (dst -> src -> {port}) from `bitcube`.

        Existing (dst, src) entries are kept as they are. The port sets are
        shared with `bitcube`, not copied.
        '''
        for src, ports_by_dst in self.bitcube.items():
            for dst, ports in ports_by_dst.items():
                self.invbitcube.setdefault(dst, {}).setdefault(src, ports)

    def build_service_table(self):
        '''
        Populate `servicebitcube` (dst -> port -> {src}) from `invbitcube`.

        The entry of every dst present in `invbitcube` is rebuilt from scratch.
        '''
        for dst, ports_by_src in self.invbitcube.items():
            srcs_by_port = self.servicebitcube[dst] = dict()
            for src, ports in ports_by_src.items():
                for port in ports:
                    srcs_by_port.setdefault(port, set()).add(src)

    def estimate(self, key, task):  # flow label: 1 field, element label: 1 field
        '''
        Estimate the spread of flow `key` for one of three tasks.

        task 0 (DC):  key is a src; one count per dst = number of sampled ports.
        task 1 (DPC): key is a src; one count per port = number of sampled dsts.
        task 2 (SC):  key is a dst; one count per src = number of sampled ports.

        Each element with count k contributes 1 / (1 - (1 - prob) ** e), where
        e = k when k <= split_num and e = k / prob otherwise.

        Returns the rounded sum as an int, 1 when `key` was never sampled, and
        0 for any other `task` value.
        '''
        # num_dict maps "count k" -> "number of elements having that count".
        num_dict = defaultdict(int)
        if task == 0 or task == 2:
            table = self.bitcube if task == 0 else self.invbitcube
            if key not in table:
                return 1
            for ports in table[key].values():
                num_dict[len(ports)] += 1
        elif task == 1:
            if key not in self.bitcube:
                return 1
            dsts_per_port = defaultdict(int)
            for ports in self.bitcube[key].values():
                for port in ports:
                    dsts_per_port[port] += 1
            for dst_count in dsts_per_port.values():
                num_dict[dst_count] += 1

        estimate_val = 0.0
        for k, v in num_dict.items():
            exponent = k if k <= self.split_num else k / self.prob
            estimate_val += v / (1 - (1 - self.prob) ** exponent)
        return int(round(estimate_val))

    def estimate_per_service_flow(self, key):
        '''
            Per-service flow: the flow label combines dst and port
            (flow label: 2 fields, element label: 1 field).

            key: indexable pair (dst, port).
            Returns the number of sampled sources divided by prob, rounded to
            int, or 1 when the (dst, port) pair was never sampled.
        '''
        if key[0] not in self.servicebitcube:
            return 1
        srcs_by_port = self.servicebitcube[key[0]]
        if key[1] not in srcs_by_port:
            return 1
        return int(round(len(srcs_by_port[key[1]]) / self.prob))

    def estimate_per_source_service_flow(self, key):
        '''
            Per-source service flow: the flow label is the source address and
            the element label is the (destination address, destination port)
            pair (flow label: 1 field, element label: 2 fields).

            Returns the number of sampled (dst, port) pairs of `key` divided by
            prob, rounded to int, or 1 when the source was never sampled.
        '''
        if key not in self.bitcube:
            return 1
        sum_bits = sum(len(ports) for ports in self.bitcube[key].values())
        return int(round(sum_bits / self.prob))

    def run(self, filename):
        '''
        Read `filename` (one "src,dst,port" record per line), record the exact
        element sets for every flow definition and feed each record to sample().

        Blank lines are skipped.
        '''
        # The context manager closes the file even if reading fails.
        with open(filename, 'r') as f:
            datas = f.readlines()
        for pkt in tqdm(datas):
            record = pkt.strip()
            if not record:
                # e.g. a trailing empty line; it cannot be unpacked into 3 fields
                continue
            src, dst, port = record.split(",")
            self.real_spread_set_dc[src].add(dst)
            self.real_spread_set_dpc[src].add(port)
            self.real_spread_set_sc[dst].add(src)
            self.real_spread_set_service[dst + " " + port].add(src)
            self.real_spread_set_sservice[src].add(dst + " " + port)
            self.sample(src, dst, port)

    def save_file(self, filename1, filename2):
        '''
        Pickle `real_spread_set_dc` to `filename1` and `bitcube` to `filename2`.
        '''
        with open(filename1, 'wb') as f:
            pickle.dump(self.real_spread_set_dc, f)
        with open(filename2, 'wb') as f:
            pickle.dump(self.bitcube, f)

In [4]:
import math
def getOptParams(N, p):
    """Return the bitmap size used for N records at sampling rate p.

    Computes N / (-ln p), rounded to the nearest integer.
    """
    neg_log_p = -math.log(p)
    bit_count = N / neg_log_p
    return int(round(bit_count))

In [5]:
# One hash seed per one-minute trace file, so each run uses its own hash function.
seed1 = 12412
seed2 = 87123
seed3 = 42131
seeds = (seed1, seed2, seed3)  # indexed by minute

for minute in range(3):
    filename = "../datas/kpse_datas/0{}.txt".format(minute)     # input trace
    filename1 = "./real_spreads_dc/0{}.pkl".format(minute)      # exact per-source dst sets
    filename2 = "./pred_spreads/0{}.pkl".format(minute)         # sampled bitcube

    # Count the records first: the bitmap is sized from the number of lines.
    # The context manager closes the file even if reading fails.
    with open(filename, 'r') as f:
        dat = f.readlines()
    N = len(dat) #10542501
    p = 0.7
    bits = getOptParams(N, p)
    print(bits, " bits, ", bits / 8 /1024 , "KB.")

    mime = MIME(p, bits, seeds[minute])
    mime.run(filename)
    mime.save_file(filename1, filename2)

6815503  bits,  831.9705810546875 KB.


100%|██████████| 2430919/2430919 [00:31<00:00, 76284.81it/s] 


6838134  bits,  834.733154296875 KB.


100%|██████████| 2438991/2438991 [00:30<00:00, 80733.27it/s] 


6803778  bits,  830.539306640625 KB.


100%|██████████| 2426737/2426737 [00:32<00:00, 75012.24it/s] 
