# Laboratorio 8 - Gymnasium y DQN
- Ricardo Méndez
- Sara Echverría
- Melissa Pérez Alarcón, 21385

https://github.com/MelissaPerez09/Lab08-CC3104

In [1]:
import random
import collections
import math
import os
from typing import Deque, Tuple


import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import imageio

In [2]:
# --- Environment / reproducibility -----------------------------------------
ENV_NAME = 'CartPole-v1'  # presumably the Gymnasium environment id - confirm at the gym.make call
SEED = 42                 # shared seed for the random / numpy / torch generators

# --- Optimisation -----------------------------------------------------------
GAMMA = 0.99              # discount factor applied to the bootstrapped next-state value
LR = 0.001                # learning rate handed to the Adam optimizer
BATCH_SIZE = 64           # number of transitions drawn per gradient step

# --- Replay memory ----------------------------------------------------------
BUFFER_SIZE = 10_000      # capacity passed to ReplayBuffer
MIN_REPLAY_SIZE = 1_000   # presumably a warm-up size before learning starts - verify in the training loop

# --- Schedule ---------------------------------------------------------------
TARGET_UPDATE_FREQ = 1_000  # steps
MAX_EPISODES = 800
MAX_STEPS_PER_EPISODE = 500

# --- Exploration (epsilon-greedy) --------------------------------------------
EPS_START = 1.0
EPS_END = 0.01
EPS_DECAY = 50_000        # NOTE(review): decay constant; unit (steps?) depends on the schedule code - confirm

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = torch.device('cpu' if not torch.cuda.is_available() else 'cuda')

In [3]:
# Seed Python's, NumPy's and PyTorch's global generators with the same value.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(SEED)

<torch._C.Generator at 0x117455b30>

In [4]:
# One environment step as stored in the replay memory.
Transition = collections.namedtuple('Transition', ('state', 'action', 'reward', 'next_state', 'done'))


class ReplayBuffer:
    """Bounded FIFO memory of ``Transition`` tuples with uniform sampling.

    Once ``capacity`` transitions are stored, each new ``push`` evicts the
    oldest one (behaviour of ``collections.deque(maxlen=...)``).
    """

    def __init__(self, capacity: int):
        """Create an empty buffer holding at most ``capacity`` transitions."""
        self.buffer: Deque[Transition] = collections.deque(maxlen=capacity)

    # NOTE: push / sample / __len__ must be defined at class level. They were
    # previously nested inside __init__, which made them throw-away local
    # functions, so ``len(buffer)`` and ``buffer.push(...)`` failed at runtime.

    def push(self, *args):
        """Append one transition.

        Args:
            *args: The five ``Transition`` fields in order:
                ``state, action, reward, next_state, done``.
        """
        self.buffer.append(Transition(*args))

    def sample(self, batch_size: int):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            A tuple ``(states, actions, rewards, next_states, dones)`` of
            tensors placed on ``DEVICE``. ``states`` / ``next_states`` are
            float32 with the batch on the first axis; ``actions`` (int64),
            ``rewards`` and ``dones`` (float32) get a trailing axis of size 1
            via ``unsqueeze(1)`` (i.e. ``(batch, 1)`` when the stored values
            are scalars).

        Raises:
            ValueError: If ``batch_size`` exceeds the number of stored
                transitions (raised by ``random.sample``).
        """
        batch = random.sample(self.buffer, batch_size)
        # Stack through np.array first so torch builds the tensor from one
        # contiguous array instead of a Python list of arrays.
        states = torch.tensor(np.array([t.state for t in batch]), dtype=torch.float32, device=DEVICE)
        actions = torch.tensor([t.action for t in batch], dtype=torch.int64, device=DEVICE).unsqueeze(1)
        rewards = torch.tensor([t.reward for t in batch], dtype=torch.float32, device=DEVICE).unsqueeze(1)
        next_states = torch.tensor(np.array([t.next_state for t in batch]), dtype=torch.float32, device=DEVICE)
        # done flags become 0.0 / 1.0 so they can be used as a multiplicative mask.
        dones = torch.tensor([t.done for t in batch], dtype=torch.float32, device=DEVICE).unsqueeze(1)
        return states, actions, rewards, next_states, dones

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)

In [5]:
class QNetwork(nn.Module):
    """Fully connected network mapping an observation to one value per action.

    Architecture: Linear -> ReLU -> Linear -> ReLU -> Linear, with both hidden
    layers ``hidden_dim`` units wide.
    """

    def __init__(self, obs_dim: int, action_dim: int, hidden_dim: int = 128):
        """Build the layer stack.

        Args:
            obs_dim: Size of the input feature axis.
            action_dim: Size of the output axis (one entry per action).
            hidden_dim: Width of the two hidden layers.
        """
        super().__init__()
        # Layers are created in input-to-output order so parameter
        # initialisation consumes the RNG in the same sequence as before.
        layer_stack = [
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        ]
        self.net = nn.Sequential(*layer_stack)

    def forward(self, x):
        """Return the network output for input ``x``."""
        action_values = self.net(x)
        return action_values

In [6]:
class DQNAgent:
    """DQN agent: online/target Q-networks, Adam optimizer and replay memory.

    The online network is trained by ``update``; the target network only
    changes when ``soft_update_target`` (or ``load``) is called.
    """

    def __init__(self, obs_dim, action_dim):
        """Create both networks (target starts as a copy of online).

        Args:
            obs_dim: Observation feature size passed to ``QNetwork``.
            action_dim: Number of discrete actions.
        """
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.online_net = QNetwork(obs_dim, action_dim).to(DEVICE)
        self.target_net = QNetwork(obs_dim, action_dim).to(DEVICE)
        self.target_net.load_state_dict(self.online_net.state_dict())
        self.optimizer = optim.Adam(self.online_net.parameters(), lr=LR)
        self.replay_buffer = ReplayBuffer(BUFFER_SIZE)
        # Not modified by any method of this class; presumably advanced by
        # the training loop - verify against caller.
        self.total_steps = 0

    def select_action(self, state, epsilon=0.0):
        """Pick an action epsilon-greedily.

        Args:
            state: A single observation (e.g. ``np.array``), without batch axis.
            epsilon: Probability of choosing a uniformly random action. The
                default ``0.0`` always takes the greedy action.

        Returns:
            The chosen action index as a Python ``int``.
        """
        if random.random() < epsilon:
            return random.randrange(self.action_dim)

        # Add a batch axis of size 1 so the network sees a batch of one.
        s = torch.tensor(state, dtype=torch.float32, device=DEVICE).unsqueeze(0)
        with torch.no_grad():
            qvals = self.online_net(s)
        return int(torch.argmax(qvals, dim=1).item())

    def update(self):
        """Run one gradient step on a sampled mini-batch.

        Returns:
            The scalar MSE loss as a ``float``, or ``None`` when the replay
            buffer holds fewer than ``BATCH_SIZE`` transitions.
        """
        # NOTE(review): this gates on BATCH_SIZE, not MIN_REPLAY_SIZE; confirm
        # the training loop enforces the warm-up if one is intended.
        if len(self.replay_buffer) < BATCH_SIZE:
            return None

        states, actions, rewards, next_states, dones = self.replay_buffer.sample(BATCH_SIZE)

        # Q(s, a) of the online network for the actions actually taken.
        q_values = self.online_net(states).gather(1, actions)

        # Target: r + gamma * max_a' Q_target(s', a') * (1 - done)
        # Computed without gradients so only the online network is trained.
        with torch.no_grad():
            next_q_values = self.target_net(next_states)
            max_next_q_values, _ = torch.max(next_q_values, dim=1, keepdim=True)
            target_q = rewards + GAMMA * (1 - dones) * max_next_q_values

        loss = nn.functional.mse_loss(q_values, target_q)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def soft_update_target(self, tau=1.0):
        """Move the target network toward the online network.

        Args:
            tau: Interpolation factor in ``[0, 1]``. The default ``1.0``
                copies the online weights verbatim (the behaviour this method
                always had). A smaller value applies
                ``target = tau * online + (1 - tau) * target`` to every
                parameter. Only parameters are blended, not buffers.

        Raises:
            ValueError: If ``tau`` is outside ``[0, 1]``.
        """
        if not 0.0 <= tau <= 1.0:
            raise ValueError(f"tau must be in [0, 1], got {tau}")

        if tau == 1.0:
            # Exact copy, including any buffers.
            self.target_net.load_state_dict(self.online_net.state_dict())
            return

        with torch.no_grad():
            for target_param, online_param in zip(
                self.target_net.parameters(), self.online_net.parameters()
            ):
                target_param.mul_(1.0 - tau).add_(online_param, alpha=tau)

    def save(self, path):
        """Write the online network's ``state_dict`` to ``path``."""
        torch.save(self.online_net.state_dict(), path)

    def load(self, path):
        """Load online weights from ``path`` and copy them to the target.

        SECURITY: ``torch.load`` may unpickle arbitrary objects; only load
        checkpoint files from a trusted source.
        """
        self.online_net.load_state_dict(torch.load(path, map_location=DEVICE))
        self.target_net.load_state_dict(self.online_net.state_dict())
