In [1]:
import random
import math
from collections import defaultdict, namedtuple
from itertools import count

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

from utils import *

# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# One agent/environment interaction step, as stored in the replay memory.
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'next_state', 'reward'],
)

In [3]:
class ReplayMemory(object):
    """Fixed-capacity ring buffer of `Transition` records.

    Transitions are appended until `capacity` entries are stored; after that
    every new transition overwrites the oldest one.
    """

    def __init__(self, capacity):
        """Create an empty buffer.

        Args:
            capacity: maximum number of transitions kept; must be >= 1.

        Raises:
            ValueError: if `capacity` is smaller than 1.
        """
        # A non-positive capacity would otherwise only fail on the first
        # push(), with an unrelated IndexError / ZeroDivisionError.
        if capacity < 1:
            raise ValueError(
                "capacity must be a positive integer, got %r" % (capacity,))
        self.capacity = capacity
        self.memory = []
        # Slot that the next push() writes to once the buffer is full.
        self.position = 0

    def push(self, *args):
        """Saves a transition.

        Args:
            *args: the `Transition` fields, in field order.

        Raises:
            TypeError: if `args` does not match the `Transition` fields; the
                buffer is left untouched in that case.
        """
        # Build the record first so a malformed call cannot leave a `None`
        # placeholder behind in the buffer.
        transition = Transition(*args)
        if len(self.memory) < self.capacity:
            self.memory.append(transition)
        else:
            self.memory[self.position] = transition
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Return `batch_size` distinct stored transitions, chosen at random.

        Raises:
            ValueError: (from `random.sample`) if `batch_size` exceeds the
                number of stored transitions.
        """
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

In [4]:
class DQN(nn.Module):
    """Small fully connected Q-network with an epsilon-greedy action helper."""

    def __init__(self, inp_size, emb_size, hid_size, out_size):
        """Build the three linear layers.

        Args:
            inp_size: number of input features per state.
            emb_size: width of the first hidden layer.
            hid_size: width of the second hidden layer.
            out_size: number of outputs, one Q-value per action.
        """
        super(DQN, self).__init__()
        self.emb = nn.Linear(inp_size, emb_size)

        self.mid = nn.Linear(emb_size, hid_size)
        self.out = nn.Linear(hid_size, out_size)

        # Number of getAction() calls so far; drives the epsilon decay.
        self.steps_done = 0

    def forward(self, x):
        """Return one Q-value estimate per action for every row of `x`."""
        emb = F.relu(self.emb(x))
        out = F.relu(self.mid(emb))

        # Return the raw linear head. Q-values are regressed onto discounted
        # returns, which are not confined to (0, 1); a softmax here would make
        # those targets unreachable.
        return self.out(out)

    def getAction(self, state):
        """Choose an action for `state` with an epsilon-greedy policy.

        Reads the module-level `EPS_START`, `EPS_END`, `EPS_DECAY`,
        `AUTO_STOP` and `device`. Every call increments `self.steps_done`,
        which decays the exploration threshold exponentially from `EPS_START`
        towards `EPS_END`.

        Args:
            state: 2-D tensor; `state[0, 0]` is compared against `AUTO_STOP`.

        Returns:
            A 1x1 `torch.long` tensor holding the chosen action index.
        """
        self.steps_done += 1

        sample = random.random()
        eps_threshold = EPS_END + (EPS_START - EPS_END) * \
            math.exp(-1. * self.steps_done / EPS_DECAY)

        if state[0, 0] == AUTO_STOP:
            # Forced action 0 when the first state feature equals AUTO_STOP.
            return torch.tensor([[0]], device=device, dtype=torch.long)
        elif sample > eps_threshold:
            # Exploit: greedy action according to *this* network rather than
            # whichever network happens to be bound to a global name.
            with torch.no_grad():
                return self(state).max(1)[1].view(1, 1)
        else:
            # Explore: uniform over this network's own outputs, so the index
            # is always valid for the Q-value head.
            return torch.tensor([[random.randrange(self.out.out_features)]],
                                device=device, dtype=torch.long)


In [5]:
# --- Hyperparameters -------------------------------------------------------
BATCH_SIZE = 128                  # transitions sampled per optimisation step
GAMMA = 0.999                     # discount applied to next-state values
EPS_START, EPS_END = 0.9, 0.05    # exploration threshold: initial / asymptotic
EPS_DECAY = 200                   # decay constant (in getAction calls)
TARGET_UPDATE = 10                # episodes between target-network syncs

# Number of discrete actions (Game2.step: 0 stops, anything else continues).
n_actions = 2

# --- Networks, optimiser and replay memory ---------------------------------
# Two identically shaped networks; the policy network is built first.
policy_net, target_net = (
    DQN(inp_size=2, emb_size=100, hid_size=256, out_size=2).to(device)
    for _ in range(2)
)
# The target network starts as a copy of the policy network, in eval mode.
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = optim.RMSprop(policy_net.parameters())
memory = ReplayMemory(capacity=10_000)


In [6]:
def optimize_model():
    """Run one optimisation step of `policy_net` on a replay-memory batch.

    Uses the module-level `memory`, `policy_net`, `target_net`, `optimizer`,
    `device`, `BATCH_SIZE` and `GAMMA`. Returns immediately (without touching
    the networks) while `memory` holds fewer than `BATCH_SIZE` transitions.

    Otherwise the Q-values of the taken actions are regressed, with a Huber
    loss, onto `reward + GAMMA * max_a Q_target(next_state, a)`, where the
    value of a terminal next state (`next_state is None`) is 0. Gradients are
    clamped element-wise to [-1, 1] before the optimiser step.
    """
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
    # detailed explanation). This converts batch-array of Transitions
    # to Transition of batch-arrays.
    batch = Transition(*zip(*transitions))

    # A next_state of None marks a transition that ended its episode.
    non_final_next_states = [s for s in batch.next_state if s is not None]
    # Boolean (not uint8) mask: indexing with uint8 tensors is deprecated.
    non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                  device=device, dtype=torch.bool)

    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Q(s_t, a): evaluate all actions, then pick the column of the action
    # that was actually taken in each transition.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # V(s_{t+1}) from the "older" target_net; stays 0 for terminal states.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    # torch.cat() rejects an empty list, so skip the target network entirely
    # when every sampled transition is terminal.
    if non_final_next_states:
        with torch.no_grad():
            next_q = target_net(torch.cat(non_final_next_states))
            next_state_values[non_final_mask] = next_q.max(1)[0]

    # Expected Q values (the regression targets).
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Huber loss
    loss = F.smooth_l1_loss(state_action_values,
                            expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    # Element-wise clamp to [-1, 1]; parameters without a gradient are skipped.
    nn.utils.clip_grad_value_(policy_net.parameters(), 1)
    optimizer.step()

In [7]:
class Game2():
    """Sequential stopping game over a random sequence of integers.

    Each episode draws `n_idx` integers from the inclusive range [lo, hi].
    The player sees one value at a time and either stops (action 0) or moves
    on to the next value (any other action).
    """

    def __init__(self, lo, hi, n_idx, replace, reward_fn, reward):
        """Store the configuration and start the first episode.

        Args:
            lo: smallest value that can be drawn.
            hi: largest value that can be drawn (inclusive).
            n_idx: number of values per episode.
            replace: forwarded to `np.random.choice`; False draws distinct
                values and therefore needs `n_idx <= hi - lo + 1`.
            reward_fn: callable taking this game; `step` uses element `[0]`
                of its result as the reward, `getReward` returns it whole.
            reward: stored as `self.reward`; not read by this class
                (presumably consumed by `reward_fn` - TODO confirm).
        """
        self.lo = lo
        self.hi = hi
        self.n_idx = n_idx
        self.replace = replace
        self.reward_fn = reward_fn
        self.reward = reward

        self.reset()

    def step(self, action):
        """Apply one action.

        Args:
            action: 0 stops at the current value; anything else advances to
                the next value.

        Returns:
            `(idx, val, reward, done)`: the current index and value after the
            action, the reward as a float (0.0 unless the episode ended) and
            whether the episode ended.
        """
        # There is nothing left to reveal after the last value, so asking to
        # continue there is treated as a stop instead of indexing past the
        # end of `self.values`.
        on_last_value = self.idx >= len(self.values) - 1
        if action == 0 or on_last_value:
            reward = self.reward_fn(self)[0]
            done = True
        else:
            self.idx += 1
            self.val = self.values[self.idx]
            reward, done = 0., False

        return self.idx, self.val, float(reward), bool(done)

    def getReward(self):
        """Mark the game as over and return the full `reward_fn(self)` result."""
        self.game_over = True
        return self.reward_fn(self)

    def reset(self):
        """Reset the game"""
        self.values = np.random.choice(np.arange(self.lo, self.hi+1),
                                       size=self.n_idx,
                                       replace=self.replace)
        # Same values in descending order.
        self.values_sorted = np.sort(self.values)[::-1]

        self.idx = 0
        self.val = self.values[0]

        self.max_val = self.values.max()
        # Index of the first occurrence of the maximum.
        self.max_idx = self.values.argmax()

        self.game_over = False
        
# Keyword arguments for Game2: 20 distinct values drawn from 1..25 per episode.
game_params = dict(
    lo=1,
    hi=25,
    n_idx=20,
    replace=False,
    reward_fn=rewardTopN,
    reward=dict(n=10, pos=10, neg=-10),
)

game = Game2(**game_params)

# Value of idx / n_idx at the last index; DQN.getAction returns action 0
# whenever the first state feature equals it.
AUTO_STOP = (game.n_idx - 1) / game.n_idx

In [8]:
num_episodes = 50


def make_state(idx, val):
    """Encode a game position as a (1, 2) tensor on `device`.

    The two features are `idx / game.n_idx` and `val / game.hi`.
    """
    # Created directly on `device`: the networks live there, and a CPU state
    # would fail with a device mismatch as soon as CUDA is in use.
    return torch.tensor([[idx / game.n_idx, val / game.hi]], device=device)


for i_episode in range(num_episodes):
    # Start a fresh episode and encode its first position.
    game.reset()
    state = make_state(game.idx, game.val)

    for t in count():
        # Choose and apply an action.
        action = policy_net.getAction(state)
        idx, val, reward, done = game.step(action.item())
        reward = torch.tensor([reward], device=device)

        # A terminal transition is stored with next_state = None.
        next_state = None if done else make_state(idx, val)

        memory.push(state, action, next_state, reward)
        state = next_state

        # One optimisation step on the policy network per environment step.
        optimize_model()

        if done:
            break

    # Periodically copy the policy weights into the target network.
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())
        