# Initialization

In [None]:
from citylearn.citylearn import CityLearnEnv
import time, random, typing, cProfile, traceback
import numpy as np

def action_space_to_dict(aspace):
    """Summarise a box-style space as a plain dictionary.

    Only intended for box spaces: the argument must expose ``high``, ``low``,
    ``shape`` and ``dtype`` attributes.

    Returns:
        A dict with keys ``"high"``, ``"low"``, ``"shape"`` (the attribute
        values passed through unchanged) and ``"dtype"`` (the dtype rendered
        with ``str``).
    """
    summary = dict(high=aspace.high, low=aspace.low, shape=aspace.shape)
    summary["dtype"] = str(aspace.dtype)
    return summary

def env_reset(env):
    """Reset ``env`` and bundle everything an agent needs to start an episode.

    Returns:
        A dict with four entries:
        ``"action_space"`` / ``"observation_space"`` - one
        ``action_space_to_dict`` summary per element of ``env.action_space`` /
        ``env.observation_space``;
        ``"building_info"`` - the values of ``env.get_building_information()``
        as a list;
        ``"observation"`` - whatever ``env.reset()`` returned.
    """
    initial_observations = env.reset()
    act_spaces = env.action_space
    obs_spaces = env.observation_space
    info_per_building = list(env.get_building_information().values())
    act_space_summaries = list(map(action_space_to_dict, act_spaces))
    obs_space_summaries = list(map(action_space_to_dict, obs_spaces))
    return {
        "action_space": act_space_summaries,
        "observation_space": obs_space_summaries,
        "building_info": info_per_building,
        "observation": initial_observations,
    }

# Agent Setup

In [None]:
class Constants:
    """Hard-coded settings for the training run in this notebook."""

    # Number of completed episodes after which train() stops.
    episodes: int = 3
    # Schema file handed to CityLearnEnv when train() builds the environment.
    schema_path: str = "./data/citylearn_challenge_2022_phase_1/schema.json"

from agents.agents import ddpg
from agents.networks import central_critic, comm_net
from agents.features import *
from rewards import get_reward, rewards

from agents.orderenforcingwrapper import OrderEnforcingAgent

# Build the agent used by train(): a DDPGAgent wrapped in OrderEnforcingAgent.
# train() only talks to the wrapper, via register_reset() and compute_action().
agent_wrapper = OrderEnforcingAgent(agent = ddpg.DDPGAgent(
    # Actor and critic are handed over un-called (no parentheses here), so the
    # agent receives the CommNet / CentralCritic objects themselves.
    # NOTE(review): presumably DDPGAgent instantiates them itself - confirm.
    actor = comm_net.CommNet,
    critic = central_critic.CentralCritic,
    # Feature engineers, in contrast, are instantiated right here.
    actor_feature=BaseFeatureEngineer(),
    # The critic's engineer is constructed around its own BaseFeatureEngineer
    # instance (a separate object from the actor's).
    critic_feature=CentralCriticEngineer(BaseFeatureEngineer())
))

# Rebind the reward_function attribute on get_reward to rewards.default_reward.
# NOTE(review): presumably this selects the reward the environment uses - confirm
# against the rewards package.
get_reward.reward_function = rewards.default_reward


# Training

In [None]:
'''
README@ChenRang:
TODO: Adjust this for training purposes
1. Logging (Debugging module and experimenter module)
2. Saving episode (check with zzx for API details)
3. JSON parameter setup instead of hardcoded values
4. Reward selection system (also integrate into JSON)
5. Automated unit tests for different modules
6. Automated profiling
7. Neptune AI / MLFlow experiment logging (low priority)
8. Multiprocessing (low priority)
'''



def _timed_call(fn, *args):
    """Call ``fn(*args)`` and return ``(result, elapsed_seconds)``."""
    started = time.perf_counter()
    result = fn(*args)
    return result, time.perf_counter() - started


def train(episodes=None, schema_path=None):
    """Run the agent against a CityLearn environment for a number of episodes.

    The environment is stepped with the actions produced by the module-level
    ``agent_wrapper``. Each time ``env.step`` reports ``done`` the episode's
    evaluation metrics are recorded, the environment is reset and the agent is
    re-registered. Progress is printed every 1000 steps, and a summary
    (average costs and total time spent inside the agent) is printed at the
    end. A ``KeyboardInterrupt`` stops the loop early but still prints the
    summary for the episodes completed so far.

    Args:
        episodes: Number of episodes to complete before stopping. Defaults to
            ``Constants.episodes``. A value <= 0 performs no environment steps.
        schema_path: Schema passed to ``CityLearnEnv``. Defaults to
            ``Constants.schema_path``.

    Raises:
        ValueError: If any value returned by ``env.evaluate()`` is NaN.
    """
    if episodes is None:
        episodes = Constants.episodes
    if schema_path is None:
        schema_path = Constants.schema_path

    env = CityLearnEnv(schema=schema_path)
    agent = agent_wrapper

    # Only time spent inside the agent is accumulated; environment stepping
    # and resets are deliberately excluded.
    agent_time_elapsed = 0.0
    actions, elapsed = _timed_call(agent.register_reset, env_reset(env))
    agent_time_elapsed += elapsed

    episodes_completed = 0
    num_steps = 0
    interrupted = False
    episode_metrics = []
    try:
        # Bounding the loop here (rather than `while True` + trailing break)
        # avoids stepping the environment at all when the budget is zero.
        while episodes_completed < episodes:
            observations, _, done, _ = env.step(actions)
            if done:
                episodes_completed += 1
                # The first two entries of evaluate() are treated as the price
                # cost and the emission cost respectively.
                metrics_t = env.evaluate()
                if np.any(np.isnan(metrics_t)):
                    raise ValueError("Episode metrics are nan, please contact organizers")
                metrics = {"price_cost": metrics_t[0], "emission_cost": metrics_t[1]}
                episode_metrics.append(metrics)
                print(f"Episode complete: {episodes_completed} | Latest episode metrics: {metrics}")

                # Start the next episode: fresh observations, fresh agent state.
                actions, elapsed = _timed_call(agent.register_reset, env_reset(env))
            else:
                actions, elapsed = _timed_call(agent.compute_action, observations)
            agent_time_elapsed += elapsed

            num_steps += 1
            if num_steps % 1000 == 0:
                print(f"Num Steps: {num_steps}, Num episodes: {episodes_completed}")
    except KeyboardInterrupt:
        print("========================= Stopping Evaluation =========================")
        interrupted = True

    if not interrupted:
        print("=========================Completed=========================")

    if episode_metrics:
        print("Average Price Cost:", np.mean([e["price_cost"] for e in episode_metrics]))
        print("Average Emission Cost:", np.mean([e["emission_cost"] for e in episode_metrics]))
    # The timing was previously accumulated but never shown.
    print(f"Total time taken by agent: {agent_time_elapsed}s")


In [None]:
train()

In [None]:
import cProfile

# Profile one complete train() run and print the stats sorted by internal time.
# runcall() turns the profiler on for the duration of the call and off again
# afterwards (also when the call raises), so no explicit enable()/disable()
# pair is needed. Wrapping Profile.run() in an extra enable() is redundant at
# best, and newer interpreters reject enabling an already-active profiler.
pr = cProfile.Profile()
pr.runcall(train)

pr.print_stats(sort='time')
