In [1]:
#!pip install argh
#!pip install "gym[atari, accept-rom-license]"




[notice] A new release of pip available: 23.2.1 -> 23.3.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [2]:
import torch
import numpy as np
from tqdm import tqdm
import gym
import time
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from dataclasses import dataclass
from typing import Any
from random import sample, random
import wandb
from collections import deque
import argh
import cv2
from random import randint

ModuleNotFoundError: No module named 'gym'

In [None]:
class ConvModel(nn.Module):
    """Convolutional Q-network: two conv layers followed by a two-layer MLP head.

    Args:
        obs_shape: (channels, height, width) of a single observation; the
            channel count (e.g. number of stacked frames) sets the first
            conv layer's input channels.
        num_actions: number of discrete actions, i.e. the output width.
        lr: learning rate of the Adam optimizer stored on ``self.opt``.
    """

    def __init__(self, obs_shape, num_actions, lr=0.0001):
        assert len(obs_shape) == 3
        super().__init__()
        self.obs_shape = obs_shape
        self.num_actions = num_actions
        in_channels = obs_shape[0]  # previously hard-coded to 4
        self.conv_net = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, 16, (8, 8), stride=(4, 4)),
            torch.nn.ReLU(),
            torch.nn.Conv2d(16, 32, (4, 4), stride=(2, 2)),
            torch.nn.ReLU(),
        )

        # Size the first linear layer by running one dummy observation through
        # the conv stack instead of hand-computing the output geometry.
        with torch.no_grad():
            dummy = torch.zeros((1, *obs_shape))
            fc_size = self.conv_net(dummy)[0].numel()

        self.fc_net = torch.nn.Sequential(
            torch.nn.Linear(fc_size, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, num_actions),
        )
        self.opt = optim.Adam(self.parameters(), lr=lr)

    def forward(self, x):
        """Return Q-values of shape (batch, num_actions) for a (batch, C, H, W) input.

        Inputs are divided by 255, so they are expected as raw 0-255 pixel values.
        """
        conv_latent = self.conv_net(x / 255.0)
        return self.fc_net(torch.flatten(conv_latent, start_dim=1))




In [None]:
class FrameStackingAndResizingEnv:
    """Resize frames to grayscale (h, w) and stack the last ``num_stack`` of them.

    The stacked observation has shape (num_stack, h, w), with index 0 holding
    the newest frame. Both the legacy gym API (``reset() -> obs``,
    ``step() -> (obs, reward, done, info)``) and the gym>=0.26 API
    (``reset() -> (obs, info)``,
    ``step() -> (obs, reward, terminated, truncated, info)``) are accepted from
    the wrapped env. This wrapper always exposes the legacy form: ``reset``
    returns the observation alone and ``step`` returns a 4-tuple.
    """

    def __init__(self, env, w, h, num_stack=4):
        self.env = env
        self.n = num_stack
        self.w = w
        self.h = h

        self.buffer = np.zeros((num_stack, h, w), 'uint8')

    def _preprocess_frame(self, frame):
        """Resize an RGB frame to (h, w) and convert it to single-channel grayscale."""
        image = cv2.resize(frame, (self.w, self.h))  # cv2 expects (width, height)
        image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        return image

    def step(self, action):
        """Step the wrapped env and return (stacked_obs, reward, done, info)."""
        result = self.env.step(action)
        if len(result) == 5:
            # gym>=0.26: an episode ends on either termination or truncation.
            im, reward, terminated, truncated, info = result
            done = terminated or truncated
        else:
            im, reward, done, info = result
        im = self._preprocess_frame(im)
        # Shift older frames back one slot (numpy handles the overlapping copy),
        # then write the newest frame into slot 0.
        self.buffer[1:self.n, :, :] = self.buffer[0:self.n-1, :, :]
        self.buffer[0, :, :] = im
        return self.buffer.copy(), reward, done, info

    @property
    def observation_space(self):
        """Zero array shaped like a stacked observation (used for its ``.shape``)."""
        # gym.spaces.Box()
        return np.zeros((self.n, self.h, self.w))

    @property
    def action_space(self):
        """Action space of the wrapped env, passed through unchanged."""
        return self.env.action_space

    def reset(self):
        """Reset the wrapped env and fill the stack with copies of the first frame."""
        result = self.env.reset()
        # gym>=0.26 returns (obs, info); older versions return obs alone.
        im = result[0] if isinstance(result, tuple) else result
        im = self._preprocess_frame(im)
        self.buffer = np.stack([im]*self.n, 0)
        return self.buffer.copy()

    def render(self, mode=None):
        """Delegate rendering to the wrapped env.

        With gym>=0.26 the render mode is fixed at ``make()`` time and
        ``render()`` takes no argument, so ``mode`` is forwarded only when given.
        """
        if mode is None:
            return self.env.render()
        return self.env.render(mode)


In [None]:
@dataclass
class Sarsd:
    """A single replay transition: (state, action, reward, next_state, done)."""

    state: Any       # observation the action was taken from
    action: int      # index of the action taken
    reward: float    # reward for this step, as passed in by the caller
    next_state: Any  # observation returned after taking the action
    done: bool       # True when the episode ended on this step
    
class ReplayBuffer:
    """Fixed-capacity ring buffer of transitions with uniform random sampling.

    ``idx`` counts every insert ever made. Once it reaches ``buffer_size``,
    new inserts overwrite the oldest slots.
    """

    def __init__(self, buffer_size=100000):
        self.buffer_size = buffer_size
        self.buffer = [None for _ in range(buffer_size)]  # pre-allocated slots
        self.idx = 0  # total inserts so far (never wrapped)

    def insert(self, sars):
        """Store ``sars`` in the next slot, overwriting the oldest entry once full."""
        slot = self.idx % self.buffer_size
        self.buffer[slot] = sars
        self.idx += 1

    def sample(self, num_samples):
        """Return ``num_samples`` distinct stored entries chosen uniformly at random.

        Asserts that at least ``num_samples`` entries have been stored.
        """
        filled = min(self.idx, self.buffer_size)
        assert num_samples <= filled
        # Draw positions within the populated prefix rather than slicing it,
        # so the prefix is not copied on every call.
        chosen = sample(range(filled), num_samples)
        return [self.buffer[i] for i in chosen]

def update_tgt_model(m, tgt):
    """Hard target update: copy all of ``m``'s parameters and buffers into ``tgt`` in place."""
    online_weights = m.state_dict()
    tgt.load_state_dict(online_weights)

def train_step(model, state_transitions, tgt, num_actions, device, gamma=0.99):
    """Run one DQN gradient step on a minibatch of transitions.

    Args:
        model: online Q-network; must expose an ``opt`` optimizer attribute.
        state_transitions: sequence of objects with ``state``, ``action``,
            ``reward``, ``next_state`` and ``done`` attributes.
        tgt: target Q-network, evaluated without gradients for the bootstrap value.
        num_actions: size of the discrete action space (one-hot width).
        device: torch device the minibatch is moved to.
        gamma: discount factor applied to the bootstrapped next-state value.

    Returns:
        The scalar Smooth-L1 loss tensor (still attached to the autograd graph).
    """
    # Stack once in numpy, then convert in a single call.
    cur_states = torch.as_tensor(
        np.stack([s.state for s in state_transitions]), dtype=torch.float32, device=device
    )
    next_states = torch.as_tensor(
        np.stack([s.next_state for s in state_transitions]), dtype=torch.float32, device=device
    )
    # Keep rewards/mask 1-D (batch,). A (batch, 1) reward column added to (batch,)
    # terms would silently broadcast the TD target to (batch, batch).
    rewards = torch.tensor(
        [s.reward for s in state_transitions], dtype=torch.float32, device=device
    )
    mask = torch.tensor(
        [0.0 if s.done else 1.0 for s in state_transitions], dtype=torch.float32, device=device
    )
    actions = torch.tensor(
        [s.action for s in state_transitions], dtype=torch.long, device=device
    )

    with torch.no_grad():
        qvals_next = tgt(next_states).max(1)[0]  # (batch,) best next-state value

    model.opt.zero_grad()
    qvals = model(cur_states)
    one_hot_actions = F.one_hot(actions, num_actions)
    q_taken = torch.sum(qvals * one_hot_actions, -1)  # Q(s, a) for the taken action
    # Terminal transitions (mask == 0) do not bootstrap from the next state.
    target = rewards + mask * qvals_next * gamma

    loss_fn = nn.SmoothL1Loss()
    loss = loss_fn(q_taken, target)

    loss.backward()
    model.opt.step()
    return loss



In [None]:
def main(test=False, chkpt=None, device='cuda'):
    """Train (or, with ``test=True``, watch) a DQN agent on Atari Breakout.

    Args:
        test: if True, render each step, act greedily (eps = 0) and never train.
        chkpt: optional path to a saved state_dict loaded into the online model.
        device: torch device string used for both networks.

    Runs until interrupted with Ctrl-C. In training mode the target network is
    re-synced after every ``tgt_model_update`` training steps, and its weights
    are saved to ``models/<step_num>.pth``.
    """
    import os

    memory_size = 50000          # replay-buffer capacity (transitions)
    min_rb_size = 2000           # transitions collected before training starts
    sample_size = 100            # minibatch size
    lr = 0.001
    # eps_max = 1.0
    eps_min = 0.05               # floor for epsilon-greedy exploration
    eps_decay = 0.99995          # per-step multiplicative epsilon decay
    env_steps_before_train = 10  # env steps between gradient updates
    tgt_model_update = 1000      # gradient updates between target-network syncs
    reward_scale = 0.1           # rewards are scaled by this before being stored

    # NOTE(review): with render_mode='rgb_array', env.render() returns a frame
    # instead of opening a window; test mode may need 'human' — confirm against
    # the installed gym/ale-py version.
    env = gym.make('ALE/Breakout-v5', render_mode='rgb_array')
    env = FrameStackingAndResizingEnv(env, 84, 84, 4)
    last_observation = env.reset()

    m = ConvModel(env.observation_space.shape, env.action_space.n, lr=lr).to(device)
    if chkpt is not None:
        # map_location lets a checkpoint saved on one device load on another.
        m.load_state_dict(torch.load(chkpt, map_location=device))
    tgt = ConvModel(env.observation_space.shape, env.action_space.n, lr=lr).to(device)
    update_tgt_model(m, tgt)

    rb = ReplayBuffer(memory_size)
    steps_since_train = 0
    epochs_since_tgt = 0

    # Starts negative so eps_decay ** step_num stays >= 1 (purely random
    # actions) while the replay buffer warms up.
    step_num = -1 * min_rb_size
    episode_rewards = []
    rolling_reward = 0

    tq = tqdm()
    try:
        while True:
            if test:
                env.render()
                time.sleep(0.05)
            tq.update(1)

            eps = 0 if test else max(eps_min, eps_decay ** step_num)

            if random() < eps:
                action = env.action_space.sample()
            else:
                # Inference only: no autograd graph needed for action selection.
                with torch.no_grad():
                    qvals = m(torch.Tensor(last_observation).unsqueeze(0).to(device))
                action = qvals.argmax(dim=1).item()

            observation, reward, done, info = env.step(action)
            rolling_reward += reward

            rb.insert(Sarsd(last_observation, action, reward * reward_scale, observation, done))

            last_observation = observation

            if done:
                episode_rewards.append(rolling_reward)
                if test:
                    print(rolling_reward)
                rolling_reward = 0
                # The next transition must start from the fresh episode,
                # not from the terminal frame of the previous one.
                last_observation = env.reset()

            steps_since_train += 1
            step_num += 1

            if (
                (not test)
                and rb.idx > min_rb_size
                and steps_since_train > env_steps_before_train
            ):
                loss = train_step(
                    m, rb.sample(sample_size), tgt, env.action_space.n, device
                )
                # Mean over episodes finished since the last report; nan if none did.
                avg_reward = np.mean(episode_rewards) if episode_rewards else float('nan')
                print(f"loss: {loss.detach().cpu().item()}, eps: {eps}, avg_reward: {avg_reward}")
                episode_rewards = []
                epochs_since_tgt += 1
                if epochs_since_tgt > tgt_model_update:
                    print("updating target model")
                    update_tgt_model(m, tgt)
                    epochs_since_tgt = 0
                    os.makedirs("models", exist_ok=True)
                    torch.save(tgt.state_dict(), f"models/{step_num}.pth")

                steps_since_train = 0

    except KeyboardInterrupt:
        pass
    finally:
        tq.close()
        env.close()



In [None]:
if __name__ == "__main__":
    # main() defaults to CUDA; fall back to CPU so the notebook also runs on
    # machines without a GPU.
    main(device="cuda" if torch.cuda.is_available() else "cpu")

NamespaceNotFound: Namespace ALE not found. Have you installed the proper package for ALE?