In [1]:
!pip install gymnasium

Collecting gymnasium
  Downloading gymnasium-0.29.1-py3-none-any.whl (953 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/953.9 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m225.3/953.9 kB[0m [31m7.7 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m953.9/953.9 kB[0m [31m17.5 MB/s[0m eta [36m0:00:00[0m
Collecting farama-notifications>=0.0.1 (from gymnasium)
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)
Installing collected packages: farama-notifications, gymnasium
Successfully installed farama-notifications-0.0.4 gymnasium-0.29.1


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
# Show whether PyTorch can see a CUDA device in this runtime.
torch.cuda.is_available()

True

In [4]:
import time
import random

import numpy as np
import gymnasium as gym
import matplotlib.pyplot as plt

from tqdm import tqdm
from itertools import count

from IPython import display

In [5]:
class FCQ(nn.Module):
    """Fully connected network mapping a state to one Q-value per action.

    Args:
        input_dim: Size of the state vector.
        output_dim: Number of outputs (one per discrete action).
        hidden_dims: Widths of the hidden layers; must contain at least one entry.
        activation_fc: Activation applied after the input layer and after every
            hidden layer (not after the output layer).

    The module moves itself to ``cuda:0`` when CUDA is available, otherwise it
    stays on the CPU; the chosen device is exposed as ``self.device``.
    """

    def __init__(self, input_dim, output_dim, hidden_dims=(32, 32), activation_fc=F.relu):
        super(FCQ, self).__init__()
        self.activation_fc = activation_fc
        self.input_layer = nn.Linear(input_dim, hidden_dims[0])
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(hidden_dims[:-1], hidden_dims[1:])]
        )
        self.output_layer = nn.Linear(hidden_dims[-1], output_dim)

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.to(self.device)

    def _format(self, state):
        """Return ``state`` as a float32 tensor on the model's device.

        Non-tensor input (e.g. a single observation array) is converted and
        given a leading batch dimension. Tensor input is assumed to be batched
        already; it is only moved/cast when its device or dtype differs, so
        float64 or CPU tensors no longer fail inside the linear layers.
        """
        if isinstance(state, torch.Tensor):
            return state.to(device=self.device, dtype=torch.float32)
        x = torch.tensor(state, device=self.device, dtype=torch.float32)
        return x.unsqueeze(0)

    def forward(self, state):
        """Compute Q-values for ``state``; returns a tensor whose last dim is ``output_dim``."""
        x = self._format(state)
        x = self.activation_fc(self.input_layer(x))
        for hidden_layer in self.hidden_layers:
            x = self.activation_fc(hidden_layer(x))
        return self.output_layer(x)

    def numpy_float_to_device(self, variable):
        """Convert a NumPy array to a float32 tensor on the model's device."""
        return torch.from_numpy(variable).float().to(self.device)

    def load(self, experiences):
        """Convert a batch of NumPy arrays into tensors on the model's device.

        Args:
            experiences: 5-tuple ``(states, actions, rewards, new_states,
                is_terminals)`` of NumPy arrays.

        Returns:
            The same 5-tuple as tensors: ``actions`` as int64 (so it can be used
            as a ``gather`` index), everything else as float32.
        """
        states, actions, rewards, new_states, is_terminals = experiences
        states = self.numpy_float_to_device(states)
        actions = torch.from_numpy(actions).long().to(self.device)
        new_states = self.numpy_float_to_device(new_states)
        rewards = self.numpy_float_to_device(rewards)
        is_terminals = self.numpy_float_to_device(is_terminals)
        return states, actions, rewards, new_states, is_terminals

In [6]:
class GreedyStrategy():
    """Action selection that always picks the highest-valued action."""

    def select_action(self, model, state):
        """Return the index of the largest Q-value that ``model`` gives for ``state``."""
        with torch.no_grad():
            estimates = model(state)
            as_array = estimates.detach().cpu().numpy().squeeze()
            best = np.argmax(as_array)
        return best

In [7]:
class EGreedyStrategy():
    """Epsilon-greedy action selection with a fixed exploration rate."""

    def __init__(self, epsilon=0.1):
        # Probability threshold: a uniform draw above it means "act greedily".
        self.epsilon = epsilon

    def select_action(self, model, state):
        """Pick the greedy action, or a uniformly random one when the draw is <= epsilon.

        The Q-values are always computed first (even when exploring) because
        their length determines the number of actions to sample from.
        """
        with torch.no_grad():
            estimates = model(state).detach().cpu().numpy().squeeze()

        # One uniform draw decides the branch; a second draw happens only when exploring.
        if np.random.rand() > self.epsilon:
            return np.argmax(estimates)
        return np.random.randint(len(estimates))

In [8]:
class NFQ():
    """Neural Fitted Q-iteration style agent for discrete-action environments.

    Experiences are collected with the training strategy until ``batch_size``
    transitions are available; the online model is then fitted on that batch for
    ``epochs`` gradient steps and the buffer is cleared.

    Args:
        value_model_fn: ``(nS, nA) -> model``. The model must be callable on
            states and provide ``load(batches)`` to turn NumPy batches into tensors.
        value_optimizer_fn: ``(model, lr) -> optimizer``.
        value_optimizer_lr: Learning rate handed to ``value_optimizer_fn``.
        training_strategy_fn: ``() -> strategy`` used while collecting experience.
        evaluation_strategy_fn: ``() -> strategy`` used in ``evaluate``.
        batch_size: Number of transitions collected before each fitting phase.
        epochs: Number of optimisation passes over each collected batch.
    """

    def __init__(self, value_model_fn, value_optimizer_fn, value_optimizer_lr, training_strategy_fn, evaluation_strategy_fn, batch_size, epochs):
        self.value_model_fn = value_model_fn
        self.value_optimizer_fn = value_optimizer_fn
        self.value_optimizer_lr = value_optimizer_lr
        self.training_strategy_fn = training_strategy_fn
        self.evaluation_strategy_fn = evaluation_strategy_fn
        self.batch_size = batch_size
        self.epochs = epochs

    def train(self, env, seed, gamma, max_episodes):
        """Train on ``env`` for ``max_episodes`` episodes, then evaluate for 100 episodes.

        Args:
            env: Gymnasium-style environment (``reset``/``step`` with the
                5-tuple step API) with a discrete action space.
            seed: Seed for torch, NumPy, ``random`` and the environment.
            gamma: Discount factor used in the TD target.
            max_episodes: Number of training episodes.

        Returns:
            Mean episode return over the final 100 evaluation episodes.
            The environment is closed before returning.
        """
        self.gamma = gamma
        self.seed = seed
        torch.manual_seed(self.seed); np.random.seed(self.seed); random.seed(self.seed)

        nS, nA = env.observation_space.shape[0], env.action_space.n

        self.online_model = self.value_model_fn(nS, nA)
        self.value_optimizer = self.value_optimizer_fn(self.online_model, self.value_optimizer_lr)
        # Use the factories given to __init__, not same-named notebook globals.
        self.training_strategy = self.training_strategy_fn()
        # Backward-compatible alias for the previously misspelled attribute name.
        self.training_stategy = self.training_strategy
        self.evaluation_strategy = self.evaluation_strategy_fn()
        self.experiences = []

        for episode in tqdm(range(1, max_episodes + 1), leave=True):
            if episode == 1:
                # Seed the environment once so the whole run is reproducible.
                state, info = env.reset(seed=self.seed)
            else:
                state, info = env.reset()

            while True:
                state, is_terminal = self.interaction_step(state, env)

                if len(self.experiences) >= self.batch_size:
                    # Transpose the list of (s, a, r, s', done) tuples into five
                    # column arrays, each with one row per transition.
                    batches = [np.vstack(column) for column in zip(*self.experiences)]
                    experiences = self.online_model.load(batches)

                    for _ in range(self.epochs):
                        self.optimize_model(experiences)
                    self.experiences.clear()

                if is_terminal:
                    break

        rewards, final_eval_score, score_std = self.evaluate(self.online_model, env, n_episodes=100)
        env.close()
        print("Training Complete")
        print(f"Final evaluation score: {final_eval_score:.2f} -+ ScoreSTD: {score_std:.2f}")
        return final_eval_score


    def interaction_step(self, state, env):
        """Take one training step in ``env`` and store the transition.

        Returns:
            ``(next_state, done)`` where ``done`` is true when the episode ended
            for any reason (termination or truncation).
        """
        action = self.training_strategy.select_action(self.online_model, state)
        next_state, reward, is_terminal, is_truncated, info = env.step(action)
        # Only a true termination zeroes the bootstrap term. In the Gymnasium
        # API `terminated` already excludes pure time-limit truncation, and a
        # step that is both terminated and truncated is still a terminal state.
        is_failure = bool(is_terminal)
        experience = (state, action, reward, next_state, float(is_failure))
        self.experiences.append(experience)

        return next_state, (is_terminal or is_truncated)

    def optimize_model(self, experiences):
        """Run one gradient step on the half-MSE TD loss for a batch of tensors.

        Args:
            experiences: ``(states, actions, rewards, next_states, is_terminals)``
                as produced by the model's ``load``; ``actions`` is used as a
                ``gather`` index along dim 1.
        """
        states, actions, rewards, next_states, is_terminals = experiences

        # Bootstrapped target: r + gamma * max_a' Q(s', a'), with the bootstrap
        # term masked out for terminal transitions. No gradient flows through it.
        with torch.no_grad():
            max_q_next = self.online_model(next_states).max(1)[0].unsqueeze(1)
        target_q = rewards + self.gamma * max_q_next * (1 - is_terminals)

        # Current estimate of Q(s, a) for the actions actually taken.
        q_sa = self.online_model(states).gather(1, actions)

        td_errors = q_sa - target_q
        value_loss = td_errors.pow(2).mul(0.5).mean()

        self.value_optimizer.zero_grad()
        value_loss.backward()
        self.value_optimizer.step()

    def evaluate(self, eval_policy_model, eval_env, n_episodes=1):
        """Run ``n_episodes`` episodes with the evaluation strategy.

        Returns:
            ``(rewards, mean, std)`` where ``rewards`` is the list of
            undiscounted episode returns.
        """
        rewards = []
        for _ in tqdm(range(n_episodes), leave=True):
            state, info = eval_env.reset()
            episode_return = 0
            while True:
                action = self.evaluation_strategy.select_action(eval_policy_model, state)
                state, reward, is_terminal, is_truncated, info = eval_env.step(action)
                episode_return += reward
                if is_terminal or is_truncated:
                    break
            rewards.append(episode_return)
        return rewards, np.mean(rewards), np.std(rewards)

In [9]:
# Environment instance handed to the agent's train() call in the next cell.
env = gym.make("CartPole-v1")

In [10]:
# --- Model and optimiser factories -------------------------------------------
def value_model_fn(nS, nA):
    """Build the Q-network with hidden layers of 512 and 128 units."""
    return FCQ(nS, nA, hidden_dims=(512, 128))


def value_optimizer_fn(net, lr):
    """Build an RMSprop optimiser over the network's parameters."""
    return torch.optim.RMSprop(net.parameters(), lr=lr)


value_optimizer_lr = 0.0005


# --- Action-selection strategy factories -------------------------------------
def training_strategy_fn():
    """Exploratory strategy used while collecting experience."""
    return EGreedyStrategy(epsilon=0.5)


def evaluation_strategy_fn():
    """Purely greedy strategy used for the final evaluation."""
    return GreedyStrategy()


# --- Fitting schedule ---------------------------------------------------------
batch_size = 1024  # transitions collected before each fitting phase
epochs = 40        # optimisation passes over each collected batch

agent = NFQ(
    value_model_fn=value_model_fn,
    value_optimizer_fn=value_optimizer_fn,
    value_optimizer_lr=value_optimizer_lr,
    training_strategy_fn=training_strategy_fn,
    evaluation_strategy_fn=evaluation_strategy_fn,
    batch_size=batch_size,
    epochs=epochs,
)
final_eval_score = agent.train(env, seed=90, gamma=1.00, max_episodes=5000)

100%|██████████| 5000/5000 [04:17<00:00, 19.40it/s]
100%|██████████| 100/100 [00:14<00:00,  7.08it/s]

Training Complete
Final evaluation score: 489.19 -+ ScoreSTD: 21.03



