# Imports

In [None]:
import json
import copy

from src.utils.cl_rewards import *
from src.utils.utils import get_env_from_config, set_seed

# Torch
import torch

# Tensordict modules
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor

# Data collection
from torchrl.collectors import SyncDataCollector
from torchrl.data import Composite, Categorical, Bounded, Unbounded
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage

# Env
from torchrl.envs import RewardSum, TransformedEnv, EnvBase
from torchrl.envs.utils import check_env_specs

# Multi-agent network
from torchrl.modules import MultiAgentMLP, ProbabilisticActor, TanhNormal

# Loss
from torchrl.objectives import ValueEstimators, A2CLoss

# Utils

from torch import multiprocessing
from matplotlib import pyplot as plt
from tqdm import tqdm
from gymnasium.wrappers import NormalizeReward
from citylearn.wrappers import (
    NormalizedObservationWrapper,
)

from src.utils.cl_torchrl_helper import (
    create_env,
    plot_rewards_and_actions
)

# Constants

In [None]:
# Registry of reward functions selectable by name. The selected class is passed
# as ``reward_function`` in the environment configuration.
REWARDS = dict(
    cost=Cost,
    weighted_cost_emissions=WeightedCostAndEmissions,
    cost_pen_no_batt=CostNoBattPenalization,
    cost_pen_bad_batt=CostBadBattUsePenalization,
    cost_pen_bad_action=CostIneffectiveActionPenalization,
)

## Create training and validation environments

In [None]:
# Common configurations for environments

# Observations exposed to the agents (passed as ``active_observations`` to the env config).
active_observations = [
    'hour',
    'day_type',
    'solar_generation',
    'net_electricity_consumption',
    'electrical_storage_soc',
    'non_shiftable_load',
    'non_shiftable_load_predicted_4h',
    'non_shiftable_load_predicted_6h',
    'non_shiftable_load_predicted_12h',
    'non_shiftable_load_predicted_24h',
    'direct_solar_irradiance',
    'direct_solar_irradiance_predicted_6h',
    'direct_solar_irradiance_predicted_12h',
    'direct_solar_irradiance_predicted_24h',
    'selling_price'
]

data_path = 'data/naive_data/'
reward = 'weighted_cost_emissions'  # Key into REWARDS
seed = 0
day_count = 1

# device = 'cpu'

# Preferred CUDA device ordinal. CUDA is skipped when the multiprocessing start
# method is "fork" (presumably because CUDA cannot be re-initialised in forked
# workers). It is also skipped when the requested ordinal does not exist on this
# machine, so the notebook falls back to CPU instead of failing later with an
# invalid-device error.
device_ix = 7
is_fork = multiprocessing.get_start_method() == "fork"
use_cuda = (
    torch.cuda.is_available()
    and not is_fork
    and device_ix < torch.cuda.device_count()
)
device = torch.device(device_ix) if use_cuda else torch.device("cpu")
print("Using device:", device)

set_seed(seed)

# Training configurations

schema_filepath = data_path + 'schema.json'

# JSON is UTF-8 by specification; do not depend on the platform default encoding.
with open(schema_filepath, encoding="utf-8") as json_file:
    schema_dict = json.load(json_file)

train_env_config = {
    "schema": schema_dict,
    "central_agent": False,
    "active_observations": active_observations,
    "reward_function": REWARDS[reward],
    "random_seed": seed,
    "day_count": day_count,
    "personal_encoding": True,
}

train_env = create_env(train_env_config, device, seed)

# Validation configurations: identical to training except for the schema (dataset).

schema_filepath = data_path + 'eval/schema.json'

with open(schema_filepath, encoding="utf-8") as json_file:
    schema_dict = json.load(json_file)

eval_env_config = {
    **train_env_config,
    "schema": schema_dict,
}

eval_env = create_env(eval_env_config, device, seed)

### Check specs for training and validation environments

In [None]:
# Inspect the training environment's full action, reward, done and observation specs.
print("action_spec:", train_env.full_action_spec)
print("reward_spec:", train_env.full_reward_spec)
print("done_spec:", train_env.full_done_spec)
print("observation_spec:", train_env.full_observation_spec)

In [None]:
# Same inspection for the validation environment; its specs should match the training env's.
print("action_spec:", eval_env.full_action_spec)
print("reward_spec:", eval_env.full_reward_spec)
print("done_spec:", eval_env.full_done_spec)
print("observation_spec:", eval_env.full_observation_spec)

In [None]:
# TorchRL sanity check: verify that data produced by each env is consistent with its declared specs.
check_env_specs(train_env)
check_env_specs(eval_env)

# A2C training

## Networks configuration

### Policy definition

In [None]:
def create_probabilistic_policy(env, share_parameters_policy=True, device="cpu"):
    """Build a stochastic multi-agent actor for ``env``.

    A decentralised ``MultiAgentMLP`` (``centralised=False``) maps each agent's
    observation to ``2 * n_actions`` values. ``NormalParamExtractor`` splits
    them into ``("agents", "loc")`` and ``("agents", "scale")``. A
    ``ProbabilisticActor`` then samples actions from a ``TanhNormal`` bounded by
    the action spec. It also writes the sample log-probability to
    ``("agents", "sample_log_prob")``, which the A2C loss consumes.

    Args:
        env: multi-agent TorchRL env exposing ``n_agents``, ``observation_spec``,
            ``action_spec``, ``unbatched_action_spec`` and ``action_key``.
        share_parameters_policy: share one set of MLP weights across agents.
        device: device on which the MLP parameters are created.

    Returns:
        The ``ProbabilisticActor`` policy.
    """
    obs_dim = env.observation_spec["agents", "observation"].shape[-1]
    act_dim = env.action_spec.shape[-1]

    # One (loc, scale) pair per action dimension.
    loc_scale_net = torch.nn.Sequential(
        MultiAgentMLP(
            n_agent_inputs=obs_dim,
            n_agent_outputs=act_dim * 2,
            n_agents=env.n_agents,
            centralised=False,
            share_params=share_parameters_policy,
            device=device,
            depth=2,
            num_cells=256,
            activation_class=torch.nn.Tanh,
        ),
        NormalParamExtractor(),
    )

    dist_param_keys = (("agents", "loc"), ("agents", "scale"))

    param_module = TensorDictModule(
        loc_scale_net,
        in_keys=[("agents", "observation")],
        out_keys=list(dist_param_keys),
    )

    action_spec = env.unbatched_action_spec
    action_space = action_spec[env.action_key].space

    return ProbabilisticActor(
        module=param_module,
        spec=action_spec,
        in_keys=list(dist_param_keys),
        out_keys=[env.action_key],
        distribution_class=TanhNormal,
        # Squash sampled actions into the env's action bounds.
        distribution_kwargs={
            "low": action_space.low,
            "high": action_space.high,
        },
        return_log_prob=True,
        log_prob_key=("agents", "sample_log_prob"),
    )

policy = create_probabilistic_policy(train_env, share_parameters_policy=True, device=device)

### Critic definition

In [None]:
def create_critic(env, share_parameters_critic=True, mappo=True, device="cpu"):
    """Build the multi-agent state-value network for ``env``.

    Reads ``("agents", "observation")`` and writes one value per agent to
    ``("agents", "state_value")``.

    Args:
        env: multi-agent TorchRL env exposing ``n_agents`` and ``observation_spec``.
        share_parameters_critic: share one set of MLP weights across agents.
        mappo: if True the MLP is centralised (each agent's value is computed
            from all agents' observations); if False each agent only sees its own.
        device: device on which the MLP parameters are created.

    Returns:
        A ``TensorDictModule`` wrapping the value MLP.
    """
    hidden_cfg = dict(depth=2, num_cells=256, activation_class=torch.nn.Tanh)

    value_mlp = MultiAgentMLP(
        n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1],
        n_agent_outputs=1,
        n_agents=env.n_agents,
        centralised=mappo,
        share_params=share_parameters_critic,
        device=device,
        **hidden_cfg,
    )

    return TensorDictModule(
        module=value_mlp,
        in_keys=[("agents", "observation")],
        out_keys=[("agents", "state_value")],
    )

critic = create_critic(train_env, share_parameters_critic=True, mappo=True, device=device)

### Verify that the actor and critic are correctly configured

In [None]:
# Smoke test: run the (untrained) actor and critic on freshly reset training-env
# states and print the resulting TensorDicts to check the keys and shapes they write.
# Each call resets the env independently.
print("Running policy:", policy(train_env.reset()))
print("Running value:", critic(train_env.reset()))

In [None]:
def sample_data(env, policy, device, frames_per_batch, n_iters):
    """Create a ``SyncDataCollector`` that runs ``policy`` in ``env``.

    Despite the name, nothing is collected here: iterate the returned collector
    to obtain ``n_iters`` batches of ``frames_per_batch`` frames each.
    Both execution and storage happen on ``device``.

    Returns:
        The configured ``SyncDataCollector``.
    """
    collector_kwargs = dict(
        device=device,
        storing_device=device,
        frames_per_batch=frames_per_batch,
        # The collector stops after this many frames in total.
        total_frames=n_iters * frames_per_batch,
    )
    return SyncDataCollector(env, policy, **collector_kwargs)

### Replay buffer

In [None]:
def create_replay_buffer(frames_per_batch, device, minibatch_size):
    """Create an on-policy replay buffer sized to one collected batch.

    The storage holds exactly ``frames_per_batch`` frames on ``device``.
    Sampling is without replacement, in minibatches of ``minibatch_size``
    frames, so one pass of ``frames_per_batch // minibatch_size`` samples
    visits each stored frame at most once.

    Returns:
        The configured ``ReplayBuffer``.
    """
    storage = LazyTensorStorage(frames_per_batch, device=device)
    sampler = SamplerWithoutReplacement()
    return ReplayBuffer(storage=storage, sampler=sampler, batch_size=minibatch_size)

### Loss function (A2C)

In [None]:
def create_loss_module(policy, critic, env, gamma, lmbda):
    """Build an A2C loss (with entropy bonus) wired to the multi-agent keys.

    The loss reads the env's reward/action keys and the per-agent log-prob and
    value entries. Its done/terminated keys point to ``("agents", "done")`` and
    ``("agents", "terminated")``: per-agent copies that must be written into the
    data beforehand so they match the reward shape. A GAE value estimator is
    attached with the given discount and lambda.

    Args:
        policy: actor producing ``("agents", "sample_log_prob")``.
        critic: value network producing ``("agents", "state_value")``.
        env: environment providing ``reward_key`` and ``action_key``.
        gamma: discount factor.
        lmbda: GAE lambda.

    Returns:
        The configured ``A2CLoss``.
    """
    a2c_loss = A2CLoss(
        actor_network=policy,
        critic_network=critic,
        entropy_bonus=True,
    )

    per_agent_keys = {
        "sample_log_prob": ("agents", "sample_log_prob"),
        "value": ("agents", "state_value"),
        "done": ("agents", "done"),
        "terminated": ("agents", "terminated"),
    }
    a2c_loss.set_keys(reward=env.reward_key, action=env.action_key, **per_agent_keys)

    a2c_loss.make_value_estimator(ValueEstimators.GAE, gamma=gamma, lmbda=lmbda)

    return a2c_loss

### Training loop

In [None]:
def _expand_agent_done_keys(tensordict_data, env):
    """Write per-agent ``("next", "agents", "done"|"terminated")`` entries.

    Copies the shared ``("next", "done")`` / ``("next", "terminated")`` flags and
    broadcasts them to the shape of ``("next", env.reward_key)``, as expected by
    the value estimator.
    """
    # NOTE(review): the unsqueeze/repeat/unsqueeze chain is tied to the done
    # shape produced by this project's env (see create_env) — confirm before reuse.
    reward_shape = tensordict_data.get_item_shape(("next", env.reward_key))
    for flag in ("done", "terminated"):
        tensordict_data.set(
            ("next", "agents", flag),
            tensordict_data.get(("next", flag))
            .unsqueeze(-1)
            .repeat(1, 1, env.n_agents)
            .unsqueeze(-1)
            .expand(reward_shape),
        )


def _evaluate_policy(eval_env, policy, n_episodes):
    """Return the mean final episode reward of ``policy`` over ``n_episodes`` rollouts.

    Each rollout runs for ``eval_env.cl_env.unwrapped.time_steps - 1`` steps.
    The policy is put in eval mode for the duration and switched back to train
    mode afterwards, even if a rollout raises.

    Raises:
        ValueError: if ``n_episodes`` is not positive.
    """
    if n_episodes <= 0:
        raise ValueError("n_eval_episodes must be a positive integer")

    max_steps = eval_env.cl_env.unwrapped.time_steps - 1
    reward_sum = 0.0
    with torch.no_grad():
        policy.eval()
        try:
            for _ in range(n_episodes):
                rollout = eval_env.rollout(max_steps, policy=policy)
                # ``episode_reward`` is assumed to be a running sum (e.g. a RewardSum
                # transform), so its value at the last step is the episode return.
                # Averaging it over all steps would not be, and would not match the
                # done-masked training metric below.
                final_step = rollout[..., -1]
                reward_sum += final_step.get(("next", "agents", "episode_reward")).mean().item()
        finally:
            policy.train()
    return reward_sum / n_episodes


def train_policy(
        env, eval_env, n_iters, collector, loss_module, replay_buffer, num_epochs,
        frames_per_batch, minibatch_size, max_grad_norm, optim,
        policy=None, n_eval_episodes=None,
    ):
    """Train a multi-agent A2C policy and track training/evaluation rewards.

    For every batch yielded by ``collector``:

    1. expand the shared done/terminated flags into per-agent keys;
    2. compute GAE advantages on the time-ordered batch;
    3. store the flattened batch in ``replay_buffer``;
    4. run ``num_epochs`` passes of ``frames_per_batch // minibatch_size``
       clipped gradient steps;
    5. evaluate on ``eval_env`` and keep a deep copy of the best policy.

    Args:
        env: training env (provides ``n_agents`` and ``reward_key``).
        eval_env: evaluation env; ``eval_env.cl_env.unwrapped.time_steps`` sets
            the evaluation rollout length.
        n_iters: expected number of collector batches (used for the progress bar).
        collector: iterable of collected batches, assumed to hold
            ``frames_per_batch`` frames each.
        loss_module: A2C loss with a value estimator configured.
        replay_buffer: buffer refilled with every collected batch.
        num_epochs: passes over each collected batch.
        frames_per_batch: frames per collected batch.
        minibatch_size: frames per gradient step.
        max_grad_norm: gradient-norm clipping threshold.
        optim: optimizer stepping the loss module's parameters.
        policy: actor to evaluate and snapshot. If omitted, the module-global
            ``policy`` is used (the behavior of earlier versions).
        n_eval_episodes: evaluation rollouts per iteration; defaults to
            ``minibatch_size`` (the behavior of earlier versions).

    Returns:
        ``(best_policy, train_reward_means, eval_reward_means)``. ``best_policy``
        is ``None`` if the collector yields no batch.
    """
    if policy is None:
        # Earlier versions read the notebook-global ``policy`` implicitly; keep
        # that fallback so existing call sites behave the same.
        policy = globals()["policy"]
    if n_eval_episodes is None:
        n_eval_episodes = minibatch_size

    episode_reward_mean_list = []
    episode_reward_mean_list_eval = []

    best_policy = None
    best_eval_reward = -float("inf")

    GAE = loss_module.value_estimator

    with tqdm(total=n_iters, nrows=10, desc="episode: 0, reward_mean: 0, eval_reward_mean: 0") as pbar:

        episode = 0

        for tensordict_data in collector:

            # The value estimator expects done/terminated to match the reward shape.
            _expand_agent_done_keys(tensordict_data, env)

            # Compute GAE on the time-ordered batch BEFORE it is flattened and copied
            # into the replay buffer. Otherwise the advantage/value_target keys never
            # reach the sampled minibatches, and the loss recomputes GAE on shuffled,
            # non-consecutive frames.
            with torch.no_grad():
                GAE(
                    tensordict_data,
                    params=loss_module.critic_network_params,
                    target_params=loss_module.target_critic_network_params,
                )

            data_view = tensordict_data.reshape(-1)  # Flatten the batch size to shuffle data
            replay_buffer.extend(data_view)

            for _ in range(num_epochs):
                for _ in range(frames_per_batch // minibatch_size):
                    subdata = replay_buffer.sample()
                    loss_vals = loss_module(subdata)

                    loss_value = (
                        loss_vals["loss_objective"]
                        + loss_vals["loss_critic"]
                        + loss_vals["loss_entropy"]
                    )

                    loss_value.backward()
                    torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_grad_norm)
                    optim.step()
                    optim.zero_grad()

            # Evaluation
            episode_reward_mean_eval = _evaluate_policy(eval_env, policy, n_eval_episodes)
            episode_reward_mean_list_eval.append(episode_reward_mean_eval)

            # Keep a snapshot of the best policy seen so far.
            if episode_reward_mean_eval > best_eval_reward:
                best_eval_reward = episode_reward_mean_eval
                best_policy = copy.deepcopy(policy)

            # Logging: mean episode reward over the episodes that ended in this batch.
            done = tensordict_data.get(("next", "agents", "done"))
            episode_reward_mean = (
                tensordict_data.get(("next", "agents", "episode_reward"))[done].mean().item()
            )
            episode_reward_mean_list.append(episode_reward_mean)

            episode += 1

            pbar.set_description(f"episode: {episode}, reward_mean: {episode_reward_mean}, eval_reward_mean: {episode_reward_mean_eval}")
            pbar.update()

    return best_policy, episode_reward_mean_list, episode_reward_mean_list_eval

In [None]:
# Configure data collection
# Smoke test: a tiny collector (48 frames per batch, 3 batches) to inspect the layout
# of the collected TensorDicts. This rebinds the global ``collector``; the training
# cell below creates a fresh one.
collector = sample_data(train_env, policy, device, 48, 3)

for tensordict_data in collector:

    print("tensordict_data:", tensordict_data)



## Launch training

In [None]:
# Sampling
days_per_batch = 512
frames_per_batch = days_per_batch * 24  # Frames (env steps) collected per training iteration; presumably 24 hourly steps per day
n_iters = 60  # Number of sampling and training iterations
total_frames = frames_per_batch * n_iters  # Informational only: sample_data recomputes this from its arguments

# Training
num_epochs = 30  # Passes over each collected batch per iteration (each pass = frames_per_batch // minibatch_size optimizer steps)
minibatch_size = 24  # Frames per optimizer step; train_policy also uses it as the number of evaluation rollouts by default
max_grad_norm = 1.0  # Maximum norm for the gradients
lr = 1e-2  # Learning rate

# A2C
gamma = 1  # discount factor (1 = undiscounted returns)
lmbda = 0.9  # lambda for generalised advantage estimation

# Create networks (fresh, untrained instances replacing the ones built in earlier cells).
# NOTE: train_policy falls back to the global ``policy`` for evaluation when none is
# passed, so this rebinding is what gets evaluated and snapshotted.
policy = create_probabilistic_policy(train_env, share_parameters_policy=True, device=device)
critic = create_critic(train_env, share_parameters_critic=True, mappo=True, device=device)

# Configure data collection
collector = sample_data(train_env, policy, device, frames_per_batch, n_iters)

# Create replay buffer
replay_buffer = create_replay_buffer(frames_per_batch, device, minibatch_size)

# Create loss module
loss_module = create_loss_module(policy, critic, train_env, gamma, lmbda)

# Create optimizer over the loss module's parameters (presumably both actor and critic weights)
optim = torch.optim.Adam(loss_module.parameters(), lr)

# Train policy
# Returns the best policy (by evaluation reward) plus per-iteration train/eval reward means.

a2c_policy, a2c_reward_mean_list, a2c_eval_reward_mean_list = train_policy(
    env=train_env,
    eval_env=eval_env,
    n_iters=n_iters,
    collector=collector,
    loss_module=loss_module,
    replay_buffer=replay_buffer,
    num_epochs=num_epochs,
    frames_per_batch=frames_per_batch,
    minibatch_size=minibatch_size,
    max_grad_norm=max_grad_norm,
    optim=optim,
)

### Plot results

In [None]:
# Plot the training and evaluation reward curves. The envs and the best policy are also
# passed, presumably to plot that policy's actions (see plot_rewards_and_actions).
plot_rewards_and_actions(a2c_reward_mean_list, a2c_eval_reward_mean_list, train_env, eval_env, a2c_policy)