In [9]:
import torch
import multiprocessing
from torchrl.envs import RewardSum, TransformedEnv
from torchrl.envs.libs.vmas import VmasEnv
from torchrl.envs.utils import check_env_specs
from matplotlib import pyplot as plt
from tqdm import tqdm
import numpy as np


class SARSATrainer:
    def __init__(
        self,
        num_epochs,
        frames_per_batch,
        device,
        total_frames,
        gamma,
        epsilon_start,
        epsilon_end,
        epsilon_decay_steps,
        env,
    ):
        self.num_epochs = num_epochs
        self.frames_per_batch = frames_per_batch
        self.device = device
        self.total_frames = total_frames
        self.gamma = gamma
        self.epsilon_start = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay_steps = epsilon_decay_steps
        self.env = env

        # Initialize Q-table
        self.num_states = self.get_num_states(env.observation_space)
        self.num_actions = self.get_num_actions(env.action_space)
        self.q_values = np.zeros((self.num_states, self.num_actions))

    def get_num_states(self, observation_space):
        if observation_space is None:
            return 0
        elif isinstance(observation_space, tuple):
            return sum(self.get_num_states(space) for space in observation_space)
        else:
            return observation_space.n if hasattr(observation_space, 'n') else observation_space.shape[0]

    def get_num_actions(self, action_space):
        if action_space is None:
            return 0
        elif isinstance(action_space, tuple):
            return sum(self.get_num_actions(space) for space in action_space)
        else:
            return action_space.n if hasattr(action_space, 'n') else action_space.shape[0]


    def train(self, max_steps):
        pbar = tqdm(total=self.total_frames, desc="episode_reward_mean = 0")
        episode_reward_mean_list = []

        epsilon_decay = np.linspace(self.epsilon_start, self.epsilon_end, self.epsilon_decay_steps)

        for step, epsilon in zip(range(self.total_frames), epsilon_decay):
            state = self.env.reset()
            done = False
            episode_reward = 0

            # Select initial action using epsilon-greedy policy
            action = self.select_action(state, epsilon)

            while not done:
                # Take action, observe next state and reward
                next_state, reward, done, _ = self.env.step(action)

                # Select next action using epsilon-greedy policy
                next_action = self.select_action(next_state, epsilon)

                # Update Q-value based on SARSA algorithm
                self.update_q_value(state, action, reward, next_state, next_action)

                # Update state, action and episode reward
                state = next_state
                action = next_action
                episode_reward += reward

            episode_reward_mean_list.append(episode_reward)
            pbar.update(1)

        pbar.close()
        return episode_reward_mean_list

    def select_action(self, state, epsilon):
        # Epsilon-greedy policy
        if np.random.rand() < epsilon:
            return self.env.action_space.sample()  # Random action
        else:
            return np.argmax(self.q_values[state])

    def update_q_value(self, state, action, reward, next_state, next_action):
        # SARSA update rule: Q(s, a) += alpha * (r + gamma * Q(s', a') - Q(s, a))
        alpha = 0.1  # Learning rate
        self.q_values[state, action] += alpha * (
            reward + self.gamma * self.q_values[next_state, next_action] - self.q_values[state, action]
        )


def main():
    torch.manual_seed(0)
    is_fork = multiprocessing.get_start_method() == "fork"
    device = (
        torch.device(0)
        if torch.cuda.is_available() and not is_fork
        else torch.device("cpu")
    )

    # Training
    num_epochs = 1
    gamma = 0.9
    epsilon_start = 1.0
    epsilon_end = 0.1
    epsilon_decay_steps = 100
    max_steps = 100
    num_vmas_envs = 1  # SARSA doesn't require parallel environments
    scenario_name = "navigation"
    n_agents = 3

    env = VmasEnv(
        scenario=scenario_name,
        num_envs=num_vmas_envs,
        continuous_actions=True,
        max_steps=max_steps,
        device=device,
        n_agents=n_agents,
    )

    env = TransformedEnv(
        env,
        RewardSum(in_keys=[env.reward_key], out_keys=[("agents", "episode_reward")]),
    )

    check_env_specs(env)

    sarsa_trainer = SARSATrainer(
        num_epochs=num_epochs,
        frames_per_batch=max_steps,
        device=device,
        total_frames=epsilon_decay_steps,
        gamma=gamma,
        epsilon_start=epsilon_start,
        epsilon_end=epsilon_end,
        epsilon_decay_steps=epsilon_decay_steps,
        env=env,
    )

    episode_rewards = sarsa_trainer.train(max_steps)

    plt.plot(episode_rewards)
    plt.xlabel("Episode")
    plt.ylabel("Reward")
    plt.title("Episode Reward")
    plt.show()


if __name__ == "__main__":
    main()


2024-04-15 02:31:41,508 [torchrl][INFO] check_env_specs succeeded!


TypeError: 'NoneType' object is not subscriptable