In [9]:
import gymnasium as gym
import ale_py
gym.register_envs(ale_py)
from IPython import display
from gymnasium.wrappers import RecordVideo
import ipywidgets as widgets
import warnings
import torch
import torch_directml
from torch import nn
from torch.nn import functional as F
from torch import optim
from collections import deque, namedtuple
import random
from torchvision.transforms import ToTensor
import math
import numpy as np
try:
    global device
    device = torch_directml.device()
except NameError:
    device = torch.device(
            "cuda" if torch.cuda.is_available() else
            "mps" if torch.backends.mps.is_available() else
            "cpu"
        )
warnings.filterwarnings("ignore")
device

device(type='privateuseone', index=0)

In [3]:
# Initialize environment
base_env = gym.make('ALE/IceHockey-v5', render_mode='rgb_array')
env = RecordVideo(base_env, video_folder="./videos", disable_logger=True)
done = False

obs, info = env.reset()
t = 0
max_steps = 200

# Simulate an episode
while not done:

    # Take a random action
    action = env.action_space.sample()
    new_obs, reward, terminated, truncated, info = env.step(action)

    done = terminated or truncated or t > max_steps
    t += 1

# Close environment
env.close()

# Render recording
widgets.Video.from_file(
    f"./videos/rl-video-episode-0.mp4", autoplay=False, loop=False, width=700
)

env.close()

A.L.E: Arcade Learning Environment (version 0.10.1+unknown)
[Powered by Stella]


- obs: array(210, 160, 3)
    + obs dataset: array(n, 210, 160, 3)
- reward: float

설계: 에피소드를 플레이해 기억 버퍼에 저장, 4개씩 임의로 뽑아 상태로 활용

In [4]:
Transition = namedtuple("Transition", ("state", "action", "next_state", "reward"))
# Transition이란 이름을 일종의 구조체로

class ReplayMemory(object):

    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

In [None]:
class DQN(nn.Module):
    def __init__(self, n_actions):
        super().__init__()
        self.fc=nn.Sequential(
            nn.AdaptiveAvgPool2d(1), 
            nn.Flatten(),
            nn.Linear(64, n_actions)
            )
        self.conv = nn.Sequential(
            nn.Conv2d(3, 192, 3, padding="same"),
            nn.ReLU(),
            nn.MaxPool2d(5),  # (42, 32)
            nn.Conv2d(192, 768, 3, padding="same"),
            nn.ReLU(),
            nn.MaxPool2d(2),  # (21, 16)
            nn.Conv2d(768, 64, 3, padding="same"),
            nn.ReLU(),
        )
    def forward(self, x):
        conv0 = self.conv(x)
        conv0_result = self.fc(conv0)
        return conv0_result

In [7]:
# BATCH_SIZE is the number of transitions sampled from the replay buffer
# GAMMA is the discount factor as mentioned in the previous section
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
# TAU is the update rate of the target network
# LR is the learning rate of the ``AdamW`` optimizer
BATCH_SIZE = 128
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
TAU = 0.005
LR = 1e-4

# Get number of actions from gym action space
n_actions = env.action_space.n
# Get the number of state observations
state, info = env.reset()
n_observations = state.shape

policy_net = DQN(n_actions).to(device)
target_net = DQN(n_actions).to(device)
# 처음에는 파라미터가 완전히 같게 시작
# 코딩 실습: 강화학습 과제 시간에 하기
target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10000)


steps_done = 0

In [None]:
# module test

test_state = []
for _ in range(4):
    test_state.append(env.step(0)[0])
test_state = np.array(test_state, dtype=np.float64) # (4, 210, 160, 3)
test_state = test_state.transpose(0, 3, 1, 2) # (4, 3, 210, 160)
test_state = torch.Tensor(test_state).to(device=device)
print(test_state.dtype)
policy_net(test_state)

torch.float32
torch.Size([4, 3, 210, 160])


tensor([[ 0.1966, -0.1492,  2.5425, -0.4242,  2.1744,  0.7381,  0.3016, -4.8014,
         -3.8713,  6.2697, -2.4546, -0.2726,  4.1840,  2.4821, -0.7777, -6.5005,
          0.8856,  3.9952],
        [ 0.1966, -0.1492,  2.5425, -0.4242,  2.1744,  0.7381,  0.3016, -4.8014,
         -3.8713,  6.2697, -2.4546, -0.2726,  4.1840,  2.4821, -0.7777, -6.5005,
          0.8856,  3.9952],
        [ 0.1966, -0.1492,  2.5425, -0.4242,  2.1744,  0.7381,  0.3016, -4.8014,
         -3.8713,  6.2697, -2.4546, -0.2726,  4.1840,  2.4821, -0.7777, -6.5005,
          0.8856,  3.9952],
        [ 0.1966, -0.1492,  2.5425, -0.4242,  2.1744,  0.7381,  0.3016, -4.8014,
         -3.8713,  6.2697, -2.4546, -0.2726,  4.1840,  2.4821, -0.7777, -6.5005,
          0.8856,  3.9952]], device='privateuseone:0',
       grad_fn=<AddmmBackward0>)