# Initialization

In [None]:
from citylearn.citylearn import CityLearnEnv
import time, random, typing, cProfile, traceback
import numpy as np

def action_space_to_dict(aspace):
    """Summarise a box-style space as a plain dictionary.

    Only intended for box spaces, i.e. objects exposing ``high``, ``low``,
    ``shape`` and ``dtype`` attributes.

    Returns a dict with the keys ``"high"``, ``"low"``, ``"shape"`` (the
    attribute values, passed through untouched) and ``"dtype"`` (the string
    form of the space's dtype).
    """
    summary = {field: getattr(aspace, field) for field in ("high", "low", "shape")}
    # The dtype is stored as text rather than as the dtype object itself.
    summary["dtype"] = str(aspace.dtype)
    return summary

def env_reset(env):
    """Reset ``env`` and package what an agent needs to start an episode.

    The returned dict has four entries:

    * ``"action_space"``: one ``action_space_to_dict`` summary per entry of
      ``env.action_space``.
    * ``"observation_space"``: the same summaries for
      ``env.observation_space``.
    * ``"building_info"``: the values of ``env.get_building_information()``
      as a list.
    * ``"observation"``: whatever ``env.reset()`` returned.
    """
    # Reset first; the spaces and building information are read afterwards.
    first_observations = env.reset()
    act_spaces, obs_spaces = env.action_space, env.observation_space
    info_values = list(env.get_building_information().values())

    return {
        "action_space": [action_space_to_dict(space) for space in act_spaces],
        "observation_space": [action_space_to_dict(space) for space in obs_spaces],
        "building_info": info_values,
        "observation": first_observations,
    }

# Agent Setup

In [None]:
class Constants:
    """Fixed settings for this notebook, kept together as class attributes."""

    # Location of the CityLearn schema file (challenge 2022, phase 1 data).
    schema_path = "./data/citylearn_challenge_2022_phase_1/schema.json"

    # Episode count; presumably the intended number of training episodes
    # (TODO confirm against where it is consumed).
    episodes = 3

from agents.agents import ddpg
from agents.networks import central_critic, comm_net
from agents.features import *
from rewards import get_reward, rewards

from agents.orderenforcingwrapper import OrderEnforcingAgent

# Build the module-level agent: a DDPGAgent wrapped in OrderEnforcingAgent.
# The actor (CommNet) and critic (CentralCritic) are passed as-is, not called
# here, while the feature engineers are instantiated on the spot; the critic's
# engineer wraps its own BaseFeatureEngineer instance.
# NOTE(review): BaseFeatureEngineer / CentralCriticEngineer presumably come
# from the star import of agents.features -- confirm.
agent_wrapper = OrderEnforcingAgent(agent = ddpg.DDPGAgent(
    actor = comm_net.CommNet,
    critic = central_critic.CentralCritic,
    actor_feature=BaseFeatureEngineer(),
    critic_feature=CentralCriticEngineer(BaseFeatureEngineer())
))

# Replace the reward_function attribute on get_reward with the default reward
# from the rewards module.
get_reward.reward_function = rewards.default_reward


# Training

In [None]:
def train(agent_json=None,
          episodes_number=3,
          schema_path='./data/citylearn_challenge_2022_phase_1/schema.json',
          experiment_id=None,
          preload=False):
    """Run a training session and return the metrics of each finished episode.

    The environment is stepped with the actions produced by the module-level
    ``agent_wrapper`` until ``episodes_number`` episodes have completed or the
    run is interrupted with Ctrl-C.

    Args:
        agent_json: Currently unused; reserved for configuring the agent
            (see the TODO below).
        episodes_number: Number of episodes to complete before stopping.
            A value of zero or less performs no environment steps.
        schema_path: Path of the CityLearn schema used to build the
            environment.
        experiment_id: Identifier for this run. When ``None`` a fresh random
            id is generated on every call. Currently not used further.
        preload: Currently unused.

    Returns:
        A list with one dict per completed episode, each holding the keys
        ``"price_cost"`` and ``"emmision_cost"`` (spelling kept as-is because
        it is part of the printed/returned data). If the run is interrupted,
        the list contains the episodes completed so far.
    """
    if experiment_id is None:
        # Generate the id per call; a random default in the signature would
        # be evaluated only once, when the function is defined.
        experiment_id = str(random.randint(0, 10000))

    # TODO: Agent should be initialized with JSON here

    # Honour the caller-supplied schema instead of a fixed global path.
    env = CityLearnEnv(schema=schema_path)
    obs_dict = env_reset(env)

    agent = agent_wrapper
    actions = agent.register_reset(obs_dict)

    episodes_completed = 0
    interrupted = False
    episode_metrics = []

    try:
        while episodes_completed < episodes_number:
            observations, _, done, _ = env.step(actions)
            if not done:
                actions = agent.compute_action(observations)
                continue

            episodes_completed += 1
            # evaluate() is indexed positionally: [0] price, [1] emission.
            metrics_t = env.evaluate()
            metrics = {"price_cost": metrics_t[0], "emmision_cost": metrics_t[1]}
            episode_metrics.append(metrics)
            print(f"Episode complete: {episodes_completed} | Latest episode metrics: {metrics}")

            # Start the next episode and let the agent re-register on the
            # fresh observation bundle.
            obs_dict = env_reset(env)
            actions = agent.register_reset(obs_dict)
    except KeyboardInterrupt:
        print("========================= Stopping Evaluation =========================")
        interrupted = True

    if not interrupted:
        print("=========================Completed=========================")

    return episode_metrics


In [None]:
# Start a training run with train()'s default arguments.
train()