In [None]:
### Get the game ready.

from gym.envs import atari
import torch  

# Names of some Atari games. This list is informational only: nothing in the
# visible cells reads `games`; the environment below is created for 'pong'.
games = ["adventure",
         "air_raid",
         "alien",
         "crazy_climber",
         "elevator_action",
         "gravitar",
         "keystone_kapers",
         "king_kong",
         "laser_gates",
         "mr_do",
         "ms_pacman",
         "jamesbond",
         "koolaid",
         "zaxxon"
         # There are way more than this!
         ]

# Build the emulator for Pong, asking for image observations and frameskip=3,
# then reset it so it is ready to be stepped.
env = atari.AtariEnv(
    game='pong',
    frameskip=3,
    obs_type='image'
)
env.reset()

# Start from the full number of actions the environment reports...
num_actions = len(env.get_action_meanings())

# ...then override it. If some buttons are unused, ignore them. In Pong, I've
# decided to use just 3 moves. pick_move() translates these 3 policy indices
# into environment action indices whenever num_actions differs from the full count.
num_actions = 3

def pick_move(move):
    """Translate a policy action index into the index handed to ``env.step``.

    Args:
        move: one-element tensor holding the policy's action index
            (``.item()`` is called on it).

    Returns:
        The plain Python index. When ``num_actions`` equals the number of
        actions the environment reports, the index passes through unchanged;
        otherwise it is mapped through the reduced table 0 -> 0, 1 -> 2, 2 -> 3.

    Raises:
        KeyError: if the reduced table is in use and the index is not 0, 1 or 2.
    """
    index = move.item()
    full_action_count = len(env.get_action_meanings())
    if num_actions != full_action_count:
        # Reduced action set: policy index -> environment action index.
        return {0: 0, 1: 2, 2: 3}[index]
    return index

# Every action name the environment exposes...
print(env.get_action_meanings())

# ...followed by the subset the policy can actually choose, one per line,
# as "<policy index> ,  <environment action name>".
for i in range(num_actions):
    env_action = pick_move(torch.tensor([i]))
    print(i, ", ", env.get_action_meanings()[env_action])

In [None]:
### How to make images.
            
import matplotlib.pyplot as plt    
import torch.nn.functional as F 
import torchvision.transforms.functional as F2

device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 
        
def image_to_ten(image):
    """Turn a raw emulator frame into a small two-level tensor for the policy.

    The frame is cropped (rows 34 up to 16 from the bottom; the removed strips
    are useless in Pong), converted to grayscale, resized with
    ``F2.resize(image, 32)``, min/max-normalised, binarised at a threshold of
    0.99 and finally mapped so that "off" pixels are -1 and "on" pixels are +1.

    Args:
        image: NumPy array indexed (row, column, channel); the channel axis is
            moved to the front before conversion.

    Returns:
        A float tensor on ``device`` whose values are exactly -1 or +1.
    """
    image = torch.from_numpy(image).float()
    image = image[34:-16, :, :]  # Cut out the top and bottom, useless in Pong.
    image = image.permute(-1, 0, 1)

    # NOTE(review): to_pil_image receives a float tensor holding the raw pixel
    # values; torchvision documents a [0, 1] range for float input - confirm
    # this conversion does what is intended.
    image = F2.to_pil_image(image)
    image = F2.to_grayscale(image)
    image = F2.resize(image, 32)
    image = F2.to_tensor(image)

    # Shift so the darkest pixel is 0, then scale so the brightest is 1.
    image = image - image.min()
    peak = image.max().item()
    if peak > 0:
        # A completely uniform frame has peak == 0; dividing would fill the
        # tensor with NaNs, so such a frame is left as all zeros instead.
        image = image / peak

    # Binarise. Using >= sends a pixel exactly at the threshold to "on"
    # instead of leaving it at an in-between value.
    threshold = .99
    image = (image >= threshold).float()

    # Map {0, 1} -> {-1, +1}.
    image = image * 2 - 1

    return image.to(device)

def show_image(image):
    """Display a channel-first tensor with values in [-1, 1] as a 5x5 inch image.

    Args:
        image: tensor indexed (channel, row, column), on any device. A
            single-channel image is replicated to three channels so it is
            drawn as gray levels.
    """
    # Scalar arithmetic instead of a ones tensor created on the global device:
    # this works no matter where `image` lives. detach() lets tensors that
    # carry gradient history be drawn as well.
    image = (image.detach().cpu() + 1) / 2
    image = image.permute(1, 2, 0)  # channel-first -> channel-last for imshow
    if image.shape[-1] == 1:
        image = torch.cat([image, image, image], dim=-1)
    plt.figure(figsize=(5, 5))
    plt.imshow(image)
    plt.axis('off')
    plt.show()
    plt.close()
    plt.ioff()
    
    
    
### Play at random for a bit, just to show an interesting image.

import random

def move_random():
    """Step the environment once with a uniformly random policy action."""
    random_index = random.randrange(num_actions)
    env_action = pick_move(torch.tensor([random_index]))
    env.step(env_action)
    
# Take 50 random steps so the sample frame is an interesting one.
for i in range(50):
    move_random()

# Grab the current frame, preprocess it and show the result.
image_1 = image_to_ten(env._get_image())
show_image(image_1)

# Two frames joined along the first axis; the network below takes its input
# shape from this tensor.
two_images = torch.cat([image_1, image_1])

In [None]:
### Policy converting observations into action-probabilities. 

import torch.nn as nn
from torchsummary import summary
import torch.optim as optim
import numpy as np
    
class ConstrainedConv2d(nn.Conv2d):
    """``nn.Conv2d`` that convolves with its kernel clamped to [-1, 1].

    The stored ``weight`` parameter itself is not modified; only the value
    passed to the convolution is clamped on each forward pass.
    """

    def forward(self, input):
        bounded_weight = torch.clamp(self.weight, min=-1.0, max=1.0)
        return nn.functional.conv2d(
            input,
            bounded_weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )

class DQN(nn.Module):
    """Small convolutional policy network.

    Despite the name, the final layer is a softmax, so the network outputs
    action probabilities (one per entry of ``num_actions``), not Q-values.
    The input layout is taken from the module-level ``two_images`` tensor:
    two channels, one per stacked frame.
    """

    def __init__(self):
        super(DQN, self).__init__()

        self.cnn = nn.Sequential(
            ConstrainedConv2d(
                in_channels = 2,
                out_channels = 4,
                kernel_size = (3,3),
                stride = (2,2),
                padding = (1,1)
            ),
            nn.LeakyReLU(),
            ConstrainedConv2d(
                in_channels = 4,
                out_channels = 4,
                kernel_size = (3,3),
                stride = (2,2),
                padding = (1,1)
            ),
            nn.LeakyReLU()
        )

        # Push one all-zero observation through the conv stack to find out how
        # many features it produces. No gradients are needed for this probe.
        with torch.no_grad():
            example = self.cnn(torch.zeros(two_images.shape).unsqueeze(0))
        # Features per sample. Read from the tensor itself rather than via
        # np.product, an alias that was removed in NumPy 2.0.
        quantity = example[0].numel()

        self.lin = nn.Sequential(
            nn.Linear(
                in_features = quantity,
                out_features = num_actions),
            nn.Softmax(dim = -1))

    def forward(self, x):
        """Return action probabilities with shape (batch, num_actions).

        A 3-D input (a single observation) is given a leading batch axis first.
        """
        if(len(x.shape) == 3):
            x = x.unsqueeze(0)
        x = self.cnn(x)
        x = torch.flatten(x, start_dim = 1)
        x = self.lin(x)
        return(x)

    def get_move(self, move):
        """Sample one action index according to the given probabilities.

        Args:
            move: probabilities for a single observation, as returned by
                ``forward`` (a leading axis of size 1 is squeezed away).

        Returns:
            A 0-d tensor on ``device`` holding the sampled index.
        """
        moves = list(range(num_actions))
        weights = move.squeeze(0).tolist()
        # Sampling uses Python's `random` module, not torch's generator.
        move = torch.tensor(random.choices(moves, weights = weights)).to(device).squeeze(0)
        return(move)
    

        
# Create the policy network on the selected device and an Adam optimizer
# (default settings) over all of its parameters.
dqn = DQN().to(device)
opti = optim.Adam(dqn.parameters())

# If you've already started training, load the saved model
# ("model.pt" is the file train() writes after every game).
# NOTE(review): the path below is an absolute path on one particular machine;
# adjust it before uncommenting.
#import os
#os.chdir("C://Users//tedjt//Desktop//Thinkster//77 More Pong//code")
#dqn.load_state_dict(torch.load("model.pt"))

# Show the layer structure, then a blank line.
print(dqn)
print()

# Per-layer summary for an input shaped like two stacked frames.
summary(dqn, two_images.shape)

In [None]:
### How to manage rewards

GAMMA = 0.99

def discount_rewards(r, gamma=None):
    """Compute discounted returns, restarting the sum at every non-zero reward.

    Walking backwards through ``r``, a running total is multiplied by the
    discount factor and the current reward is added. Whenever the current
    reward is non-zero the running total is first reset to 0, so each non-zero
    reward only propagates back to the steps since the previous non-zero reward.

    Args:
        r: 1-D tensor of per-step rewards.
        gamma: discount factor; ``None`` (the default) uses the module-level
            ``GAMMA`` as it is at call time.

    Returns:
        A float32 tensor with the same length as ``r``, on the same device as ``r``.
    """
    if gamma is None:
        gamma = GAMMA
    # One transfer to Python floats instead of one .item() call (and, on a
    # GPU, one synchronisation) per time step.
    step_rewards = r.tolist()
    discounted = [0.0] * len(step_rewards)
    running_add = 0.0
    for t in reversed(range(len(step_rewards))):
        if step_rewards[t] != 0:
            running_add = 0.0
        running_add = running_add * gamma + step_rewards[t]
        discounted[t] = running_add
    return torch.tensor(discounted, dtype=torch.float32, device=r.device)

# For example: 24 steps with a +1 reward at the end of the first third and a
# -1 reward on the very last step.
rewards = torch.zeros(24)
rewards[rewards.shape[0] // 3 - 1] = 1
rewards[-1] = -1
print(rewards)
print(discount_rewards(rewards))

In [None]:
### How to remember stuff
            
class memory():
    """Append-only list of the entries recorded during the current game."""

    def __init__(self):
        self.mem = []

    def push(self, *args):
        """Record one entry.

        ``push(item)`` stores ``item`` unchanged; ``push(a, b, ...)`` stores
        the tuple ``(a, b, ...)`` as a single entry.

        Raises:
            TypeError: if called without any argument.
        """
        if not args:
            raise TypeError("push() needs at least one value to store")
        self.mem.append(args[0] if len(args) == 1 else args)

    def empty(self):
        """Start a fresh list.

        The attribute is rebound rather than cleared in place on purpose:
        update() grabs ``mem.mem`` and then calls ``empty()``, so the list it
        is holding must keep its contents.
        """
        self.mem = []

mem = memory()

In [None]:
### Update the policy based on memory

import datetime
from IPython.display import clear_output

# Reference point for the "Time:" line printed after every game.
start_time = datetime.datetime.now()

def duration():
    """Return the time elapsed since ``start_time`` with the microseconds dropped."""
    elapsed = datetime.datetime.now() - start_time
    # Rebuild the timedelta from its day and second fields only.
    return datetime.timedelta(days=elapsed.days, seconds=elapsed.seconds)

def update():
    """Run one optimisation step on everything recorded in ``mem``.

    Each recorded entry is ``(probabilities, chosen action, reward)``. The
    loss is a binary cross-entropy between the predicted probabilities and the
    one-hot chosen actions, weighted element-wise by the discounted rewards.
    Gradients are clamped to [-1, 1] before the optimizer step.

    Side effects: empties ``mem``, clears the cell output and prints the game's
    total reward, the smallest/largest predicted probability and the elapsed time.

    Returns:
        ``(total_reward, min_probability, max_probability)``.
    """
    episode = mem.mem
    mem.empty()

    # (steps, num_actions) probabilities as produced during play.
    probs = torch.cat([entry[0] for entry in episode], dim = 0)

    # Chosen actions as one-hot float targets.
    chosen = torch.stack([entry[1] for entry in episode], dim = 0)
    targets = F.one_hot(chosen, num_classes = num_actions).float()

    # Raw rewards -> total for reporting, discounted copy for weighting.
    raw_rewards = torch.stack([entry[2] for entry in episode], dim = 0)
    total_reward = int(raw_rewards.sum().item())
    # Repeat each step's weight across the action axis to match `probs`.
    weights = discount_rewards(raw_rewards).unsqueeze(-1).repeat(1, num_actions)

    lowest = probs.min().item()
    highest = probs.max().item()

    clear_output()
    print("Reward:", total_reward)
    print("Min/Max:", round(lowest,3), round(highest,3))
    print("Time:", duration())

    loss = nn.BCELoss(weight = weights)(probs, targets)
    opti.zero_grad()
    loss.backward()
    for param in dqn.parameters():
        param.grad.data.clamp_(-1, 1)
    opti.step()

    return(total_reward, lowest, highest)

In [None]:
### How to train!
    
def train():
    """Play one game with the current policy, save the weights, then learn from it.

    Every step feeds the current and the previous preprocessed frame to the
    network (on the first step the current frame is used twice), samples an
    action, and records ``(probabilities, action, reward)`` in ``mem``.
    When the environment reports ``done`` the model is written to "model.pt"
    and ``update()`` is run.

    Returns:
        Whatever ``update()`` returns:
        ``(total_reward, min_probability, max_probability)``.
    """
    last_state = None
    state = image_to_ten(env.reset())
    # In games besides Pong, negative-rewarding life-loss might help
    #old_lives = 5
    while(True):
        env.render()
        #show_image(state)
        # Identity check: `!= None` on a tensor goes through the tensor's own
        # comparison operator, which is not what is meant here.
        previous = state if last_state is None else last_state
        observation = torch.cat([state, previous])
        moves = dqn(observation)
        move = dqn.get_move(moves)
        # The fourth value is the environment's info dict.
        next_state, reward, done, info = env.step(pick_move(move))
        #lives = info["ale.lives"]
        #if(lives < old_lives):
        #    reward -= 1
        #    old_lives = lives
        next_state = image_to_ten(next_state)
        mem.push((moves, move, torch.tensor(reward).to(device)))
        last_state = state
        state = next_state
        if(done):
            break
    torch.save(dqn.state_dict(), "model.pt")
    return(update())

In [None]:
### Plot training.
    
# Maximum number of most-recent games shown in the plots.
plot_len = 100

def plot_it(rewards, mins, maxes):
    """Plot the training history for (at most) the last ``plot_len`` games.

    Two figures are drawn: the total reward per game with a zero reference
    line, and the smallest and largest predicted probability per game.

    Args:
        rewards: per-game total rewards.
        mins: per-game minimum predicted probability.
        maxes: per-game maximum predicted probability.

    Nothing is drawn when the history is empty.
    """
    if len(rewards) == 0:
        return

    xs = list(range(len(rewards)))

    # Keep only the most recent `plot_len` games.
    if len(xs) > plot_len:
        xs = xs[-plot_len:]
        rewards = rewards[-plot_len:]
        mins = mins[-plot_len:]
        maxes = maxes[-plot_len:]

    # The figure has to exist before plotting; creating it afterwards leaves
    # the curves on a default-sized figure and adds an empty one.
    plt.figure(figsize = (5,5))
    plt.plot(xs, rewards)
    plt.plot(xs, [0] * len(xs))
    plt.title("Reward per game")
    plt.xlabel("Game")
    plt.show()
    plt.close()
    plt.ioff()

    plt.figure(figsize = (5,5))
    plt.plot(xs, mins)
    plt.plot(xs, maxes)
    plt.title("Min / max action probability per game")
    plt.xlabel("Game")
    plt.show()
    plt.close()
    plt.ioff()

In [None]:
### Train!
    
# Per-game history: total reward, smallest and largest predicted probability.
rewards, mins, maxes = [], [], []
i = 0

# There is no exit condition; interrupt the kernel to stop training.
while True:
    i += 1
    reward, m_1, m_2 = train()
    print("Games:", i)
    for history, value in ((rewards, reward), (mins, m_1), (maxes, m_2)):
        history.append(value)
    plot_it(rewards, mins, maxes)