In [1]:
import gym
import torch
import numpy as np
import multiprocessing as mp
from a3_a2c import A2C
import copy

In [63]:
class NDA2C(A2C):
    def partial_train(self, i, params, info):
        env = copy.deepcopy(self.env)
        state = env.reset()
        score = 0
        ep_counter = 0
        gamma = params["gamma"]
        epochs = params['epochs']
        steps = params["nstep"]

        for e in range(epochs):
            done = False
            values, probs, rewards = [], [], []
            G = torch.Tensor([0])

            for step in range(steps):
                if done: break

                value = self.critic(torch.from_numpy(state).float())
                policy = self.actor(torch.from_numpy(state).float())
                action = self.choose_action(policy)

                state, reward, done, _ = env.step(action)
                values.append(value)
                probs.append(policy[action])
                rewards.append(reward)

                if done:
                    self.print_progress(ep_counter, e, epochs, i, score)
                    info = self.store_progress(info, ep_counter, score)

                    state = env.reset()
                    score = -1
                    ep_counter += 1
                else:
                    score += reward
                    G = value.detach()

            values = torch.concat(values).flip(
                dims=(0, ))
            probs = torch.stack(probs).flip(
                dims=(0, ))
            rewards = torch.Tensor(rewards).flip(
                dims=(0, ))

            returns = self.calc_returns(G, rewards, gamma)

            critic_loss = self.critic_loss_func(values, returns)

            advantage = returns - values.detach()
            actor_loss = self.actor_loss(probs, advantage)

            self.critic_optim.zero_grad()
            critic_loss.backward()
            self.critic_optim.step()

            self.actor_optim.zero_grad()
            actor_loss.backward()
            self.actor_optim.step()
        return info

    def train(self, params):
        self.actor.share_memory()
        self.critic.share_memory()

        processes = []
        info = mp.Manager().dict()
        workers = params["workers"]

        for i in range(workers):
            p = mp.Process(target=self.partial_train, args=(
                i, params, info))
            p.start()
            processes.append(p)

        for p in processes:
            p.join()

        for p in processes:
            p.terminate()

        return info


In [64]:
env = gym.make('MountainCar-v0')

agent = NDA2C(env, 150, 1e-3)

training_params = {
    'epochs': 200,  
    'gamma': 0.99,
    'nstep': 50,
    'workers': mp.cpu_count()
}

info = agent.partial_train(0, training_params, {})

torch.Size([50])
torch.Size([50])
torch.Size([50])
worker: 0, epoch:3, episode: 0, score: -199.00
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
torch.Size([50])
t

In [65]:
info

{0: {'score': -199.0, 'count': 1},
 1: {'score': -200.0, 'count': 1},
 2: {'score': -200.0, 'count': 1},
 3: {'score': -200.0, 'count': 1},
 4: {'score': -200.0, 'count': 1},
 5: {'score': -200.0, 'count': 1},
 6: {'score': -200.0, 'count': 1},
 7: {'score': -200.0, 'count': 1},
 8: {'score': -200.0, 'count': 1},
 9: {'score': -200.0, 'count': 1},
 10: {'score': -200.0, 'count': 1},
 11: {'score': -200.0, 'count': 1},
 12: {'score': -200.0, 'count': 1},
 13: {'score': -200.0, 'count': 1},
 14: {'score': -200.0, 'count': 1},
 15: {'score': -200.0, 'count': 1},
 16: {'score': -200.0, 'count': 1},
 17: {'score': -200.0, 'count': 1},
 18: {'score': -200.0, 'count': 1},
 19: {'score': -200.0, 'count': 1},
 20: {'score': -200.0, 'count': 1},
 21: {'score': -200.0, 'count': 1},
 22: {'score': -200.0, 'count': 1},
 23: {'score': -200.0, 'count': 1},
 24: {'score': -200.0, 'count': 1},
 25: {'score': -200.0, 'count': 1},
 26: {'score': -200.0, 'count': 1},
 27: {'score': -200.0, 'count': 1},
 2