# 导入需要的所有库

In [53]:
import torch as th
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from collections import namedtuple
import random
from copy import deepcopy
from torch.optim import Adam
import numpy as np
import gym
from gym.spaces import Discrete, Box
# from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt

# 定义环境

In [54]:
# Simulation size parameters shared by the cells below.
num_servers = 8  # number of backend servers
num_load_balancers = 1  # number of load balancers (one agent each)
num_packets = 100_000  # total number of packets generated per episode

In [55]:
class Packet:
    """A single request travelling through the simulated network.

    Attributes:
        ip: identifier given by the creator of the packet.
        time_received: the only timestamp kept for a packet (the time it
            is received).
        processing_time: how long the packet takes to process.
        waiting: flag, initially False; it is set by whoever queues the
            packet.
    """

    def __init__(self, ip, time_received, processing_time):
        self.waiting = False
        self.processing_time = processing_time
        self.time_received = time_received
        self.ip = ip

    def __repr__(self):
        # Only the arrival time (rounded to 3 decimals) is shown.
        rounded = round(self.time_received, 3)
        return f" @time: {rounded!s}"

class LoadBalancer:
    """A load balancer that forwards packets to a chosen server."""

    def __init__(self, ip):
        self.ip = ip

    def distribute(self, server, packet):
        """Hand ``packet`` to ``server`` and return whatever the server's
        ``add_packet`` call returns."""
        outcome = server.add_packet(packet)
        return outcome

    def __repr__(self):
        return f"Load Balancer: {self.ip!s}"

class Server:
    """A server with a first-come-first-served, non-preemptive queue.

    Attributes:
        ip: identifier of the server.
        capacity: maximum number of packets admitted without being marked
            as waiting.
        queue: list of packets currently held, in arrival order.
        time_pointer: latest completion time (``time_received +
            processing_time``) among the packets processed so far.
        processed_number: count of packets processed so far.
        waiting: accumulated waiting time of processed packets that had
            been marked as waiting.
    """

    def __init__(self, ip, capacity):
        self.ip = ip
        self.capacity = capacity
        self.queue = []
        self.time_pointer = 0
        self.processed_number = 0
        self.waiting = 0

    def reset(self):
        """Empty the queue and zero every counter."""
        self.queue.clear()
        self.time_pointer = 0
        self.processed_number = 0
        self.waiting = 0

    def add_packet(self, packet):
        """Queue ``packet`` and report whether it fit within ``capacity``.

        The packet is always appended. When the queue already holds
        ``capacity`` packets the packet is flagged ``waiting = True`` and
        False is returned; otherwise the flag is cleared and True is
        returned.
        """
        # Strict comparison: with `<=` the server admitted capacity + 1
        # packets before it started marking packets as waiting.
        has_room = len(self.queue) < self.capacity
        packet.waiting = not has_room
        self.queue.append(packet)
        return has_room

    def process(self, current_time):
        """Remove every packet at the head of the queue that has finished
        by ``current_time``.

        A packet counts as finished when ``time_received +
        processing_time <= current_time``. Processing stops at the first
        unfinished packet, so later packets stay behind it (FCFS).
        """
        # NOTE: the original began with an `if ...: pass` guard that had no
        # effect; the loop below already does nothing for an empty queue.
        finished = 0
        for head in self.queue:
            finished_at = head.time_received + head.processing_time
            if finished_at > current_time:
                break
            if finished_at > self.time_pointer:
                self.time_pointer = finished_at
            if head.waiting:
                self.waiting += current_time - head.time_received
            finished += 1

        # Drop the finished prefix in one in-place operation rather than
        # calling pop(0) per packet, which shifts the whole list each time.
        del self.queue[:finished]
        self.processed_number += finished

    def get_processed_number(self):
        """Return the number of packets processed so far."""
        return self.processed_number

    def __repr__(self):
        return str(self.queue)

In [56]:
class NetworkEnv(gym.Env):
    """Multi-agent load-balancing environment.

    Each agent is a ``LoadBalancer``. On every step, all packets whose
    arrival time lies before the current clock are handed to the balancer
    selected by ``packet.ip % n``, which forwards them to the server named
    by that agent's action. The servers then process what has finished.

    Args:
        num_packets: number of packets generated per episode.
        num_servers: number of servers; must equal ``len(server_capacity)``.
        num_balancers: number of load balancers (agents).
        collaborative: stored as ``shared_reward``; not read by the code
            in this class.
        server_capacity: per-server capacity list. ``None`` is treated as
            an empty list.

    Raises:
        ValueError: if ``num_servers`` differs from ``len(server_capacity)``.
    """

    # Upper bound on how far the clock advances in one step.
    _MAX_TIME_STEP = 0.004
    # Arrival times are drawn with random.randrange(0, _ARRIVAL_HORIZON).
    _ARRIVAL_HORIZON = 10

    def __init__(self, num_packets=0, num_servers=0, num_balancers=0, collaborative=True, server_capacity=None):
        if server_capacity is None:
            server_capacity = []
        if num_servers != len(server_capacity):
            raise ValueError(
                "num_servers ({}) must equal len(server_capacity) ({})".format(
                    num_servers, len(server_capacity)
                )
            )

        # Remembered so reset() regenerates the same number of packets
        # instead of depending on a module-level global.
        self.num_packets = num_packets
        # The original drew arrivals with randrange(0, 1), which is always 0;
        # the constructor now uses the same generator as reset().
        self.packets = self._generate_packets(num_packets)

        self.server_capacity = server_capacity

        self.n = num_balancers  # number of agents (load balancers)
        self.shared_reward = collaborative
        self.time = 0  # current simulation clock
        self.agents = [LoadBalancer(i) for i in range(num_balancers)]
        self.servers = [Server(i, self.server_capacity[i]) for i in range(num_servers)]
        self.waiting_packets = []
        self.index = 0  # position of the next undelivered packet in self.packets

    def _generate_packets(self, count):
        """Create ``count`` random packets sorted by arrival time."""
        # NOTE(review): randrange yields integer arrival times only, and
        # gauss(0.004, 0.002) can return negative processing times, which
        # can make the clock step backwards in step(). Both are kept
        # as-is; confirm whether that is intended.
        packets = [
            Packet(
                ip=i,
                time_received=random.randrange(0, self._ARRIVAL_HORIZON),
                processing_time=random.gauss(0.004, 0.002),
            )
            for i in range(count)
        ]
        packets.sort(key=lambda x: x.time_received)
        return packets

    def _observe(self):
        """Build the per-agent observation.

        For every server the feature vector holds the mean processing time
        of its queued packets (0 when the queue is empty) followed by its
        ``time_pointer``. Each agent gets its own copy of that vector.
        """
        features = []
        for s in self.servers:
            if s.queue:
                features.append(sum(p.processing_time for p in s.queue) / len(s.queue))
            else:
                features.append(0)
            features.append(s.time_pointer)
        # Separate copies, so mutating one agent's observation cannot
        # silently change the others'.
        return [list(features) for _ in range(self.n)]

    def step(self, action_n):
        """Advance the simulation by one step.

        Args:
            action_n: sequence indexed by balancer id; each entry is the
                index of the server that balancer sends its packets to.

        Returns:
            ``(obs, reward, done, info)`` where ``obs`` and ``reward`` are
            lists with one entry per agent. The reward is 0 until the
            episode ends; on the final step it is the negative mean of the
            servers' accumulated ``waiting``.
        """
        info = {}

        # Gather every packet that arrived before the current clock.
        pending = list(self.waiting_packets)
        self.waiting_packets = []
        step_len = self._MAX_TIME_STEP
        total = len(self.packets)
        # Walk by index instead of slicing self.packets[self.index:], which
        # copied the whole remaining list on every step.
        while self.index < total:
            p = self.packets[self.index]
            if p.time_received >= self.time:
                break
            pending.append(p)
            step_len = min(p.processing_time, step_len)
            self.index += 1

        # The clock advances by the shortest processing time seen, capped
        # at _MAX_TIME_STEP.
        self.time += step_len
        pending.sort(key=lambda x: x.time_received)

        # Packets are spread over balancers by ip; each balancer forwards
        # to the server chosen by its action.
        for p in pending:
            lb_id = p.ip % self.n
            self.agents[lb_id].distribute(self.servers[action_n[lb_id]], p)

        for s in self.servers:
            s.process(self.time)

        # With no packets at all the episode is over immediately.
        done = not self.packets or self.time >= self.packets[-1].time_received

        obs = self._observe()

        if done:
            reward = [-sum(s.waiting for s in self.servers) / len(self.servers)] * self.n
        else:
            reward = [0] * self.n

        return obs, reward, done, info

    def reset(self):
        """Start a new episode and return the initial observation."""
        self.packets = self._generate_packets(self.num_packets)
        self.time = 0
        self.index = 0
        self.waiting_packets = []
        for s in self.servers:
            s.reset()
        return self._observe()

In [57]:
# Build the environment with one capacity entry per server.
env = NetworkEnv(
    num_packets=num_packets,
    num_servers=num_servers,
    num_balancers=num_load_balancers,
    collaborative=True,
    server_capacity=[500, 100, 200, 200, 1000, 500, 600, 100],
)

In [58]:
# Baseline: run a few episodes where every balancer picks a server
# uniformly at random, and print the return of each episode.
episodes = 10
for e in range(1, episodes + 1):
    state = env.reset()
    score, done = 0, False  # score accumulates the episode return

    while not done:
        action_n = [random.randint(0, num_servers - 1) for _ in range(env.n)]
        obs, reward, done, info = env.step(action_n)
        score = score + reward[0]
    print(f'Episode: {e} Score: {score}')

Episode: 1 Score: -142.75520038587877
Episode: 2 Score: -143.44209515150146
Episode: 3 Score: -143.132197728536
Episode: 4 Score: -142.74988363269364
Episode: 5 Score: -143.75438204853728
Episode: 6 Score: -145.7352247480723
Episode: 7 Score: -145.1783655520474
Episode: 8 Score: -142.34326776661766
Episode: 9 Score: -141.92909301741912
Episode: 10 Score: -143.13308845723958
