In [1]:
import os
import cv2
import csv
import math
import h5py
import time
import pickle
import random
import argparse
import numpy as np
from datetime import datetime

from skimage.transform import pyramid_gaussian, resize

import torch
from torch import optim
from torch import nn
from torch.nn import functional as F
import multiprocessing as mp
import threading

from collections import deque, namedtuple

# Model Parameters

In [2]:
# Command-line options for the ACER agent. Each option's meaning is given by its own help string.
parser = argparse.ArgumentParser(description='ACER')
parser.add_argument('--seed', type=int, default=123, help='Random seed')
parser.add_argument('--num-processes', type=int, default=3, metavar='N', help='Number of training async agents (does not include single validation agent)')
parser.add_argument('--T-max', type=int, default=100000000, metavar='STEPS', help='Number of training steps')
parser.add_argument('--t-max', type=int, default=100, metavar='STEPS', help='Max number of forward steps for A3C before update')
parser.add_argument('--max-episode-length', type=int, default=15, metavar='LENGTH', help='Maximum episode length')
parser.add_argument('--hidden-size', type=int, default=49, metavar='SIZE', help='Hidden size of LSTM cell')
# NOTE(review): the next three defaults (49, 49, 50) and --max-episode-length (15) repeat the
# module constants N_DR_ELEMENTS, PRESENT, ABSENT and MAX_STEPS defined below; the code visible
# in this file reads the constants, not these options -- confirm which one is authoritative.
parser.add_argument('--n-dr-elements', type=int, default=49, metavar='SIZE', help='Number of objects in display')
parser.add_argument('--present-action', type=int, default=49, metavar='SIZE', help='Present Action Value')
parser.add_argument('--absent-action', type=int, default=50, metavar='SIZE', help='Absent Action Value')
parser.add_argument('--memory-capacity', type=int, default=100000, metavar='CAPACITY', help='Experience replay memory capacity')
parser.add_argument('--replay-ratio', type=int, default=4, metavar='r', help='Ratio of off-policy to on-policy updates')
parser.add_argument('--replay-start', type=int, default=20000, metavar='EPISODES', help='Number of transitions to save before starting off-policy training')
parser.add_argument('--discount', type=float, default=0.99, metavar='γ', help='Discount factor')
parser.add_argument('--trace-decay', type=float, default=1, metavar='λ', help='Eligibility trace decay factor')
parser.add_argument('--trace-max', type=float, default=10, metavar='c', help='Importance weight truncation (max) value')
parser.add_argument('--trust-region-decay', type=float, default=0.99, metavar='α', help='Average model weight decay rate')
parser.add_argument('--trust-region-threshold', type=float, default=1, metavar='δ', help='Trust region threshold value')
parser.add_argument('--lr', type=float, default=0.0001, metavar='η', help='Learning rate')
parser.add_argument('--rmsprop-decay', type=float, default=0.99, metavar='α', help='RMSprop decay factor')
parser.add_argument('--batch-size', type=int, default=8, metavar='SIZE', help='Off-policy batch size')
parser.add_argument('--entropy-weight', type=float, default=0.0001, metavar='β', help='Entropy regularisation weight')
parser.add_argument('--max-gradient-norm', type=float, default=40, metavar='VALUE', help='Gradient L2 normalisation')
parser.add_argument('--evaluation-interval', type=int, default=25000, metavar='STEPS', help='Number of training steps between evaluations (roughly)')
parser.add_argument('--evaluation-episodes', type=int, default=10, metavar='N', help='Number of evaluation episodes to average over')
parser.add_argument('--name', type=str, default='results', help='Save folder')
parser.add_argument('--on-policy', action='store_true', help='Use pure on-policy training (A3C)')
parser.add_argument('--trust-region', action='store_true', help='Use trust region')
parser.add_argument('--pretrain-model-available', action='store_true', help='Use pre trained model weights')

# Task constants shared by the environment, the networks and the train/test loops.
N_DR_ELEMENTS = 49                # number of fixation actions (action ids 0 .. N_DR_ELEMENTS-1)
N_ACTIONS = N_DR_ELEMENTS + 2     # fixation actions plus the two terminal answers below
PRESENT = N_DR_ELEMENTS           # action id for answering "target present"
ABSENT = N_DR_ELEMENTS+1          # action id for answering "target absent"
MAX_STEPS = 15                    # an episode is ended once this many steps have been taken
N_ROWS = 7
PATH="./Lab_Model"                # file from which Env loads the CNN_Module state_dict
IMG_HEIGHT = 64
IMG_WIDTH = 64
IMG_CHANNELs = 3

PYRAMID_LEVEL = 4 # Number of pyramid levels, excluding the original image as a level.

# NOTE(review): the original comment calls this "pixels per degree", yet FOVEA_RADIUS and
# Env.get_eccentricity divide by it, which treats it as degrees per pixel -- confirm the units.
CONVERSION_FACTOR = 6/15.5 #number of pixels per degree. Here, 15.5 is the display size in degree (size used in the experiment).

FOVEA = 2 # Fovea size, in degrees

FOVEA_RADIUS = int(np.round(FOVEA/CONVERSION_FACTOR)) # Fovea radius in pixels (evaluates to 5 with the values above)

# Utility Functions

In [3]:
class Counter:
    """Integer counter stored in a multiprocessing shared value and guarded by its own lock."""

    def __init__(self):
        # 'i' selects a signed C int; the counter starts at zero.
        self.val = mp.Value('i', 0)
        self.lock = mp.Lock()

    def increment(self):
        """Add one to the counter while holding the lock."""
        with self.lock:
            self.val.value = self.val.value + 1

    def value(self):
        """Return the current count, read while holding the lock."""
        with self.lock:
            current = self.val.value
        return current

def state_to_tensor(state):
    """Turn a NumPy array into a float32 tensor with an extra leading axis of size 1."""
    as_float = torch.from_numpy(state).float()
    return torch.unsqueeze(as_float, 0)

# Map every fixation index (as a string key, "0" .. "48") to its (x, y) pixel coordinate.
# The 7 x 7 grid is walked row by row; coordinates start at 5 and advance by 9 pixels.
FIXATION_DICT = {}
fixated_location = 0
ver = 5
for row in range(7):
    for col, hor in enumerate(range(5, 5 + 7 * 9, 9)):
        FIXATION_DICT[str(fixated_location)] = (hor, ver)
        fixated_location += 1
    ver += 9
# Leave the horizontal cursor at the start-of-row value.
hor = 5

# Replay Memory

In [4]:
# One stored step of an episode: the state, the action taken, the reward and the behaviour policy.
Transition = namedtuple('Transition', field_names=('state', 'action', 'reward', 'policy'))

In [5]:
class EpisodicReplayMemory():
    """Replay buffer that keeps whole episodes (lists of Transition) in a bounded deque."""

    def __init__(self, capacity, max_episode_length):
        # The capacity is given in transitions; dividing by the maximum episode length
        # gives the number of episodes kept (the stored transition count may be much lower).
        self.num_episodes = capacity // max_episode_length
        self.memory = deque(maxlen=self.num_episodes)
        self.trajectory = []

    def append(self, state, action, reward, policy):
        """Record s_i, a_i, r_i+1, µ(·|s_i); an action of None closes the current episode."""
        step = Transition(state, action, reward, policy)
        self.trajectory.append(step)
        if action is not None:
            return
        # Terminal states are saved with action None: store the episode and start a new one.
        self.memory.append(self.trajectory)
        self.trajectory = []

    def sample(self, maxlen=0):
        """Return a random stored episode, or a random window of maxlen + 1 steps from it."""
        episode = self.memory[random.randrange(len(self.memory))]
        episode_len = len(episode)
        # Without a positive maxlen, or for a short enough episode, return it whole.
        if maxlen <= 0 or episode_len <= maxlen + 1:
            return episode
        # The window holds one extra step: the next state after the final "maxlen" state.
        start = random.randrange(episode_len - maxlen - 1)
        return episode[start:start + maxlen + 1]

    def sample_batch(self, batch_size, maxlen=0):
        """Sample batch_size trajectories, cut them to a common length and group by timestep."""
        drawn = [self.sample(maxlen=maxlen) for _ in range(batch_size)]
        shortest = min(len(trajectory) for trajectory in drawn)
        # zip(*...) transposes so that each inner list holds one timestep across the batch.
        return [list(timestep) for timestep in zip(*(trajectory[:shortest] for trajectory in drawn))]

    def length(self):
        """Return the number of episodes saved in memory."""
        return len(self.memory)

    def __len__(self):
        """Return the total number of transitions over all saved episodes."""
        return sum(map(len, self.memory))

# Optimiser

In [6]:
class SharedRMSprop(optim.RMSprop):
    """RMSprop variant whose per-parameter statistics can be moved to shared memory.

    The optimiser state ('step' and 'square_avg') is created eagerly in the constructor
    rather than lazily on the first step, so that share_memory() can be called before
    any update. Momentum and centring are always disabled.

    Args:
        params: iterable of parameters (or parameter groups) to optimise.
        lr: learning rate η.
        alpha: decay factor α of the running average of squared gradients.
        eps: constant ε added to the root of the running average.
        weight_decay: L2 penalty coefficient; 0 disables it.
    """

    def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0):
        super(SharedRMSprop, self).__init__(params, lr=lr, alpha=alpha, eps=eps, weight_decay=weight_decay, momentum=0, centered=False)

        # State initialisation (must be done before step, else will not be shared between threads)
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'] = p.data.new_zeros(1)
                state['square_avg'] = p.data.new_zeros(p.data.size())

    def share_memory(self):
        """Move every state tensor created in the constructor into shared memory."""
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'].share_memory_()
                state['square_avg'].share_memory_()

    def step(self, closure=None):
        """Apply one RMSprop update to every parameter that has a gradient.

        Args:
            closure: optional callable that re-evaluates the model and returns the loss.

        Returns:
            The value returned by closure, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]

                square_avg = state['square_avg']
                alpha = group['alpha']

                state['step'] += 1

                # The scalar multipliers are passed by keyword (alpha= / value=): the
                # positional-scalar overloads used previously are deprecated in PyTorch.
                if group['weight_decay'] != 0:
                    grad = grad.add(p.data, alpha=group['weight_decay'])

                # g = αg + (1 - α)Δθ^2
                square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)

                # θ ← θ - ηΔθ/(√g + ε)
                avg = square_avg.sqrt().add_(group['eps'])
                p.data.addcdiv_(grad, avg, value=-group['lr'])

        return loss

# Network

In [7]:
class Unit(nn.Module):
    """Building block: 3x3 convolution (stride 1), then batch normalisation, then ReLU."""

    def __init__(self, in_channels, out_channels, padding):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, input):
        """Apply conv -> bn -> relu to input and return the result."""
        return self.relu(self.bn(self.conv(input)))

class CNN_Module(nn.Module):
    """Convolutional network mapping 3x64x64 images to 98 sigmoid outputs.

    Ten Unit blocks (conv + batch norm + ReLU) interleaved with three 2x2 max-pools,
    followed by two fully connected layers.
    """
    def __init__(self):
        super(CNN_Module, self).__init__()
        
        # Four unpadded 3x3 convolutions: each shrinks the spatial size by 2 (64 -> 56).
        self.unit1 = Unit(in_channels=3,out_channels=8, padding=0)
        self.unit2 = Unit(in_channels=8,out_channels=16, padding=0)
        self.unit3 = Unit(in_channels=16,out_channels=24, padding=0)
        self.unit4 = Unit(in_channels=24,out_channels=32, padding=0)
        
        self.pool1 = nn.MaxPool2d(kernel_size=2)  # 56 -> 28

        # Padded convolutions keep the spatial size unchanged.
        self.unit5 = Unit(in_channels=32,out_channels=32, padding=1)
        self.unit6 = Unit(in_channels=32,out_channels=32, padding=1)
        self.unit7 = Unit(in_channels=32,out_channels=32, padding=1)
        
        self.pool2 = nn.MaxPool2d(kernel_size=2)  # 28 -> 14
        
        self.unit8 = Unit(in_channels=32,out_channels=32, padding=1)
        self.unit9 = Unit(in_channels=32,out_channels=32, padding=1)
        
        self.pool3 = nn.MaxPool2d(kernel_size=2)  # 14 -> 7
        
        self.unit10 = Unit(in_channels=32,out_channels=32, padding=1)
        
        # The same Unit objects are registered twice: as attributes (unit1, ...) and inside
        # this Sequential (net.0, ...). Both sets of names therefore appear in the state_dict,
        # so renaming or reordering them would change the checkpoint layout.
        self.net = nn.Sequential(self.unit1, self.unit2, self.unit3, self.unit4, self.pool1, self.unit5, self.unit6, \
                                 self.unit7, self.pool2, self.unit8, self.unit9, self.pool3,  self.unit10)

        # 7 * 7 * 32 is the flattened size of the final 32-channel 7x7 feature map.
        self.fc1 = nn.Linear(in_features=7 * 7 * 32, out_features=392)
        self.activation1 = nn.ReLU()
        # NOTE(review): this dropout layer is created but forward() never applies it.
        self.dropout = nn.Dropout(0.2)
        self.fc2 = nn.Linear(in_features=392, out_features=98)
        self.activation2 = nn.Sigmoid()

    def forward(self, input):
        """Return a (batch, 98) tensor of sigmoid activations for the given images.

        The input is reshaped to (-1, 3, 64, 64), so any tensor with a multiple of
        3 * 64 * 64 elements is accepted.
        """
        input = input.view(-1, 3, 64, 64)
        output = self.net(input)
        # Flatten the feature maps to one vector per image.
        output = output.view(-1, 7 * 7 * 32)
        output = self.fc1(output)
        output = self.activation1(output)
        output = self.fc2(output)
        output = self.activation2(output)
        return output

class ActorCritic(nn.Module):
    """Recurrent actor-critic: linear embedding -> LSTM cell -> policy head and Q-value head."""

    def __init__(self, state_size, action_size, hidden_size):
        super().__init__()
        self.fc1 = nn.Linear(state_size, hidden_size)
        self.lstm = nn.LSTMCell(hidden_size, hidden_size)
        self.fc_actor = nn.Linear(hidden_size, action_size)
        self.fc_critic = nn.Linear(hidden_size, action_size)

    def forward(self, x, h):
        """Run one recurrent step.

        Args:
            x: batch of state vectors.
            h: tuple (hidden state, cell state) for the LSTM cell.

        Returns:
            (policy, Q, V, h): action probabilities, per-action values, the expectation
            of Q under the policy, and the updated (hidden state, cell state) tuple.
        """
        embedded = self.fc1(x)
        h = self.lstm(embedded, h)
        hidden_state = h[0]

        logits = self.fc_actor(hidden_state)
        # Intended to prevent 1s and hence NaNs. Note that 1 - 1e-20 evaluates to exactly 1.0
        # in floating point, so this clamp leaves softmax outputs unchanged.
        policy = F.softmax(logits, dim=1).clamp(max=1 - 1e-20)

        Q = self.fc_critic(hidden_state)
        # V is the expectation of Q under π.
        V = (Q * policy).sum(1, keepdim=True)
        return policy, Q, V, h

# Environment

In [8]:
class Env():
    """Visual-search environment driven by fixation actions and a present/absent answer.

    Actions 0 .. N_DR_ELEMENTS-1 fixate one location of the display: the image is blurred
    according to distance from the fixation point, passed through the CNN_Module loaded
    from PATH, and the network output becomes the new state. The actions PRESENT and
    ABSENT end the episode with a reward of +2.0 for a correct answer and -2.0 otherwise.
    Fixations carry a reward of -0.1 each.

    Images and labels are read from the datasets "Images" and "target_status" of the
    HDF5 file ./dr_data.h5, which is opened read-only and kept open on self.dr_data.
    """

    def __init__(self, args):
        # args is accepted for interface compatibility; nothing is read from it here.
        self.num_feats = N_DR_ELEMENTS
        self.num_actions = N_ACTIONS
        self.steps = 0
        self.total_time = 0.0
        self.image = None
        self.state = None
        # NOTE(review): the model is left in its default training mode, so the BatchNorm
        # layers inside CNN_Module use per-batch statistics in step(); confirm whether
        # self.model.eval() is intended.
        self.model = CNN_Module()
        self.model.load_state_dict(torch.load(PATH, map_location='cpu'))
        self.correct = 0
        self.target_present = False

        path = os.path.join('.','dr_data.h5')
        self.dr_data = h5py.File(path, 'r')

    def step(self, action):
        """Apply one action.

        Args:
            action: integer action id; values below N_DR_ELEMENTS are fixations,
                PRESENT / ABSENT are the terminal answers.

        Returns:
            (state, reward, done, info): the flattened current state, the reward for
            this step, whether the episode has ended, and an empty info string.
        """
        self.steps += 1
        done = False
        reward = -0.1
        info = ''

        if action < N_DR_ELEMENTS:
            # Fixation: re-sample the image around the chosen location and run the CNN.
            fixation_loc = FIXATION_DICT[str(action)]
            input_image = self.sampling(self.image, fixation_loc[0], fixation_loc[1])
            with torch.no_grad():
                self.prob_out = self.model(torch.from_numpy(input_image).float().to(torch.device('cpu')))
            self.state = self.prob_out[0].detach().numpy()

        elif (action == PRESENT and self.target_present) or (action == ABSENT and not self.target_present):
            reward = 2.0
            done = True
            self.correct = 1

        else:
            reward = -2.0
            done = True
            self.correct = 0

        # Time-out: only an episode that has not just been answered counts as incorrect.
        # (Previously a correct answer given on the final allowed step was rewarded with
        # +2.0 but still recorded as correct = 0.)
        if not done and self.steps >= MAX_STEPS:
            done = True
            self.correct = 0

        return self.state.flatten(), reward, done, info

    def reset(self):
        """Pick a random image, clear the episode counters and return the all-zero start state."""
        idx = np.random.randint(len(self.dr_data["Images"]))
        self.image = self.dr_data["Images"][idx]
        self.image = self.image.astype('uint8')

        self.steps = 0
        self.total_time = 0.0

        self.correct = 0
        self.target_present = bool(self.dr_data["target_status"][idx] == 1)

        self.state = np.zeros((1, N_DR_ELEMENTS+N_DR_ELEMENTS))

        return self.state.flatten()

    def get_eccentricity(self, fixated_x, fixated_y):
        """Return an integer (IMG_HEIGHT, IMG_WIDTH) map of pyramid levels around a fixation.

        Pixels inside the fovea get level 0; the level grows with distance from the fovea
        and is scaled so that the farthest pixel gets PYRAMID_LEVEL.
        """
        # Generate a mask with the same height and width as the image.
        mask = 255*np.ones((IMG_HEIGHT, IMG_WIDTH), dtype='uint8')

        # Fovea is represented as a filled circle of zeros at (fixated_x, fixated_y) with radius FOVEA_RADIUS.
        cv2.circle(mask, (fixated_x, fixated_y), FOVEA_RADIUS, 0, -1)

        # Distance transform: OpenCV's Euclidean distance of every pixel from the fovea.
        eccentricity = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
        eccentricity = eccentricity/CONVERSION_FACTOR
        eccentricity = (eccentricity / np.max(eccentricity)) * PYRAMID_LEVEL
        eccentricity = np.round(eccentricity)
        # The np.int alias has been removed from NumPy; the builtin int is the equivalent dtype.
        eccentricity = eccentricity.astype(int)

        return eccentricity

    def smooth_pyramid(self, image, layers=4):
        """Return the Gaussian pyramid of image with every level resized back to 64x64.

        Args:
            image: multichannel image array.
            layers: max_layer passed to pyramid_gaussian.

        Returns:
            List of arrays, one per pyramid level, starting with the original resolution.
        """
        # NOTE(review): `multichannel` was replaced by `channel_axis` in newer scikit-image
        # releases; confirm the installed version before changing this call.
        pyr_img = []
        for (i, resized) in enumerate(pyramid_gaussian(image, max_layer=layers, downscale=1.7, multichannel=True)):
            pyr_img.append(resize(resized, (64,64), anti_aliasing=False, preserve_range=True, anti_aliasing_sigma=i**4, mode='constant'))
        return pyr_img

    def sampling(self, image, fixate_x, fixate_y):
        """Build the foveated network input for a fixation at (fixate_x, fixate_y).

        The image is converted from RGB to Lab, and every pixel is taken from the pyramid
        level given by its eccentricity.

        Returns:
            Array of shape (1, 3, 64, 64).
        """
        eccentricity = self.get_eccentricity(fixate_x, fixate_y)
        image = cv2.cvtColor(image, cv2.COLOR_RGB2Lab)
        pyramid = self.smooth_pyramid(image, PYRAMID_LEVEL)
        im_ = np.zeros(image.shape)
        for ecc in range(np.max(eccentricity)+1):
            i  = np.argwhere(eccentricity == ecc)
            if len(i) > 0:
                im_[i[:,0], i[:,1]] = pyramid[ecc][i[:,0], i[:,1]]
        im_ = im_.reshape(-1,64,64,3)
        # PyTorch expects channels first: (1, 64, 64, 3) -> (1, 3, 64, 64).
        im_ = np.swapaxes(im_, 3, 2)
        im_ = np.swapaxes(im_, 2, 1)
        return im_

# Test 

In [9]:
def test(rank, args, T, shared_model):
    """Periodically evaluate the shared model with a greedy policy and log the results.

    Runs until the global step counter exceeds args.T_max. An evaluation is performed at
    start-up and then whenever at least args.evaluation_interval global steps have passed
    since the previous one. Each evaluation plays args.evaluation_episodes episodes,
    prints the averages, rewrites <save_dir>/results.pck, appends a row to
    <save_dir>/test_results.csv and saves the model weights to <save_dir>/model.pth,
    where save_dir is ./<args.name>.

    Args:
        rank: worker index, added to args.seed to seed torch.
        args: parsed options (seed, hidden_size, name, T_max, evaluation_episodes,
            evaluation_interval are read here).
        T: shared global step counter exposing value().
        shared_model: model whose weights are copied at the start of every episode.
    """
    torch.manual_seed(args.seed + rank)

    env = Env(args)

    model = ActorCritic(N_DR_ELEMENTS+N_DR_ELEMENTS, N_ACTIONS, args.hidden_size)

    model.eval()

    save_dir = os.path.join('.', args.name)
    # Make sure the result files below can be written even if the folder is missing.
    os.makedirs(save_dir, exist_ok=True)

    can_test = True  # Test flag

    t_start = 1  # Test step counter to check against global counter

    l = str(len(str(args.T_max)))  # Max num. of digits for logging steps
    done = True  # Start new episode

    # stores step, reward, accuracy, avg_steps and time
    results_dict = {'t': [], 'reward': [], 'accuracy': [], 'avg_steps': [], 'time': []}

    while T.value() <= args.T_max:
        if can_test:
            t_start = T.value()  # Reset counter

            # Evaluate over several episodes and average results
            avg_rewards, avg_episode_lengths, avg_accuracy = [], [], []

            for _ in range(args.evaluation_episodes):
                while True:
                    # Reset or pass on hidden state
                    if done:
                        # Sync with shared model every episode
                        model.load_state_dict(shared_model.state_dict())
                        hx = torch.zeros(1, args.hidden_size)
                        cx = torch.zeros(1, args.hidden_size)

                        # Reset environment and done flag
                        state = state_to_tensor(env.reset())
                        done, episode_length = False, 0
                        reward_sum = 0

                    # Calculate policy
                    with torch.no_grad():
                        policy, _, _, (hx, cx) = model(state, (hx, cx))

                    # Choose action greedily
                    action = policy.max(1)[1][0]

                    # Step
                    state, reward, done, _ = env.step(action.item())

                    state = state_to_tensor(state)
                    reward_sum += reward
                    episode_length += 1  # Increase episode counter

                    # Log and reset statistics at the end of every episode
                    if done:
                        avg_rewards.append(reward_sum)
                        avg_episode_lengths.append(episode_length)
                        avg_accuracy.append(env.correct)
                        break

            mean_reward = sum(avg_rewards) / args.evaluation_episodes
            mean_length = sum(avg_episode_lengths) / args.evaluation_episodes
            mean_accuracy = sum(avg_accuracy) / args.evaluation_episodes

            print(('[{}] Step: {:<' + l + '} Avg. Reward: {:<8} Avg. Episode Length: {:<8} Avg. Accuracy: {:<8}').format(
                datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3],t_start,
                mean_reward,mean_length,mean_accuracy))

            # The accuracy column uses the per-episode results gathered above (it was
            # previously computed from a list that was never filled, so it was always 0).
            fields = [t_start, mean_reward, mean_length, mean_accuracy, str(datetime.now())]

            # storing data in the dictionary.
            results_dict['t'].append(t_start)
            results_dict['reward'].append(mean_reward)
            results_dict['avg_steps'].append(mean_length)
            results_dict['time'].append(str(datetime.now()))
            results_dict['accuracy'].append(mean_accuracy)

            # Dumping the results in pickle format
            with open(os.path.join(save_dir, 'results.pck'), 'wb') as f:
                pickle.dump(results_dict, f)

            # Saving the data in csv format (newline='' is what the csv module expects,
            # otherwise blank rows can appear on some platforms)
            with open(os.path.join(save_dir, 'test_results.csv'), 'a', newline='') as f:
                writer = csv.writer(f)
                writer.writerow(fields)

            torch.save(model.state_dict(), os.path.join(save_dir, 'model.pth'))  # Save model params
            can_test = False  # Finish testing
        else:
            if T.value() - t_start >= args.evaluation_interval:
                can_test = True

        time.sleep(0.001)  # Check if available to test every millisecond

    # Dumping the results in pickle format
    with open(os.path.join(save_dir, 'results.pck'), 'wb') as f:
        pickle.dump(results_dict, f)

# Training

In [10]:
# Knuth's algorithm for generating Poisson samples
def _poisson(lmbd):
    L, k, p = math.exp(-lmbd), 0, 1
    while p > L:
        k += 1
        p *= random.uniform(0, 1)
    return max(k - 1, 0)

# Transfers gradients from thread-specific model to shared model
def _transfer_grads_to_shared_model(model, shared_model):
    for param, shared_param in zip(model.parameters(), shared_model.parameters()):
        if shared_param.grad is not None:
            return
        shared_param._grad = param.grad


# Adjusts learning rate
def _adjust_learning_rate(optimiser, lr):
    for param_group in optimiser.param_groups:
        param_group['lr'] = lr



# Updates networks
def _update_networks(args, T, model, shared_model, shared_average_model, loss, optimiser):
    # Zero shared and local grads
    optimiser.zero_grad()
    """
    Calculate gradients for gradient descent on loss functions
    Note that math comments follow the paper, which is formulated for gradient ascent
    """
    loss.backward()
    
    # Gradient L2 normalisation
    nn.utils.clip_grad_norm_(model.parameters(), args.max_gradient_norm)
    
    # Transfer gradients to shared model and update
    _transfer_grads_to_shared_model(model, shared_model)
    optimiser.step()
    
    # Update shared_average_model
    for shared_param, shared_average_param in zip(shared_model.parameters(), shared_average_model.parameters()):
        shared_average_param = args.trust_region_decay * shared_average_param + (1 - args.trust_region_decay) * shared_param

        
# Computes an "efficient trust region" loss (policy head only) based on an existing loss and two distributions
def _trust_region_loss(model, distribution, ref_distribution, loss, threshold, g, k):
    
    kl = - (ref_distribution * (distribution.log()-ref_distribution.log())).sum(1).mean(0)
    
    # Compute dot products of gradients
    k_dot_g = (k*g).sum(1).mean(0)
    k_dot_k = (k**2).sum(1).mean(0)
    
    # Compute trust region update
    if k_dot_k.item() > 0:
        trust_factor = ((k_dot_g - threshold) / k_dot_k).clamp(min=0).detach()
    else:
        trust_factor = torch.zeros(1)
    
    # z* = g - max(0, (k^T∙g - δ) / ||k||^2_2)∙k
    trust_loss = loss + trust_factor*kl
    
    return trust_loss

In [11]:
# Trains model
def _train(args, T, model, shared_model, shared_average_model, optimiser, policies, Qs, Vs, actions, rewards, Qret, average_policies, old_policies=None):
    off_policy = old_policies is not None
    
    action_size = policies[0].size(1)
    
    policy_loss, value_loss = 0, 0
    
    # Calculate n-step returns in forward view, stepping backwards from the last state
    
    t = len(rewards)
    
    for i in reversed(range(t)):        
        # Importance sampling weights ρ ← π(∙|s_i) / µ(∙|s_i); 1 for on-policy
        if off_policy:
            rho = policies[i].detach() / old_policies[i]
        else:
            rho = torch.ones(1, action_size)
            
        # Qret ← r_i + γQret
        Qret = rewards[i] + args.discount * Qret
    
        # Advantage A ← Qret - V(s_i; θ)
        A = Qret - Vs[i]

        # Log policy log(π(a_i|s_i; θ))
        log_prob = policies[i].gather(1, actions[i]).log()
    
        # g ← min(c, ρ_a_i)∙∇θ∙log(π(a_i|s_i; θ))∙A
        single_step_policy_loss = -(rho.gather(1, actions[i]).clamp(max=args.trace_max) * log_prob * A.detach()).mean(0)  # Average over batch
    
        # Off-policy bias correction
        if off_policy:
            # g ← g + Σ_a [1 - c/ρ_a]_+∙π(a|s_i; θ)∙∇θ∙log(π(a|s_i; θ))∙(Q(s_i, a; θ) - V(s_i; θ)
            bias_weight = (1 - args.trace_max / rho).clamp(min=0) * policies[i]
            single_step_policy_loss -= (bias_weight * policies[i].log() * (Qs[i].detach() - Vs[i].expand_as(Qs[i]).detach())).sum(1).mean(0)
    
        if args.trust_region:        
            # KL divergence k ← ∇θ0∙DKL[π(∙|s_i; θ_a) || π(∙|s_i; θ)]
            k = -average_policies[i].gather(1, actions[i]) / (policies[i].gather(1, actions[i]) + 1e-10)
        
            if off_policy:
                g = (rho.gather(1, actions[i]).clamp(max=args.trace_max) * A / (policies[i] + 1e-10).gather(1, actions[i]) \
                     + (bias_weight * (Qs[i] - Vs[i].expand_as(Qs[i]))/(policies[i] + 1e-10)).sum(1)).detach()
        
            else:
                g = (rho.gather(1, actions[i]).clamp(max=args.trace_max) * A / (policies[i] + 1e-10).gather(1, actions[i])).detach()
      
            # Policy update dθ ← dθ + ∂θ/∂θ∙z*
            policy_loss += _trust_region_loss(model, policies[i].gather(1, actions[i]) + 1e-10, average_policies[i].gather(1, actions[i]) + 1e-10, single_step_policy_loss, args.trust_region_threshold, g, k)
        
        else:
            # Policy update dθ ← dθ + ∂θ/∂θ∙g
            policy_loss += single_step_policy_loss
    
        # Entropy regularisation dθ ← dθ + β∙∇θH(π(s_i; θ))
        policy_loss -= args.entropy_weight * -(policies[i].log() * policies[i]).sum(1).mean(0)  # Sum over probabilities, average over batch

        # Value update dθ ← dθ - ∇θ∙1/2∙(Qret - Q(s_i, a_i; θ))^2
        Q = Qs[i].gather(1, actions[i])
        value_loss += ((Qret - Q) ** 2 / 2).mean(0)  # Least squares loss

        # Truncated importance weight ρ¯_a_i = min(1, ρ_a_i)
        truncated_rho = rho.gather(1, actions[i]).clamp(max=1)
        # Qret ← ρ¯_a_i∙(Qret - Q(s_i, a_i; θ)) + V(s_i; θ)
        Qret = truncated_rho * (Qret - Q.detach()) + Vs[i].detach()
    
    # Update networks
    _update_networks(args, T, model, shared_model, shared_average_model, policy_loss + value_loss, optimiser)


In [12]:
# Acts and trains model
def train(rank, args, T, shared_model, shared_average_model, optimiser):
    """Worker loop: roll out episodes on-policy, then optionally replay stored ones.

    Args:
        rank: Worker index; added to ``args.seed`` to seed torch for this worker.
        args: Parsed options. Reads ``seed``, ``hidden_size``, ``on_policy``,
            ``memory_capacity``, ``num_processes``, ``max_episode_length``,
            ``T_max``, ``t_max``, ``replay_start``, ``replay_ratio`` and
            ``batch_size``.
        T: Global step counter; ``value()`` is polled for the stop condition and
            ``increment()`` is called once per environment step.
        shared_model: Network whose state dict is copied into the local model
            before every rollout chunk.
        shared_average_model: Network run on the same states as the local model;
            its policy output is passed to ``_train`` as the average policy.
        optimiser: Forwarded unchanged to ``_train``.

    Returns:
        None. The loop ends once ``T.value()`` exceeds ``args.T_max``.
    """
    torch.manual_seed(args.seed + rank)

    env = Env(args)
    model = ActorCritic(N_DR_ELEMENTS + N_DR_ELEMENTS, N_ACTIONS, args.hidden_size)
    model.train()

    if not args.on_policy:
        # Total replay capacity is split evenly between the training workers
        per_worker_capacity = args.memory_capacity // args.num_processes
        memory = EpisodicReplayMemory(per_worker_capacity, args.max_episode_length)

    t = 1  # Step counter local to this worker
    done = True  # Forces an environment reset on the first pass

    while T.value() <= args.T_max:
        # ---------------- On-policy phase: play one episode ----------------
        while True:
            # Refresh local weights from the shared model (at least every t_max steps)
            model.load_state_dict(shared_model.state_dict())
            t_start = t  # Step at which this chunk begins

            if done:
                # New episode: zeroed recurrent state for both networks
                hx = torch.zeros(1, args.hidden_size)
                avg_hx = torch.zeros(1, args.hidden_size)
                cx = torch.zeros(1, args.hidden_size)
                avg_cx = torch.zeros(1, args.hidden_size)

                # Fresh environment, cleared done flag and per-episode bookkeeping
                state = state_to_tensor(env.reset())
                done = False
                episode_length = 0
                prev_action = -1
            else:
                # Same episode continues: truncate backpropagation-through-time here
                # (allows freeing buffers after the backward call)
                hx, cx = hx.detach(), cx.detach()

            # Per-step outputs gathered for the on-policy update
            policies, Qs, Vs = [], [], []
            actions, rewards, average_policies = [], [], []

            while not done and t - t_start < args.t_max:
                # Forward both the local and the average network on the current state
                policy, Q, V, (hx, cx) = model(state, (hx, cx))
                average_policy, _, _, (avg_hx, avg_cx) = shared_average_model(state, (avg_hx, avg_cx))

                # Draw one action from the local policy
                action = torch.multinomial(policy, 1)[0, 0]

                # Advance the environment by one step
                next_state, reward, done, _ = env.step(action.item())
                next_state = state_to_tensor(next_state)
                episode_length += 1

                if not args.on_policy:
                    # Store the first half of the transition (tensors only) for replay
                    memory.append(state, action, reward, policy.detach())

                # Record this step's outputs for the on-policy update
                action_entry = torch.LongTensor([[action]])
                reward_entry = torch.Tensor([[reward]])
                policies.append(policy)
                Qs.append(Q)
                Vs.append(V)
                actions.append(action_entry)
                rewards.append(reward_entry)
                average_policies.append(average_policy)

                # Advance the local and global step counters
                t += 1
                T.increment()

                # Move on to the next state
                state = next_state
                prev_action = action

            # Return target used to start the backward recursion in _train:
            # zero when the episode finished on a PRESENT or ABSENT action, otherwise
            # the detached value output of the model for the current state.
            if done and (prev_action == PRESENT or prev_action == ABSENT):
                Qret = torch.zeros(1, 1)
            else:
                _, _, Qret, _ = model(state, (hx, cx))
                Qret = Qret.detach()

            if done and not args.on_policy:
                # Mark the end of the episode in replay memory (action/reward/policy are None)
                memory.append(state, None, None, None)

            # On-policy update with the collected chunk
            _train(
                args, T, model, shared_model, shared_average_model, optimiser,
                policies, Qs, Vs, actions, rewards, Qret, average_policies,
            )

            # Leave the on-policy loop once the episode is over
            if done:
                break

        # ---------------- Off-policy phase: replay stored experience ----------------
        if not args.on_policy and len(memory) >= args.replay_start:
            # Number of replay updates is drawn via _poisson from the replay ratio
            for _ in range(_poisson(args.replay_ratio)):
                # Batch of (truncated) episodes, indexed by timestep
                trajectories = memory.sample_batch(args.batch_size, maxlen=args.t_max)

                # Zeroed recurrent state, one row per batch element
                hx = torch.zeros(args.batch_size, args.hidden_size)
                avg_hx = torch.zeros(args.batch_size, args.hidden_size)
                cx = torch.zeros(args.batch_size, args.hidden_size)
                avg_cx = torch.zeros(args.batch_size, args.hidden_size)

                # Per-step outputs gathered for the off-policy update
                policies, Qs, Vs, actions, rewards = [], [], [], [], []
                old_policies, average_policies = [], []

                # Every timestep except the last, which only provides the next state
                for i in range(len(trajectories) - 1):
                    step_batch = trajectories[i]

                    # First half of the transition, stacked across the batch
                    state = torch.cat(tuple(item.state for item in step_batch), 0)
                    action = torch.LongTensor([item.action for item in step_batch]).unsqueeze(1)
                    reward = torch.Tensor([item.reward for item in step_batch]).unsqueeze(1)
                    old_policy = torch.cat(tuple(item.policy for item in step_batch), 0)

                    # Forward both networks on the stored states
                    policy, Q, V, (hx, cx) = model(state, (hx, cx))
                    average_policy, _, _, (avg_hx, avg_cx) = shared_average_model(state, (avg_hx, avg_cx))

                    # Record this step's outputs for the off-policy update
                    policies.append(policy)
                    Qs.append(Q)
                    Vs.append(V)
                    actions.append(action)
                    rewards.append(reward)
                    average_policies.append(average_policy)
                    old_policies.append(old_policy)

                    # Second half of the transition; a stored action of None marks a terminal entry
                    following_batch = trajectories[i + 1]
                    next_state = torch.cat(tuple(item.state for item in following_batch), 0)
                    done = torch.Tensor([item.action is None for item in following_batch]).unsqueeze(1)

                # Value of the state after the last replayed step
                _, _, Qret, _ = model(next_state, (hx, cx))

                # Zero the target where the next entry is terminal, keep V elsewhere
                Qret = ((1 - done) * Qret).detach()

                # Off-policy update, passing the stored behaviour policies as old_policies
                _train(
                    args, T, model, shared_model, shared_average_model, optimiser,
                    policies, Qs, Vs, actions, rewards, Qret, average_policies,
                    old_policies=old_policies,
                )
        # `done` may have been overwritten with a tensor above; restore the reset flag
        done = True


In [None]:
if __name__ == '__main__':
    # Entry point: build the shared networks and optimiser, write the CSV header,
    # then start one validation thread and `args.num_processes` training threads.

    # BLAS setup
    os.environ['OMP_NUM_THREADS'] = '2'
    os.environ['MKL_NUM_THREADS'] = '2'

    # Setup: parse an empty argument list so only the parser defaults are used
    args = parser.parse_args(args=[])

    # Create the output directory; exist_ok avoids the check-then-create race
    save_dir = os.path.join('.', args.name)
    os.makedirs(save_dir, exist_ok=True)

    torch.manual_seed(args.seed)
    T = Counter()  # Global shared counter

    shared_model = ActorCritic(N_DR_ELEMENTS + N_DR_ELEMENTS, N_ACTIONS, args.hidden_size)
    shared_model.share_memory()

    if args.pretrain_model_available:
        # Load pretrained weights.
        # NOTE(review): torch.load unpickles the file, so 'model.pth' must come from a trusted source.
        shared_model.load_state_dict(torch.load('model.pth'))

    # Create average network as a frozen copy of the shared network
    shared_average_model = ActorCritic(N_DR_ELEMENTS + N_DR_ELEMENTS, N_ACTIONS, args.hidden_size)
    shared_average_model.load_state_dict(shared_model.state_dict())
    shared_average_model.share_memory()

    for param in shared_average_model.parameters():
        param.requires_grad = False

    # Create optimiser for shared network parameters with shared statistics
    optimiser = SharedRMSprop(shared_model.parameters(), lr=args.lr, alpha=args.rmsprop_decay)
    optimiser.share_memory()

    # Start a fresh results file containing only the header row.
    # newline='' is what the csv module requires of the file object; without it
    # rows end in '\r\r\n' on platforms that translate '\n' on write.
    fields = ['t', 'rewards', 'avg_steps', 'accuracy', 'time']
    with open(os.path.join(save_dir, 'test_results.csv'), 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(fields)

    # Workers are threading.Thread objects (not mp.Process); the list keeps its
    # original name `processes`.
    processes = []

    # Start validation agent
    p = threading.Thread(target=test, args=(0, args, T, shared_model))
    p.start()
    processes.append(p)

    # Start training agents (ranks 1..num_processes; rank 0 is the validation agent)
    for rank in range(1, args.num_processes + 1):
        t = threading.Thread(target=train, args=(rank, args, T, shared_model, shared_average_model, optimiser))
        t.start()
        print('Process ' + str(rank) + ' started')
        processes.append(t)

    # Report whether each worker is running
    for p in processes:
        print(p.is_alive())

    # Clean up: block until every worker has finished
    for p in processes:
        p.join()


Process 1 started
Process 2 started
Process 3 started
True
True
True
True
[2019-04-06 09:42:12,797] Step: 0         Avg. Reward: -1.5000000000000002 Avg. Episode Length: 15.0     Avg. Accuracy: 0.0     
[2019-04-06 09:52:52,782] Step: 25001     Avg. Reward: -0.8     Avg. Episode Length: 1.0      Avg. Accuracy: 0.3     
[2019-04-06 10:06:04,806] Step: 50002     Avg. Reward: -1.2     Avg. Episode Length: 1.0      Avg. Accuracy: 0.2     
[2019-04-06 10:21:38,001] Step: 75002     Avg. Reward: -0.4     Avg. Episode Length: 1.0      Avg. Accuracy: 0.4     
[2019-04-06 10:37:40,011] Step: 100003    Avg. Reward: -0.8     Avg. Episode Length: 1.0      Avg. Accuracy: 0.3     
[2019-04-06 10:53:39,885] Step: 125003    Avg. Reward: 0.4      Avg. Episode Length: 1.0      Avg. Accuracy: 0.6     
[2019-04-06 11:09:34,889] Step: 150004    Avg. Reward: 0.4      Avg. Episode Length: 1.0      Avg. Accuracy: 0.6     
[2019-04-06 11:25:48,447] Step: 175006    Avg. Reward: 0.4      Avg. Episode Length: 1.0 

[2019-04-07 01:59:36,366] Step: 1675041   Avg. Reward: 0.29999999999999993 Avg. Episode Length: 2.0      Avg. Accuracy: 0.6     
[2019-04-07 02:14:36,142] Step: 1700041   Avg. Reward: 1.5000000000000002 Avg. Episode Length: 2.0      Avg. Accuracy: 0.9     
[2019-04-07 02:29:18,100] Step: 1725041   Avg. Reward: -0.10000000000000013 Avg. Episode Length: 2.0      Avg. Accuracy: 0.5     
[2019-04-07 02:43:56,092] Step: 1750042   Avg. Reward: -0.10000000000000009 Avg. Episode Length: 2.0      Avg. Accuracy: 0.5     
[2019-04-07 02:58:30,727] Step: 1775042   Avg. Reward: 0.6999999999999998 Avg. Episode Length: 2.0      Avg. Accuracy: 0.7     
[2019-04-07 03:13:41,483] Step: 1800042   Avg. Reward: -0.10000000000000009 Avg. Episode Length: 2.0      Avg. Accuracy: 0.5     
[2019-04-07 03:29:03,313] Step: 1825042   Avg. Reward: 0.29999999999999993 Avg. Episode Length: 2.0      Avg. Accuracy: 0.6     
[2019-04-07 03:44:09,865] Step: 1850042   Avg. Reward: 0.29999999999999993 Avg. Episode Length: 

[2019-04-07 17:08:18,028] Step: 3300063   Avg. Reward: 0.7      Avg. Episode Length: 2.0      Avg. Accuracy: 0.7     
[2019-04-07 17:21:41,226] Step: 3325063   Avg. Reward: 0.6699999999999999 Avg. Episode Length: 2.3      Avg. Accuracy: 0.7     
[2019-04-07 17:34:47,045] Step: 3350063   Avg. Reward: 0.29999999999999993 Avg. Episode Length: 2.0      Avg. Accuracy: 0.6     
[2019-04-07 17:48:19,995] Step: 3375063   Avg. Reward: -0.17999999999999994 Avg. Episode Length: 2.8      Avg. Accuracy: 0.5     
[2019-04-07 18:02:05,598] Step: 3400063   Avg. Reward: 1.0899999999999999 Avg. Episode Length: 2.1      Avg. Accuracy: 0.8     
[2019-04-07 18:15:30,471] Step: 3425064   Avg. Reward: 1.09     Avg. Episode Length: 2.1      Avg. Accuracy: 0.8     
[2019-04-07 18:29:10,581] Step: 3450065   Avg. Reward: 1.1      Avg. Episode Length: 2.0      Avg. Accuracy: 0.8     
[2019-04-07 18:42:51,428] Step: 3475065   Avg. Reward: 1.5000000000000002 Avg. Episode Length: 2.0      Avg. Accuracy: 0.9     
[20

In [None]:
# Parse an empty argument list so `args` holds the parser's default values
args = parser.parse_args(args=[])

In [None]:
# Show the parsed `pretrain_model_available` option as this cell's output
args.pretrain_model_available