In [None]:
import gym
from gym import wrappers
from gym.spaces.utils import flatdim
import torch as pt
import cv2

import numpy as np
import matplotlib.pyplot as plt
from tqdm import trange, tqdm

import io
import base64
from IPython.display import HTML

# Comment out for debugging
import warnings
warnings.filterwarnings('ignore')

In [None]:
def play_video(ep_number: int):
    """Build an inline HTML5 video player for a recorded episode.

    Reads ``./gym-results/rl-video-episode-{ep_number}.mp4`` and embeds its
    bytes in the page as a base64 ``data:`` URI.

    Args:
        ep_number: Episode index that appears in the recording's file name.

    Returns:
        An ``IPython.display.HTML`` object wrapping a ``<video>`` element.

    Raises:
        FileNotFoundError: If no recording exists for ``ep_number``.
    """
    # Read-only mode is enough here, and the context manager guarantees the
    # file handle is released even if the read fails.
    with io.open(f"./gym-results/rl-video-episode-{ep_number}.mp4", 'rb') as video_file:
        video = video_file.read()
    encoded = base64.b64encode(video)
    return HTML(data='''
        <video width="360" height="auto" alt="test" controls><source src="data:video/mp4;base64,{0}" type="video/mp4" /></video>'''
    .format(encoded.decode('ascii')))

def smooth(y, box_pts):
    """Apply a uniform (box) moving-average filter to ``y``.

    Args:
        y: 1-D sequence of numbers.
        box_pts: Number of taps in the averaging kernel.

    Returns:
        The result of ``np.convolve(y, kernel, mode='valid')`` where every
        kernel tap equals ``1 / box_pts``.
    """
    kernel = np.ones(box_pts)
    kernel /= box_pts
    return np.convolve(y, kernel, mode='valid')

In [None]:
# Baseline: play Pong with uniformly random actions and record every point won.
env = gym.make("ALE/Pong-v5")
env = wrappers.RecordVideo(env, "./gym-results", new_step_api=True)
env.reset(seed=42)
# Seed the action sampler as well so the random rollout is repeatable.
env.action_space.seed(42)

# hyperparams
num_steps = 500000
random_rewards = [0]
num_eps = 0
for _ in trange(num_steps):
    action = env.action_space.sample()
    # With new_step_api=True the step result is
    # (observation, reward, terminated, truncated, info).
    observation, reward, terminated, truncated, info = env.step(action)
    # Only points scored by the agent are counted; conceded points become 0.
    random_rewards.append(reward if reward > 0 else 0)

    # An episode ends on either flag; checking only the first one would keep
    # stepping a truncated episode without resetting it.
    if terminated or truncated:
        observation = env.reset()
        num_eps += 1

env.close()

plt.plot(smooth(random_rewards, 200))
plt.title(f"{sum(random_rewards)} wins out of {num_eps} episodes")

In [None]:
# Total points scored by the random agent (only positive rewards were recorded).
sum(random_rewards)

In [None]:
import collections
# Several useful wrapper environments
class FireResetEnv(gym.Wrapper):
    """Wrapper that takes action 1 (FIRE) and then action 2 after every reset.

    The constructor asserts that the wrapped game exposes at least three
    actions and that action 1 is named ``'FIRE'``.
    """

    def __init__(self, env=None):
        super(FireResetEnv, self).__init__(env)
        meanings = env.unwrapped.get_action_meanings()
        assert meanings[1] == 'FIRE'
        assert len(meanings) >= 3

    def step(self, action):
        """Forward the action to the wrapped environment unchanged."""
        return self.env.step(action)

    def reset(self, **kwargs):
        """Reset the wrapped env, then take actions 1 and 2.

        Args:
            **kwargs: Forwarded to the wrapped environment's ``reset``.
                ``return_info=True`` is not supported here.

        Returns:
            The observation produced by the last warm-up step, or by the
            most recent reset if a warm-up step ended the episode.
        """
        obs = self.env.reset(**kwargs)
        for warmup_action in (1, 2):
            result = self.env.step(warmup_action)
            # Index 2 is done/terminated in both step formats; 5-value results
            # additionally carry the truncation flag at index 3.
            finished = bool(result[2]) or (len(result) == 5 and bool(result[3]))
            if finished:
                obs = self.env.reset(**kwargs)
            else:
                obs = result[0]
        return obs

class MaxAndSkipEnv(gym.Wrapper):
    """Repeat each action ``skip`` times and return the max of the last two frames.

    Rewards from the repeated steps are summed. The returned observation is
    the element-wise maximum over the (up to two) most recent raw frames.
    """

    def __init__(self, env=None, skip=4):
        super(MaxAndSkipEnv, self).__init__(env)
        # Only the two most recent frames are kept for the max-pool.
        self._obs_buffer = collections.deque(maxlen=2)
        self._skip = skip

    def step(self, action):
        """Apply ``action`` up to ``skip`` times.

        Returns:
            ``(max_frame, total_reward, terminated, truncated, info)`` where
            the last three values come from the final inner step taken.
        """
        total_reward = 0.0
        terminated = False
        truncated = False
        info = {}
        for _ in range(self._skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            self._obs_buffer.append(obs)
            total_reward += reward
            # Stop repeating as soon as the episode is over for any reason.
            if terminated or truncated:
                break
        max_frame = np.max(np.stack(self._obs_buffer), axis=0)
        return max_frame, total_reward, terminated, truncated, info

    def reset(self):
        """Clear the frame buffer, reset the wrapped env and return its observation."""
        self._obs_buffer.clear()
        obs = self.env.reset()
        self._obs_buffer.append(obs)
        return obs

# TODO: Still produces broken ball and split user paddle
class ProcessFrame84(gym.ObservationWrapper):
    """Convert RGB frames to 84x84x1 uint8 grayscale images."""

    def __init__(self, env=None):
        super().__init__(env)
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(84, 84, 1), dtype=np.uint8
        )

    def observation(self, obs):
        """Run every incoming frame through :meth:`process`."""
        return ProcessFrame84.process(obs)

    @staticmethod
    def process(frame):
        """Grayscale, downscale and crop a single frame.

        Args:
            frame: Array holding 210*160*3 or 250*160*3 values.

        Returns:
            ``uint8`` array of shape ``(84, 84, 1)``.

        Raises:
            AssertionError: If ``frame`` has any other number of elements.
        """
        n_values = frame.size
        if n_values == 210 * 160 * 3:
            rows = 210
        elif n_values == 250 * 160 * 3:
            rows = 250
        else:
            assert False, "Unknown resolution."

        rgb = np.reshape(frame, [rows, 160, 3]).astype(np.float32)
        # Weighted sum of the three colour channels.
        gray = rgb[:, :, 0] * 0.299 + rgb[:, :, 1] * 0.587 + rgb[:, :, 2] * 0.114
        shrunk = cv2.resize(gray, (84, 110), interpolation=cv2.INTER_AREA)
        cropped = shrunk[16:100, :]  # remove scoreboard + bottom
        return np.reshape(cropped, [84, 84, 1]).astype(np.uint8)

class ImageToPyTorch(gym.ObservationWrapper):
    """Move the last (channel) axis of each observation to the front."""

    def __init__(self, env):
        super().__init__(env)
        src_shape = self.observation_space.shape
        # Declared shape: last axis first, then the first two axes.
        channels_first = (src_shape[-1], src_shape[0], src_shape[1])
        self.observation_space = gym.spaces.Box(
            low=0.0, high=1.0, shape=channels_first, dtype=np.float32
        )

    def observation(self, observation):
        """Return ``observation`` with axis 2 moved to position 0."""
        return np.moveaxis(observation, 2, 0)

# Stacks several frames together
class BufferWrapper(gym.ObservationWrapper):
    """Keep a rolling buffer of recent observations stacked along axis 0.

    Args:
        env: Environment to wrap.
        n_steps: How many times the wrapped observation bounds are repeated
            along axis 0 to size the buffer.
        dtype: dtype of the buffer and of the declared observation space.
    """

    def __init__(self, env, n_steps, dtype=np.float32):
        super(BufferWrapper, self).__init__(env)
        self.dtype = dtype
        old_space = env.observation_space
        self.observation_space = gym.spaces.Box(
            old_space.low.repeat(n_steps, axis=0),
            old_space.high.repeat(n_steps, axis=0),
            dtype=dtype)
        # Allocate up front so observation() is usable before the first reset().
        self.buffer = np.zeros_like(self.observation_space.low,
                                    dtype=self.dtype)

    def reset(self):
        """Zero the buffer, reset the wrapped env and return the first stack."""
        self.buffer = np.zeros_like(self.observation_space.low,
                                    dtype=self.dtype)
        return self.observation(self.env.reset())

    def observation(self, observation):
        """Shift the buffer by one slot, append ``observation`` and return a copy."""
        self.buffer[:-1] = self.buffer[1:]
        self.buffer[-1] = observation
        # Return a copy: handing out the live buffer would make every earlier
        # observation silently change on the next step.
        return self.buffer.copy()

class ScaledFloatFrame(gym.ObservationWrapper):
    """Convert observations to float32 and divide them by 255."""

    def observation(self, obs):
        """Return ``obs`` as a float32 array scaled by 1/255."""
        as_float = np.array(obs).astype(np.float32)
        return as_float / 255.0

def make_pong(env_id="ALE/Pong-v5", video_dir="./gym-results", skip=2, n_frames=4):
    """Create the wrapped environment used by the DQN cells.

    Wrappers are applied in this order: ``RecordVideo`` -> ``MaxAndSkipEnv``
    -> ``FireResetEnv`` -> ``ProcessFrame84`` -> ``ImageToPyTorch`` ->
    ``BufferWrapper`` -> ``ScaledFloatFrame``.

    Args:
        env_id: Gym environment id passed to ``gym.make``.
        video_dir: Folder handed to ``RecordVideo`` for episode recordings.
        skip: Action-repeat count passed to ``MaxAndSkipEnv``.
        n_frames: Stack depth passed to ``BufferWrapper``.

    Returns:
        The fully wrapped environment. Calling with no arguments reproduces
        the original fixed configuration.
    """
    env = gym.make(env_id)
    env = wrappers.RecordVideo(env, video_dir, new_step_api=True)
    env = MaxAndSkipEnv(env, skip=skip)
    env = FireResetEnv(env)
    env = ProcessFrame84(env)
    env = ImageToPyTorch(env)
    env = BufferWrapper(env, n_frames)
    env = ScaledFloatFrame(env)

    return env



In [None]:
import torch.nn as nn

class QNet(nn.Module):
    """Small convolutional Q-network: two conv layers and two linear layers.

    Args:
        obs_space: Observation space; ``obs_space.low.shape`` gives the
            channels-first input shape.
        act_space: Action space; ``flatdim(act_space)`` gives the number of
            Q-value outputs.

    The module also owns its RMSprop optimizer (``self.optimizer``) and its
    MSE loss (``self.loss``).
    """

    def __init__(self, obs_space, act_space):
        super(QNet, self).__init__()
        self.obs_shape = obs_space.low.shape
        self.act_shape = flatdim(act_space)
        # Not used by the network itself; kept so the attribute stays available.
        self.replay_memory = []

        self.conv1 = nn.Conv2d(self.obs_shape[0], 16, 8, stride=4)
        self.conv2 = nn.Conv2d(16, 32, 4, stride=2)
        # Derive the flattened conv output size from a dummy pass instead of
        # hard-coding it (it is 2592 for a 4x84x84 input).
        with pt.no_grad():
            n_flat = self._conv_features(pt.zeros(1, *self.obs_shape)).shape[1]
        self.fc1 = nn.Linear(n_flat, 256)
        self.out = nn.Linear(256, self.act_shape)

        self.optimizer = pt.optim.RMSprop(self.parameters(), lr=1e-4)
        self.loss = nn.MSELoss()

    def _conv_features(self, x):
        """Run both conv layers with ReLU and flatten everything but the batch axis."""
        x = pt.relu(self.conv1(x))
        x = pt.relu(self.conv2(x))
        return pt.flatten(x, start_dim=1)

    def forward(self, x):
        """Return one Q-value per action for each input in the batch."""
        x = pt.relu(self.fc1(self._conv_features(x)))
        return self.out(x)





In [None]:
def epsilon_greedy(q_vals, epsilon=0.1):
    """Pick an action epsilon-greedily from one state's Q-values.

    Args:
        q_vals: Tensor of Q-values whose last dimension indexes the actions
            (e.g. shape ``(1, n_actions)``).
        epsilon: Probability of choosing a uniformly random action.

    Returns:
        The chosen action as a plain ``int`` in both branches.
    """
    if pt.rand(1).item() < epsilon:
        # Size the random draw from the Q-vector rather than a fixed count.
        n_actions = q_vals.shape[-1]
        return int(pt.randint(n_actions, (1,)).item())

    # .item() works regardless of device or autograd state.
    return int(pt.argmax(q_vals).item())

class ReplayMemory():
    """Fixed-size ring buffer of ``(obs, action, obs_n, reward, done)`` transitions.

    Args:
        size: Maximum number of transitions kept; older ones are overwritten.
        obs_space: Object whose ``.shape`` gives the shape of one observation.
        act_shape: Accepted for interface compatibility; not used.
    """

    def __init__(self, size, obs_space, act_shape):
        self.size = size
        self.counter = 0  # total number of transitions ever stored

        obs_shape = obs_space.shape

        # Storage uses the dtypes that sample_batch hands out, which halves
        # the footprint of the two observation arrays versus float64.
        self.obs = np.zeros((size, *obs_shape), dtype=np.float32)
        self.actions = np.zeros(size, dtype=np.int64)
        self.rewards = np.zeros(size, dtype=np.float32)
        self.obs_n = np.zeros((size, *obs_shape), dtype=np.float32)
        self.done = np.zeros(size, dtype=bool)

    def store_transition(self, obs, action, obs_n, r, done):
        """Write one transition, overwriting the oldest slot once full."""
        indx = self.counter % self.size

        self.obs[indx] = obs
        self.actions[indx] = action
        self.rewards[indx] = r
        self.obs_n[indx] = obs_n
        self.done[indx] = done

        self.counter += 1

    def sample_batch(self, batch_size):
        """Draw ``batch_size`` distinct stored transitions uniformly at random.

        Returns:
            ``(obs, actions, obs_n, rewards, terminal)`` as float, long,
            float, float and bool tensors respectively.

        Raises:
            ValueError: If ``batch_size`` exceeds the number of stored
                transitions.
        """
        max_mem = min(self.counter, self.size)
        if batch_size > max_mem:
            raise ValueError(
                f"Cannot sample {batch_size} transitions; only {max_mem} stored."
            )
        batch = np.random.choice(max_mem, batch_size, replace=False)

        obs = pt.from_numpy(self.obs[batch]).float()
        actions = pt.from_numpy(self.actions[batch]).long()
        rewards = pt.from_numpy(self.rewards[batch]).float()
        obs_n = pt.from_numpy(self.obs_n[batch]).float()
        terminal = pt.from_numpy(self.done[batch]).bool()

        return obs, actions, obs_n, rewards, terminal



In [None]:

# Test: push one transition through the replay memory and the network.
env = make_pong()
obs = env.reset()
memory = ReplayMemory(10, env.observation_space, env.action_space)

dqn = QNet(env.observation_space, env.action_space)
print(dqn)
obs_n, r, done, _, _ = env.step(env.action_space.sample())

# q_vals = dqn(pt.tensor(obs))
# action = epsilon_greedy(q_vals)
action = 0

print(f"Storing transition: {obs.shape}, {action}, {obs_n.shape}, {r}, {done}")
# Store the same transition six times so a batch of six can be sampled.
for _ in range(6):
    memory.store_transition(obs, action, obs_n, r, done)
# sample_batch returns (obs, actions, obs_n, rewards, terminal) in that order.
states, actions, states_, rewards, dones = memory.sample_batch(6)
print(f"actions: {actions}")
q_vals = dqn(states)
print(f"All q-vals: {q_vals}")

# Release the environment and its video recorder.
env.close()

In [None]:
# Pseudocode from DQN Paper
env = make_pong()

# Hyperparams
epsilon = 0.99
batch_size = 16
update_freq = 2500
gamma = 0.99

# Initialize Replay Memory to capacity N
replay_capacity = 1000
memory = ReplayMemory(replay_capacity, env.observation_space, env.action_space)

# Initialize action-value Function Q with random weights
Q = QNet(env.observation_space, env.action_space)
T = QNet(env.observation_space, env.action_space)
# The target network starts as an exact copy of Q and is only used for inference.
T.load_state_dict(Q.state_dict())
T.eval()

num_episodes = 1000
pbar = tqdm(range(num_episodes))
dqn_rewards = [0]
wins = [0]
losses = [0]
i = 0  # global environment-step counter, drives the target-network sync
for ep in pbar:
    obs = env.reset()
    done = False
    ep_reward = 0
    while not done:

        # Select action random action with prob epsilon, otherwise argmax_a Q(obs, a)
        with pt.no_grad():
            q_vals = Q(pt.from_numpy(obs).unsqueeze(0).float())
        action = int(epsilon_greedy(q_vals, epsilon=epsilon))

        obs_n, r, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        i += 1

        ep_reward += r
        wins.append(r if r > 0 else 0)
        # Only true termination is stored as the done flag used in the target.
        memory.store_transition(obs, action, obs_n, r, terminated)
        obs = obs_n

        if (memory.counter > memory.size):
            # Update network weights on batch
            Q.optimizer.zero_grad()

            states, actions, states_, rewards, dones = memory.sample_batch(batch_size)

            q_vals = Q.forward(states)
            # Squeeze to shape (batch,) so the loss compares like with like
            # instead of broadcasting (batch,) against (batch, 1).
            q_pred = q_vals.gather(1, actions.unsqueeze(-1)).squeeze(-1)
            q_next = T.forward(states_).detach().max(dim=1)[0]

            # Bootstrap from the next state only when the episode did NOT end.
            q_target = rewards + gamma * (~dones).float() * q_next

            loss = Q.loss(q_pred, q_target)

            loss.backward()
            # Optional gradient clipping:
            # for param in Q.parameters():
            #     param.grad.data.clamp_(-1, 1)
            Q.optimizer.step()
            losses.append(loss.item())

        # Sync the target network every update_freq environment steps.
        if i % update_freq == 0:
            T.load_state_dict(Q.state_dict())
            T.eval()

    epsilon = max(0.05, epsilon - 1/500)
    dqn_rewards.append(ep_reward)

    # ep is zero-based, so ep + 1 episodes have finished at this point.
    pbar.set_description(f"Epsilon: {epsilon:0.2f}, Win rate: {np.sum(wins)/(ep + 1):0.2f}, Rewards: {np.mean(dqn_rewards[-50:]):0.2f}")

    if (ep % 20 == 0):
        print(f"batch Q vals: {q_vals}")
        plt.clf()
        plt.plot(smooth(dqn_rewards, 10))
        plt.savefig(f"./gym-results/episode_{ep}.png")

        plt.clf()
        plt.plot(smooth(losses, 1000))
        plt.savefig(f"./gym-results/losses_episode_{ep}.png")

# Release the environment and its video recorder.
env.close()



In [None]:
def moving_average(a, n=10):
    """Trailing mean of ``a`` over windows of ``n`` consecutive values.

    Args:
        a: 1-D sequence of numbers.
        n: Window length.

    Returns:
        Array with one mean per full window, or the list ``[0]`` when ``a``
        holds fewer than ``n`` values.
    """
    running = np.cumsum(a, dtype=float)
    # Difference of cumulative sums n apart gives each window's total.
    window_totals = running[n:] - running[:-n]
    running[n:] = window_totals
    full_windows = running[n - 1:]
    if len(full_windows) > 0:
        return full_windows / n
    return [0]

In [None]:
# Play one episode greedily with the trained Q-network.
env = make_pong()
obs = env.reset()
done = False
while not done:
    obs_format = pt.from_numpy(obs).unsqueeze(0).float()
    # Inference only: no autograd graph is needed here.
    with pt.no_grad():
        q_vals = Q(obs_format)
    action = int(pt.argmax(q_vals).item())

    print(f"action selected: {action}, vals: {q_vals}, obs: {np.sum(obs)}")
    obs, r, terminated, truncated, _ = env.step(action)
    # The episode is over on either flag.
    done = terminated or truncated

# Release the environment and its video recorder.
env.close()
    
