In [1]:
import gymnasium as gym
import random
import time

import torch
from torch import nn, optim, autograd
from torchvision import transforms
import torch.nn.functional as F

from collections import namedtuple, deque

In [2]:
# Pick the compute device: use CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using {} Device".format(device))

Using cuda Device


In [3]:
# A single replay-buffer record: the state an action was taken in, the action
# itself, the state that followed, and the reward received.
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'next_state', 'reward'],
)

class ReplayMemory(object):
    """Bounded FIFO buffer of ``Transition`` records.

    Storage is a ``deque`` created with ``maxlen=capacity``, so once the
    buffer is full every new push silently discards the oldest entry.
    """

    def __init__(self, capacity: int) -> None:
        """Create an empty buffer holding at most ``capacity`` transitions."""
        self.memory = deque([], maxlen=capacity)

    def push(self, *args) -> None:
        """Save a transition.

        ``args`` are forwarded positionally to ``Transition``
        (state, action, next_state, reward).
        """
        self.memory.append(Transition(*args))

    def sample(self, batch_size: int) -> list:
        """Return ``batch_size`` stored items drawn at random without replacement.

        Raises:
            ValueError: propagated from ``random.sample`` when ``batch_size``
                is larger than the number of stored items (or negative).
        """
        return random.sample(self.memory, batch_size)

    def __len__(self) -> int:
        """Return the number of transitions currently stored."""
        return len(self.memory)

In [4]:
class DQN(nn.Module):
    """Fully connected network with two 128-unit ReLU hidden layers.

    Maps an input of ``n_observations`` features to ``n_actions`` outputs,
    one value per action.
    """

    def __init__(self, n_observations: int, n_actions: int) -> None:
        super().__init__()
        hidden_width = 128
        self.layer1 = nn.Linear(n_observations, hidden_width)
        self.layer2 = nn.Linear(hidden_width, hidden_width)
        self.layer3 = nn.Linear(hidden_width, n_actions)

    def forward(self, x):
        """Run ``x`` through the three linear layers, with ReLU after the first two.

        Per the original author's note, this is called either with a single
        element (to choose the next action) or with a batch (during
        optimization), and the result looks like
        ``tensor([[left0exp, right0exp], ...])``.
        """
        first_hidden = F.relu(self.layer1(x))
        second_hidden = F.relu(self.layer2(first_hidden))
        return self.layer3(second_hidden)

In [5]:
# Create the CartPole-v1 environment. render_mode="human" requests on-screen
# rendering (NOTE(review): presumably requires a display; confirm before
# running this notebook headless).
env = gym.make("CartPole-v1", render_mode="human")

In [6]:
# Inspect the observation space; the recorded output shows a 4-element
# float32 Box.
env.observation_space

Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32)

In [7]:
# Inspect the action space; the recorded output shows Discrete(2), matching
# the two actions (0 and 1) sampled in the rollout loop.
env.action_space

Discrete(2)

In [8]:
# Sanity-check the environment by playing a few episodes with uniformly
# random actions and printing the total reward collected in each one.
episodes = 10
for episode in range(episodes):
    # NOTE(review): the same seed is passed on every reset, which presumably
    # makes each episode start from the same initial state; confirm this is
    # intended rather than seeding only the first reset.
    observation, info = env.reset(seed=42)
    terminated = False
    truncated = False  # must exist before the loop condition is evaluated
    score = 0

    # An episode ends when the environment reports EITHER termination or
    # truncation. The parentheses matter: `not terminated or truncated`
    # parses as `(not terminated) or truncated`, which keeps stepping after
    # a truncation.
    while not (terminated or truncated):
        env.render()
        action = random.choice([0,1])
        observation, reward, terminated, truncated, info = env.step(action)
        score += reward
    print('Episode:{} Score:{}'.format(episode, score))

Episode:0 Score:18.0
Episode:1 Score:8.0
Episode:2 Score:42.0
Episode:3 Score:26.0
Episode:4 Score:16.0
Episode:5 Score:17.0
Episode:6 Score:21.0
Episode:7 Score:13.0
Episode:8 Score:21.0
Episode:9 Score:9.0


In [9]:
# Close the environment now that the random rollouts are finished.
env.close()