# Imports

In [None]:
import json

from src.utils.cl_rewards import *
from src.utils.utils import get_env_from_config, set_seed

# Torch
import torch

# Tensordict modules
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor

# Data collection
from torchrl.collectors import SyncDataCollector
from torchrl.data import Composite, Categorical, Bounded, Unbounded
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage

# Env
from torchrl.envs import RewardSum, TransformedEnv, EnvBase
from torchrl.envs.utils import check_env_specs

# Multi-agent network
from torchrl.modules import MultiAgentMLP, ProbabilisticActor, TanhNormal

# Loss
from torchrl.objectives import ClipPPOLoss, ValueEstimators, ReinforceLoss, A2CLoss, KLPENPPOLoss

# Utils

from torch import multiprocessing
from matplotlib import pyplot as plt
from tqdm import tqdm
from gymnasium.wrappers import NormalizeReward
from citylearn.wrappers import (
    NormalizedObservationWrapper,
)

# Constants

In [None]:
# Registry mapping the reward name used in the env configs (see `reward` below)
# to its reward-function class. The classes presumably come from the
# `src.utils.cl_rewards` star import.
REWARDS = dict(
    cost=Cost,
    weighted_cost_emissions=WeightedCostAndEmissions,
    cost_pen_no_batt=CostNoBattPenalization,
    cost_pen_bad_batt=CostBadBattUsePenalization,
    cost_pen_bad_action=CostIneffectiveActionPenalization,
)

# Utils

In [None]:
def _style_axis(ax, xlabel, ylabel, title):
    """Apply the shared x/y labels, title, grid and legend to one subplot."""
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid()
    ax.legend()


def _rollout_episode(env, policy):
    """Roll out ``policy`` in ``env`` and gather the per-building data to plot.

    Returns a dict with:
      * ``"actions"``: CPU tensor of the policy's actions with the agent axis
        kept explicit -- ``(time, n_agents)`` when each agent has one action;
      * ``"soc"`` / ``"net_consumption"``: per-building series read from the
        CityLearn buildings after the rollout;
      * ``"opt_actions"`` / ``"opt_soc"``: the env's optimal actions / SOC with
        their first two axes swapped, so that ``[:, i]`` selects agent ``i``.
    """
    cl_env = env.cl_env.unwrapped
    rollout = env.rollout(cl_env.time_steps - 1, policy=policy)
    actions = rollout.get(env.action_key)
    # Flatten the leading (batch, time) dims but keep the agent axis, so that
    # ``[:, i]`` also works with a single agent (a bare ``.squeeze()`` would
    # drop it). Move to CPU so ``.numpy()`` works when the env runs on a GPU.
    actions = actions.cpu().reshape(-1, env.n_agents, actions.shape[-1]).squeeze(-1)
    buildings = cl_env.buildings
    return {
        "actions": actions,
        "soc": [b.electrical_storage.soc for b in buildings],
        "net_consumption": [b.net_electricity_consumption for b in buildings],
        "opt_actions": torch.tensor(np.array(cl_env.optimal_actions), requires_grad=False).swapaxes(0, 1),
        "opt_soc": torch.tensor(np.array(cl_env.optimal_soc), requires_grad=False).swapaxes(0, 1),
    }


def plot_rewards_and_actions(train_rewards, eval_rewards, train_env, eval_env, policy):
    """Plot learning curves and one rolled-out episode in each environment.

    The 2x3 figure contains:
      * [0, 0] training vs. evaluation reward curves;
      * [0, 1] / [0, 2] policy actions vs. optimal actions (train / eval);
      * [1, 0] net electricity consumption of every building in both envs;
      * [1, 1] / [1, 2] battery SOC vs. optimal SOC (train / eval).

    Args:
        train_rewards: per-iteration mean training rewards.
        eval_rewards: per-iteration mean evaluation rewards.
        train_env: training env exposing ``cl_env``, ``n_agents`` and ``action_key``.
        eval_env: evaluation env with the same interface.
        policy: policy used to roll out one episode in each env.
    """
    fig, axs = plt.subplots(2, 3, figsize=(18, 8))

    # Learning curves
    axs[0, 0].plot(train_rewards, label="Training Reward Mean")
    axs[0, 0].plot(eval_rewards, label="Evaluation Reward Mean")
    _style_axis(axs[0, 0], "Training Iterations", "Reward", "Training and Evaluation Rewards")

    # One episode per env, train first then eval.
    with torch.no_grad():
        train = _rollout_episode(train_env, policy)
        evaluation = _rollout_episode(eval_env, policy)

    n_agents = train_env.n_agents

    # Actions vs. optimal actions for each building
    for i in range(n_agents):
        axs[0, 1].plot(train["actions"][:, i].numpy(), label=f"Train Agent {i} actions")
        axs[0, 1].plot(train["opt_actions"][:, i], label=f"Train Agent {i} optimal actions", linestyle="--")
        axs[0, 2].plot(evaluation["actions"][:, i].numpy(), label=f"Eval Agent {i} actions")
        axs[0, 2].plot(evaluation["opt_actions"][:, i], label=f"Eval Agent {i} optimal actions", linestyle="--")
    _style_axis(axs[0, 1], "Hour of the day", "Action", "Train Actions Comparison")
    _style_axis(axs[0, 2], "Hour of the day", "Action", "Eval Actions Comparison")

    # SOC vs. optimal SOC for each building
    for i in range(n_agents):
        axs[1, 1].plot(train["soc"][i], label=f"Train Agent {i} SOC")
        axs[1, 1].plot(train["opt_soc"][:, i], label=f"Train Agent {i} optimal SOC", linestyle="--")
        axs[1, 2].plot(evaluation["soc"][i], label=f"Eval Agent {i} SOC")
        axs[1, 2].plot(evaluation["opt_soc"][:, i], label=f"Eval Agent {i} optimal SOC", linestyle="--")
    _style_axis(axs[1, 1], "Hour of the day", "SOC", "Train SOC Comparison")
    _style_axis(axs[1, 2], "Hour of the day", "SOC", "Eval SOC Comparison")

    # Net electricity consumption for each building (train solid, eval dashed)
    for i in range(n_agents):
        axs[1, 0].plot(train["net_consumption"][i], label=f"Train Agent {i} Net Electricity Consumption")
        axs[1, 0].plot(evaluation["net_consumption"][i], label=f"Eval Agent {i} Net Electricity Consumption", linestyle="--")
    _style_axis(axs[1, 0], "Hour of the day", "Net Electricity Consumption", "Net Electricity Consumption Comparison")

    plt.suptitle(f'Results for day {train_env.cl_env.unwrapped.episode_tracker.episode_start_time_step}')
    plt.tight_layout()
    plt.show()

# Create CityLearn wrapper for TorchRL

In [None]:
class CityLearnMultiAgentEnv(EnvBase):
    """TorchRL environment exposing a decentralised CityLearn env, one agent per building.

    Per-agent data lives under the ``"agents"`` sub-tensordict:
    ``("agents", "observation")``, ``("agents", "action")`` and
    ``("agents", "reward")`` (one scalar reward per agent). A single env-level
    ``"done"`` flag is exposed at the root.

    The CityLearn env is wrapped in gymnasium's ``NormalizeReward`` (``gamma=1``)
    and CityLearn's ``NormalizedObservationWrapper``. Only one CityLearn env is
    driven, and actions are squeezed along the leading batch axis before
    stepping, so ``batch_size`` is expected to be 1.
    """

    def __init__(self, env_config, device, seed, batch_size=1):
        """Create the wrapped CityLearn env and the TorchRL specs.

        Args:
            env_config: config dict passed to ``get_env_from_config``.
            device: device for the specs and the output tensordicts.
            seed: seed forwarded to ``get_env_from_config``.
            batch_size: leading batch dimension of every spec (expected 1).
        """
        super().__init__()

        self.cl_env = get_env_from_config(config=env_config, seed=seed)
        self.cl_env = NormalizeReward(self.cl_env, gamma=1, epsilon=1e-8)
        self.cl_env = NormalizedObservationWrapper(self.cl_env)
        self.batch_size = torch.Size([batch_size,])
        self.device = device

        action_specs = []
        observation_specs = []
        reward_specs = []

        cl_env_as = self.cl_env.action_space
        cl_env_os = self.cl_env.observation_space
        # The env exposes one action space per controlled building, and those
        # spaces are indexed per agent below. Derive the agent count from them
        # rather than from the raw schema, so the two can never disagree.
        self.n_agents = len(cl_env_as)
        self.n_actions = cl_env_as[0].shape[0]
        self.n_observations = cl_env_os[0].shape[0]

        for i in range(self.n_agents):

            action_specs.append(Bounded(
                low=cl_env_as[i].low,
                high=cl_env_as[i].high,
                shape=cl_env_as[i].shape,
                device=device
            ))
            reward_specs.append(Unbounded(
                shape=(1,), dtype=torch.float, device=device
            ))
            observation_specs.append(Bounded(
                low=cl_env_os[i].low,
                high=cl_env_os[i].high,
                shape=cl_env_os[i].shape,
                device=device
            ))

        # Batched specs: (batch, n_agents, ...). The unbatched variants
        # (n_agents, ...) are used to build the policy.

        self.action_spec = Composite({
            "agents": Composite(
                {"action": torch.stack(action_specs, dim=0).view(self.batch_size[0], self.n_agents, cl_env_as[0].shape[0])},
                shape=(self.batch_size[0], self.n_agents,)
            )
        }, batch_size=self.batch_size)

        self.unbatched_action_spec = Composite({
            "agents": Composite(
                {"action": torch.stack(action_specs, dim=0).view(self.n_agents, cl_env_as[0].shape[0])},
                shape=(self.n_agents,)
            )
        })

        self.reward_spec = Composite({
            "agents": Composite(
                {"reward": torch.stack(reward_specs, dim=0).view(self.batch_size[0], self.n_agents, 1)},
                shape=(self.batch_size[0], self.n_agents,)
            )
        }, batch_size=self.batch_size)

        self.unbatched_reward_spec = Composite({
            "agents": Composite(
                {"reward": torch.stack(reward_specs, dim=0).view(self.n_agents, 1)},
                shape=(self.n_agents,)
            )
        })

        self.observation_spec = Composite({
            "agents": Composite(
                {"observation": torch.stack(observation_specs, dim=0).expand(self.batch_size[0], self.n_agents, cl_env_os[0].shape[0])},
                shape=(self.batch_size[0], self.n_agents,)
            )
        }, batch_size=self.batch_size)

        self.unbatched_observation_spec = Composite({
            "agents": Composite(
                {"observation": torch.stack(observation_specs, dim=0).expand(self.n_agents, cl_env_os[0].shape[0])},
                shape=(self.n_agents,)
            )
        })

        self.done_spec = Categorical(n=2, shape=torch.Size((self.batch_size[0], )), dtype=torch.bool)

        # Initialize state variables
        self.current_step = 0
        self.done = False

    def _reset(self, tensordict=None):
        """Reset CityLearn and return the initial per-agent observations.

        ``tensordict`` is ignored. The returned "done" flag is False.
        """
        observations, _ = self.cl_env.reset()
        tensordict = TensorDict({
            "agents": TensorDict({
                "observation": torch.tensor(
                    np.array(observations), dtype=torch.float
                ).reshape(self.batch_size[0], len(observations), len(observations[0])),
            }, torch.Size((self.batch_size[0], self.n_agents)), device=self.device),
            "done": torch.tensor(False, dtype=torch.bool).repeat(*self.batch_size)
        }, batch_size=self.batch_size, device=self.device)

        return tensordict

    def _step(self, tensordict):
        """Apply ``("agents", "action")`` to CityLearn and return the next step.

        Rewards are shaped ``(batch, n_agents, 1)`` to match ``reward_spec``
        (one scalar per agent, whatever the action dimensionality).
        Observations are shaped ``(batch, n_agents, n_observations)``.
        """
        actions = tensordict['agents', 'action'].cpu().squeeze(0).numpy()
        # gymnasium step API: (obs, reward, terminated, truncated, info); the
        # third element ("terminated") is used as the done signal.
        next_obs, rewards, done, _, _ = self.cl_env.step(actions)
        self.done = done

        step_results = TensorDict({
            "agents": TensorDict({
                # Trailing dim is 1 as declared in reward_spec. Using
                # n_actions here would break as soon as n_actions != 1.
                "reward": torch.tensor(
                    rewards, dtype=torch.float
                ).reshape(self.batch_size[0], self.n_agents, 1),
                "observation": torch.tensor(
                    np.array(next_obs), dtype=torch.float
                ).reshape(self.batch_size[0], self.n_agents, self.n_observations),
            }, batch_size=torch.Size((self.batch_size[0], self.n_agents)), device=self.device),
            "done": torch.tensor(done, dtype=torch.bool).repeat(*self.batch_size)
        }, batch_size=self.batch_size, device=self.device)

        return step_results

    def _set_seed(self, seed):
        """Seed the wrapped CityLearn env and torch's global RNG."""
        self.cl_env.seed(seed)
        torch.manual_seed(seed)

    def close(self, *args, **kwargs):
        """Close the wrapped CityLearn env, then let ``EnvBase`` mark this env closed.

        Any arguments (e.g. the keyword-only ``raise_if_closed`` accepted by
        recent TorchRL releases and passed by ``TransformedEnv.close``) are
        forwarded to ``EnvBase.close``.
        """
        self.cl_env.close()
        super().close(*args, **kwargs)

## Utility to create a configured environment

In [None]:
def create_env(env_config, device, seed):
    """Build a CityLearn TorchRL env that also tracks per-agent episode rewards.

    The base ``CityLearnMultiAgentEnv`` is wrapped in a ``TransformedEnv``.
    Its ``RewardSum`` transform accumulates ``reward_key`` into
    ``("agents", "episode_reward")``.
    """
    base_env = CityLearnMultiAgentEnv(env_config=env_config, device=device, seed=seed)
    reward_tracker = RewardSum(
        in_keys=[base_env.reward_key],
        out_keys=[("agents", "episode_reward")],
    )
    return TransformedEnv(base_env, reward_tracker)

## Create training and validation environments

In [None]:
# Common configurations for environments
from pathlib import Path

active_observations = [
    'hour',
    'day_type',
    'solar_generation',
    'net_electricity_consumption',
    'electrical_storage_soc',
    'non_shiftable_load',
    'direct_solar_irradiance',
    'direct_solar_irradiance_predicted_6h',
    'direct_solar_irradiance_predicted_12h',
    'direct_solar_irradiance_predicted_24h',
    'selling_price'
]

data_path = 'data/naive_data/'
reward = 'weighted_cost_emissions'
seed = 0
price_margin = 0.1
day_count = 1

# device = 'cpu'

# Preferred CUDA ordinal. As in the TorchRL tutorials, CUDA is only used when the
# multiprocessing start method is not "fork". It is also skipped when this
# machine does not have a GPU with that ordinal, instead of crashing later.
device_ix = 7
is_fork = multiprocessing.get_start_method() == "fork"
device = (
    torch.device("cuda", device_ix)
    if torch.cuda.is_available() and not is_fork and device_ix < torch.cuda.device_count()
    else torch.device("cpu")
)

set_seed(seed)


def _load_schema(path):
    """Read a CityLearn schema JSON file into a dict."""
    with open(path, encoding="utf-8") as json_file:
        return json.load(json_file)


def _make_env_config(schema):
    """Return the env config shared by the training and validation envs for ``schema``."""
    return {
        "schema": schema,
        "central_agent": False,
        "active_observations": active_observations,
        "reward_function": REWARDS[reward],
        "random_seed": seed,
        "day_count": day_count,
        "extended_obs": True,
        "price_margin": price_margin,
        "personal_encoding": True,
    }


# Training configurations

schema_filepath = Path(data_path) / 'schema.json'
schema_dict = _load_schema(schema_filepath)
train_env_config = _make_env_config(schema_dict)
train_env = create_env(train_env_config, device, seed)

# Validation configurations

schema_filepath = Path(data_path) / 'eval' / 'schema.json'
schema_dict = _load_schema(schema_filepath)
eval_env_config = _make_env_config(schema_dict)
eval_env = create_env(eval_env_config, device, seed)

### Check specs for training and validation environments

In [None]:
# Show the action, reward, done and observation specs of the training env.
for _kind in ("action", "reward", "done", "observation"):
    print(f"{_kind}_spec:", getattr(train_env, f"full_{_kind}_spec"))

In [None]:
# Show the action, reward, done and observation specs of the evaluation env.
for _kind in ("action", "reward", "done", "observation"):
    print(f"{_kind}_spec:", getattr(eval_env, f"full_{_kind}_spec"))

In [None]:
# Run TorchRL's spec consistency check on both envs (training first).
for _env in (train_env, eval_env):
    check_env_specs(_env)

# PPO Training

## Networks configuration

### Policy definition

In [None]:
def create_probabilistic_policy(env, share_parameters_policy=True, device="cpu"):
    """Build a multi-agent TanhNormal actor for ``env``.

    Pipeline: a per-agent MLP maps each observation to ``2 * n_actions`` values.
    ``NormalParamExtractor`` splits them into ``loc`` / ``scale``. A
    ``ProbabilisticActor`` then samples actions bounded by the low/high of the
    env's unbatched action spec. The sample log-prob is written to
    ``("agents", "sample_log_prob")`` for the PPO loss.

    Args:
        env: env exposing ``observation_spec``, ``action_spec``, ``n_agents``,
            ``unbatched_action_spec`` and ``action_key``.
        share_parameters_policy: share the MLP weights across agents.
        device: device for the MLP parameters.
    """
    n_obs_per_agent = env.observation_spec["agents", "observation"].shape[-1]
    n_actions_per_agent = env.action_spec.shape[-1]
    param_keys = [("agents", "loc"), ("agents", "scale")]

    backbone = MultiAgentMLP(
        n_agent_inputs=n_obs_per_agent,
        n_agent_outputs=2 * n_actions_per_agent,  # mean and std per action
        n_agents=env.n_agents,
        centralised=False,
        share_params=share_parameters_policy,
        device=device,
        depth=2,
        num_cells=256,
        activation_class=torch.nn.Tanh,
    )
    param_net = TensorDictModule(
        torch.nn.Sequential(backbone, NormalParamExtractor()),
        in_keys=[("agents", "observation")],
        out_keys=param_keys,
    )

    action_space = env.unbatched_action_spec[env.action_key].space
    return ProbabilisticActor(
        module=param_net,
        spec=env.unbatched_action_spec,
        in_keys=param_keys,
        out_keys=[env.action_key],
        distribution_class=TanhNormal,
        distribution_kwargs={"low": action_space.low, "high": action_space.high},
        return_log_prob=True,  # needed by the PPO loss
        log_prob_key=("agents", "sample_log_prob"),
    )


policy = create_probabilistic_policy(train_env, share_parameters_policy=True, device=device)

### Critic definition

In [None]:
def create_critic(env, share_parameters_critic=True, mappo=True, device="cpu"):
    """Build the per-agent state-value network for ``env``.

    The MLP outputs one value per agent and writes it to
    ``("agents", "state_value")``. ``mappo`` is passed as ``MultiAgentMLP``'s
    ``centralised`` flag: centralised critic (MAPPO) when True, decentralised
    (IPPO) otherwise.

    Args:
        env: env exposing ``observation_spec`` and ``n_agents``.
        share_parameters_critic: share the MLP weights across agents.
        mappo: use a centralised critic.
        device: device for the MLP parameters.
    """
    n_obs_per_agent = env.observation_spec["agents", "observation"].shape[-1]
    value_mlp = MultiAgentMLP(
        n_agent_inputs=n_obs_per_agent,
        n_agent_outputs=1,
        n_agents=env.n_agents,
        centralised=mappo,
        share_params=share_parameters_critic,
        device=device,
        depth=2,
        num_cells=256,
        activation_class=torch.nn.Tanh,
    )
    return TensorDictModule(
        module=value_mlp,
        in_keys=[("agents", "observation")],
        out_keys=[("agents", "state_value")],
    )


critic = create_critic(train_env, share_parameters_critic=True, mappo=True, device=device)

### Verify that the actor and critic are correctly configured

In [None]:
# Smoke test: run the untrained actor, then the critic, each on a fresh reset.
for _label, _module in (("Running policy:", policy), ("Running value:", critic)):
    print(_label, _module(train_env.reset()))

## Data collector

In [None]:
def sample_data(env, policy, device, frames_per_batch, n_iters):
    """Return a ``SyncDataCollector`` that yields ``n_iters`` batches of ``frames_per_batch`` frames.

    Despite the name, no data is gathered here; iterating the returned
    collector performs the sampling. Collection and storage both use ``device``.
    """
    return SyncDataCollector(
        env,
        policy,
        device=device,
        storing_device=device,
        frames_per_batch=frames_per_batch,
        total_frames=n_iters * frames_per_batch,
    )

### Replay buffer

In [None]:
def create_replay_buffer(frames_per_batch, device, minibatch_size):
    """Build an on-policy buffer that holds one collected batch.

    The storage capacity is ``frames_per_batch`` on ``device``. Sampling is
    without replacement and returns minibatches of ``minibatch_size``.
    """
    storage = LazyTensorStorage(frames_per_batch, device=device)
    sampler = SamplerWithoutReplacement()
    return ReplayBuffer(storage=storage, sampler=sampler, batch_size=minibatch_size)

### PPO loss function

In [None]:
def create_loss_module(policy, critic, env, clip_epsilon, entropy_eps, gamma, lmbda):
    """Build a clipped-PPO loss wired to the multi-agent keys, with a GAE value estimator.

    Advantage normalisation is disabled so advantages are not normalised
    across the agent dimension. ``done`` / ``terminated`` are read from the
    per-agent ``("agents", ...)`` entries; the training loop has to create them.
    """
    loss_module = ClipPPOLoss(
        actor_network=policy,
        critic_network=critic,
        clip_epsilon=clip_epsilon,
        entropy_coef=entropy_eps,
        normalize_advantage=False,
    )
    key_map = dict(
        reward=env.reward_key,
        action=env.action_key,
        sample_log_prob=("agents", "sample_log_prob"),
        value=("agents", "state_value"),
        done=("agents", "done"),
        terminated=("agents", "terminated"),
    )
    loss_module.set_keys(**key_map)
    loss_module.make_value_estimator(ValueEstimators.GAE, gamma=gamma, lmbda=lmbda)
    return loss_module

### Training loop

In [None]:
def _expand_done_to_agents(tensordict_data, n_agents, reward_key):
    """Write per-agent copies of the env-level done/terminated flags into the batch.

    The value estimator reads ``("next", "agents", "done")`` and
    ``("next", "agents", "terminated")`` and expects them to have the same
    shape as the per-agent reward.
    """
    reward_shape = tensordict_data.get_item_shape(("next", reward_key))
    for key in ("done", "terminated"):
        tensordict_data.set(
            ("next", "agents", key),
            tensordict_data.get(("next", key))
            .unsqueeze(-1)
            .repeat(1, 1, n_agents)
            .unsqueeze(-1)
            .expand(reward_shape),
        )


def _evaluate_policy(eval_env, policy, n_episodes):
    """Return the mean episode return of ``policy`` over ``n_episodes`` eval rollouts.

    ``("next", "agents", "episode_reward")`` is cumulative (``RewardSum``), so
    it is read at the LAST step of each rollout, where it holds the episode
    return, and averaged over agents. This matches the training metric, which
    reads the same key at ``done`` steps. The policy is switched to eval mode
    and always restored to train mode.
    """
    if n_episodes < 1:
        raise ValueError(f"n_episodes must be >= 1, got {n_episodes}")
    total = 0.0
    policy.eval()
    try:
        with torch.no_grad():
            for _ in range(n_episodes):
                rollout = eval_env.rollout(eval_env.cl_env.unwrapped.time_steps - 1, policy=policy)
                # Time is the last batch dim of a rollout; take its final step.
                final_step = rollout[..., -1]
                total += final_step.get(("next", "agents", "episode_reward")).mean().item()
    finally:
        policy.train()
    return total / n_episodes


def train_policy(
        env, eval_env, n_iters, collector, loss_module, replay_buffer, num_epochs, frames_per_batch, minibatch_size, max_grad_norm, optim,
        policy=None, n_eval_episodes=None,
    ):
    """Run the multi-agent PPO loop: collect, compute GAE, optimise, evaluate.

    Args:
        env: training env (provides ``n_agents`` and ``reward_key``).
        eval_env: env used for evaluation rollouts after every iteration.
        n_iters: number of collector iterations (progress-bar total).
        collector: data collector yielding one batch per iteration.
        loss_module: PPO loss with a value estimator attached.
        replay_buffer: buffer the batch is pushed into and minibatches sampled from.
        num_epochs: optimisation passes per collected batch.
        frames_per_batch: frames per collected batch.
        minibatch_size: frames per gradient step.
        max_grad_norm: gradient-norm clipping threshold.
        optim: optimizer over ``loss_module.parameters()``.
        policy: the actor being trained. Defaults to the notebook-global
            ``policy`` (the previous implicit behaviour).
        n_eval_episodes: evaluation rollouts per iteration. Defaults to
            ``minibatch_size`` (the previous behaviour).

    Returns:
        ``(policy, train_reward_means, eval_reward_means)``, where both lists
        hold one value per iteration.
    """
    if policy is None:
        # Backward compatibility: earlier versions silently used the
        # notebook-global ``policy``.
        policy = globals()["policy"]
    if n_eval_episodes is None:
        n_eval_episodes = minibatch_size

    episode_reward_mean_list = []
    episode_reward_mean_list_eval = []

    GAE = loss_module.value_estimator

    with tqdm(total=n_iters, nrows=10, desc="episode: 0, reward_mean: 0, eval_reward_mean: 0") as pbar:

        for episode, tensordict_data in enumerate(collector, start=1):

            # The value estimator expects done/terminated with the per-agent reward shape.
            _expand_done_to_agents(tensordict_data, env.n_agents, env.reward_key)

            with torch.no_grad():
                GAE(
                    tensordict_data,
                    params=loss_module.critic_network_params,
                    target_params=loss_module.target_critic_network_params,
                )  # Compute GAE and add it to the data

            data_view = tensordict_data.reshape(-1)  # Flatten the batch size to shuffle data
            replay_buffer.extend(data_view)

            for _ in range(num_epochs):
                for _ in range(frames_per_batch // minibatch_size):
                    subdata = replay_buffer.sample()
                    loss_vals = loss_module(subdata)

                    loss_value = (
                        loss_vals["loss_objective"]
                        + loss_vals["loss_critic"]
                        + loss_vals["loss_entropy"]
                    )

                    loss_value.backward()
                    torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_grad_norm)
                    optim.step()
                    optim.zero_grad()

            # Make sure the collector samples with the freshly updated weights.
            collector.update_policy_weights_()

            # Evaluating
            episode_reward_mean_eval = _evaluate_policy(eval_env, policy, n_eval_episodes)
            episode_reward_mean_list_eval.append(episode_reward_mean_eval)

            # Logging: episode returns of the episodes that finished in this batch.
            done = tensordict_data.get(("next", "agents", "done"))
            episode_reward_mean = (
                tensordict_data.get(("next", "agents", "episode_reward"))[done].mean().item()
            )
            episode_reward_mean_list.append(episode_reward_mean)

            pbar.set_description(f"episode: {episode}, reward_mean: {episode_reward_mean}, eval_reward_mean: {episode_reward_mean_eval}")
            pbar.update()

    return policy, episode_reward_mean_list, episode_reward_mean_list_eval

## Launch training

In [None]:
# Sampling: each iteration collects `days_per_batch` days of 24 hourly frames.
days_per_batch = 512
frames_per_batch = 24 * days_per_batch  # team frames collected per training iteration
n_iters = 25  # sampling + training iterations
total_frames = n_iters * frames_per_batch

# Optimisation
num_epochs = 30  # optimisation passes over each collected batch
minibatch_size = 256  # frames per gradient step
max_grad_norm = 1.0  # gradient-norm clipping threshold
lr = 1e-3  # Adam learning rate

# PPO
clip_epsilon = 0.2  # PPO ratio clipping
gamma = 1  # discount factor
lmbda = 0.9  # GAE lambda
entropy_eps = 1e-3  # entropy bonus coefficient

# Networks: per-agent (non-shared) actor, shared centralised critic.
policy = create_probabilistic_policy(env=train_env, share_parameters_policy=False, device=device)
critic = create_critic(env=train_env, share_parameters_critic=True, mappo=True, device=device)

# Data pipeline, loss and optimiser (built in dependency order).
collector = sample_data(
    env=train_env, policy=policy, device=device,
    frames_per_batch=frames_per_batch, n_iters=n_iters,
)
replay_buffer = create_replay_buffer(
    frames_per_batch=frames_per_batch, device=device, minibatch_size=minibatch_size,
)
loss_module = create_loss_module(
    policy=policy, critic=critic, env=train_env,
    clip_epsilon=clip_epsilon, entropy_eps=entropy_eps, gamma=gamma, lmbda=lmbda,
)
optim = torch.optim.Adam(loss_module.parameters(), lr=lr)

# Train policy
ppo_policy, ppo_reward_mean_list, ppo_eval_reward_mean_list = train_policy(
    train_env, eval_env, n_iters, collector, loss_module, replay_buffer,
    num_epochs, frames_per_batch, minibatch_size, max_grad_norm, optim,
)

### Plot results

In [None]:
# Plot learning curves and one episode of the trained policy in each env.
plot_rewards_and_actions(
    train_rewards=ppo_reward_mean_list,
    eval_rewards=ppo_eval_reward_mean_list,
    train_env=train_env,
    eval_env=eval_env,
    policy=ppo_policy,
)

In [None]:
# Roll out the *trained* policy once on the evaluation env for inspection.
# Use `ppo_policy` (returned by training) rather than the global `policy`:
# re-running the policy-definition cell rebinds `policy` to an untrained actor.
with torch.no_grad():
    eval_rollout = eval_env.rollout(eval_env.cl_env.unwrapped.time_steps - 1, policy=ppo_policy)

In [None]:
# Display the SOC series of the first evaluation building (presumably as left
# by the most recent rollout on eval_env — confirm).
eval_env.cl_env.unwrapped.buildings[0].electrical_storage.soc

In [None]:
# Display the env's optimal SOC series at index 0 for comparison; the plotting
# helper swaps axes and indexes agents on the result, so index 0 here presumably
# selects the first building.
eval_env.cl_env.unwrapped.optimal_soc[0]

In [None]:
# Display the start time step of the current evaluation episode (the value the
# plot title reports as the "day" for the training env).
eval_env.cl_env.unwrapped.episode_tracker.episode_start_time_step