In [61]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import math
import numpy as np
from tqdm.notebook import tqdm
import yaml
import pickle
from utils.replay_buffer import ReplayBuffer
from agents.network import QNetwork
from utils.converter import Converter
import grid2op
from grid2op.Action import TopologyChangeAction
from utils.data_saver import TrajectoryDataLoader
import os

In [62]:
class CQLAgent:
    """Discrete-action Conservative Q-Learning (CQL) agent trained on offline batches.

    Holds an online Q-network, a slowly tracking target network (Polyak
    averaging with rate ``tau``) and an Adam optimizer over the online network.
    ``QNetwork`` is assumed to map a batch of states to a ``(batch, n_actions)``
    tensor of Q-values -- TODO confirm against ``agents.network``.
    """

    def __init__(self, cfg):
        """Create both networks and the optimizer.

        Args:
            cfg: Mapping of hyper-parameters. It is forwarded unchanged to
                ``QNetwork`` and must contain ``'lr'``. ``learn`` additionally
                reads the optional keys ``'gamma'`` (default 0.99) and
                ``'cql_alpha'`` (default 0.5).
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.cfg = cfg
        self.q_net = QNetwork(self.cfg).to(self.device)
        self.target_net = QNetwork(self.cfg).to(self.device)
        # Start the target network as an exact copy of the online network.
        self.target_net.load_state_dict(self.q_net.state_dict())
        self.optimizer = optim.Adam(self.q_net.parameters(), lr=self.cfg['lr'])
        self.tau = 1e-3

    def update_target_network(self):
        """Soft-update the target network: ``target <- tau * online + (1 - tau) * target``."""
        # no_grad replaces the legacy ``.data`` access; the arithmetic is unchanged.
        with torch.no_grad():
            for target_param, param in zip(self.target_net.parameters(), self.q_net.parameters()):
                target_param.copy_(self.tau * param + (1 - self.tau) * target_param)

    def choose_action(self, observation):
        """Return the index of the highest-valued action for one observation.

        Args:
            observation: Array-like state vector accepted by ``torch.as_tensor``.

        Returns:
            int: ``argmax`` over the online network's output.
        """
        # Cast explicitly so float64 NumPy vectors do not clash with float32 weights.
        state = torch.as_tensor(observation, dtype=torch.float32, device=self.device)
        # Pure inference: do not build an autograd graph.
        with torch.no_grad():
            q_val = self.q_net(state)
        return torch.argmax(q_val).item()

    @staticmethod
    def _action_indices(actions, n_actions):
        """Convert a batch of actions to a 1-D int64 tensor of action indices.

        NOTE(review): the action encoding stored in the trajectories is not
        visible here. A trailing dimension equal to ``n_actions`` is treated as
        a one-hot encoding; anything else is treated as integer indices.
        """
        if n_actions > 1 and actions.dim() >= 2 and actions.shape[-1] == n_actions:
            return actions.reshape(-1, n_actions).argmax(dim=1)
        return actions.reshape(-1).long()

    def learn(self, states, actions, rewards, states_, dones):
        """Run one gradient step on a batch of offline transitions.

        The loss is the one-step TD error on the dataset actions plus
        ``cql_alpha`` times the CQL regulariser
        ``logsumexp_a Q(s, a) - Q(s, a_data)``.

        Args:
            states: Batch of states fed to the online network.
            actions: Dataset actions, as indices or one-hot rows (see
                ``_action_indices``).
            rewards: One reward per transition.
            states_: Batch of next states fed to the target network.
            dones: Per-transition termination flags (bool or 0/1).

        Returns:
            float: Value of the combined loss for this batch.
        """
        states = states.to(self.device)
        actions = actions.to(self.device)
        rewards = rewards.to(self.device)
        states_ = states_.to(self.device)
        dones = dones.to(self.device)

        # NOTE(review): assumes the discount lives under 'gamma' -- confirm with config.yml.
        gamma = self.cfg.get('gamma', 0.99)
        cql_alpha = self.cfg.get('cql_alpha', 0.5)

        q_values = self.q_net(states)
        action_idx = self._action_indices(actions, q_values.shape[1])
        # Q(s, a) for the action actually taken in the dataset.
        q_taken = q_values.gather(1, action_idx.unsqueeze(1)).squeeze(1)

        # Bootstrapped target; it must not propagate gradients into target_net.
        with torch.no_grad():
            next_q_max = self.target_net(states_).max(dim=1).values
            not_done = 1.0 - dones.reshape(-1).to(next_q_max.dtype)
            td_target = rewards.reshape(-1).to(next_q_max.dtype) + gamma * not_done * next_q_max

        q_loss = F.mse_loss(q_taken, td_target)

        # CQL regulariser: push down Q on all actions, push up Q on dataset actions.
        cql_loss = (torch.logsumexp(q_values, dim=1) - q_taken).mean()

        # Keep cql_loss as a tensor: calling .item() here would drop its gradient.
        total_loss = q_loss + cql_alpha * cql_loss

        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()

        self.update_target_network()

        return total_loss.item()

    def save_models(self, path):
        """Write the online network's weights to ``<path>/CQL.pth``, creating ``path`` if needed."""
        os.makedirs(path, exist_ok=True)
        torch.save(self.q_net.state_dict(), os.path.join(path, "CQL.pth"))

    def load_model(self, path):
        """Load weights previously written by ``save_models``.

        Args:
            path: Either the directory passed to ``save_models`` or the path of
                the ``.pth`` file itself.
        """
        if os.path.isdir(path):
            path = os.path.join(path, "CQL.pth")
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        state_dict = torch.load(path, map_location=self.device)
        self.q_net.load_state_dict(state_dict)
        # Keep the target network consistent with the restored online network.
        self.target_net.load_state_dict(state_dict)

In [63]:
def read_yaml_file(file_path):
    """Parse a YAML file and return its contents.

    Args:
        file_path: Path of the YAML file to read.

    Returns:
        The object produced by ``yaml.safe_load``, or ``None`` when the file
        is not valid YAML (the parser error is printed in that case).
    """
    with open(file_path, 'r') as stream:
        try:
            parsed = yaml.safe_load(stream)
        except yaml.YAMLError as err:
            print(f"Error reading YAML file: {err}")
            return None
        return parsed

In [64]:
# Hyper-parameters for the agent, parsed from the YAML config.
yaml_data = read_yaml_file('config.yml')
# Offline trajectories used for training. os.path.join yields the same string
# as before on Windows and a valid relative path on POSIX systems.
data = TrajectoryDataLoader(os.path.join("Data", "trajectory.pkl"), batch_size=32)

env_name = "rte_case5_example"  # or any other name.
env = grid2op.make(env_name, test=True, action_class=TopologyChangeAction)
# Used by the evaluation cell to turn the agent's integer action into an env action.
converter = Converter(env)



In [65]:

agent = CQLAgent(yaml_data)

# Batches whose states tensor has exactly this shape are not trained on
# (presumably the short final batch of the loader -- TODO confirm).
skipped_shape = torch.Size([16, 182])

# 499 full passes over the offline dataset.
for i in tqdm(range(1, 500), desc="Episodes"):
    for states, actions, rewards, next_states, dones in data:
        if states.shape == skipped_shape:
            continue
        agent.learn(states, actions, rewards, next_states, dones)

Episodes:   0%|          | 0/499 [00:00<?, ?it/s]

In [66]:
from tqdm.notebook import tqdm
import numpy as np

# Roll the trained agent through one episode of `env`, recording every
# observation and reward along the way.
obs = env.reset()
all_obs = [obs]
reward = env.reward_range[0]  # placeholder until the first env.step
reward_list = []
done = False
nb_step = 0
print("Very CQL Simulation")


with tqdm(total=env.chronics_handler.max_timestep()) as pbar:
    while not done:
        action = agent.choose_action(obs.to_vect())
        # Integer action -> one-hot vector -> grid2op action object.
        one_hot_action = converter.int_one_hot(action)
        env_action = converter.convert_one_hot_encoding_act_to_env_act(one_hot_action)
        obs, reward, done, _ = env.step(env_action)
        reward_list.append(reward)
        pbar.update(1)
        if not done:
            # The terminal observation is not stored and not counted as a step.
            all_obs.append(obs)
            nb_step += 1

reward_list_simple_DQN = np.copy(reward_list)


Very CQL Simulation


  0%|          | 0/2016 [00:00<?, ?it/s]

In [68]:
# Persist the trained online network's weights under this directory.
checkpoint_dir = "Agents/CQL"
agent.save_models(checkpoint_dir)