In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import matplotlib.pyplot as plt
import CBEngine
import json
import traceback
import argparse
import logging
import os
import sys
import time
from pathlib import Path
import re
import gym
import numpy as np
from datetime import datetime

import torch
import torch.nn as nn
from torch.distributions.normal import Normal
from torch.distributions.categorical import Categorical
from torch.optim import Adam

import agent.gym_cfg as gym_cfg
from agent.utils import PreprocEnv,get_args
from agent.agent_my import agent_specs, TestAgent
from agent.utils import read_config, process_roadnet, process_delay_index, run_simulation
from core import combined_shape, mlp, count_vars, discount_cumsum, VPGBuffer

In [2]:
# --- Simulator / environment setup -------------------------------------------
# Path of the simulator configuration file; it is passed both to the gym
# environment below and to run_simulation() in the evaluation cell.
simulator_cfg_file = './cfg/simulator.cfg'

# Gym-side configuration dict taken from the project's gym_cfg module.
gym_configs = gym_cfg.gym_cfg().cfg
# Parsed simulator config; only 'road_file_addr' is read from it in this cell.
simulator_configs = read_config(simulator_cfg_file)
# Create the CBEngine gym environment with 16 simulator threads.
env = gym.make(
    'CBEngine-v0',
    simulator_cfg_file=simulator_cfg_file,
    thread_num=16,
    gym_dict=gym_configs
)
# Scenario names; the first entry is used later as the key into agent_specs.
scenario = [
    'test'
]


# Parse the road network referenced by the simulator config.
roadnet_path = Path(simulator_configs['road_file_addr'])
intersections, roads, agents = process_roadnet(roadnet_path)
# NOTE(review): presumably 0 turns off engine logging / warnings - confirm
# against the CBEngine API.
env.set_log(0)
env.set_warning(0)
# Run arguments; later cells read args.gamma, args.action_interval and args.steps.
# NOTE(review): if get_args() parses sys.argv it may pick up the Jupyter
# kernel's own command-line arguments - verify.
args = get_args()

In [14]:
def add_weight_decay(pars, l2_value):
    """Split named parameters into optimizer groups without and with weight decay.

    A parameter is exempt from decay when it is one-dimensional, when its name
    ends with ".bias", or when the sum of its dimension sizes is below 3.
    Every other parameter is placed in the group that receives ``l2_value``.

    Args:
        pars: iterable of ``(name, parameter)`` pairs, e.g. the result of
            ``module.named_parameters()``.
        l2_value: weight-decay coefficient assigned to the decayed group.

    Returns:
        A two-element list of param-group dicts, exempt group first, in the
        format accepted by torch optimizers.
    """
    def _is_exempt(name, param):
        dims = tuple(param.shape)
        return len(dims) == 1 or name.endswith(".bias") or sum(dims) < 3

    exempt, regularized = [], []
    for name, param in pars:
        bucket = exempt if _is_exempt(name, param) else regularized
        bucket.append(param)

    return [
        {'params': exempt, 'weight_decay': 0.},
        {'params': regularized, 'weight_decay': l2_value},
    ]

def compute_loss_v(data):
    """Mean-squared error between the critic's predictions and the stored returns.

    Args:
        data: mapping with at least ``'obs'`` (indexable along its first axis,
            with a ``shape`` attribute) and ``'ret'`` (the regression targets).

    Returns:
        Scalar tensor: the mean of the squared differences between the stacked
        outputs of the global ``agent.v`` and ``data['ret']``.
    """
    observations, returns = data['obs'], data['ret']
    # The critic is evaluated one stored step at a time, in buffer order.
    # NOTE(review): presumably because the agent keeps per-step internal state
    # (update() calls agent.reset() right before this) - confirm.
    values = torch.stack(
        [agent.v(observations[step]) for step in range(observations.shape[0])]
    )
    squared_error = (values - returns) ** 2
    return squared_error.mean()

def compute_loss_pi(data):
    """Vanilla policy-gradient loss plus diagnostics for one batch of experience.

    The global ``agent.pi`` is evaluated one stored step at a time, in buffer
    order; each call returns the action distribution and the log-probability
    of the stored action.

    Args:
        data: mapping with ``'obs'``, ``'act'``, ``'adv'`` and ``'logp'``
            entries; ``'obs'`` and ``'act'`` are indexed along their first axis.

    Returns:
        ``(loss_pi, pi_info)`` where ``loss_pi`` is the scalar tensor
        ``-(logp * adv).mean()`` and ``pi_info`` is a dict with
        ``'kl'`` (mean of old minus new log-probabilities) and ``'ent'``
        (policy entropy averaged over *all* steps of the batch).
    """
    obs, act, adv, logp_old = data['obs'], data['act'], data['adv'], data['logp']

    step_logps, step_entropies = [], []
    for i in range(act.shape[0]):
        pi, logp = agent.pi(obs[i], act[i])
        step_logps.append(logp)
        # Collect the entropy of every step. Previously only the distribution
        # left over from the final iteration was used, so the reported entropy
        # described a single step instead of the batch. Detached because it is
        # a diagnostic, not part of the loss.
        step_entropies.append(pi.entropy().mean().detach())
    ar_logp = torch.stack(step_logps)

    # Policy loss
    loss_pi = -(ar_logp * adv).mean()

    # Useful extra info
    approx_kl = (logp_old - ar_logp).mean().item()
    ent = torch.stack(step_entropies).mean().item()
    pi_info = dict(kl=approx_kl, ent=ent)
    return loss_pi, pi_info

def update(train_v_iters=10):
    """Run one policy-gradient update from the contents of the global buffer.

    Takes a single optimizer step on the policy, followed by
    ``train_v_iters`` optimizer steps on the value function, all on the same
    batch returned by ``buf.get()``. ``agent.reset()`` is called before every
    forward pass over the batch.
    NOTE(review): presumably reset() clears per-episode/recurrent state so each
    pass replays the batch from the same starting state - confirm.

    Args:
        train_v_iters: number of value-function gradient steps. Defaults to 10,
            the value that used to be hard-coded here.

    Returns:
        Dict of diagnostics: ``'loss_pi'`` (float), ``'loss_v'`` (float of the
        last value-function step, or ``None`` when ``train_v_iters`` is 0),
        plus the ``'kl'`` and ``'ent'`` entries produced by compute_loss_pi.
        Callers that ignore the return value are unaffected.
    """
    data = buf.get()

    # Train policy with a single step of gradient descent
    agent.reset()
    pi_optimizer.zero_grad()
    loss_pi, pi_info = compute_loss_pi(data)
    loss_pi.backward()
    pi_optimizer.step()

    # Value function learning
    last_loss_v = None
    for _ in range(train_v_iters):
        agent.reset()
        vf_optimizer.zero_grad()
        loss_v = compute_loss_v(data)
        loss_v.backward()
        vf_optimizer.step()
        last_loss_v = loss_v.item()

    diagnostics = dict(pi_info)
    diagnostics['loss_pi'] = loss_pi.item()
    diagnostics['loss_v'] = last_loss_v
    return diagnostics

In [16]:
# --- Agent, rollout buffer and optimizers -------------------------------------
# First element of the reset() result; its keys are used to discover agent ids.
o = env.reset()[0]

# Each key is split on '_' and the leading integer is kept once per agent.
agent_id_list = list({int(key.split('_')[0]) for key in o})

# Register a fresh TestAgent for the 'test' scenario and attach it to this env.
agent_specs['test'] = TestAgent()
agent = agent_specs[scenario[0]]
agent.load_agent_list(agent_id_list)
agent.get_data_from_env(env)  # original author's note: "I know this is hacky"

# Rollout length per epoch and the buffer sized to hold exactly that many steps.
local_steps_per_epoch = 80
buf = VPGBuffer(
    (22, 25),
    (agent.n_agnts,),
    local_steps_per_epoch,
    args.gamma,
)

# Optimizer hyper-parameters for the policy (pi) and value (v) networks.
arch = {
    'weight_decay_pi': 1e-4,
    'lr_pi': 0.001,
    'weight_decay_v': 1e-4,
    'lr_v': 0.001,
}
prs_pi = add_weight_decay(agent.pi.named_parameters(), arch['weight_decay_pi'])
prs_v = add_weight_decay(agent.v.named_parameters(), arch['weight_decay_v'])
pi_optimizer = Adam(prs_pi, lr=arch['lr_pi'])
vf_optimizer = Adam(prs_v, lr=arch['lr_v'])

def make_step(j, act):
    """Hold one action for up to ``args.action_interval`` simulator steps.

    Calls ``agent.env_preproc.step(env, act)`` repeatedly, accumulating the
    first element of each result, and stops early once every entry of the
    returned done-dict is true (original note on the step call: "simulates
    10 sec").

    Args:
        j: running simulator-step counter; incremented once per step taken.
        act: action passed unchanged to every step of the interval.

    Returns:
        ``(j, done, mean_obs)``: the updated counter, whether all entries of
        the last done-dict were true, and the accumulated first return value
        averaged over the steps that were actually executed.

    Raises:
        ValueError: if ``args.action_interval`` is not positive (no step could
            be taken, so there is nothing to average).
    """
    if args.action_interval <= 0:
        raise ValueError(
            "args.action_interval must be positive, got {}".format(args.action_interval)
        )

    accumulated = 0
    steps_taken = 0
    done = False
    for _ in range(args.action_interval):
        step_out, _, d, _ = agent.env_preproc.step(env, act)
        j += 1
        steps_taken += 1
        accumulated += step_out
        done = all(d.values())
        if done:
            print('all done')
            break

    # Average over the steps really taken. Dividing by the full interval
    # under-scaled the result whenever the loop stopped early on `done`.
    mean_obs = accumulated / steps_taken
    return j, done, mean_obs

# --- Training loop -------------------------------------------------------------
# Alternative epoch count left by the author: for e in range(args.episodes):
Nepochs  = 100
for e in range(Nepochs):
    # Reset the agent before collecting a new rollout.
    agent.reset()


    # Per-step mean rewards of this epoch, used only for the progress printout.
    rw_average = []
    # Fresh environment; the first element of reset() is preprocessed into the
    # observation the agent consumes.
    o = agent.env_preproc.preproc_obs(env.reset()[0])
    # Simulator-step counter, advanced inside make_step().
    i = 0

    for t in range(local_steps_per_epoch):
        a, v, logp = agent.step(o)
        # The same action is held for args.action_interval simulator steps.
        i, don, next_o = make_step(i, a)
        # Per-row reward: sum of columns 13..24 minus sum of columns 1..12 of
        # the next observation.
        # NOTE(review): the meaning of these column ranges is not visible here -
        # confirm against agent.env_preproc.
        rw = next_o[:, 13: 25].sum(dim=1)-next_o[:, 1: 13].sum(dim=1)
        # A single scalar reward (mean over rows) is stored for the step.
        buf.store(o, a, rw.mean(), v, logp)
        rw_average.append(rw.mean())
        # Update obs (critical!)
        o = next_o
        # Stop when every agent is done or the step counter passes
        # int(args.steps/10) - args.action_interval - 1.
        if don or i>int(args.steps/10)-args.action_interval-1:
            print(don, i)
            break

    # Bootstrap value for the unfinished trajectory: 0 when the episode ended,
    # otherwise the agent's value estimate for the last observation.
    # NOTE(review): `don` is only bound inside the loop above, so this relies
    # on local_steps_per_epoch being at least 1.
    v = 0
    if not don:
        _, v, _ = agent.step(o)
    buf.finish_path(v)

    # NOTE(review): after an early break the buffer holds fewer than
    # local_steps_per_epoch steps - confirm VPGBuffer.get() accepts that.
    update()


    print("ep:{}/{}, average tt:{} , mean rwd: {}".format(e, Nepochs, env.eng.get_average_travel_time(), np.mean(rw_average)))

ep:0/100, average tt:309.05070307625857 , mean rwd: 0.47826701402664185
ep:1/100, average tt:309.1964945300553 , mean rwd: -0.2809658944606781
ep:2/100, average tt:308.6563402407189 , mean rwd: 0.3779829442501068
ep:3/100, average tt:309.00997393176925 , mean rwd: 0.1457386314868927
ep:4/100, average tt:308.6411853697606 , mean rwd: 0.6017045378684998
ep:5/100, average tt:308.06065437707497 , mean rwd: 0.6129261255264282
ep:6/100, average tt:307.9430158491861 , mean rwd: 0.5723011493682861
ep:7/100, average tt:308.00596316743224 , mean rwd: 0.594318151473999
ep:8/100, average tt:307.87474375168244 , mean rwd: 0.5282670855522156
ep:9/100, average tt:307.99399613310266 , mean rwd: 0.3475852310657501
ep:10/100, average tt:308.15105135746154 , mean rwd: 0.5223010778427124
ep:11/100, average tt:307.9895105583217 , mean rwd: 0.6146306991577148
ep:12/100, average tt:308.0913563031395 , mean rwd: 0.44275569915771484
ep:13/100, average tt:307.8068714508937 , mean rwd: 0.5340908765792847
ep:14/1

KeyboardInterrupt: 

In [None]:
# --- Evaluation ------------------------------------------------------------------
# Reset the trained agent, then score it with the starter-kit simulation run.
agent.reset()
scores = run_simulation(agent_specs, simulator_cfg_file, gym_cfg)
# The first two entries of the result are reported below.
served_vehicles, delay_index = scores[0], scores[1]
print('total_served_vehicles ', served_vehicles, '  ;delay_index ', delay_index)

INFO:/starter-kit/agent/utils.py:

INFO:/starter-kit/agent/utils.py:****************************************
