In [31]:
import gym
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import pyglet
from itertools import count
import math
import random
from PIL import Image
from torch.distributions import Categorical
from gym import ObservationWrapper
from gym import ActionWrapper
import os

In [32]:
# Restrict CUDA to the device with index 1 for this kernel.
# NOTE(review): the index is hard-coded; on a host without a second GPU this
# presumably hides every device and training falls back to CPU - confirm.
os.environ.update(CUDA_VISIBLE_DEVICES="1")

In [33]:
class PongObsWrapper(ObservationWrapper):
    """Convert each raw frame into a binary ``(1, 80, 80)`` float32 image.

    The frame is cropped to rows 35..194, subsampled by two along both
    spatial axes, reduced to its first colour channel, and thresholded:
    pixels equal to 0, 144 or 109 become 0.0, everything else becomes 1.0.

    Assumes the wrapped env emits 210x160x3 frames (so the crop yields
    exactly 80x80 pixels) - TODO confirm for envs other than Pong.
    """

    def __init__(self, env):
        super().__init__(env)
        # Advertise what observation() really returns; without this the
        # wrapper keeps reporting the wrapped env's raw frame space.
        self.observation_space = gym.spaces.Box(
            low=0.0, high=1.0, shape=(1, 80, 80), dtype=np.float32
        )

    def observation(self, image):
        """Return the preprocessed frame without modifying ``image``.

        Args:
            image: raw frame from the wrapped environment.

        Returns:
            ``np.ndarray`` of shape ``(1, 80, 80)`` and dtype ``float32``
            containing only 0.0 and 1.0.
        """
        # Crop + downsample + channel select in one slice. This is a view of
        # the caller's array, so build the mask from it instead of assigning
        # into it (the previous in-place masking wrote into the input frame).
        frame = image[35:195:2, ::2, 0]
        # 144 and 109 are the two background values erased by the original code.
        foreground = (frame != 0) & (frame != 144) & (frame != 109)
        return foreground.astype(np.float32).reshape(1, 80, 80)
    
class PongActionWrapper(ActionWrapper):
    """Translate the agent's two discrete choices into env action ids.

    Agent action 0 is forwarded as 2 and agent action 1 as 3.
    NOTE(review): these are presumably the two paddle-move actions of Pong;
    confirm against the environment's action meanings.
    """

    def __init__(self, env):
        super().__init__(env)
        # Lookup table: agent action index -> id sent to the wrapped env.
        self.action_ = dict(enumerate((2, 3)))

    def action(self, action):
        """Return the wrapped-env action id for agent action ``action``.

        Raises:
            KeyError: if ``action`` is not 0 or 1.
        """
        env_action = self.action_[action]
        return env_action

In [34]:
class net(nn.Module):
    """Small CNN feature extractor followed by a two-layer MLP head.

    Args:
        input_dim: ``(channels, height, width)`` of one observation.
        output_dim: number of output units of the final linear layer.

    The flattened feature size is measured with a dummy forward pass, so any
    input size that survives the two pooling stages is supported.
    """

    def __init__(self, input_dim, output_dim):
        super(net, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=input_dim[0], out_channels=6, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5, 1, 2),
            nn.ReLU(True),
            # kernel 4, stride 2: an 80x80 input is 40x40 here and leaves as 16 x 19 x 19
            nn.MaxPool2d(4, 2),
            nn.Flatten()    # -> (batch, 16*19*19 = 5776) for an 80x80 input
        )

        # Probe the extractor once to learn the flattened feature size.
        with torch.no_grad():
            probe = self.conv1(torch.zeros(1, *input_dim))
        # Flatten leaves (batch, features); keep a plain Python int so that
        # nn.Linear.in_features is not a numpy scalar.
        latent_dim = int(probe.shape[1])

        # self.conv1 is deliberately reused as the first stage of self.fc, so
        # attribute names and state_dict keys stay as they were.
        self.fc = nn.Sequential(
            self.conv1,
            nn.Linear(latent_dim, 256),
            nn.ReLU(True),
            nn.Linear(256, output_dim)
        )

    def forward(self, x):
        """Map a batch ``(N, C, H, W)`` to ``(N, output_dim)`` raw outputs.

        The input is cast to float32; no softmax or other activation is
        applied to the result.
        """
        out = self.fc(x.float())
        return out

In [35]:
# Environment: raw Pong, then frame preprocessing, then the 2-action remap.
env = gym.make('Pong-v0')
env = PongObsWrapper(env)
env = PongActionWrapper(env)

# Use the GPU when one is visible to torch, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Network / agent settings.
input_dim = (1, 80, 80)  # shape of one preprocessed observation
a_d = 2                  # number of agent actions
gamma = 0.99             # discount factor
# NOTE(review): both learning rates look large for Adam on a CNN - confirm.
a_lr = 1e-2              # actor learning rate
c_lr = 1e-1              # critic learning rate

# Training schedule: episodes are grouped into progress-bar "batches".
num_episodes = 10000
batchsize = 10
num_batch = num_episodes // batchsize

In [36]:
class a2c:
    """One-step advantage actor-critic with separate actor and critic nets.

    Args:
        input_dim: observation shape handed to both networks.
        a_d: number of discrete actions (actor output size).
        gamma: discount factor used in the TD target.
        a_lr: Adam learning rate of the actor.
        c_lr: Adam learning rate of the critic.
        device: torch device holding the networks and all training tensors.
    """

    def __init__(self, input_dim, a_d, gamma, a_lr, c_lr, device):
        # Keep the device on the instance; act()/upd() used to read a global.
        self.device = device
        self.actor = net(input_dim, a_d).to(device)
        self.critic = net(input_dim, 1).to(device)
        self.actor_optim = torch.optim.Adam(self.actor.parameters(), lr = a_lr)
        self.critic_optim = torch.optim.Adam(self.critic.parameters(), lr = c_lr)
        self.gamma = gamma

    def forward(self, s_input):
        """Return the action distribution for a batch of states.

        The actor emits unnormalised logits; Categorical normalises them.
        """
        return Categorical(logits=self.actor(s_input))

    def act(self, s):
        """Sample one action (Python int) for a single un-batched state."""
        s_input = torch.as_tensor(s, dtype=torch.float, device=self.device).unsqueeze(0)
        # Acting never needs gradients; skip building the autograd graph.
        with torch.no_grad():
            return self.forward(s_input).sample().item()

    def l2t(self, List):  # list-to-tensor function, delete the last entry
        """Convert ``List[:-1]`` to a float tensor; returns None for an empty list."""
        if len(List) > 0:
            return torch.tensor(List[0:-1], dtype = torch.float)

    def _to_tensor(self, values, dtype):
        """Stack a Python list into one tensor of ``dtype`` on ``self.device``."""
        # Go through a single ndarray: torch.tensor() on a list of ndarrays
        # converts element by element and is very slow for image states.
        return torch.as_tensor(np.asarray(values), dtype=dtype, device=self.device)

    def upd(self, traj):
        """Run one actor step and one critic step on a finished trajectory.

        Args:
            traj: dict with equally long lists under the keys ``'s'``,
                ``'a'``, ``'s_'``, ``'r'`` and ``'dones'``.

        An empty trajectory is ignored.
        """
        if len(traj['r']) == 0:
            return

        s_tensor = self._to_tensor(traj['s'], torch.float)
        a_tensor = self._to_tensor(traj['a'], torch.int64).view(-1, 1)
        s__tensor = self._to_tensor(traj['s_'], torch.float)
        r_tensor = self._to_tensor(traj['r'], torch.float).view(-1, 1)
        dones_tensor = self._to_tensor(traj['dones'], torch.float).view(-1, 1)

        # V(s) is needed with gradients for the critic loss; evaluate it once.
        value = self.critic(s_tensor)
        # The bootstrap target is treated as a constant; terminal steps do
        # not bootstrap from the next state.
        with torch.no_grad():
            td_target = r_tensor + self.gamma * (1 - dones_tensor) * self.critic(s__tensor)
        td_err = (td_target - value).detach()

        # The actor outputs logits, so log-probabilities come from
        # log_softmax. Taking log() of the raw logits (as before) is NaN for
        # negative logits and pushes the policy in the wrong direction.
        log_prob = F.log_softmax(self.actor(s_tensor), dim=-1).gather(1, a_tensor)
        a_loss = -(log_prob * td_err).mean()
        c_loss = F.mse_loss(value, td_target)

        self.actor_optim.zero_grad()
        self.critic_optim.zero_grad()
        a_loss.backward()
        c_loss.backward()
        self.actor_optim.step()
        self.critic_optim.step()

In [37]:
# Build the agent from the configuration cell and start an empty log of
# per-episode returns (appended to by the training loop).
agent = a2c(
    input_dim=input_dim,
    a_d=a_d,
    gamma=gamma,
    a_lr=a_lr,
    c_lr=c_lr,
    device=device,
)
return_list = list()

In [38]:
# Start a fresh log: mode "w" truncates any existing pong_data.csv, leaving
# only the column header for the rows the training loop appends.
with open("pong_data.csv", "w") as log_file:
    header = "step, episode length, total return\n"
    log_file.write(header)

In [39]:
# Play num_batch * batchsize episodes. After every finished episode the agent
# is updated once on that episode's transitions and one row is appended to
# pong_data.csv; each outer iteration only groups episodes under a progress bar.
for i in range(num_batch):
    with tqdm(total=batchsize, desc='Iter %d' % i) as pbar:
        for i_episode in range(batchsize):
            episode_idx = i * batchsize + i_episode  # 0-based index over the whole run
            traj = {key: [] for key in ('s', 'a', 's_', 'r', 'dones')}
            total_return = 0
            s, done = env.reset(), False

            # Roll out one full episode, recording every transition.
            while not done:
                a = agent.act(s)
                s_, r, done, _ = env.step(a)
                for key, value in zip(('s', 'a', 's_', 'r', 'dones'), (s, a, s_, r, done)):
                    traj[key].append(value)
                total_return += r
                s = s_

            return_list.append(total_return)
            agent.upd(traj)

            # One CSV row per episode: index, number of steps, undiscounted return.
            with open("pong_data.csv", "a") as f:
                f.write("{},{},{}\n".format(episode_idx, len(traj['dones']), total_return))

            # Refresh the bar's summary every 10th episode of the group.
            if (i_episode + 1) % 10 == 0:
                postfix = {
                    'epis': '%d' % (episode_idx + 1),
                    'mean return': '%.5f' % np.mean(return_list[-10:]),
                }
                pbar.set_postfix(postfix)
            pbar.update(1)

Iter 0: 100%|██████████| 10/10 [00:23<00:00,  2.36s/it, epis=10, mean return=-21.00000]
Iter 1: 100%|██████████| 10/10 [00:23<00:00,  2.34s/it, epis=20, mean return=-21.00000]
Iter 2: 100%|██████████| 10/10 [00:23<00:00,  2.33s/it, epis=30, mean return=-21.00000]
Iter 3: 100%|██████████| 10/10 [00:23<00:00,  2.34s/it, epis=40, mean return=-21.00000]
Iter 4: 100%|██████████| 10/10 [00:23<00:00,  2.32s/it, epis=50, mean return=-21.00000]
Iter 5: 100%|██████████| 10/10 [00:23<00:00,  2.36s/it, epis=60, mean return=-21.00000]
Iter 6: 100%|██████████| 10/10 [00:23<00:00,  2.34s/it, epis=70, mean return=-21.00000]
Iter 7: 100%|██████████| 10/10 [00:23<00:00,  2.34s/it, epis=80, mean return=-21.00000]
Iter 8: 100%|██████████| 10/10 [00:23<00:00,  2.36s/it, epis=90, mean return=-21.00000]
Iter 9: 100%|██████████| 10/10 [00:23<00:00,  2.34s/it, epis=100, mean return=-21.00000]
Iter 10: 100%|██████████| 10/10 [00:23<00:00,  2.34s/it, epis=110, mean return=-21.00000]
Iter 11: 100%|██████████| 10/

KeyboardInterrupt: 

In [None]:
# Learning curve: total return of every finished episode, in play order.
episodes_list = list(range(len(return_list)))
fig, ax = plt.subplots()
ax.plot(episodes_list, return_list)
ax.set_xlabel('Episodes')
ax.set_ylabel('Returns')
ax.set_title('A2C on {}'.format("Pong-v0"))
plt.show()