In [None]:
import torch
import torch.nn as nn

In [None]:
import numpy as np
import random
import matplotlib.pyplot as plt
%matplotlib inline
import pandas as pd
import copy
import gym
import retro

In [8]:
class ImageToPyTorch(gym.ObservationWrapper):
    """Move the last axis of a 3-D image observation to the front (e.g. HWC -> CHW).

    Only the axis order changes: values and dtype pass through untouched, and
    ``observation_space`` advertises exactly that (the wrapped space's bounds
    and dtype, with axis 2 moved to position 0).

    Assumes the wrapped environment exposes a 3-D ``gym.spaces.Box``
    observation space -- TODO confirm for environments other than the one
    used in this notebook.
    """

    def __init__(self, env):
        super(ImageToPyTorch, self).__init__(env)
        old_space = self.observation_space
        # Derive bounds and dtype from the wrapped space instead of declaring
        # float32 values in [0, 1]: observation() neither casts nor rescales,
        # so such a declaration would not describe the arrays it returns.
        self.observation_space = gym.spaces.Box(
            low=np.moveaxis(old_space.low, 2, 0),
            high=np.moveaxis(old_space.high, 2, 0),
            dtype=old_space.dtype,
        )

    def observation(self, observation):
        """Return ``observation`` with axis 2 moved to position 0."""
        return np.moveaxis(observation, 2, 0)

# Create the AirStriker emulator, then wrap it so observations are re-ordered
# by ImageToPyTorch before the rest of the notebook sees them.
raw_env = retro.make(game='AirStriker-Genesis')
env = ImageToPyTorch(raw_env)

In [None]:
class DQN(nn.Module):
    """Convolutional Q-network: a batch of image observations in, one value per action out.

    Args:
        input_shape: shape of a single observation. Index 0 is used as the
            number of input channels; the full shape sizes a dummy input
            that determines the flattened convolution output.
        n_actions: number of outputs of the final linear layer.
    """

    def __init__(self, input_shape, n_actions):
        super(DQN, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )

        conv_out_size = self._get_conv_out(input_shape)
        self.fc = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions)
        )

    def _get_conv_out(self, shape):
        """Return the number of features the conv stack emits for one input of ``shape``."""
        # This is only a shape probe, so skip autograd bookkeeping.
        with torch.no_grad():
            o = self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    def forward(self, x):
        """Run the conv stack, flatten per sample, and apply the fully connected head."""
        conv_out = self.conv(x).view(x.size()[0], -1)
        return self.fc(conv_out)

In [83]:
# Hyperparameters read by main().
N = 1000            # main() re-copies Q into QHat whenever its loop counter is a multiple of N
miniBatchSize = 32  # at most this many stored transitions are sampled per training step

In [149]:
def _build_q_network(input_shape, n_actions):
    """Build the Q-network: a batch of channel-first observations in, one value per action out."""
    conv = nn.Sequential(
        nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
        nn.ReLU(),
        nn.Conv2d(32, 64, kernel_size=4, stride=2),
        nn.ReLU(),
        nn.Conv2d(64, 64, kernel_size=3, stride=1),
        nn.ReLU(),
    )
    # Probe with a dummy observation to size the first fully connected layer.
    with torch.no_grad():
        n_features = conv(torch.zeros(1, *input_shape)).numel()
    return nn.Sequential(
        *conv.children(),
        nn.Flatten(),
        nn.Linear(n_features, 512),
        nn.ReLU(),
        nn.Linear(512, n_actions),
    )


def _to_batch(observations):
    """Stack observations into one float32 tensor; uint8 data is rescaled to [0, 1]."""
    stacked = np.stack([np.asarray(o) for o in observations])
    scale = 255.0 if stacked.dtype == np.uint8 else 1.0
    return torch.from_numpy(stacked.astype(np.float32) / scale)


def main(nEpisode=100, gamma = 0.99, epsilon0 = 0.9, maxIter = 200, bufferSize = 5000):
    """Train a deep Q-network on the module-level ``env`` and return it.

    Each environment step picks an epsilon-greedy action, stores the
    transition in a bounded replay buffer, and performs one batched gradient
    step on up to ``miniBatchSize`` sampled transitions. Targets are
    ``r`` for terminal transitions and ``r + gamma * max_a QHat(s', a)``
    otherwise, where ``QHat`` is a copy of ``Q`` refreshed every ``N``
    environment steps (counted across episodes).

    Args:
        nEpisode: number of episodes to run.
        gamma: discount factor applied to the bootstrapped value.
        epsilon0: initial exploration probability; it is multiplied by 0.99
            on every environment step and has no lower bound.
        maxIter: maximum number of steps per episode.
        bufferSize: maximum number of transitions kept for replay; the
            oldest are discarded first.

    Returns:
        The trained network. It maps a float tensor of shape
        ``(batch, *env.observation_space.shape)`` to ``(batch, n_actions)``.
    """
    from collections import deque

    n_actions = int(env.action_space.n)
    # NOTE(review): for a MultiBinary action space the network's action index
    # i is sent to the environment as "only entry i set". That is the
    # narrowest reading of indexing actions with range(env.action_space.n);
    # confirm it is the action set you want to learn over.
    is_multi_binary = isinstance(env.action_space, gym.spaces.MultiBinary)

    def to_env_action(index):
        """Convert a network action index into what ``env.step`` expects."""
        if not is_multi_binary:
            return index
        mask = np.zeros(n_actions, dtype=np.int8)
        mask[index] = 1
        return mask

    Q = _build_q_network(env.observation_space.shape, n_actions)
    QHat = copy.deepcopy(Q)
    loss_fn = torch.nn.MSELoss()
    # Created once so the momentum buffers persist across updates.
    optimizer = torch.optim.SGD(Q.parameters(), lr = 0.01, momentum = 0.9)
    # A sequence (not a set): observations are unhashable arrays and
    # random.sample needs an indexable population.
    buffer = deque(maxlen=bufferSize)
    epsilon = epsilon0
    total_steps = 0

    for _episode in range(nEpisode):
        obs = env.reset()
        for _ in range(maxIter):
            epsilon = 0.99 * epsilon
            if np.random.random() < epsilon:
                action = random.randrange(n_actions)
            else:
                with torch.no_grad():
                    action = int(Q(_to_batch([obs])).argmax(dim=1).item())

            obsNext, reward, done, _info = env.step(to_env_action(action))
            buffer.append((obs, action, float(reward), float(done), obsNext))
            obs = obsNext

            batch = random.sample(buffer, min(len(buffer), miniBatchSize))
            states = _to_batch([t[0] for t in batch])
            actions = torch.tensor([t[1] for t in batch], dtype=torch.long)
            rewards = torch.tensor([t[2] for t in batch], dtype=torch.float32)
            dones = torch.tensor([t[3] for t in batch], dtype=torch.float32)
            next_states = _to_batch([t[4] for t in batch])

            # The target is a constant for this update: no gradient flows
            # through QHat, and terminal transitions do not bootstrap.
            with torch.no_grad():
                next_best = QHat(next_states).max(dim=1)[0]
            targets = rewards + gamma * next_best * (1.0 - dones)

            # Q-value of the action actually taken in each sampled transition.
            predicted = Q(states).gather(1, actions.unsqueeze(1)).squeeze(1)
            loss = loss_fn(predicted, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Sync on a global step counter; the episode index would sync on
            # every step of episode 0 and then not again until episode N.
            total_steps += 1
            if total_steps % N == 0:
                QHat.load_state_dict(Q.state_dict())

            if done:
                break
    return Q

In [143]:
# Train for 20 episodes; every other hyperparameter keeps main()'s default.
Q = main(nEpisode=20)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19


In [144]:
def transform(Q, s, a):
    """Evaluate ``Q`` on the pair ``(s, a)`` and return the result as a Python float.

    ``s`` and ``a`` are packed into a two-element float tensor, which is the
    only argument passed to ``Q``.
    """
    pair = torch.FloatTensor([s, a])
    value = Q(pair)
    return float(value)

In [145]:
# Tabulate Q over every (observation index, action index) pair.
# NOTE(review): this reads `.n` from env.observation_space, so it presumes a
# discrete observation space -- confirm `env` is such an environment before
# running this cell.
n_states = env.observation_space.n
n_actions = env.action_space.n
q_table = np.ones((n_states, n_actions))
for state, action in np.ndindex(n_states, n_actions):
    q_table[state, action] = transform(Q, state, action)

In [146]:
def testPolicy (q_table, nEpisode = 2000):
    success = 0
    for _ in range(nEpisode):
        t = 0
        observation = env.reset()
        done  = False
        actionTable = np.argmax(q_table, axis = 1)
        while not done and t < 200:
            action = actionTable[observation]
            observation, reward, done, info = env.step(action)
            t += 1

        if reward == 1:
            success += 1
    return success / nEpisode

In [147]:
# Success rate of the greedy policy read off q_table, using testPolicy's default episode count.
testPolicy(q_table)

0.0

In [148]:
# Display the tabulated Q-values (rows: observation indices, columns: action indices).
q_table

array([[0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]])