In [0]:
import gym
import numpy as np
import math
import random
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [0]:
from gym import error, spaces, utils
from gym.utils import seeding

# Create a custom environment 
class WTAEnv(gym.Env):
    metadata = {'render.modes': ['console']}

    def __init__(self, target_values, prob, device):
        ''' 
        Assignment - m x n array that maps weapons to assigned target, where
        the assigned target is one-hot encoded to differentiate between unassigned
        weapons (row is all zero's) and assigned weapons (exactly one column 
        index of row is 1)
        target_values - 1D array that maps targets to their value when destroyed
        prob - m x n array where prob[i, j] = probability of weapon i killing target j
        ''' 
        super(WTAEnv, self).__init__()

        self.n = len(target_values)
        self.m = len(prob)

        self.empty_assignment = np.zeros((self.m, self.n))
        self.assignment = np.copy(self.empty_assignment)

        # Keep track of the assigned weapons. Index i is 1 if 
        # weapon i is already assigned, 0 if it is unassigned. Used in order
        # to set probabilities of infeasible weapon-target assignments to 0
        self.weapons_assigned = np.zeros(self.m)
        
        self.target_values = target_values 
        self.prob = prob
        
        # The action space - a number 0 <= i < m * n, where
        # (weapon = i // n, target = i % n)
        # Assigns weapon to target (one target per weapon)
        self.action_space = spaces.MultiDiscrete([self.m, self.n])
        self.device = device
        
    def decode_action(self, action):
        '''
        Given an action, return the weapon and target associated with
        that action.
        '''
        return action // self.n, action % self.n

    def step(self, action):
        weapon, target = self.decode_action(action[0][0])
        if not self.action_space.contains([weapon, target]):
            raise ValueError("Received invalid action={} which is not part of the action space".format(action))
        # Update the weapons remaining and the assignment
        self.weapons_assigned[weapon] = 1
        self.assignment[weapon][target] = 1
        # The reward is simply expected value of the chosen weapon killing 
        # the chosen target times the value of the target
        reward = self.prob[weapon][target] * self.target_values[target]
        done = np.sum(self.weapons_assigned) == self.m
        return (self.get_state(), reward, done, {})

    def get_state(self):
        '''
        First flatten the matrix of one-hot encoded assignment to get 1D vector
        of length m * n. Then append the matrix of weapons_remaining
        Our state is a size m * n + m + n array.
        state[:m*n] is self.assignments
        state[m*n:m*n+m] is the number of weapons left 
        state[m*n+m:] is self.target_values
        '''
        flat_assign = self.assignment.flatten()
        state = np.concatenate([flat_assign, self.weapons_assigned, self.target_values])
        # Return state as a tensor and add a batch dimension
        return torch.tensor(state, device=self.device).unsqueeze(0)

    def reset(self):
        '''
        Important: the observation must be a numpy array
        :return: (np.array) 
        '''
        self.assignment = np.copy(self.empty_assignment)
        self.weapons_assigned = np.zeros(self.m)
        return self.get_state()

    def render(self, mode='human', close=False):
        pass

In [0]:
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))

class ReplayMemory(object):
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Saves a transition."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

In [0]:
def generate_initial_assignment(n, m):
    '''   
    Start with all weapons unassigned 
    ''' 
    return np.zeros((m, n)).astype(np.int32)

In [0]:
steps_done = 0

def select_action(state, weapons_assigned, m, n):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    weapons_tensor = torch.tensor(weapons_assigned)
    if sample > eps_threshold:
      with torch.no_grad():
        policy_net.eval()
        nn_prob = policy_net(state)
        policy_net.train()
        ind = torch.max(nn_prob, dim = 1)[1]
        return torch.tensor([[ind]], device=device, dtype=torch.int)
    else:
      weapons_left = (weapons_tensor == 0).nonzero().flatten()
      weapon = np.random.choice(weapons_left.numpy())
      target = random.randrange(n)
      ind = weapon * n + target
      return torch.tensor([[ind]], device=device, dtype=torch.int)

episode_durations = []

def plot_durations(ep):
    fig = plt.figure(2)
    plt.clf()
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())
    # Take 100 episode averages and plot them too
    if len(durations_t) >= 100:
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())

    plt.pause(0.001)  # pause a bit so that plots are updated
    if is_ipython:
        display.clear_output(wait=True)
        display.display(plt.gcf())
        # if ep % 5 == 0:
        #   fig.savefig('episode{}'.format(ep))

In [0]:
def optimize_model(memory):
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
    # detailed explanation). This converts batch-array of Transitions
    # to Transition of batch-arrays.
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state would've been the one after which simulation ended)
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                          batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state
                                                if s is not None])

    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)
    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken. These are the actions which would've been taken
    # for each batch state according to policy_net
    state_action_values = policy_net(state_batch).gather(dim=1, index =action_batch.long())

    # Compute V(s_{t+1}) for all next states.
    # Expected values of actions for non_final_next_states are computed based
    # on the "older" target_net; selecting their best reward with max(1)[0].
    # This is merged based on the mask, such that we'll have either the expected
    # state value or 0 in case the state was final.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    loss = F.smooth_l1_loss(state_action_values.double(), expected_state_action_values.unsqueeze(1).double())
    value_loss = loss.item()
    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()
    return value_loss

In [0]:
'''Take in n by m matrix, convert it to 1D feature vector '''
class DQN(nn.Module):
  
    def __init__(self, n, m, embedding_size=8):
        super(DQN, self).__init__()
        # The assignment becomes embedded, so it has size m * embedding_size
        # when flattened
        self.assignment_size = m * n
        # The input consists of the current partial assignment, with the 
        # n target values appended to the end
        self.input_size = self.assignment_size + n
        self.n = n
        self.m = m
        self.mask_val = -2 ** 30
          
        units = 50
        # Return the probabilities of making a new assignment of weapon 
        # m to target n
        self.output_size = m * n
        self.lin1 = nn.Linear(self.input_size, units)
        self.drop1 = nn.Dropout(0.2)
        self.lin2 = nn.Linear(units, self.output_size)
        self.drop2 = nn.Dropout(0.2)

    # Called with either one element to determine next action, or a batch
    # during optimization. Returns tensor([[left0exp,right0exp]...]).
    def forward(self, state):
        end_assign = self.m * self.n
        assignment = state[:, :end_assign].float()
        assigned = state[:, end_assign:(end_assign + m)].int()
        values = state[:, (end_assign + m):].float()
        # Append the target values
        x = torch.cat([assignment, values], dim=1)
        x = F.relu(self.drop1(self.lin1(x)))
        x = F.relu(self.drop2(self.lin2(x)))
        # Mask out weapons that have already been assigned by adding a large
        # negative number 
        mask = assigned.unsqueeze(-1).repeat(1, 1, n).view(assigned.shape[0], -1)
        mask *= self.mask_val 
        x = x.add(torch.tensor(mask))
        return x

In [0]:
BATCH_SIZE = 128
GAMMA = 0.999
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 200
TARGET_UPDATE = 100

n = 4
m = 5
lower_val = 25
upper_val = 50
lower_prob = 0.6
upper_prob = 0.9
values = np.random.uniform(lower_val, upper_val, n)
prob = np.random.uniform(lower_prob, upper_prob, (m, n))
# assignment, weapons = generate_initial_assignment(n, m)
env = WTAEnv(values, prob, device)

policy_net = DQN(n, m).to(device)
target_net = DQN(n, m).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = optim.RMSprop(policy_net.parameters())

In [29]:
num_episodes = 50
env.reset()
memory = ReplayMemory(10000)
losses = []
for i_episode in range(num_episodes):
    # Initialize the environment and state
    env.reset()
    state = env.get_state()
    for t in count():
        # print(f'episode {i_episode}/{num_episodes}, iteration {t}', end='\n')
        # Select and perform an action
        action = select_action(state, env.weapons_assigned, env.m, env.n)
        observation, reward, done, _ = env.step(action)
        reward = torch.tensor([reward], device=device)
        
        if not done:
            next_state = observation
        else:
            next_state = None

        # Store the transition in memory
        memory.push(state, action, next_state, reward)
        # Move to the next state
        state = next_state

        # Perform one step of the optimization (on the target network)
        loss = optimize_model(memory)
        if loss != None:
          losses.append(loss)
        if done:
          episode_durations.append(t + 1)
          # plot_durations(i_episode)
          break
    # Update the target network, copying all weights and biases in DQN
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())

# Just print out the last 10 values or so of loss
print(losses[-10:])
env.render()
env.close()
plt.show()
plt.ioff()



[10.64640604438131, 9.394657822303108, 8.441663288591723, 8.915235121685468, 10.866693179366283, 10.15083732941228, 8.753035384793995, 9.701084736108823, 8.95574353385202, 8.745841075013935]
