In [1]:
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
from PIL import Image
from pyemd import emd

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

# set up matplotlib
# IPython's display helper is only imported when the active backend name
# contains 'inline'.
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

# Turn on pyplot's interactive mode.
plt.ion()

# if gpu is to be used
# Falls back to the CPU when torch reports that CUDA is unavailable.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# One stored experience tuple: (state, action, next_state, reward).
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):
    """Bounded experience buffer that overwrites its oldest entry once full.

    Attributes:
        capacity: maximum number of transitions kept.
        memory:   list holding the stored ``Transition`` objects.
        position: index of the slot the next ``push`` writes to.
    """

    def __init__(self, capacity):
        self.memory = []
        self.position = 0
        self.capacity = capacity

    def push(self, *args):
        """Store one transition; ``args`` are the four ``Transition`` fields."""
        slot = self.position
        if len(self.memory) < self.capacity:
            # Buffer still growing: reserve the slot before filling it.
            self.memory.append(None)
        self.memory[slot] = Transition(*args)
        # Advance the write cursor, wrapping around at capacity.
        self.position = (slot + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen at random."""
        return random.sample(self.memory, k=batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

def combine(state, goal):
    """Concatenate ``state`` and ``goal`` along dimension 1.

    Returns ``None`` when ``state`` is ``None`` so that the "no next state"
    marker passes through unchanged.
    """
    if state is not None:
        return torch.cat([state, goal], dim=1)
    return None

In [3]:
class miniDQN(nn.Module):
    """Small fully connected Q-network with a single ReLU hidden layer.

    Args:
        hidden:      width of the hidden layer.
        in_features: width of the input vector. Defaults to 8, the value
                     that was previously hard-coded.
        n_actions:   number of output values. Defaults to 4, the value that
                     was previously hard-coded.

    The layer attributes keep the names ``W1`` and ``W2`` so existing
    state dicts remain loadable.
    """

    def __init__(self, hidden, in_features=8, n_actions=4):
        super().__init__()
        self.W1 = nn.Linear(in_features, hidden)
        self.W2 = nn.Linear(hidden, n_actions)
        # Xavier-uniform initialisation of both weight matrices; biases keep
        # nn.Linear's default initialisation.
        nn.init.xavier_uniform_(self.W1.weight)
        nn.init.xavier_uniform_(self.W2.weight)

    def forward(self, x):
        """Map a ``(batch, in_features)`` input to ``(batch, n_actions)`` outputs."""
        return self.W2(F.relu(self.W1(x)))

In [4]:
# Grid dimensions; step() rejects moves that leave [0, num_row) x [0, num_col).
num_row = 5
num_col = 5

# (d_row, d_col) offset applied by each action index in step().
action_direction = [(0,1), (0,-1), (1,0), (-1,0)]
# Integer codes per cell type. Not referenced by the other code visible in
# this notebook.
block = { "empty" : 0, "blue" : 1, "red" : 2, "goal" : 3, "wall" : 4}

# Cells that set the red / blue flag when entered (see step()).
reds = [(3,0)]
blues = [(0,3)]
start = (0,0)
end = (4,4)
# Cells that cannot be entered; empty here.
walls = []

# States are 4-vectors (row, col, blue, red): the start has both flags
# cleared, the goal state has both flags set.
start_state = np.asarray(list(start) + [0,0])
end_state = np.asarray(list(end) + [1,1])


def is_done(s):
    """Truthy when state ``s`` has both flags set and sits on the ``end`` cell.

    ``s`` is unpacked as (row, col, blue, red). As with a plain ``and``
    chain, the value returned is the first falsy operand, or the final
    position comparison.
    """
    row, col, blue, red = s
    at_end = (row, col) == end
    return blue and red and at_end

def step(s, a):
    """Apply action index ``a`` to state ``s`` and return the resulting state.

    A move that would leave the grid or enter a wall cell is a no-op and
    returns ``s`` itself. Otherwise the position is updated and the blue /
    red flags are switched on when the new cell is listed in ``blues`` /
    ``reds``; a flag that is already set stays set.
    """
    row, col, blue, red = tuple(s)
    d_row, d_col = action_direction[a]
    new_row = row + d_row
    new_col = col + d_col
    target = (new_row, new_col)

    on_grid = 0 <= new_row < num_row and 0 <= new_col < num_col
    if not on_grid or target in walls:
        return s

    if not blue:
        blue = target in blues
    if not red:
        red = target in reds

    return np.asarray((new_row, new_col, blue, red))

def get_reward(s, a, s_prime):
    """Reward for a transition: 0.0 when taken from a done state, else -1.0.

    Only the source state ``s`` is inspected; ``a`` and ``s_prime`` are
    accepted but unused.
    """
    return 0.0 if is_done(s) else -1.0

def L1(s1, s2):
    """Sum of absolute element-wise differences between two vectors, as a float."""
    first, second = np.asarray(s1), np.asarray(s2)
    gaps = np.abs(first - second)
    return 1.0 * gaps.sum()

In [5]:
#Very adhoc, requires no dynamics
# Memo of previously computed distances, keyed by (tuple(a), tuple(b)).
# It does not track changes to `reds` / `blues`, so it must be reset
# whenever those lists change.
distance_cache = {}

def distance(a, b):
    """Hand-written distance between two (row, col, blue, red) states.

    For every flag that differs between ``a`` and ``b``, the route is forced
    through a cell that sets that flag:

    * no flag differs  -> plain L1 distance between the two states;
    * only blue differs -> cheapest detour via any blue cell;
    * only red differs  -> cheapest detour via any red cell;
    * both differ       -> cheapest route visiting one red and one blue cell,
      in either order, over *every* (red, blue) combination.

    The previous version paired reds and blues with ``zip``, so with several
    cells of each colour only same-index pairs were tried. Results are
    unchanged when there is exactly one cell of each colour.

    Walls are ignored. Results are memoised in ``distance_cache``.
    """
    ta = tuple(a)
    tb = tuple(b)
    key = (ta, tb)
    if key in distance_cache:
        return distance_cache[key]

    row_a, col_a, blue_a, red_a = ta
    row_b, col_b, blue_b, red_b = tb
    pos_a = (row_a, col_a)
    pos_b = (row_b, col_b)

    blue_match = blue_a == blue_b
    red_match = red_a == red_b

    if red_match and blue_match:
        dist = L1(ta, tb)
    elif red_match:
        dist = min(L1(pos_a, blue) + L1(blue, pos_b) for blue in blues)
    elif blue_match:
        dist = min(L1(pos_a, red) + L1(red, pos_b) for red in reds)
    else:
        dist = min(
            min(L1(pos_a, red) + L1(red, blue) + L1(blue, pos_b),
                L1(pos_a, blue) + L1(blue, red) + L1(red, pos_b))
            for red in reds
            for blue in blues
        )
    distance_cache[key] = dist
    return dist

In [39]:
# --- Hyperparameters -------------------------------------------------------
BATCH_SIZE = 128          # transitions sampled per optimisation step
GAMMA = 0.999             # discount factor
EPS_START = 0.9           # initial exploration probability
EPS_END = 0.06            # exploration probability approached as steps grow
EPS_DECAY = 4000          # time constant of the exponential epsilon decay
TARGET_UPDATE = 50        # episodes between target-network syncs
HIDDEN_LAYER_SIZE = 20
EPISODE_LENGTH = 30       # maximum number of steps per episode
LEARNING_RATE = 0.0003

# When True a fresh expert trajectory is sampled for every pre-fill
# iteration; otherwise the first trajectory is reused.
DIVERSE_EXPERT = True

num_episodes = 40000

# Seed every RNG the notebook draws from. Previously only Python's `random`
# was seeded, leaving np.random.choice (expert tie-breaking, entropy
# sampling) and torch (network initialisation) unseeded.
SEED = 3
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Reset the distance memo so a re-run of this cell starts clean.
distance_cache = {}

policy_net = miniDQN(HIDDEN_LAYER_SIZE).to(device)
target_net = miniDQN(HIDDEN_LAYER_SIZE).to(device)
# The target network starts as an exact copy and is only used for inference.
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = optim.Adam(policy_net.parameters(), lr = LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.99999)
memory = ReplayMemory(40000)

In [40]:
# Number of select_action calls so far; drives the epsilon decay schedule.
steps_done = 0

def select_action(state, goal, test = False):
    """Epsilon-greedy action for ``state`` conditioned on ``goal``.

    Args:
        state: (1, 4) float tensor.
        goal:  (1, 4) float tensor.
        test:  when True, always take the greedy action.

    Returns:
        A (1, 1) long tensor holding the chosen action index.

    Every call draws one number from ``random`` and increments
    ``steps_done``, including calls with ``test=True``.
    """
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if test or sample > eps_threshold:
        with torch.no_grad():
            # Callers build state/goal on the CPU; move the network input to
            # the model's device so the forward pass also works on CUDA.
            x = torch.cat((state, goal), 1).to(device)
            return policy_net(x).max(1)[1].view(1, 1)
    return torch.tensor([[random.randrange(4)]], device=device, dtype=torch.long)

def select_action_entropy(state, goal, alpha):
    """Sample an action from a softmax over ``alpha * Q(state, goal, .)``.

    Args:
        state: (1, 4) float tensor.
        goal:  (1, 4) float tensor.
        alpha: inverse temperature; larger values concentrate the
               distribution on the highest-valued action.

    Returns:
        A (1, 1) long tensor holding the sampled action index.
    """
    with torch.no_grad():
        # Move the input to the model's device so this also works on CUDA.
        x = torch.cat((state, goal), 1).to(device)
        q_values = policy_net(x).squeeze()
        # softmax is the numerically stable form of exp(z) / sum(exp(z));
        # the hand-written version could overflow for large alpha * Q.
        # Double precision keeps the probabilities summing to 1 within
        # np.random.choice's tolerance.
        probs = F.softmax(alpha * q_values.double(), dim=-1)
    # .cpu() is required before .numpy() when the model lives on the GPU.
    dist = probs.cpu().numpy()
    return torch.tensor([[np.random.choice(len(dist), p = dist)]], device=device, dtype=torch.long)
    

def select_expert_action(state, goal):
    """Pick the action whose successor state is closest to ``goal``.

    Each action is simulated with ``step`` and scored with ``distance``;
    ties between equally good actions are broken uniformly at random via
    ``np.random.choice``.

    Returns:
        A (1, 1) long tensor holding the chosen action index.
    """
    #Assumes determinism
    # The action count comes from action_direction, the table step() indexes
    # into, rather than a hard-coded 4.
    action_scores = np.array(
        [distance(step(state, a), goal) for a in range(len(action_direction))],
        dtype=float,
    )
    best_actions = np.flatnonzero(action_scores == action_scores.min())
    action = np.random.choice(best_actions)
    return torch.tensor([[action]], device=device, dtype=torch.long)

def sample_expert_trajectory():
    """Roll out the expert policy from ``start_state`` towards ``end_state``.

    Returns a list of ``(state, action, next_state, reward)`` tuples of
    tensors, at most ``EPISODE_LENGTH`` long. Termination is judged on the
    state *before* the move, so the final entry of a finished episode starts
    in the done state and carries ``next_state = None``.
    """
    steps = []
    current = start_state

    for _ in range(EPISODE_LENGTH):
        chosen = select_expert_action(current, end_state)
        action_index = chosen.item()

        successor = step(current, action_index)
        terminal = is_done(current)
        reward_t = torch.tensor(
            [get_reward(current, action_index, successor)], device=device
        )

        current_t = torch.from_numpy(current).float().unsqueeze(0)
        if terminal:
            successor_t = None
        else:
            successor_t = torch.from_numpy(successor).float().unsqueeze(0)

        steps.append((current_t, chosen, successor_t, reward_t))
        if terminal:
            break
        current = successor

    return steps

In [41]:
def optimize_model():
    """Run one Double-DQN optimisation step on a batch from ``memory``.

    Returns:
        The scalar loss as a float, or ``None`` when the replay memory holds
        fewer than ``BATCH_SIZE`` transitions and no update was made.
    """
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose a list of Transitions into one Transition of tuples.
    batch = Transition(*zip(*transitions))

    # Boolean mask of transitions that have a successor state. torch.bool
    # replaces the deprecated uint8 mask.
    non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                  device=device, dtype=torch.bool)
    non_final = [s for s in batch.next_state if s is not None]

    # Stored tensors may have been created on the CPU; move the batch to the
    # model's device.
    state_batch = torch.cat(batch.state).to(device)
    action_batch = torch.cat(batch.action).to(device)
    reward_batch = torch.cat(batch.reward).to(device)

    state_action_values = policy_net(state_batch).gather(1, action_batch)

    #DDQN
    # The policy net chooses the next action, the target net evaluates it.
    # Terminal transitions keep a next-state value of zero.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    if non_final:  # torch.cat([]) raises, so skip when every sample is terminal
        non_final_next_states = torch.cat(non_final).to(device)
        with torch.no_grad():
            policy_action_indices = policy_net(non_final_next_states).max(1)[1]
            target_actions = target_net(non_final_next_states)
        # squeeze(1) keeps a 1-D result even with a single non-final sample.
        next_state_values[non_final_mask] = target_actions.gather(
            1, policy_action_indices.unsqueeze(1)).squeeze(1)
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    # Clamp each gradient element to [-1, 1] before the update.
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()
    return loss.item()

In [42]:
# Running list of per-step losses, averaged and cleared every 1000 episodes.
losses = [0]
# The real goal as a (1, 4) float tensor, appended to every state fed to the network.
torch_end = torch.from_numpy(end_state).float().unsqueeze(0)

# Pre-fill the replay memory with expert demonstrations. With DIVERSE_EXPERT
# a new trajectory is sampled on every iteration; otherwise the single
# trajectory sampled up front is pushed 10000 times.
expert_trajectories = [sample_expert_trajectory()]
for i in range(10000):
    if DIVERSE_EXPERT:
        expert_trajectories.append(sample_expert_trajectory())
    for torch_state, action, torch_next_state, reward in expert_trajectories[-1]:
        memory.push(combine(torch_state, torch_end), action, combine(torch_next_state, torch_end), reward)

for i_episode in range(num_episodes):
    state = start_state
    next_state = None
    # The last episode is run greedily (test=True) and prints its states.
    final = i_episode == num_episodes - 1
    trajectory = []
    
    # The learning-rate scheduler is stepped once per 500 episodes, not per optimiser step.
    if i_episode % 500 == 0:
        scheduler.step()
    
    for t in range(EPISODE_LENGTH):
        torch_state = torch.from_numpy(state).float().unsqueeze(0)

        action = select_action(torch_state, torch_end, final)
            
        next_state = step(state, action.item())
        torch_next_state = torch.from_numpy(next_state).float().unsqueeze(0)
        reward = get_reward(state, action.item(), next_state)
        # Termination is judged on the state *before* the move, so one more
        # transition is recorded from the done state, with next_state = None.
        done = is_done(state)
        reward = torch.tensor([reward], device=device)

        if done:
            torch_next_state = None

        trajectory.append((torch_state, action, torch_next_state, reward))

        # One optimisation step per environment step; None (memory not yet
        # large enough) is logged as 0.
        loss = optimize_model()
        losses.append(loss if loss is not None else 0)
        if final:
            print(state)
        if done:
            break
        state = next_state

    if i_episode % 1000 == 0:
        print(i_episode)
        print(np.mean(np.asarray(losses)))
        losses = []

    #HER updates with goal as real goal OR last state in trajectory
    # `state` is whatever the episode loop left behind: the done state if the
    # episode finished, otherwise the state reached after the last step.
    for torch_state, action, torch_next_state, reward in trajectory:
        memory.push(combine(torch_state, torch_end), action, combine(torch_next_state, torch_end), reward)
        if not is_done(state):
            # Relabel with the last reached state as the goal. The boolean
            # local_reward adds 1 to the stored reward when this transition
            # lands on that substitute goal.
            # NOTE(review): the relabelled transition keeps its next_state
            # (it is not marked terminal), unlike the real-goal convention
            # where reward 0 is paired with next_state = None -- confirm
            # this difference is intended.
            torch_local_goal = torch.from_numpy(state).float().unsqueeze(0)
            local_reward = torch.equal(torch_next_state, torch_local_goal)
            new_reward = torch.add(reward, local_reward)
            memory.push(combine(torch_state, torch_local_goal), action, combine(torch_next_state, torch_local_goal), new_reward)
                    
                          
    # Update the target network
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())
    
        
print('Complete')




0
0.38847021229805484
1000
0.05111439498597756
2000
0.12478386413094134
3000
0.45092014468499914
4000
0.29605271870646077
5000
0.2258685481688745
6000
0.12007253752096982
7000
0.04371648656464179
8000
0.021334728512579462
9000
0.01093820929897364
10000
0.012014316723250948
11000
0.014022343133917406
12000
0.041863681958052644
13000
0.08901956006481683
14000
0.07231782415087981
15000
0.06202855918922732
16000
0.022801288374194
17000
0.01220506660997345
18000
0.003973251197270051
19000
0.004384407612070548
20000
0.00615568041589518
21000
0.010150316466067818
22000
0.008543751950415563
23000
0.007174221876029408
24000
0.008793049879841447
25000
0.011347572735524574
26000
0.008198494071975383
27000
0.0027211072684390604
28000
0.012243853237334987
29000
0.01208595953563764
30000
0.004905742340543527
31000
0.0018752481010136986
32000
0.0067144675218266765
33000
0.008840188748981176
34000
0.006939716111046678
35000
0.0022967207471766434
36000
0.0017262735102881366
37000
0.0017301052479709016


In [None]:
def predicted_distance(state, goal):
    """Learned distance estimate from ``state`` to ``goal``.

    Both arguments are numpy state vectors. The estimate is the negated
    maximum Q-value of ``policy_net`` for the concatenated (state, goal)
    input, returned as a Python float.
    """
    torch_state = torch.from_numpy(state).float().unsqueeze(0)
    torch_goal = torch.from_numpy(goal).float().unsqueeze(0)
    # Move the input to the model's device so this also works on CUDA.
    x = torch.cat((torch_state, torch_goal), 1).to(device)
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        best_q = policy_net(x).max(1)[0]
    return -1 * best_q.view(1, 1).item()

def predicted_measure_ball(state, goal, n = 500, alpha = 4.0):
    """Empirical one-step successor distribution under the softmax policy.

    Samples ``n`` actions with ``select_action_entropy`` and applies each to
    ``state`` with ``step``.

    Args:
        state: numpy state vector.
        goal:  numpy goal vector.
        n:     number of sampled actions.
        alpha: inverse temperature passed to ``select_action_entropy``.
               Defaults to 4.0, the value that was previously hard-coded.

    Returns:
        Dict mapping each reached successor state (as a tuple) to its
        relative frequency. Empty when ``n`` is 0.
    """
    torch_state = torch.from_numpy(state).float().unsqueeze(0)
    torch_goal = torch.from_numpy(goal).float().unsqueeze(0)

    # Count hits as integers and normalise once, rather than accumulating
    # 1.0/n repeatedly, so the frequencies carry no summed rounding error.
    counts = {}
    for _ in range(n):
        action = select_action_entropy(torch_state, torch_goal, alpha = alpha)
        key = tuple(step(state, action.item()))
        counts[key] = counts.get(key, 0) + 1
    return {key: hits / n for key, hits in counts.items()}

def wasserstein(a, b, goal):
    """Earth mover's distance between the successor distributions of ``a`` and ``b``.

    Identical states short-circuit to 0. Otherwise both successor
    distributions are estimated with ``predicted_measure_ball``, laid out
    over their combined support, and compared by ``emd`` using the
    hand-written ``distance`` as the ground metric.
    """
    if tuple(a) == tuple(b):
        return 0

    ball_a = predicted_measure_ball(a, goal)
    ball_b = predicted_measure_ball(b, goal)
    support = list(set(ball_a) | set(ball_b))

    hist_a = np.asarray([ball_a.get(s, 0.0) for s in support])
    hist_b = np.asarray([ball_b.get(s, 0.0) for s in support])
    ground = np.asarray([
        [distance(np.asarray(col_state), np.asarray(row_state)) for col_state in support]
        for row_state in support
    ])
    return emd(hist_a, hist_b, ground)
