
Upgrade to PyTorch 0.4

Closes #5
Kaixhin committed Jun 1, 2018
1 parent 0d53eb4 commit 5b7ca5d75bf16629ddaf68ecab4ab6c7dcccf56c
Showing with 39 additions and 40 deletions.
  1. +2 −0 main.py
  2. +3 −5 model.py
  3. +6 −6 test.py
  4. +27 −28 train.py
  5. +1 −1 utils.py
2 main.py
@@ -58,6 +58,7 @@
# mp.set_start_method(platform.python_version()[0] == '3' and 'spawn' or 'fork') # Force true spawning (not forking) if available
torch.manual_seed(args.seed)
T = Counter() # Global shared counter
gym.logger.set_level(gym.logger.ERROR) # Disable Gym warnings

# Create shared network
env = gym.make(args.env)
@@ -88,6 +89,7 @@
for rank in range(1, args.num_processes + 1):
p = mp.Process(target=train, args=(rank, args, T, shared_model, shared_average_model, optimiser))
p.start()
print('Process ' + str(rank) + ' started')
processes.append(p)

# Clean up
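The two main.py additions silence Gym's warning output and log when each worker starts. For context, a minimal sketch of the Hogwild-style spawn/join pattern they slot into (the toy train function and the worker count are stand-ins, not the repo's full file):

    import gym
    import torch.multiprocessing as mp

    gym.logger.set_level(gym.logger.ERROR)  # Disable Gym warnings, as added above

    def train(rank):  # Stand-in worker; the real train() also takes the shared models and optimiser
        print('Process ' + str(rank) + ' training')

    if __name__ == '__main__':
        processes = []
        for rank in range(1, 4):
            p = mp.Process(target=train, args=(rank,))
            p.start()
            print('Process ' + str(rank) + ' started')
            processes.append(p)
        for p in processes:
            p.join()  # Clean up: wait for every worker to finish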
8 model.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
from torch import nn
from torch.nn import functional as F


class ActorCritic(nn.Module):
@@ -8,19 +9,16 @@ def __init__(self, observation_space, action_space, hidden_size):
self.state_size = observation_space.shape[0]
self.action_size = action_space.n

self.relu = nn.ReLU(inplace=True)
self.softmax = nn.Softmax(dim=1)

self.fc1 = nn.Linear(self.state_size, hidden_size)
self.lstm = nn.LSTMCell(hidden_size, hidden_size)
self.fc_actor = nn.Linear(hidden_size, self.action_size)
self.fc_critic = nn.Linear(hidden_size, self.action_size)

def forward(self, x, h):
x = self.relu(self.fc1(x))
x = F.relu(self.fc1(x))
h = self.lstm(x, h) # h is (hidden state, cell state)
x = h[0]
policy = self.softmax(self.fc_actor(x)).clamp(max=1 - 1e-20) # Prevent 1s and hence NaNs
policy = F.softmax(self.fc_actor(x), dim=1).clamp(max=1 - 1e-20) # Prevent 1s and hence NaNs
Q = self.fc_critic(x)
V = (Q * policy).sum(1, keepdim=True) # V is expectation of Q under π
return policy, Q, V, h
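The model.py hunk drops the stored nn.ReLU/nn.Softmax modules in favour of the functional API, where softmax needs an explicit dim in 0.4. A minimal sketch with a toy tensor (not the repo's code):

    import torch
    from torch.nn import functional as F

    x = torch.randn(1, 4)            # One row of pre-activation scores
    h = F.relu(x)                    # Functional ReLU: no module instance kept on the class
    policy = F.softmax(h, dim=1)     # dim=1 normalises across actions; 0.4 requires it explicitly
    print(policy.sum(1))             # Rows sum to 1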
12 test.py
@@ -3,7 +3,6 @@
from datetime import datetime
import gym
import torch
from torch.autograd import Variable

from model import ActorCritic
from utils import state_to_tensor, plot_line
@@ -35,8 +34,8 @@ def test(rank, args, T, shared_model):
if done:
# Sync with shared model every episode
model.load_state_dict(shared_model.state_dict())
hx = Variable(torch.zeros(1, args.hidden_size), volatile=True)
cx = Variable(torch.zeros(1, args.hidden_size), volatile=True)
hx = torch.zeros(1, args.hidden_size)
cx = torch.zeros(1, args.hidden_size)
# Reset environment and done flag
state = state_to_tensor(env.reset())
done, episode_length = False, 0
@@ -47,13 +46,14 @@ def test(rank, args, T, shared_model):
env.render()

# Calculate policy
policy, _, _, (hx, cx) = model(Variable(state, volatile=True), (hx.detach(), cx.detach())) # Break graph for memory efficiency
with torch.no_grad():
policy, _, _, (hx, cx) = model(state, (hx, cx))

# Choose action greedily
action = policy.max(1)[1].data[0]
action = policy.max(1)[1][0]

# Step
state, reward, done, _ = env.step(action)
state, reward, done, _ = env.step(action.item())
state = state_to_tensor(state)
reward_sum += reward
done = done or episode_length >= args.max_episode_length # Stop episodes at a max length
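The test.py changes replace the 0.3-style volatile=True flag with the torch.no_grad() context and call .item() to hand a plain Python number to env.step(). A self-contained sketch of that inference pattern, using a toy linear layer as a stand-in for the actor-critic:

    import torch
    from torch import nn
    from torch.nn import functional as F

    net = nn.Linear(4, 3)                     # Stand-in for the policy head
    state = torch.randn(1, 4)

    with torch.no_grad():                     # Replaces volatile=True: no graph is recorded
        policy = F.softmax(net(state), dim=1)

    action = policy.max(1)[1]                 # Greedy action as a 1-element LongTensor
    print(action.item())                      # .item() gives the Python int env.step() expects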
55 train.py
@@ -5,7 +5,6 @@
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable

from memory import EpisodicReplayMemory
from model import ActorCritic
@@ -45,7 +44,7 @@ def _update_networks(args, T, model, shared_model, shared_average_model, loss, o
"""
loss.backward()
# Gradient L2 normalisation
nn.utils.clip_grad_norm(model.parameters(), args.max_gradient_norm)
nn.utils.clip_grad_norm_(model.parameters(), args.max_gradient_norm)

# Transfer gradients to shared model and update
_transfer_grads_to_shared_model(model, shared_model)
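PyTorch 0.4 renames the in-place gradient clipping function to clip_grad_norm_ (the trailing underscore marks in-place operations). A minimal sketch; the max_norm value here is illustrative, the repo takes it from args.max_gradient_norm:

    import torch
    from torch import nn

    layer = nn.Linear(8, 2)
    loss = layer(torch.randn(1, 8)).sum()
    loss.backward()
    total_norm = nn.utils.clip_grad_norm_(layer.parameters(), max_norm=40)  # Clips in place, returns the pre-clip norm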
@@ -65,24 +64,24 @@ def _trust_region_loss(model, distribution, ref_distribution, loss, threshold):
model.zero_grad()
loss.backward(retain_graph=True)
# Gradients should be treated as constants (not using detach as volatility can creep in when double backprop is not implemented)
g = [Variable(param.grad.data.clone()) for param in model.parameters() if param.grad is not None]
g = [param.grad.data.clone() for param in model.parameters() if param.grad is not None]
model.zero_grad()

# KL divergence k ← ∇θ0∙DKL[π(∙|s_i; θ_a) || π(∙|s_i; θ)]
kl = F.kl_div(distribution.log(), ref_distribution, size_average=False)
# Compute gradients from (negative) KL loss (increases KL divergence)
(-kl).backward(retain_graph=True)
k = [Variable(param.grad.data.clone()) for param in model.parameters() if param.grad is not None]
k = [param.grad.data.clone() for param in model.parameters() if param.grad is not None]
model.zero_grad()

# Compute dot products of gradients
k_dot_g = sum(torch.sum(k_p * g_p) for k_p, g_p in zip(k, g))
k_dot_k = sum(torch.sum(k_p ** 2) for k_p in k)
# Compute trust region update
if k_dot_k.data[0] > 0:
if k_dot_k.item() > 0:
trust_factor = ((k_dot_g - threshold) / k_dot_k).clamp(min=0)
else:
trust_factor = Variable(torch.zeros(1))
trust_factor = torch.zeros(1)
# z* = g - max(0, (k^T∙g - δ) / ||k||^2_2)∙k
z_star = [g_p - trust_factor.expand_as(k_p) * k_p for g_p, k_p in zip(g, k)]
trust_loss = 0
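In 0.4 reductions such as torch.sum return 0-dimensional tensors, so the old k_dot_k.data[0] indexing gives way to .item(). A toy example of the comparison used in the trust-region update above:

    import torch

    k = torch.randn(5)
    k_dot_k = torch.sum(k ** 2)   # A 0-dim tensor in 0.4, not a 1-element vector
    if k_dot_k.item() > 0:        # .item() extracts the Python float; replaces .data[0]
        print(k_dot_k.item())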
@@ -104,7 +103,7 @@ def _train(args, T, model, shared_model, shared_average_model, optimiser, polici
if off_policy:
rho = policies[i].detach() / old_policies[i]
else:
rho = Variable(torch.ones(1, action_size))
rho = torch.ones(1, action_size)

# Qret ← r_i + γQret
Qret = rewards[i] + args.discount * Qret
@@ -169,8 +168,8 @@ def train(rank, args, T, shared_model, shared_average_model, optimiser):

# Reset or pass on hidden state
if done:
hx, avg_hx = Variable(torch.zeros(1, args.hidden_size)), Variable(torch.zeros(1, args.hidden_size))
cx, avg_cx = Variable(torch.zeros(1, args.hidden_size)), Variable(torch.zeros(1, args.hidden_size))
hx, avg_hx = torch.zeros(1, args.hidden_size), torch.zeros(1, args.hidden_size)
cx, avg_cx = torch.zeros(1, args.hidden_size), torch.zeros(1, args.hidden_size)
# Reset environment and done flag
state = state_to_tensor(env.reset())
done, episode_length = False, 0
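With the Tensor/Variable merge in 0.4, torch.zeros already returns an autograd-ready tensor, so the hidden states no longer need a Variable wrapper. A quick illustration (128 is just a placeholder for args.hidden_size):

    import torch

    hx = torch.zeros(1, 128)     # Plain tensor; usable directly as an LSTM cell state
    print(hx.requires_grad)      # False by default, so no gradient flows into the reset state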
@@ -184,25 +183,25 @@ def train(rank, args, T, shared_model, shared_average_model, optimiser):

while not done and t - t_start < args.t_max:
# Calculate policy and values
policy, Q, V, (hx, cx) = model(Variable(state), (hx, cx))
average_policy, _, _, (avg_hx, avg_cx) = shared_average_model(Variable(state), (avg_hx, avg_cx))
policy, Q, V, (hx, cx) = model(state, (hx, cx))
average_policy, _, _, (avg_hx, avg_cx) = shared_average_model(state, (avg_hx, avg_cx))

# Sample action
action = policy.multinomial().data[0, 0] # Graph broken as loss for stochastic action calculated manually
action = torch.multinomial(policy, 1)[0, 0]

# Step
next_state, reward, done, _ = env.step(action)
next_state, reward, done, _ = env.step(action.item())
next_state = state_to_tensor(next_state)
reward = args.reward_clip and min(max(reward, -1), 1) or reward # Optionally clamp rewards
done = done or episode_length >= args.max_episode_length # Stop episodes at a max length
episode_length += 1 # Increase episode counter

if not args.on_policy:
# Save (beginning part of) transition for offline training
memory.append(state, action, reward, policy.data) # Save just tensors
memory.append(state, action, reward, policy.detach()) # Save just tensors
# Save outputs for online training
[arr.append(el) for arr, el in zip((policies, Qs, Vs, actions, rewards, average_policies),
(policy, Q, V, Variable(torch.LongTensor([[action]])), Variable(torch.Tensor([[reward]])), average_policy))]
(policy, Q, V, torch.LongTensor([[action]]), torch.Tensor([[reward]]), average_policy))]

# Increment counters
t += 1
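Sampling now goes through torch.multinomial with an explicit sample count, and the stored policy uses detach() rather than .data so the saved tensor is cleanly cut from the graph. A toy sketch of both calls:

    import torch

    policy = torch.tensor([[0.1, 0.2, 0.7]])      # A toy action distribution
    action = torch.multinomial(policy, 1)[0, 0]   # Draw one action index per row
    stored = policy.detach()                      # Graph-free view, safe to keep in replay memory
    print(action.item(), stored)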
@@ -214,14 +213,14 @@ def train(rank, args, T, shared_model, shared_average_model, optimiser):
# Break graph for last values calculated (used for targets, not directly as model outputs)
if done:
# Qret = 0 for terminal s
Qret = Variable(torch.zeros(1, 1))
Qret = torch.zeros(1, 1)

if not args.on_policy:
# Save terminal state for offline training
memory.append(state, None, None, None)
else:
# Qret = V(s_i; θ) for non-terminal s
_, _, Qret, _ = model(Variable(state), (hx, cx))
_, _, Qret, _ = model(state, (hx, cx))
Qret = Qret.detach()

# Train the network on-policy
@@ -239,34 +238,34 @@ def train(rank, args, T, shared_model, shared_average_model, optimiser):
trajectories = memory.sample_batch(args.batch_size, maxlen=args.t_max)

# Reset hidden state
hx, avg_hx = Variable(torch.zeros(args.batch_size, args.hidden_size)), Variable(torch.zeros(args.batch_size, args.hidden_size))
cx, avg_cx = Variable(torch.zeros(args.batch_size, args.hidden_size)), Variable(torch.zeros(args.batch_size, args.hidden_size))
hx, avg_hx = torch.zeros(args.batch_size, args.hidden_size), torch.zeros(args.batch_size, args.hidden_size)
cx, avg_cx = torch.zeros(args.batch_size, args.hidden_size), torch.zeros(args.batch_size, args.hidden_size)

# Lists of outputs for training
policies, Qs, Vs, actions, rewards, old_policies, average_policies = [], [], [], [], [], [], []

# Loop over trajectories (bar last timestep)
for i in range(len(trajectories) - 1):
# Unpack first half of transition
state = torch.cat((trajectory.state for trajectory in trajectories[i]), 0)
action = Variable(torch.LongTensor([trajectory.action for trajectory in trajectories[i]])).unsqueeze(1)
reward = Variable(torch.Tensor([trajectory.reward for trajectory in trajectories[i]])).unsqueeze(1)
old_policy = Variable(torch.cat((trajectory.policy for trajectory in trajectories[i]), 0))
state = torch.cat(tuple(trajectory.state for trajectory in trajectories[i]), 0)
action = torch.LongTensor([trajectory.action for trajectory in trajectories[i]]).unsqueeze(1)
reward = torch.Tensor([trajectory.reward for trajectory in trajectories[i]]).unsqueeze(1)
old_policy = torch.cat(tuple(trajectory.policy for trajectory in trajectories[i]), 0)

# Calculate policy and values
policy, Q, V, (hx, cx) = model(Variable(state), (hx, cx))
average_policy, _, _, (avg_hx, avg_cx) = shared_average_model(Variable(state), (avg_hx, avg_cx))
policy, Q, V, (hx, cx) = model(state, (hx, cx))
average_policy, _, _, (avg_hx, avg_cx) = shared_average_model(state, (avg_hx, avg_cx))

# Save outputs for offline training
[arr.append(el) for arr, el in zip((policies, Qs, Vs, actions, rewards, average_policies, old_policies),
(policy, Q, V, action, reward, average_policy, old_policy))]

# Unpack second half of transition
next_state = torch.cat((trajectory.state for trajectory in trajectories[i + 1]), 0)
done = Variable(torch.Tensor([trajectory.action is None for trajectory in trajectories[i + 1]]).unsqueeze(1))
next_state = torch.cat(tuple(trajectory.state for trajectory in trajectories[i + 1]), 0)
done = torch.Tensor([trajectory.action is None for trajectory in trajectories[i + 1]]).unsqueeze(1)

# Do forward pass for all transitions
_, _, Qret, _ = model(Variable(next_state), (hx, cx))
_, _, Qret, _ = model(next_state, (hx, cx))
# Qret = 0 for terminal s, V(s_i; θ) otherwise
Qret = ((1 - done) * Qret).detach()
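torch.cat expects a sequence of tensors, which is why the hunk above wraps the generator expressions in tuple(...). A small sketch of the batching pattern:

    import torch

    states = [torch.randn(1, 4) for _ in range(3)]   # One state per sampled trajectory
    batch = torch.cat(tuple(s for s in states), 0)   # tuple(...) turns the generator into a sequence
    print(batch.shape)                               # torch.Size([3, 4])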

2 utils.py
@@ -31,7 +31,7 @@ def plot_line(xs, ys_population):
mean_colour = 'rgb(0, 172, 237)'
std_colour = 'rgba(29, 202, 255, 0.2)'

ys = torch.Tensor(ys_population)
ys = torch.tensor(ys_population)
ys_min = ys.min(1)[0].squeeze()
ys_max = ys.max(1)[0].squeeze()
ys_mean = ys.mean(1).squeeze()
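utils.py switches to the torch.tensor factory, which copies the given data and infers its dtype, whereas torch.Tensor is the FloatTensor constructor. A toy version of the statistics computed in plot_line:

    import torch

    ys = torch.tensor([[1.0, 2.0], [3.0, 4.0]])   # Data-driven construction, dtype inferred as float32
    print(ys.min(1)[0], ys.max(1)[0], ys.mean(1)) # Per-row min, max and mean, as plotted above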
