In [None]:
import os
import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn

from classes_and_functions.plot import plot_graph
from classes_and_functions.serialize import serialize_loss_step_reward
from classes_and_functions.ini_agent_replay_buffer import initialize_agent_and_replay_buffer
from classes_and_functions.train_and_evaluate import train_and_eval

# Run configuration. The whole dict is passed to the project helpers used in
# the training cell (initialize_agent_and_replay_buffer, train_and_eval) and
# is dumped verbatim to settings.txt below. How each key is interpreted is
# decided inside those helpers, not in this notebook.
# NOTE(review): "ep"/"lr" start equal to INITIAL_EP/INITIAL_LR; presumably they
# hold the current epsilon / learning rate while the *_EP / *_LR keys describe
# the decay schedule -- confirm in train_and_eval.
config = {
    "EPISODES": 10000,
    "BATCH_SIZE": 256,
    "BUFFER_SIZE": 100000,
    "DISCOUNT_FACTOR": 0.95,
    "TARGET_UPDATE": 10,
    "DECAY_TIME": 100,
    "ep": 1,
    "INITIAL_EP": 1,
    "MIN_EP": 0.01,
    "POWER_EP": 7,
    "lr": 0.02,
    "INITIAL_LR": 0.02,
    "MIN_LR": 0.0001,
    "POWER_LR": 2,
    "LAYERS": [64, 128],
    "ACTIVATION_FUNCTION": nn.Sigmoid(),
    "SEED": 24,
    "ENVIRONMENT": 'CartPole-v1',
}

# Every artifact of this run is written under a per-seed folder.
path = f"online/seed_{config['SEED']}"

# True when the run folder was already present before this cell executed.
# Retained as a notebook-level name in case other cells read it.
folder_is_exists = os.path.exists(path)

# Create the run folder together with its "results" subfolder.
# exist_ok=True keeps the cell re-runnable and, unlike the previous
# "only create when the run folder is missing" check, also creates "results"
# when the run folder exists but the subfolder does not (e.g. after an
# interrupted earlier run), which the serialization step is pointed at.
os.makedirs(f"{path}/results", exist_ok=True)

# Snapshot the configuration next to the results, one "KEY = value" line per
# entry (values are rendered with str(), so the activation shows as its repr).
with open(f"{path}/settings.txt", "w") as f:
    for key, value in config.items():
        f.write(f"{key} = {value}\n")

# Report which device type PyTorch can see; this only prints, it does not
# move anything onto a device.
if torch.cuda.is_available():
    print("Using GPU")
else:
    print("Using CPU")

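## Training the agent

The next cell initializes the agent and its replay buffer, runs training with the target network enabled, saves the recorded loss/step/reward series, and plots them.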

In [None]:

# Build the agent and its replay buffer from the run configuration.
dqn_agent, replay_buffer = initialize_agent_and_replay_buffer(config)

# Train and evaluate from start_index=1 up to stop_index=EPISODES with TN=True
# (the target-network variant). Three per-run series come back: mean loss,
# steps and reward.
mean_loss_list, step_list, reward_list = train_and_eval(
    dqn_agent,
    replay_buffer,
    config,
    start_index=1,
    stop_index=config['EPISODES'],
    TN=True,
)

# Persist the three series into the run's results folder.
serialize_loss_step_reward(
    mean_loss_list,
    step_list,
    reward_list,
    "",
    f"{path}/results",
)

# Figure 1: mean loss per episode, y-axis limited to [0, 10].
plot_graph(
    mean_loss_list,
    "Episodes",
    "Mean loss",
    color="orange",
    ylim=[0, 10],
    path_name=f"{path}/mean_loss",
)

# Figure 2: steps per episode, drawn with plt.bar.
plot_graph(
    step_list,
    "Episodes",
    "Steps",
    type=plt.bar,
    color="green",
    path_name=f"{path}/steps",
)

# Figure 3: reward per episode, drawn with plt.plot.
plot_graph(
    reward_list,
    "Episodes",
    "Reward",
    type=plt.plot,
    color="blue",
    path_name=f"{path}/reward",
)

