In [1]:
%load_ext autoreload
%autoreload 2
import ipywidgets as widgets
from ipywidgets import interact, interactive

from citylearn.citylearn import CityLearnEnv
import time, random, typing, cProfile, traceback
import numpy as np
from common.initialization_methods import *
from visualiser.frame_cache import *

class Constants:
    '''Notebook-wide settings grouped in a single namespace.'''

    # Location of the CityLearn schema JSON (2022 challenge, phase 1 data set).
    schema_path = "./data/citylearn_challenge_2022_phase_1/schema.json"

    # A frame is rendered and cached every time the step counter is a
    # multiple of this value (1 means every step).
    steps_per_frame_save = 1

    # Episode count setting.
    # NOTE(review): not referenced by the code visible in this notebook -- confirm
    # whether it is still needed.
    episodes = 3

from agents.agents import ddpg
from agents.networks import central_critic, comm_net
from agents.features import *
from rewards import get_reward, rewards

from agents.orderenforcingwrapper import OrderEnforcingAgent

# Default agent: a DDPG agent with a CommNet actor and a centralised critic,
# wrapped in OrderEnforcingAgent.
agent_wrapper = OrderEnforcingAgent(
    agent=ddpg.DDPGAgent(
        actor=comm_net.CommNet,
        critic=central_critic.CentralCritic,
        actor_feature=BaseFeatureEngineer(),
        critic_feature=CentralCriticEngineer(BaseFeatureEngineer()),
    )
)

# Default reward function.
get_reward.reward_function = rewards.default_reward

# Agent name -> agent class.
agent_dict = dict(ddpg=ddpg.DDPGAgent)

# Agent name -> default constructor settings.
# Insertion order is significant: the values are later unpacked positionally
# into the agent constructor, so do not reorder these entries.
agent_spec_dict = dict(
    ddpg=dict(
        actor=comm_net.CommNet,
        critic=central_critic.CentralCritic,
        actor_feature=BaseFeatureEngineer(),
        critic_feature=CentralCriticEngineer(BaseFeatureEngineer()),
        a_kwargs={},
        c_kwargs={},
        gamma=0.99,
        lr=3e-4,
        tau=0.001,
        batch_size=32,
        memory_size=4096,
        device='cpu',
    )
)

# Agent name -> settings whose JSON value is a *name* that must be looked up
# in agent_obj_value rather than used literally.
agent_obj_field = dict(
    ddpg={"actor", "critic", "actor_feature", "critic_feature"},
)

# Agent name -> (name as written in the JSON spec -> Python object).
agent_obj_value = {
    "ddpg": {
        "comm_net.CommNet": comm_net.CommNet,
        "central_critic.CentralCritic": central_critic.CentralCritic,
        "BaseFeatureEngineer()": BaseFeatureEngineer(),
        "CentralCriticEngineer(BaseFeatureEngineer())": CentralCriticEngineer(BaseFeatureEngineer()),
    }
}

# Reward name as written in the JSON spec -> reward function.
reward_function_dict = dict(default_reward=rewards.default_reward)

import json

# Read the experiment specification. Expected layout (as consumed below):
#   {"agent": {"name": <agent name>, <setting>: <value>, ...}, "reward": <reward name>}
with open('json_example.json') as f:
    specs = json.load(f)

agent_input_specs = specs.get("agent")
if not isinstance(agent_input_specs, dict):
    raise ValueError('json_example.json must contain an "agent" object')

agent_name = agent_input_specs.get("name")
if agent_name not in agent_dict or agent_name not in agent_spec_dict:
    raise ValueError(
        f"Unknown agent name {agent_name!r}; expected one of {sorted(agent_spec_dict)}"
    )

# Work on a copy so the defaults stored in agent_spec_dict are not overwritten;
# otherwise re-running this cell (or reading the defaults later) would see the
# overrides from a previous run.
agent_use_specs = dict(agent_spec_dict[agent_name])
agent_object_fields = agent_obj_field.get(agent_name, ())
agent_object_values = agent_obj_value.get(agent_name, {})

# Only settings that already exist in the defaults can be overridden; any other
# key in the JSON "agent" section (e.g. "name") is ignored.
for key in agent_use_specs:
    if key not in agent_input_specs:
        continue
    if key in agent_object_fields:
        # The JSON value is the *name* of an object; translate it to the object.
        obj_key = agent_input_specs[key]
        if obj_key not in agent_object_values:
            raise KeyError(
                f"Unknown value {obj_key!r} for agent setting {key!r}; "
                f"expected one of {sorted(agent_object_values)}"
            )
        agent_use_specs[key] = agent_object_values[obj_key]
    else:
        agent_use_specs[key] = agent_input_specs[key]

# The settings are passed positionally, so this relies on the insertion order
# of agent_spec_dict[agent_name] matching the constructor's parameter order.
agent_paras = tuple(agent_use_specs.values())


agent_wrapper = OrderEnforcingAgent(agent = agent_dict[agent_name](
    *agent_paras
))

# Optional reward override.
if "reward" in specs:
    reward_function_name = specs["reward"]
    if reward_function_name not in reward_function_dict:
        raise KeyError(
            f"Unknown reward {reward_function_name!r}; "
            f"expected one of {sorted(reward_function_dict)}"
        )
    get_reward.reward_function = reward_function_dict[reward_function_name]

def train (agent_json = None,
           episodes_number = 3,
           schema_path = './data/citylearn_challenge_2022_phase_1/schema.json',
           experiment_id = None,
           preload = False):
    '''
    Run a training session and return the metrics of every completed episode.

    Parameters
    ----------
    agent_json
        Currently ignored; the module-level ``agent_wrapper`` is used (see TODO).
    episodes_number
        Number of episodes to run. No environment step is taken when this is
        zero or negative.
    schema_path
        Schema passed to ``CityLearnEnv``.
    experiment_id
        Identifier of this run. When ``None`` a random one is drawn for each
        call. It is not consumed yet.
    preload
        Currently ignored.

    Returns
    -------
    list of dict
        One ``{"price_cost": ..., "emmision_cost": ...}`` entry per completed
        episode. If the run is interrupted with Ctrl-C / kernel interrupt, the
        entries collected so far are returned.
    '''
    # Draw the id per call: a default written in the signature is evaluated only
    # once, at definition time, so every run would share the same "random" id.
    if experiment_id is None:
        experiment_id = str(random.randint(0, 10000))

    print("============= Start Evaluation =============")

    # TODO: Agent should be initalized with JSON here

    # Honour the caller's schema instead of always using Constants.schema_path.
    env = CityLearnEnv(schema=schema_path)
    obs_dict = env_reset(env)

    agent = agent_wrapper

    actions = agent.register_reset(obs_dict)

    episodes_completed = 0
    num_steps = 0
    episode_metrics = []

    try:
        while episodes_completed < episodes_number:
            observations, _, done, _ = env.step(actions)
            if done:
                episodes_completed += 1
                metrics_t = env.evaluate()

                metrics = {"price_cost": metrics_t[0], "emmision_cost": metrics_t[1]}
                episode_metrics.append(metrics)
                print(f"Episode complete: {episodes_completed} | Latest episode metrics: {metrics}", )

                # Start the next episode from a freshly reset environment.
                obs_dict = env_reset(env)

                actions = agent.register_reset(obs_dict)
            else:
                actions = agent.compute_action(observations)

            num_steps += 1
            if num_steps % Constants.steps_per_frame_save == 0:
                append_one_frame(env.render())
    except KeyboardInterrupt:
        print("========================= Stopping Evaluation =========================")
    else:
        print("=========================Completed=========================")

    return episode_metrics

In [3]:
# Run one training session using train()'s default arguments.
train()



In [4]:
from IPython.display import display


def _show(x):
    '''Display whatever get_image_of_frame_at returns for frame index x.'''
    display(get_image_of_frame_at(x))


def _last_frame_index():
    '''Return the index of the newest cached frame, or 0 when no frame is cached.

    Clamping keeps the widgets' ``max`` from dropping below their ``min`` of 0,
    which ipywidgets rejects.
    '''
    return max(get_total_frame_number() - 1, 0)


def _on_train_clicked(_button):
    '''Run a training session, then extend the frame controls to the new frames.

    ipywidgets calls the handler with the clicked Button; it is dropped here so
    it is not forwarded to train() as its first argument.
    '''
    train()
    # The frame range was fixed when the widgets were created; refresh it so the
    # frames rendered by this run can be reached.
    last_index = _last_frame_index()
    frame_slider.max = last_index
    frame_player.max = last_index


train_button = widgets.Button(description="Train")
train_button.on_click(_on_train_clicked)

frame_slider = widgets.IntSlider(min=0, max=_last_frame_index(), value=0)
frame_player = widgets.Play(
    value=0,
    min=0,
    max=_last_frame_index(),
    step=1,
    interval=500,
    description="Press play",
    disabled=False
)
frame_input_ui = widgets.HBox([frame_player, frame_slider])
# Keep the slider and the play control on the same frame (linked in the browser).
widgets.jslink((frame_slider, 'value'), (frame_player, 'value'))
img_output = widgets.interactive_output(_show, {'x': frame_slider})
img_ui = widgets.VBox([train_button, frame_input_ui, img_output])
display(img_ui)



VBox(children=(Button(description='Train', style=ButtonStyle()), HBox(children=(Play(value=0, description='Pre…

