In [None]:
import os
import torch
import torch.nn as nn
import torch.multiprocessing as mp

import gym
import json
import shutil

import numpy as np

from collections import deque
from itertools import product, permutations

# Number of worker processes in the multiprocessing pool; one train() task is
# created per rank in range(num_processes).
num_processes = 2

# Episodes each worker runs (train() loops over range(N_EPOCH)); the 264000
# total is split evenly across the workers.
N_EPOCH = 264000 // num_processes

# Input width of both networks: update_state() stacks the 4 most recent frames.
# NOTE(review): 128 is presumably the byte length of one Breakout RAM
# observation -- confirm against the environment.
obs_dim = 128 * 4
# Number of policy outputs. train() shifts every non-zero choice up by one
# before env.step(), so env action 1 is never picked by the policy; it is only
# issued explicitly by train() before each inner control loop.
actions_dim = 4 - 1
# Width of the first hidden layer; the following layers use //2 and //4 of it.
hidden_size = 128

# Base RNG seed: the driver seeds torch with it, workers use seed + rank.
seed = 42

In [None]:
class ActorModel(nn.Module):
    """Policy network: maps an observation vector to action probabilities.

    Architecture: three Tanh hidden layers of width ``hidden_size``,
    ``hidden_size // 2`` and ``hidden_size // 4`` followed by a linear layer
    with ``action_space`` outputs and a softmax.

    Args:
        obs_shape: Number of input features.
        action_space: Number of discrete actions (size of the output).
        hidden_size: Width of the first hidden layer.
    """

    def __init__(self, obs_shape, action_space, hidden_size):
        super(ActorModel, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(obs_shape, hidden_size), nn.Tanh(),
            nn.Linear(hidden_size, hidden_size // 2), nn.Tanh(),
            nn.Linear(hidden_size // 2, hidden_size // 4), nn.Tanh(),
            # Normalise over the action axis (the last one). For the 1-D
            # observations used by train() this is the same as dim=0, but it
            # also yields a proper per-sample distribution for batched
            # (N, obs_shape) inputs instead of normalising across the batch.
            nn.Linear(hidden_size // 4, action_space), nn.Softmax(dim=-1))

        # Deterministic start: every weight is 0 and every bias is 1, so the
        # initial policy is uniform over the actions.
        # NOTE(review): identical initial values make all units of a layer
        # start out the same -- confirm this symmetric init is intended.
        for m in self.model:
            if isinstance(m, nn.Linear):
                nn.init.constant_(m.weight, 0)
                nn.init.constant_(m.bias, 1)

        self.train()

    def forward(self, inputs):
        """Return the action probabilities for ``inputs`` (softmax over the last axis)."""
        return self.model(inputs)

    def create_eligibility_traces(self, device=None):
        """Return a fresh list of zero tensors, one per parameter of ``self.model``.

        Args:
            device: Device for the trace tensors; defaults to CPU.

        Returns:
            list[torch.Tensor]: zeros shaped like each parameter, in
            ``self.model.parameters()`` order.
        """
        if device is None:
            device = torch.device('cpu')
        with torch.no_grad():
            return [torch.zeros_like(param, device=device)
                    for param in self.model.parameters()]

In [None]:
class CriticModel(nn.Module):
    """State-value network: maps an observation vector to a single value.

    Three Tanh hidden layers (``hidden_size``, ``hidden_size // 2``,
    ``hidden_size // 4``) followed by a linear layer with one output.

    Args:
        obs_shape: Number of input features.
        hidden_size: Width of the first hidden layer.
    """

    def __init__(self, obs_shape, hidden_size):
        super(CriticModel, self).__init__()

        # Layer widths from input to the last hidden layer.
        widths = [obs_shape, hidden_size, hidden_size // 2, hidden_size // 4]
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.Tanh())
        layers.append(nn.Linear(widths[-1], 1))
        self.model = nn.Sequential(*layers)

        # Deterministic start: all weights 0, all biases 1.
        for layer in self.model.children():
            if not isinstance(layer, nn.Linear):
                continue
            nn.init.zeros_(layer.weight)
            nn.init.ones_(layer.bias)

        # Kept as an (initially empty) attribute of the instance.
        self.eligibility_traces = []

        self.train()

    def forward(self, inputs):
        """Return the estimated value of ``inputs``."""
        return self.model(inputs)

    def create_eligibility_traces(self, device=None):
        """Build one zero tensor per parameter of ``self.model``.

        Args:
            device: Device for the trace tensors; CPU when omitted.

        Returns:
            list[torch.Tensor]: zeros shaped like each parameter, in
            ``self.model.parameters()`` order.
        """
        target = torch.device('cpu') if device is None else device
        with torch.no_grad():
            return [weights.data.new_zeros(weights.data.size()).to(target)
                    for weights in self.model.parameters()]

In [None]:
def update_state(ram_bytes, device=None):
    """Push a new frame into the rolling 4-frame history and return the stacked state.

    The frame is divided by 255 before being stored. The history lives on the
    function attribute ``update_state.frames_buffer`` (newest frame first) and
    is created, filled with zero frames, on the first call.

    Args:
        ram_bytes: Tensor holding the new frame.
        device: Device for the intermediate and returned tensors; CPU when omitted.

    Returns:
        torch.Tensor: the four buffered frames stacked and flattened to 1-D.
    """
    target = torch.device("cpu") if device is None else device

    history = update_state.frames_buffer
    scaled = ram_bytes / torch.Tensor([255.0]).to(target)
    if history is None:
        # First call: start from four references to the same all-zero frame.
        blank = torch.zeros(scaled.size()).to(target)
        history = deque([blank, blank, blank, blank], maxlen=4)
        update_state.frames_buffer = history
    # maxlen=4 drops the oldest frame from the right end automatically.
    history.appendleft(scaled)
    stacked = torch.stack(list(history))
    return stacked.flatten().to(target)


# Shared frame history; None until the first update_state() call.
update_state.frames_buffer = None

def save_checkpoint(state, is_best, filename='checkpoint_a.pth.tar'):
    """Serialise ``state`` to ``filename`` and optionally keep a "best" copy.

    Args:
        state: Object handed to ``torch.save``.
        is_best: When truthy, the written file is also copied to
            ``model_best.pth.tar`` in the current working directory.
        filename: Destination path of the checkpoint.
    """
    torch.save(state, filename)
    if not is_best:
        return
    shutil.copyfile(filename, 'model_best.pth.tar')


def _cosine_annealing(epoch, value_min, value_max, epoch_max, device=None):
    """Shared half-cosine schedule used by the epsilon and learning-rate helpers.

    Returns ``value_max`` at ``epoch == 0`` and ``value_min`` at
    ``epoch == epoch_max``, following half a cosine period in between.
    """
    if device is None:
        device = torch.device('cpu')
    value = value_min + 0.5 * (value_max - value_min) * (1 + np.cos(epoch / epoch_max * np.pi))
    return torch.tensor([value], dtype=torch.float, device=device)


def cosine_annealing_eps(epoch, eps_min, eps_max, epoch_max, device=None):
    """Cosine-annealed epsilon for ``epoch``.

    Args:
        epoch: Current epoch index.
        eps_min: Value reached at ``epoch == epoch_max``.
        eps_max: Value at ``epoch == 0``.
        epoch_max: Length of the schedule; must be non-zero.
        device: Device of the returned tensor; CPU when omitted.

    Returns:
        torch.Tensor: float tensor of shape ``(1,)`` holding the epsilon.
    """
    return _cosine_annealing(epoch, eps_min, eps_max, epoch_max, device=device)


def cosine_annealing_lr(epoch, lr_min, lr_max, epoch_max, device=None):
    """Cosine-annealed learning rate for ``epoch``.

    Args:
        epoch: Current epoch index.
        lr_min: Value reached at ``epoch == epoch_max``.
        lr_max: Value at ``epoch == 0``.
        epoch_max: Length of the schedule; must be non-zero.
        device: Device of the returned tensor; CPU when omitted.

    Returns:
        torch.Tensor: float tensor of shape ``(1,)`` holding the learning rate.
    """
    return _cosine_annealing(epoch, lr_min, lr_max, epoch_max, device=device)


def load_actor_critic(filename, actor_model, critic_model):
    """Restore actor and critic weights from a checkpoint file.

    The checkpoint is expected to be a dict with the keys ``'actor_dict'``,
    ``'critic_dict'`` and ``'iteration_number'`` (the layout written by
    ``save_model``).

    Args:
        filename: Path of the checkpoint.
        actor_model: Module receiving ``checkpoint['actor_dict']``.
        critic_model: Module receiving ``checkpoint['critic_dict']``.

    Returns:
        The stored ``iteration_number``, or ``None`` when ``filename`` does
        not exist (the models are left untouched in that case).
    """
    print("=> loading checkpoint '{}'".format(filename))
    if not os.path.isfile(filename):
        print("=> no checkpoint found at '{}'".format(filename))
        return None

    # Deserialise onto the CPU so a checkpoint written on a GPU machine can be
    # restored on a CPU-only host; load_state_dict() copies the values into
    # the models' existing parameters, wherever those live.
    # NOTE(security): torch.load unpickles the file -- only load checkpoints
    # from trusted sources.
    checkpoint = torch.load(filename, map_location='cpu')
    actor_model.load_state_dict(checkpoint['actor_dict'])
    critic_model.load_state_dict(checkpoint['critic_dict'])
    return checkpoint['iteration_number']


def uniform_interval_random_sampling(start, end, step_in_between):
    """Stratified random sampling of the range between ``start`` and ``end``.

    The range is cut into ``step_in_between`` equal sub-intervals and one
    value is drawn uniformly (via ``np.random.uniform``) from each of them.

    Args:
        start: First endpoint; always the first returned value.
        end: Second endpoint; always the last returned value.
        step_in_between: Number of sub-intervals, i.e. of random draws.

    Returns:
        list: ``[start, draw_1, ..., draw_k, end]`` with ``k == step_in_between``.
    """
    width = (start - end) / step_in_between

    # Boundaries of the sub-intervals, walking from start towards end.
    edges = []
    for k in range(step_in_between + 1):
        edges.append(start - width * k)

    draws = []
    for lower, upper in zip(edges, edges[1:]):
        draws.append(np.random.uniform(lower, upper))

    return [start, *draws, end]



In [None]:
def train(*args):
    """Run one actor-critic worker with eligibility traces on ``Breakout-ram-v4``.

    Designed to be called through ``pool.map``, so the single positional
    argument is a tuple
    ``(rank, actor_model, critic_model, iteration_number, gamma, lamda_actor,
    lamda_critic, alpha_actor, alpha_critic, save_models)``:

    * ``rank`` -- worker index; offsets the RNG seeds. Ranks below 2 sample
      actions from the policy, higher ranks act greedily (argmax).
    * ``actor_model`` / ``critic_model`` -- networks updated in place.
    * ``iteration_number`` -- forwarded to ``save_model``.
    * ``gamma`` -- discount factor.
    * ``lamda_actor`` / ``lamda_critic`` -- trace-decay rate of the actor and
      of the critic eligibility traces.
    * ``alpha_actor`` / ``alpha_critic`` -- peak learning rates of the cosine
      schedules.
    * ``save_models`` -- forwarded to ``save_model``.

    Relies on the module-level names ``N_EPOCH``, ``seed``, ``actions_dim`` and
    ``device``. Runs ``N_EPOCH`` episodes, then prints a JSON summary line.
    Returns ``None``.
    """
    args = args[0]
    rank = args[0]
    actor_model = args[1]
    critic_model = args[2]
    iteration_number = args[3]
    gamma = args[4]
    lamda_actor = args[5]
    lamda_critic = args[6]
    alpha_actor = args[7]
    alpha_critic = args[8]
    save_models = args[9]

    env = gym.make('Breakout-ram-v4')

    # Seed every RNG this worker uses. Actions are sampled with np.random, so
    # numpy needs its own per-rank seed; otherwise workers that inherit the
    # parent's numpy state would all draw the same random sequence.
    torch.manual_seed(seed + rank)
    np.random.seed(seed + rank)

    ln = torch.log

    total_score = 0

    for epoch in range(N_EPOCH):
        frame = env.reset()
        reset_state(frame)
        frame = torch.from_numpy(frame).to(device).to(torch.float)

        # Traces are reset at the start of every episode.
        actor_eligibility_traces = actor_model.create_eligibility_traces()
        critic_eligibility_traces = critic_model.create_eligibility_traces()

        # I accumulates gamma**t and scales the actor's gradient contribution.
        I = 1
        real_score = 0
        lr_actor = cosine_annealing_lr(epoch, lr_min=0.0000001, lr_max=alpha_actor, epoch_max=N_EPOCH)
        lr_critic = cosine_annealing_lr(epoch, lr_min=0.0000001, lr_max=alpha_critic, epoch_max=N_EPOCH)
        while True:

            # Env action 1 is issued explicitly here (once per life); the
            # policy itself never selects it. The state fed to the networks is
            # built from the difference between consecutive frames.
            last_frame = frame
            frame, reward, is_done, infos = env.step(1)
            frame = torch.from_numpy(frame).to(device).to(torch.float)

            state = update_state(frame - last_frame)
            lives = infos['ale.lives']

            # Inner loop: one life.
            while infos['ale.lives'] == lives and not is_done:
                policy = actor_model(state)
                action_probs = policy.detach().numpy()
                if (rank < 2):
                    action = np.random.choice(range(actions_dim), p=action_probs)
                else:
                    action = np.argmax(action_probs)

                env.render()

                # Gradient seed for ln_policy.backward(): only the chosen
                # action's entry is non-zero.
                # NOTE(review): the entry is policy[action] rather than 1, so
                # the back-propagated quantity is grad(pi(a)) instead of
                # grad(ln pi(a)) -- confirm this weighting is intended.
                one_hot = torch.zeros(action_probs.shape)
                one_hot[action] = policy[action]

                ln_policy = ln(policy)

                # Map policy index {0, 1, 2} to env action {0, 2, 3}.
                if action > 0:
                    action += 1

                last_frame = frame
                frame, reward, is_done, infos = env.step(action)
                frame = torch.from_numpy(frame).to(device).to(torch.float)
                next_state = update_state(frame - last_frame)

                real_score += reward

                life_lost = infos['ale.lives'] != lives
                if life_lost:
                    # Losing a life is penalised and treated as terminal for
                    # bootstrapping purposes.
                    reward = -1
                    next_state_value = torch.Tensor([0])
                    # NOTE(review): real_score is cumulative over the episode,
                    # so adding it on every lost life counts early points
                    # several times -- confirm this is the intended metric.
                    total_score += real_score
                else:
                    with torch.no_grad():
                        next_state_value = critic_model(next_state)

                current_state_value = critic_model(state)

                # TD error; detached because it is only used as a scalar
                # multiplier in the manual parameter updates below.
                delta = (reward + gamma * next_state_value - current_state_value).detach()

                actor_model.zero_grad()
                critic_model.zero_grad()
                # Back-propagate the state value itself, not delta: the
                # gradient of delta w.r.t. the critic is -grad(V), which would
                # flip the sign of the update w += lr * delta * trace.
                current_state_value.backward()
                ln_policy.backward(one_hot)

                with torch.no_grad():
                    actor_params = list(actor_model.parameters())
                    critic_params = list(critic_model.parameters())
                    for i in range(len(critic_params)):
                        # The critic trace decays with its own lambda.
                        updated_trace = gamma * lamda_critic * critic_eligibility_traces[i] + critic_params[i].grad.data
                        critic_eligibility_traces[i] = updated_trace
                        # Norm-dependent shrinkage of the weights before the step.
                        regularized = (1 - critic_params[i].data.norm(2) * 0.1) * critic_params[i].data
                        updated = regularized + lr_critic * delta * critic_eligibility_traces[i].data
                        critic_params[i].data = updated

                    for i in range(len(actor_params)):
                        updated_trace = gamma * lamda_actor * actor_eligibility_traces[i].data + I * actor_params[i].grad.data
                        actor_eligibility_traces[i].data = updated_trace
                        regularized = (1 - actor_params[i].data.norm(2) * 0.1) * actor_params[i].data
                        updated = regularized + lr_actor * delta * actor_eligibility_traces[i].data
                        actor_params[i].data = updated

                I = gamma * I
                state = next_state

            if is_done:
                save_model(rank, epoch, real_score, iteration_number, save_models, actor_model, critic_model)
                break

    print(json.dumps({'gamma': gamma, 'lambda_actor': lamda_actor, 'lambda_critic': lamda_critic , 'rank': rank, 'total_score': total_score, 'total_score_per_epoch': total_score / N_EPOCH, 'total_score_per_epoch_per_live': total_score / N_EPOCH / 5}) + ',')

def save_model(rank, epoch, real_score, iteration_number, save_models, actor_model, critic_model):
    """Log an episode score and, every 10th epoch, write a checkpoint.

    The score line is always printed. A checkpoint is written only when
    ``save_models`` is truthy and ``epoch`` is a positive multiple of 10; it
    stores ``iteration_number + 1`` together with both models' state dicts.

    Args:
        rank: Worker index, used in the log line.
        epoch: Episode index within the worker.
        real_score: Score of the finished episode.
        iteration_number: Run counter embedded in the checkpoint file name.
        save_models: Enables checkpoint writing for this worker.
        actor_model: Actor whose ``state_dict()`` is saved.
        critic_model: Critic whose ``state_dict()`` is saved.
    """
    print('[proc{}:epoch{}] Score: {}'.format(rank, epoch, real_score))

    checkpoint_due = save_models and epoch % 10 == 0 and epoch > 0
    if not checkpoint_due:
        return

    fname = "saved_checkpoint_{}_{}.pth.tar".format(epoch, iteration_number)
    print(fname)

    payload = {
        'iteration_number': iteration_number + 1,
        'actor_dict': actor_model.state_dict(),
        'critic_dict': critic_model.state_dict()
    }
    save_checkpoint(payload, False, filename=fname)

def reset_state(frame):
    """Flush the frame history kept by ``update_state`` with all-zero frames.

    Pushes four zero frames shaped like ``frame``, which is exactly the
    capacity of the history buffer, so nothing from a previous episode
    remains. Uses the module-level ``device``.

    Args:
        frame: Array whose ``shape`` defines the shape of the zero frames.
    """
    for _ in range(4):
        blank = torch.from_numpy(np.zeros(frame.shape)).to(device).to(torch.float)
        update_state(blank)


In [None]:
# Seed both RNGs used by the sweep. The sampling and shuffling of the
# hyperparameter grid below go through np.random, so seeding torch alone does
# not make the sweep reproducible.
torch.manual_seed(seed)
np.random.seed(seed)

dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU

# Stratified random samples per hyperparameter: each call returns the two
# endpoints plus one draw from each of the 20 sub-intervals (22 values).
gamma_sampling = uniform_interval_random_sampling(0.01, 0.9999, 20)
lambda_actor_sampling = uniform_interval_random_sampling(0.1, 0.9, 20)
lambda_critic_sampling = uniform_interval_random_sampling(0.1, 0.9, 20)

# NOTE(review): GAMMA is not used below -- the sweep takes gamma from
# gamma_sampling. Kept because other cells may read it.
GAMMA = 0.9
# Peak learning rates handed to the workers' cosine schedules.
ALPHA_ACTOR = 0.0025
ALPHA_CRITIC = 0.0025

# Full Cartesian grid (gamma, lambda_actor, lambda_critic), visited in random order.
hyperparams_triples = list(product(gamma_sampling, lambda_actor_sampling, lambda_critic_sampling))
np.random.shuffle(hyperparams_triples)


for gamma, lamda_actor, lamda_critic in hyperparams_triples:
    iteration_number = 0

    # Fresh networks for every hyperparameter triple.
    actor = ActorModel(obs_dim, actions_dim, hidden_size)
    actor.to(device=device, dtype=dtype)

    critic = CriticModel(obs_dim, hidden_size)
    critic.to(device=device, dtype=dtype)

    # Put the parameters in shared memory so all workers update the same tensors.
    actor.share_memory()
    critic.share_memory()

    with mp.Pool(processes=num_processes) as pool:
        # One task per rank; only rank 0 writes checkpoints (last tuple element).
        args = [(rank, actor, critic, iteration_number,
                 gamma, lamda_actor, lamda_critic, ALPHA_ACTOR, ALPHA_CRITIC, rank == 0) for rank in range(num_processes)]
        pool.map(train, args)