In [1]:
"""
Generates multi-armed bandit problems and trains a SimpleRNN policy on them with REINFORCE
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
import argparse
import sys
import random
import numpy as np
import torch
import torch.optim as optim
from torch.distributions.categorical import Categorical

from bandit import BanditProblem
from learner import SimpleRNN, makeObservation

In [2]:
def assure_equal_and_not_empty(a, b):
    """Return True iff ``a`` and ``b`` have the same length and are non-empty.

    This is only a predicate: it never raises, so callers must act on the
    returned value themselves.
    """
    size = len(a)
    return size == len(b) and size > 0


def step(gamma, optimizer, records):
    """Perform one REINFORCE policy-gradient update from a finished episode.

    The loss is ``sum_t(-log_prob_t * G_t)``, where ``G_t`` is the discounted
    return from step ``t`` onward: ``G_t = r_t + gamma * G_{t+1}``.

    Args:
        gamma: discount factor applied to future rewards.
        optimizer: a ``torch.optim`` optimizer over the policy's parameters.
        records: object with a ``get(key)`` method (e.g. ``EpisodeRecorder``)
            returning equal-length per-step lists under ``"reward"`` and
            ``"log_prob"``. Each ``log_prob`` entry must be a tensor holding
            the log-probability of the action actually taken.

    Raises:
        ValueError: if the two sequences are empty or differ in length.
    """
    rewards = records.get("reward")
    log_probs = records.get("log_prob")
    # Validate explicitly: zip() below would otherwise silently truncate
    # mismatched records, and an empty episode has nothing to learn from.
    if len(rewards) == 0 or len(rewards) != len(log_probs):
        raise ValueError(
            "expected equal, non-empty 'reward' and 'log_prob' records, "
            "got {} and {}".format(len(rewards), len(log_probs)))

    # Discounted returns, accumulated back-to-front; append + reverse keeps
    # this O(T) instead of repeated insert(0, ...).
    returns = []
    running = 0
    for r in reversed(rewards):
        running = r + gamma * running
        returns.append(running)
    returns.reverse()

    policy_loss = [-log_prob * ret for log_prob, ret in zip(log_probs, returns)]

    optimizer.zero_grad()
    # Flatten each term so 0-d log-prob tensors (which torch.cat rejects)
    # are accepted; the total sum is unchanged.
    loss = torch.cat([term.reshape(-1) for term in policy_loss]).sum()
    # retain_graph is kept as-is; presumably needed if the RNN hidden state
    # links graphs across updates — TODO confirm.
    loss.backward(retain_graph=True)
    optimizer.step()

    
class EpisodeRecorder:
    """Collects named values into per-name lists, in recording order."""

    def __init__(self):
        # name -> list of values recorded under that name
        self.data = {}

    def record(self, thing, x):
        """Append ``x`` to the list stored under ``thing``, creating it on first use."""
        self.data.setdefault(thing, []).append(x)

    def get(self, thing):
        """Return the list recorded under ``thing``.

        The returned list is the recorder's own storage, not a copy.

        Raises:
            KeyError: if nothing was ever recorded under ``thing``.
        """
        return self.data[thing]

       


In [8]:
def run(n_episodes, sequence_length, gamma, display_epochs,
        hidden_size=20, layers=2, lr=0.01, momentum=0.9, seed=None):
    """Train a ``SimpleRNN`` policy with REINFORCE on freshly generated bandits.

    Every episode builds a new ``BanditProblem`` and lets the network choose
    ``sequence_length`` actions. The previous action and reward are fed back
    through ``makeObservation``. One ``step`` update follows each episode.

    Args:
        n_episodes: number of episodes (bandit problems) to train on.
        sequence_length: number of actions taken per episode.
        gamma: discount factor forwarded to ``step``.
        display_epochs: print a progress line every ``display_epochs``
            episodes, averaged over that many recent episodes. A value
            ``<= 0`` disables printing.
        hidden_size: forwarded to ``SimpleRNN`` (default 20).
        layers: forwarded to ``SimpleRNN`` (default 2).
        lr: SGD learning rate (default 0.01).
        momentum: SGD momentum (default 0.9).
        seed: if not None, seeds Python's ``random``, NumPy's global RNG and
            torch. Whether this makes ``BanditProblem`` deterministic depends
            on which generator it draws from — TODO confirm.
    """
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    # create a meta-learner
    s = SimpleRNN(hidden_size=hidden_size, layers=layers)
    optimizer = optim.SGD(s.parameters(), lr=lr, momentum=momentum)
    history = EpisodeRecorder()

    # number of consecutive environments that share one hidden state
    K = 1
    for i in range(n_episodes // K):

        # reset the hidden state after every K environments
        s.reset_hidden()
        for y in range(K):
            start_time = time.time()

            # reset the environment
            action = 0
            reward = 0
            episode = EpisodeRecorder()
            b = BanditProblem()

            for t in range(sequence_length):
                done = (t == sequence_length - 1)

                # consult our neural network
                obs = makeObservation(0, action, reward, done)
                probs = s.forward(obs)

                # sample an action from the policy; sampling is not
                # differentiable, which is why REINFORCE is used
                di = Categorical(probs)
                action = di.sample()

                # REINFORCE needs log pi(action | obs) of the sampled action,
                # not the probability vector itself. Flattened to 1-d so that
                # `step` can concatenate the per-step terms.
                episode.record("log_prob", di.log_prob(action).reshape(-1))
                episode.record("action", action)

                # update the last reward
                reward = b.pull(action)
                episode.record("reward", reward)

            # upgrade our neural network
            step(gamma, optimizer, episode)

            # summarise the episode before the environment is discarded
            history.record("action_mean", np.mean(episode.get("action")))
            history.record("action_std", np.std(episode.get("action")))
            history.record("average_reward", np.mean(episode.get("reward")))
            history.record("time", time.time() - start_time)

            # is it learning?
            current_episode = K * i + y
            if display_epochs > 0 and current_episode % display_epochs == 0:
                recent_times = history.get("time")[-display_epochs:]
                recent_rewards = history.get("average_reward")[-display_epochs:]
                display = ""
                display += "Episode {}, Time (elapsed {:.0f}, {:.4f}s/episode), ".format(
                    current_episode, np.sum(history.get("time")), np.mean(recent_times))
                display += "Reward (avg {:.4f}) ".format(np.mean(recent_rewards))
                display += "Actions (std.dev {:.4f}) ".format(history.get("action_std")[-1])
                print(display)

In [10]:
# Experiment configuration: a short smoke-test run that reports every episode.
gamma = 0.5
n_episodes = 10
sequence_length = 5
display_epochs = 1

run(
    n_episodes=n_episodes,
    sequence_length=sequence_length,
    gamma=gamma,
    display_epochs=display_epochs,
)

Episode 0, Time (elapsed 0, 0.0169s/episode), Reward (avg 0.2000) Actions (std.dev 0.4899) 
Episode 1, Time (elapsed 0, 0.0150s/episode), Reward (avg 1.0000) Actions (std.dev 0.4000) 
Episode 2, Time (elapsed 0, 0.0185s/episode), Reward (avg 0.2000) Actions (std.dev 0.4000) 
Episode 3, Time (elapsed 0, 0.0137s/episode), Reward (avg 0.8000) Actions (std.dev 0.4899) 
Episode 4, Time (elapsed 0, 0.0166s/episode), Reward (avg 0.0000) Actions (std.dev 0.0000) 
Episode 5, Time (elapsed 0, 0.0224s/episode), Reward (avg 0.6000) Actions (std.dev 0.4899) 
Episode 6, Time (elapsed 0, 0.0176s/episode), Reward (avg 0.2000) Actions (std.dev 0.4899) 
Episode 7, Time (elapsed 0, 0.0159s/episode), Reward (avg 0.4000) Actions (std.dev 0.4899) 
Episode 8, Time (elapsed 0, 0.0190s/episode), Reward (avg 1.0000) Actions (std.dev 0.4899) 
Episode 9, Time (elapsed 0, 0.0176s/episode), Reward (avg 0.8000) Actions (std.dev 0.4000) 
