In [1]:
#!/usr/bin/env python
# coding: utf-8

# In[1]:


import sys
import torch  
import gym
import queue
import numpy as np  
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import matplotlib.pyplot as plt
import math
import random
import os

FLOAT_BYTE_SIZE = 8
HOST_CPU_FLOP = 62000000000
MAX_MEMORY = 0
MEMORY_LEVELS = 30
INVALID = -1
INVALID_INCR = -0.0001
BETA = 0.5
GAMMA = 0.9
WORKER_MEM = 165

MAX_OP_PARAM = []
MAX_ED_PARAM = []

def init():
    for i in range(5):
        MAX_OP_PARAM.append(0.0)
    for i in range(5):
        MAX_ED_PARAM.append(0.0)


init()

def calc_level(mem):
    #global MAX_MEMORY, MEMORY_LEVELS
    interval = MAX_MEMORY/MEMORY_LEVELS
    
    if(mem<=0):
            return 0

    return int(mem/interval)+1

In [2]:
class Computation_Graph:
    
    def __init__(self, filename, memoryfile):
        
        #batch size
        #number of layers, number of connections
        #for each layer
            #layer name, layer_mem in MB, # of params in layer,  dimension2, dimensioon3, ...

        fileReader = open(filename,'r')
        memoryfile = open(memoryfile,'r')

        #self.batch_size = int(fileReader.readline())
        self.batch_size = 1

        vertex, edge = (fileReader.readline().split())
        self.num_layers = int(vertex)
        self.num_connections = int(edge)

        self.layer_estimated_mem = []
        self.layer_estimated_float_op = [] 
        self.layer_output_shape = []
        self.layer_name_to_idx_map = {}
        self.layer_idx_to_name_map = {}
        self.adj_list = []
        self.rev_adj_list = []

        for i in range(self.num_layers):
            
            print(i)
            self.adj_list.append([])
            self.rev_adj_list.append([])
            
            layer_info = fileReader.readline().split() #layer_info is an array now
            memory_info = memoryfile.readline().split()
            
            print(memory_info)
            
            self.layer_name_to_idx_map[layer_info[0]] = i
            self.layer_idx_to_name_map[i] = layer_info[0]

            self.layer_estimated_mem.append(float(memory_info[1])-WORKER_MEM)
            global MAX_MEMORY
            MAX_MEMORY = max(MAX_MEMORY,float(memory_info[1]))

            self.layer_estimated_float_op.append(float(layer_info[1]))
            
            #this has to be rewritten
            output_byte = self.batch_size*FLOAT_BYTE_SIZE #floating point number is 4 byte

            #for k in range(3,len(layer_info)):
                #output_byte = output_byte*(float(layer_info[k]))

            #self.layer_output_shape.append(float(output_byte))
            self.layer_output_shape.append(0.0)
            #this has to be rewritten

        for i in range(self.num_connections):
            #this has to be rewritten
            u,v,w = fileReader.readline().split()
            self.adj_list[self.layer_name_to_idx_map[u]].append(self.layer_name_to_idx_map[v])
            self.rev_adj_list[self.layer_name_to_idx_map[v]].append(self.layer_name_to_idx_map[u])
            self.layer_output_shape[self.layer_name_to_idx_map[u]] = float(w)

        #print(type(self.adj_list[0][0]))
        #print(self.adj_list)
        fileReader.close()

    def print_info(self):
        print(self.layer_estimated_mem)
        print(self.layer_estimated_float_op)
        print(self.layer_output_shape) 
        
    def get_mem_required(self, layer_name, MB = 1):

        val = self.layer_estimated_mem[self.layer_name_to_idx_map[layer_name]]
        
        if MB == 0:
            return (float(val))*(1024.0*1024.0)

        return val

    def get_mem_level(self, layer_idx):

        global MAX_MEMORY,MEMORY_LEVELS

        interval = MAX_MEMORY/MEMORY_LEVELS

        if(self.layer_estimated_mem[layer_idx]<=0):
            return 0

        return int(self.layer_estimated_mem[layer_idx]/interval)+1


    def get_cpu_required(self, layer_name):

        return self.layer_estimated_float_op[self.layer_name_to_idx_map[layer_name]]

    def get_data_transferred(self, layer_u, layer_v):

        u = self.layer_name_to_idx_map[layer_u]
        v = self.layer_name_to_idx_map[layer_v]

        if v in self.adj_list[u]:
            return self.layer_output_shape[u]

        return 0

    def toposort(self):
        
        topoReader = open("topo_out.txt")
        topo_order = []

        for i in range(self.num_layers):
            v, level = topoReader.readline().split()
            topo_order.append((int(v),int(level)))

        return (1,topo_order)

    def check_topo_sort(self):

        acyclic,topo = self.toposort()

        if(acyclic==1):

            for v in topo:
                print(self.layer_idx_to_name_map[v[0]])

        else:
            print('no topo exists')


#sample input for testing one correctness
# 30
# 3 3
# conv_1 20 10 28 28 1
# conv_2 10 20 32 32 1
# dense_1 100 10 128 128
# conv_2 dense_1
# conv_1 dense_1
# conv_1 conv_2


In [3]:
class Device_Graph:
    
    def __init__(self,filename):
        
        #num_devices.. where devices connected in star topology
        #name, num_flop, memory in MB

        fileReader = open(filename,'r')
        
        device_no = int(fileReader.readline())
        
        self.num_devices = device_no
        self.device_memory = []
        self.device_flops = [] 
        self.device_name_to_idx_map = {}
        self.device_idx_to_name_map = {}
        self.bridge_bandwidth = [] #bandwidth to bridge
        self.filename = filename

        for i in range(device_no):
            
            edge_info = fileReader.readline().split()
            
            self.device_memory.append(float(edge_info[2])-WORKER_MEM)
            global MAX_MEMORY
            MAX_MEMORY = max(MAX_MEMORY,float(edge_info[2]))

            self.device_flops.append(float(edge_info[1]))
            
            self.device_name_to_idx_map[edge_info[0]] = i
            self.device_idx_to_name_map[i] = edge_info[0]

        for i in range(device_no):
            bw_master = float(fileReader.readline())
            self.bridge_bandwidth.append(bw_master)
        
        #print(self.bridge_bandwidth)

    def get_mem_level(self, device_idx):

        interval = MAX_MEMORY/MEMORY_LEVELS

        if(self.device_memory[device_idx]<=0):
            return 0

        return int(self.device_memory[device_idx]/interval)+1

    def get_device_flop(self, device_idx):

        return self.device_flops[device_idx]

    def print_info(self):
        print(self.device_memory)
        print(self.device_flops)
        print(self.bridge_bandwidth)

In [4]:
class Environment:
    
    def __init__(self, comp_graph, device_graph):
        
        self.comp_graph, self.device_graph = comp_graph, device_graph 
        self.current_state = np.array([],np.float)
        self.has_topo, self.topo_order = self.comp_graph.toposort()
        print(self.topo_order)
        self.available_mem = self.device_graph.device_memory
        self.placement_dict = {}
        
        
        self.valid_placements = 0
        self.invalid_placements = 0
        
    def get_layer_tuple(self, layer_idx): #this idx is topo_sort[idx], from this we get 
        #layer_idx
        #flop required
        #data transfer
        #required memory level
        ##number of neighbors
        return [
            layer_idx,self.comp_graph.layer_estimated_float_op[layer_idx],
            self.comp_graph.layer_output_shape[layer_idx],self.comp_graph.get_mem_level(layer_idx)
        ]

    def get_edge_tuple(self, device_idx):

        #current memory level of device
        #number of operations incorporated
        #bandwidth to bridge
        #flops power
        ##number of floating point operations performed
        return [
            self.device_graph.get_mem_level(device_idx),
            0,
            self.device_graph.bridge_bandwidth[device_idx],
            self.device_graph.device_flops[device_idx]
        ]


    def print_current_state(self):
        #state e ektai layer ... [0][1] to [0][3]
        print("will place "+str(self.current_state[0][0])+" now",end=' ')
        print(", requires "+str(self.current_state[0][1])+" flops",end=' ')
        print(", will tranx "+str(self.current_state[0][2])+" data",end=' ')
        print(", requires "+str(self.current_state[0][3])+" level memory")
        
        #device related info

        for i in range(1,1+len(self.device_graph.device_memory)):
            print("device "+str(i-1),end=' ')
            print(", has memory level "+str(self.current_state[i][0]),end=' ')
            print(", will exec "+str(self.current_state[i][1])+" operations",end=' ')
            print(", has "+str(self.current_state[i][2])+" bw with bridge",end=' ')
            print(", has "+str(self.current_state[i][3])+" flops")
            
        print('')

        #print(self.current_state)

    def reset(self):
        
        init_state = []
        self.done = 0
        
        self.valid_placements = 0
        self.invalid_placements = 0
        
        #for i in range(self.comp_graph.num_layers):
        init_state.append(self.get_layer_tuple(self.topo_order[0][0]))

        for i in range(len(self.device_graph.device_memory)):
            init_state.append(self.get_edge_tuple(i))

        np_init_state = np.array(init_state,np.float)

        self.current_state = np_init_state
        self.available_mem = self.device_graph.device_memory.copy() #eita reference e hocche na 

        return self.current_state
    
    def reset_device_graph(self, device_graph):
        self.device_graph = device_graph
        return self.reset()

    def is_valid_placement(self): # need further edit
        
        d = len(self.device_graph.device_memory)

        for i in range(1,1+d):
            if(self.current_state[i][0]<=0):
                return 0
        return 1
    
    def evaluate_placement(self): #edit korte hobe
        
        l = self.comp_graph.num_layers
        d = len(self.device_graph.device_memory)

        topo_idx = [] #will contain the index of vertex v in topo_order
        # original vertex list = {0,1,2,3,4}
        # in topo order = {2,1,3,0,4}
        # in topo idx = {3,1,0,2,4}
        for i in range(len(self.topo_order)):
            topo_idx.append(-1)

        for i in range(len(self.topo_order)):
            topo_idx[self.topo_order[i][0]] = i
        
        # if(self.is_valid_placement() == 0):
        #     global INVALID
        #     INVALID = INVALID + INVALID_INCR
        #     return INVALID
        
        estimated_tx_time = 0.0
        
        for i in range(l): #topo_order er serial e jacche
            
            #because ith placmement directs the placement of the ith vertex in topo_sort
            cur_device = int(self.placement_dict[self.topo_order[i][0]]) 
            
            for u in self.comp_graph.adj_list[self.topo_order[i][0]]:
                
                nxt_device =  int(self.placement_dict[u]) #we need to get the index of u from topo_idx
                payload = self.comp_graph.layer_output_shape[self.topo_order[i][0]]
                
                if(cur_device != nxt_device):
                    estimated_tx_time = estimated_tx_time + \
                    (payload/self.device_graph.bridge_bandwidth[cur_device]) + \
                    (payload/self.device_graph.bridge_bandwidth[nxt_device])  
        
        estimated_ex_time= 0.0
        

        level_wise = []
        #print(self.topo_order)
        #print(level_wise)
        for i in range(len(self.topo_order)+3):
            level_wise.append([])
        
        for i in range(len(self.topo_order)):
            #print(i)
            level_wise[self.topo_order[i][1]].append(self.topo_order[i][0])
        
        for cur_level in range(len(level_wise)):
            
            device_time = []
            
            for i in range(len(self.device_graph.device_memory)):
                device_time.append(0.0)
            
            for i in range(len(level_wise[cur_level])):
                
                cur_v = level_wise[cur_level][i]
                cur_device = self.placement_dict[cur_v]

                flop_req = self.comp_graph.layer_estimated_float_op[cur_v]
                flops_got = self.device_graph.device_flops[cur_device]
            
                device_time[cur_device] = device_time[cur_device]+flop_req/flops_got
            
            estimated_ex_time = estimated_ex_time + max(device_time)

        #estimated_energy_consumption = []

        #for i in range(l,l+d):
        #    estimated_energy_consumption.append(self.current_state[i][4])

        #return (1.0-BETA)*(estimated_ex_time+estimated_tx_time)+ \ 
            #BETA*(np.average(estimated_energy_consumption)+np.var(estimated_energy_consumption))
#         return (1.0-BETA)*estimated_tx_time+BETA*estimated_ex_time
#         return BETA*estimated_ex_time+BETA*estimated_tx_time
        return estimated_tx_time + estimated_ex_time

    def step(self, action): #action is the index of a device

        l = self.comp_graph.num_layers
        d = len(self.device_graph.device_memory)
        reward = 0.0
        
        layer_processed = int(self.current_state[0][0])
        #current memory level of device
        #number of operations incorporated
        #bandwidth to bridge
        #flops power
        self.available_mem[action] = self.available_mem[action] - \
            self.comp_graph.layer_estimated_mem[layer_processed]
        self.current_state[action+1][0] = calc_level(self.available_mem[action])
        self.current_state[action+1][1] = self.current_state[action+1][1] + 1

        self.done = self.done+1
        self.placement_dict[layer_processed] = action
        
        if self.done==l:

            if self.is_valid_placement()==0:
                self.invalid_placements += 1
#                 print(f"Finished with invalid placement")
#                 return self.current_state,self.placement_dict,-10,1
                return self.current_state,self.placement_dict,-1 * self.invalid_placements,1
#                 return self.current_state,self.placement_dict,0,1
            else:
                self.valid_placements += 1
                cost = self.evaluate_placement()
                reward = 1.0/math.sqrt(cost)
                #return self.current_state,self.placement_dict,10*reward,1
                print(f"Finished with correct placement, reward={reward}, evaluation={cost}")
#                 print(reward)
#                 return self.current_state,self.placement_dict,100*reward,1
                return self.current_state,self.placement_dict,reward,1
        else:
            estimated_tx_time = 0.0
            
            for u in self.comp_graph.rev_adj_list[layer_processed]:
                
                prev_device =  int(self.placement_dict[u]) #we need to get the index of u from topo_idx
                payload = self.comp_graph.layer_output_shape[u]
                
                if(action != prev_device):
                    estimated_tx_time = estimated_tx_time + \
                    (payload/self.device_graph.bridge_bandwidth[prev_device]) + \
                    (payload/self.device_graph.bridge_bandwidth[action])  
            
            # Why?
            if(estimated_tx_time<=0.0):
                reward = 3
            else:
                reward = 1.0/math.sqrt(estimated_tx_time)
            
            self.current_state[0] = self.get_layer_tuple(self.topo_order[self.done][0])
            
            
            ### REWARD FUNCTION?
            
#             print(f"REWARD: {reward}")
            
            
            if (self.available_mem[action]<=0.0):
                self.invalid_placements += 1
#                 return self.current_state,self.placement_dict,0,0
                return self.current_state,self.placement_dict,-5,0
            else:
                self.valid_placements += 1
#                 print(reward)
                return self.current_state,self.placement_dict,0,0
#                 return self.current_state,self.placement_dict,10*reward,0
#                 return self.current_state,self.placement_dict,0,0


layer_file = "ml_graph_vgg.txt"
mem_file = "vgg_new_memory_calc.txt"

c = Computation_Graph(layer_file,mem_file)
#c.print_info()

device_file = "dev_graph.txt"
d = Device_Graph(device_file)
d2 = Device_Graph("dev_graph_2.txt")
d3 = Device_Graph("dev_graph_3.txt")
d4 = Device_Graph("dev_graph_4.txt")
d5 = Device_Graph("dev_graph_5.txt")
d6 = Device_Graph("dev_graph_6.txt")
d7 = Device_Graph("dev_graph_7.txt")

devices = [d, d2, d3, d4, d5, d6, d7]
#d.print_info()

#c.check_topo_sort()


e = Environment(c,d3)

cur_state = e.reset()
e.print_current_state()

for i in range(0,5):
    e.step(i)
    e.print_current_state()
    print(' ')

print(MAX_MEMORY)
print(MEMORY_LEVELS)

# # for i in range(0,18):
# #     e.step(0)
# #     e.print_current_state()
# #     print('')

0
['conv_1', '1153.696']
1
['conv_2', '1929.888']
2
['pool_1', '545.7']
3
['conv_3', '1134.24']
4
['conv_4', '1264.288']
5
['pool_2', '390.9']
6
['conv_5', '654.1']
7
['conv_6', '664.9']
8
['conv_7', '543.1']
9
['pool_3', '363.7']
10
['conv_8', '498.178']
11
['conv_9', '740.3']
12
['conv_10', '740.9']
13
['pool_4', '305.3']
14
['conv_11', '640.4']
15
['conv_12', '695.9']
16
['conv_13', '670.5']
17
['pool_5', '283.9']
18
['flat_1', '200']
19
['dense_1', '630']
20
['dense_2', '670']
21
['dense_3', '324']
[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (12, 13), (13, 14), (14, 15), (15, 16), (16, 17), (17, 18), (18, 19), (19, 20), (20, 21), (21, 22)]
will place 0.0 now , requires 50176.0 flops , will tranx 50176.0 data , requires 8.0 level memory
device 0 , has memory level 25.0 , will exec 0.0 operations , has 214.0 bw with bridge , has 898887170.0 flops
device 1 , has memory level 13.0 , will exec 0.0 operations , has 129.0 bw with b

In [5]:
class PolicyNetwork(nn.Module):
    #num_inputs = size of state vectors == number of operations
    #num_actions = size of edge devices == number of actions
    
    def __init__(self, num_inputs, num_actions, hidden_size, learning_rate=3e-4):
        super(PolicyNetwork, self).__init__()

        self.num_actions = num_actions
        self.linear1 = nn.Linear(num_inputs, hidden_size) #two layer NN as policy network
        #self.linear3 = nn.Linear(32,16)
        #self.linear4 = nn.Linear(16,16)
        self.linear5 = nn.Linear(hidden_size,hidden_size*2)
#         self.linear8 = nn.Linear(hidden_size*2,hidden_size*2)
        self.linear6 = nn.Linear(hidden_size*2,hidden_size)
        #self.linear7 = nn.Linear(64,32)
        self.linear2 = nn.Linear(hidden_size, num_actions)
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)

    def forward(self, state):
        x = F.relu(self.linear1(state))
        #x = F.relu(self.linear3(x))
        #x = F.relu(self.linear4(x))
        x = F.relu(self.linear5(x))
#         x = F.relu(self.linear8(x))
        x = F.relu(self.linear6(x))
        #x = F.relu(self.linear7(x))
        x = F.softmax(self.linear2(x), dim=1)
        return x 
    
    def get_action(self, state):
        state = torch.from_numpy(state).float().unsqueeze(0)
        #print(state)
        #print(self.linear1.weight.data)
        #print(self.linear2.weight.data)
        probs = self.forward(Variable(state))
        #print(probs)
        #print(np.squeeze(probs.detach().numpy()))
        highest_prob_action = np.random.choice(self.num_actions, p=np.squeeze(probs.detach().numpy()))
        #print(highest_prob_action)
        log_prob = torch.log(probs.squeeze(0)[highest_prob_action])
        
        return highest_prob_action, log_prob

def update_policy(policy_network, rewards, log_probs):
    
    discounted_rewards = []
    #print(rewards)
    for t in range(len(rewards)):
        Gt = 0 
        pw = 0
        for r in rewards[t:]:
            Gt = Gt + GAMMA**pw * r
            pw = pw + 1
        discounted_rewards.append(Gt)
    
    discounted_rewards = torch.tensor(discounted_rewards)
#     discounted_rewards = torch.tensor(rewards)
    discounted_rewards = (discounted_rewards - discounted_rewards.mean()) / (discounted_rewards.std() + 1e-9) 
    #print(discounted_rewards)
    
    policy_gradient = []
    for log_prob, Gt in zip(log_probs, discounted_rewards):
        policy_gradient.append(-log_prob * Gt)
    
    policy_network.optimizer.zero_grad()
    policy_gradient = torch.stack(policy_gradient).sum()
    policy_gradient.backward()
    policy_network.optimizer.step()

In [6]:
import copy

def main(placement_environment, devices):
    env = placement_environment
    #v = env.op_graph.toposort()
    layer_num = env.comp_graph.num_layers
    device_num = len(env.device_graph.device_memory)

    inp_size_nn = (1+device_num)*4
    policy_net = PolicyNetwork(inp_size_nn, device_num, 64) #all 3 arguments are int

    
    converge = 0
    max_epoch = 100000
    best_state = {}
    best_reward = 0.0
    best_policy = copy.deepcopy(policy_net)
    
    best_success_rate = 0.0
    
    total_score = 0
    total_count = 0
    
    ok_count = 0
    
#     devices_idx = 0
    device_successes = {d.filename: 0 for d in devices}
    device_placements = {d.filename: 0 for d in devices}
    
    
    while converge<max_epoch:
        
        # TODO(brian): swap in new/random device_graph here?
#         state = env.reset()  #state has been made compatible
        state = env.reset_device_graph(random.choice(devices))
#         devices_idx = (devices_idx + 1) % len(devices)
        log_probs = []
        rewards = []
        
        if converge%5000==0:
            print(converge)
        
        while True:
            
            temp_state = state.copy()

            for p in range(4):
                temp_state[0][p] = float(temp_state[0][p])/MAX_OP_PARAM[p]

            for d in range(1,1+device_num):
                for p in range(4):
                    temp_state[d][p] = float(temp_state[d][p])/MAX_ED_PARAM[p]

            #network_state = temp_state.flatten()
            #env.print_current_state()
            #inp = (network_state-np.min(network_state))/np.ptp(network_state)
            #inp = network_state
            #val = np.array([np.sum(inp)/2000000000.0,np.sum(inp)/3000000000.0])
            #print(val)
            #print(inp)
            action, log_prob = policy_net.get_action(temp_state.flatten())
            #print(action)
            #print(log_prob)
            #print('action chosen '+str(action))
            new_state, mapping, reward, done = env.step(action)
            #env.print_current_state()
            log_probs.append(log_prob)
            rewards.append(reward)
            
            if done:
                
                #for i in range(layer_num):
                #    print(state[i][0], end = " ")
                #print('')

                converge += 1
                score = env.evaluate_placement()
                device_placements[env.device_graph.filename] += 1
                if env.valid_placements == env.comp_graph.num_layers:
                    total_score += score
                    total_count += 1
                    ok_count += 1
                    
                    device_successes[env.device_graph.filename] += 1
                
                # number of good rewards
                pos = [num for num in rewards if num>0.0]
                
                log_frequency = 100
                if converge%log_frequency==0:
#                     print(f"{converge}: pos={len(pos)}/{len(rewards)}, reward={reward}, evaluation={score}")
                    print(f"{converge}:")
                    print(f"\tpos={env.valid_placements}/{len(rewards)}")
                    print(f"\treward={reward}")
                    print(f"\tevaluation={score}")
                    print(f"\tgraph={env.device_graph.filename}")
#                     print(env.valid_placements)
#                     print(env.invalid_placements)
                    rolling_success_rate = ok_count / log_frequency
                    print(f"\tsuccess rate: {rolling_success_rate}%")
                    if total_count > 0:
                        print(f"\trolling average score of ok = {total_score / total_count}")
                    if total_count > 100:
                        total_score = 0
                        total_count = 0
                    ok_count = 0
                    
                    print("\tdevice graph success rates")
                    for name, successes in device_successes.items():
                        print(f"\t\t{name}: {successes/device_placements[name]}%")
                        device_successes[name] = 0
                        device_placements[name] = 0
                    
                    if rolling_success_rate > best_success_rate:
                        best_success_rate = rolling_success_rate
                        best_state = mapping
                        best_policy = copy.deepcopy(policy_net)
                        torch.save(best_policy.state_dict(), './mymodel.pt')
                        print("\tcurrent best state ")
                        print("\t", best_state)          
                    
                    print(best_policy.linear1.weight.data)
                
                update_policy(policy_net, rewards, log_probs)
                
#                 if len(pos) > best_reward:
#                     best_reward = len(pos)
#                     best_state = mapping
#                     best_policy = copy.deepcopy(policy_net)
#                     torch.save(best_policy.state_dict(), './mymodel.pt')
#                     print("current best state ")
#                     print(best_state)          



                #if converge%1000==0:
                #print("state "+str(state))
                #print("reward incurred "+str(reward))
                break
            
            state = new_state
    

      
    return best_state,best_reward, best_policy, inp_size_nn, device_num 

# print("hello")

myenv = e

def set_MAX_PARAMS(environment):
    #layer_idx         
        #flop required
        #data transfer
        #required memory level
    MAX_OP_PARAM[0] = float(len(environment.comp_graph.layer_estimated_mem))
    MAX_OP_PARAM[1] = float(max(environment.comp_graph.layer_estimated_float_op))
    MAX_OP_PARAM[2] = float(max(environment.comp_graph.layer_output_shape))
    MAX_OP_PARAM[3] = float(MEMORY_LEVELS + 1)
    #current memory level of device
        #number of operations incorporated
        #bandwidth to bridge
        #flops power

    MAX_ED_PARAM[0] = float(MEMORY_LEVELS + 1)
    MAX_ED_PARAM[1] = float(environment.comp_graph.num_layers)
    MAX_ED_PARAM[2] = float(max(environment.device_graph.bridge_bandwidth))
    MAX_ED_PARAM[3] = float(max(environment.device_graph.device_flops))

In [7]:
set_MAX_PARAMS(e)

best_state,best_reward, best_policy, inp_size, device_num = main(e, devices)
print("training complete")

print("best state "+str(best_state))
print("estimated computation time "+str(1.0/best_reward**2)+" second")



0
Finished with correct placement, reward=0.024293979601750774, evaluation=1694.3482366073258
Finished with correct placement, reward=0.03055435041914235, evaluation=1071.158873284928
100:
	pos=18/22
	reward=-4
	evaluation=1223.006521566016
	graph=dev_graph_5.txt
	success rate: 0.02%
	rolling average score of ok = 1382.7535549461268
	device graph success rates
		dev_graph.txt: 0.0%
		dev_graph_2.txt: 0.0%
		dev_graph_3.txt: 0.05555555555555555%
		dev_graph_4.txt: 0.09090909090909091%
		dev_graph_5.txt: 0.0%
		dev_graph_6.txt: 0.0%
		dev_graph_7.txt: 0.0%
	current best state 
	 {0: 0, 1: 0, 2: 6, 3: 4, 4: 5, 5: 2, 6: 2, 7: 8, 8: 3, 9: 5, 10: 7, 11: 7, 12: 1, 13: 6, 14: 1, 15: 4, 16: 5, 17: 5, 18: 1, 19: 3, 20: 8, 21: 3}
tensor([[ 0.0495, -0.1079,  0.0237,  ...,  0.1101,  0.1295,  0.1123],
        [ 0.1154, -0.0712,  0.0546,  ..., -0.1004,  0.1219,  0.0524],
        [-0.1361, -0.0961,  0.1214,  ...,  0.0880,  0.1184, -0.1251],
        ...,
        [ 0.0838, -0.1502, -0.0584,  ..., -0.084

Finished with correct placement, reward=0.015447089182713881, evaluation=4190.894152092443
Finished with correct placement, reward=0.033315053672555835, evaluation=900.9879142508726


KeyboardInterrupt: 

In [None]:
print(best_policy.linear1.weight.data)
print(best_policy.num_actions)

layer_file = "ml_graph_vgg.txt"
mem_file = "vgg_new_memory_calc.txt"

c = Computation_Graph(layer_file,mem_file)


device_file = "dev_graph.txt"
d = Device_Graph(device_file)
#d.print_info()

#c.check_topo_sort()


e_another = Environment(c,d)

cur_state = e_another.reset()
#e_another.print_current_state()



for i in range(0,22):
    print(i)
    temp_state = cur_state.copy()

    for p in range(4):
        temp_state[0][p] = float(temp_state[0][p])/MAX_OP_PARAM[p]

    for d in range(1,1+device_num):
        for p in range(4):
            temp_state[d][p] = float(temp_state[d][p])/MAX_ED_PARAM[p]

    action, log_prob = best_policy.get_action(temp_state.flatten())
    #print(action)
    cur_state, mapping, reward, done = e_another.step(action)
    #e_another.print_current_state()

#print(mapping)
mapping_file = "mapping_default.txt"

f = open(mapping_file,"w")

for key in mapping:
    f.write(c.layer_idx_to_name_map[key]+' '+str(mapping[key])+'\n')
f.close()

print('done')

In [8]:
def place_graph(env, model, device_graph):
    state = env.reset_device_graph(device_graph)
    
    for i in range(env.comp_graph.num_layers):
        curr_state = state.copy()
        
        for p in range(4):
            curr_state[0][p] = float(curr_state[0][p])/MAX_OP_PARAM[p]

        for d in range(1,1+env.device_graph.num_devices):
            for p in range(4):
                curr_state[d][p] = float(curr_state[d][p])/MAX_ED_PARAM[p]
        
        action, log_prob = model.get_action(curr_state.flatten())
        curr_state, mapping, reward, done = env.step(action)

    if env.valid_placements == env.comp_graph.num_layers:
        print("Good!")
        return True
    print("Bad :(")
    return False

import device_gen
import os

def validate(num_graphs=5, num_trials=100):
    val_graph_files = [f'tmp_val_dev_graph_{i}.txt' for i in range(num_graphs)]
    
    for filename in val_graph_files:
        print(filename)
        device_gen.write_graph(9, [300000000, 1000000000], [1024, 4096], [50, 1000], filename=filename)
    
    val_device_graphs = [Device_Graph(file) for file in val_graph_files]

    # Load best policy network
    num_devices = len(val_device_graphs[0].device_memory)
    nn_input_size = (num_devices + 1) * 4
    model = PolicyNetwork(nn_input_size, num_devices, 64)
    model.load_state_dict(torch.load('./mymodel.pt'))
    model.eval()
    
    # Create compute graph
    layer_file = "ml_graph_vgg.txt"
    mem_file = "vgg_new_memory_calc.txt"
    compute_graph = Computation_Graph(layer_file,mem_file)

    # create Env
    env = Environment(compute_graph, val_device_graphs[0])
    
    # Run trials
    results = []
    results_per_graph = {d.filename: 0 for d in val_device_graphs}
    
    for _ in range(num_trials):
        for device_graph in val_device_graphs:
            res = place_graph(env, policy_network, device_graph)
            results_per_graph[device_graph.filename] += int(res)
            results.append(res)
    
    print("Final percentage of valid placements: ", sum(results) / len(results))
    print("Per device graph:")
    for name, count in results_per_graph.items():
        print(f"\t{name}: {count / num_trials}%")
    
    ## remove temp files
    for file in val_graph_files:
        if os.path.exists(file):
            os.remove(file)
    


In [9]:
validate()

ValueError: invalid literal for int() with base 10: ''

tmp_val_dev_graph_0.txt
