# In [None]:
import os

import numpy as np
import torch
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.noise import NormalActionNoise
from stable_baselines3.common.results_plotter import load_results, ts2xy

import rl_dbs.gym_oscillator
import rl_dbs.gym_oscillator.envs
from TD3 import TD3
from utils import ReplayBuffer
def _flexible_reset(env):
    """Reset ``env`` and return only the initial observation.

    Accepts an ``(observation, info)`` tuple, a bare ``numpy.ndarray``, or a
    non-empty list from ``env.reset()``.

    Raises:
        ValueError: if ``env.reset()`` returns any other type.
    """
    reset_result = env.reset()

    # (observation, info) tuple: keep the observation.
    if isinstance(reset_result, tuple):
        return reset_result[0]

    # Observation returned directly.
    if isinstance(reset_result, np.ndarray):
        return reset_result

    # NOTE(review): a non-empty list is unpacked to its first element. That is
    # only right if the list wraps the observation rather than being the
    # observation itself — confirm against the environment.
    if isinstance(reset_result, list) and reset_result:
        return reset_result[0]

    raise ValueError(f"Unexpected reset result type: {type(reset_result)}")


def _noisy_action(policy, state, action_space, noise_std):
    """Return the policy's action for ``state`` plus Gaussian exploration noise.

    The action is flattened to exactly ``action_space.shape[0]`` components.
    Zero-mean noise with standard deviation ``noise_std`` is added, and the
    result is clipped to ``[action_space.low, action_space.high]``.
    """
    action_dim = action_space.shape[0]
    # Keep every action component. Taking only element 0 would broadcast a
    # single value across all dimensions once the noise vector is added.
    action = np.asarray(policy.select_action(state), dtype=np.float64).reshape(-1)[:action_dim]
    action = action + np.random.normal(0, noise_std, size=action_dim)
    return action.clip(action_space.low, action_space.high)


def train():
    """Train a TD3 agent on the oscillator environment.

    Hyperparameters are fixed in the block below. Each episode's total reward
    is appended to ``log.txt`` as an ``episode,reward`` line. The mean episode
    reward is printed every ``log_interval`` episodes. Once ``episode`` exceeds
    ``save_after_episode``, a checkpoint is saved into ``directory`` at the end
    of every episode.
    """
    ######### Hyperparameters #########
    random_seed = 42
    gamma = 0.99                # discount for future rewards
    batch_size = 64             # num of transitions sampled from replay buffer
    lr = 0.002
    exploration_noise = 0.1
    polyak = 0.995              # target policy update parameter (1-tau)
    policy_noise = 0.2          # target policy smoothing noise
    noise_clip = 0.5
    policy_delay = 2            # delayed policy updates parameter
    max_episodes = 1000         # max num of episodes
    max_timesteps = 10000       # max timesteps in one episode
    directory = "./td3_model/"  # save trained models
    filename = "TD3_Oscillator_{}"
    log_interval = 10           # episodes between average-reward printouts
    save_after_episode = 500    # checkpoint every episode after this one
    ###################################

    # Create the environment and read its dimensions.
    env = rl_dbs.gym_oscillator.envs.oscillatorEnv()
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    max_action = float(env.action_space.high[0])

    # Initialize policy and replay buffer.
    policy = TD3(lr, state_dim, action_dim, max_action)
    replay_buffer = ReplayBuffer()

    # Use `is not None` so that a seed of 0 is still applied.
    if random_seed is not None:
        print("Random Seed: {}".format(random_seed))
        env.reset(seed=random_seed)
        torch.manual_seed(random_seed)
        np.random.seed(random_seed)

    # Make sure the checkpoint directory exists before policy.save writes to it.
    os.makedirs(directory, exist_ok=True)

    avg_reward = 0  # reward summed over the current logging window
    with open("log.txt", "w+") as log_f:
        for episode in range(1, max_episodes + 1):
            state = _flexible_reset(env)
            ep_reward = 0
            steps = 0  # environment steps collected this episode

            for t in range(max_timesteps):
                action = _noisy_action(policy, state, env.action_space, exploration_noise)
                next_state, reward, terminated, truncated, _ = env.step(action)

                # Only true termination zeroes the bootstrap target. A truncated
                # (e.g. time-limited) transition still has a valid next_state value.
                replay_buffer.add((state, action, reward, next_state, float(terminated)))
                state = next_state
                steps = t + 1

                avg_reward += reward
                ep_reward += reward

                if terminated or truncated:
                    break

            # One batch of updates per episode: one gradient iteration per step
            # collected.
            # NOTE(review): this assumes TD3.update's second argument is the number
            # of gradient iterations, as in the reference TD3-PyTorch code.
            # Calling it on every step with the in-episode counter `t` makes the
            # update count grow quadratically with episode length — confirm in TD3.py.
            if len(replay_buffer.buffer) >= batch_size:
                policy.update(replay_buffer, steps, batch_size, gamma, polyak,
                              policy_noise, noise_clip, policy_delay)

            # Log this episode's reward.
            log_f.write('{},{}\n'.format(episode, ep_reward))
            log_f.flush()

            # Print the mean episode reward for the last log_interval episodes.
            if episode % log_interval == 0:
                avg_reward = int(avg_reward / log_interval)
                print("Episode: {}\tAverage Reward: {}".format(episode, avg_reward))
                avg_reward = 0

            # Checkpoint every episode once past the warm-up threshold.
            if episode > save_after_episode:
                policy.save(directory, filename.format(episode))

    print("Training completed.")

if __name__ == "__main__":
    # Start training only when this file is run as a script, not when imported.
    train()