In [62]:
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
from PIL import Image
from pyemd import emd

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()

# if gpu is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [63]:
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Saves a transition."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

def combine(state, goal):
    if state is None:
        return None
    return torch.cat((state, goal), 1)

In [64]:
class miniDQN(nn.Module):

    def __init__(self, hidden):
        super(miniDQN, self).__init__()
        self.W1 = nn.Linear(8, hidden)
        self.W2 = nn.Linear(hidden, 4)
        nn.init.xavier_uniform_(self.W1.weight)
        nn.init.xavier_uniform_(self.W2.weight)

    def forward(self, x):
        x = F.relu(self.W1(x))
        return self.W2(x)

In [65]:
num_row = 8
num_col = 8

action_direction = [(0,1), (0,-1), (1,0), (-1,0)]
block = { "empty" : 0, "blue" : 1, "red" : 2, "goal" : 3, "wall" : 4}

reds = [(4,0)]
blues = [(0,4)]
start = (0,0)
end = (6,6)
walls = []

start_state = np.asarray(list(start) + [0,0])
end_state = np.asarray(list(end) + [1,1])


def is_done(s):
    row, col, blue, red = tuple(s)
    return blue and red and (row, col) == end

def step(s, a):
    row, col, blue, red = tuple(s)
    a_row, a_col = action_direction[a]

    row = row + a_row
    col = col + a_col

    if row < 0 or row >= num_row or col < 0 or col >= num_col or (row, col) in walls:
        return s

    blue = blue or (row, col) in blues
    red = red or (row, col) in reds

    return np.asarray((row, col, blue, red))

def get_reward(s, a, s_prime):
    if is_done(s):
        return 0.0
    return -1.0

def L1(s1, s2):
    return 1.0 * np.sum(np.abs(np.asarray(s1) - np.asarray(s2)))

In [66]:
#Very adhoc, requires no dynamics
distance_cache = {}

def distance(a, b):
    global distance_cache
    ta = tuple(a)
    tb = tuple(b)
    if (ta,tb) in distance_cache:
        return distance_cache[(ta,tb)]
    
    row_a, col_a, blue_a, red_a = ta
    row_b, col_b, blue_b, red_b = tb
    
    blue_match = blue_a == blue_b
    red_match = red_a == red_b
    
    if red_match and not blue_match:
        dist = min([L1((row_a,col_a),blue) + L1(blue, (row_b, col_b)) for blue in blues])
    elif not red_match and blue_match:
        dist = min([L1((row_a,col_a),red) + L1(red, (row_b, col_b)) for red in reds])
    elif not red_match and not blue_match:
        x = min([L1((row_a,col_a),red) + L1(red, blue) + L1(blue, (row_b, col_b)) for red, blue in zip(reds, blues)])
        y = min([L1((row_a,col_a),blue) + L1(blue, red) + L1(red, (row_b, col_b)) for red, blue in zip(reds, blues)])
        dist = min(x,y)
    else:
        dist = L1(ta, tb)
    distance_cache[(ta,tb)] = dist
    return dist

In [73]:
BATCH_SIZE = 128
GAMMA = 0.999
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 4000
TARGET_UPDATE = 50
HIDDEN_LAYER_SIZE = 30
EPISODE_LENGTH = 75
LEARNING_RATE = 0.0008

DIVERSE_EXPERT = True
BACKPLAY = True

num_episodes = 40000
random.seed(6)
distance_cache = {}

policy_net = miniDQN(HIDDEN_LAYER_SIZE).to(device)
target_net = miniDQN(HIDDEN_LAYER_SIZE).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = optim.Adam(policy_net.parameters(), lr = LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.99999)
memory = ReplayMemory(40000)

In [74]:
steps_done = 0

def select_action(state, goal, test = False):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if test or sample > eps_threshold:
        with torch.no_grad():
            x = torch.cat((state, goal), 1)
            return policy_net(x).max(1)[1].view(1, 1)
    else:
        return torch.tensor([[random.randrange(4)]], device=device, dtype=torch.long)

def select_action_entropy(state, goal, alpha):
    x = torch.cat((state, goal), 1)
    potential = (alpha * policy_net(x).detach()).exp().squeeze()
    dist = (potential / potential.sum()).numpy()
    return torch.tensor([[np.random.choice(4, p = dist)]], device=device, dtype=torch.long)
    

def select_expert_action(state, goal):
    row, col, blue, red = tuple(state)
    action_scores = np.zeros(4)
    for a in range(4):
        #Assumes determinism
        next_state = step(state, a)
        action_scores[a] = distance(next_state, goal)
    action = np.random.choice(np.flatnonzero(action_scores == action_scores.min()))
    return torch.tensor([[action]], device=device, dtype=torch.long)    

def sample_expert_trajectory():
    state = start_state
    next_state = None

    trajectory = []
    
    for t in range(EPISODE_LENGTH):
        action = select_expert_action(state, end_state)
        
        torch_state = torch.from_numpy(state).float().unsqueeze(0)
        next_state = step(state, action.item())
        torch_next_state = torch.from_numpy(next_state).float().unsqueeze(0)
        reward = get_reward(state, action.item(), next_state)
        done = is_done(state)
        reward = torch.tensor([reward], device=device)
        if done:
            torch_next_state = None

        trajectory.append((torch_state, action, torch_next_state, reward))
        if done:
            break
        state = next_state
    return trajectory

In [75]:
def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    batch = Transition(*zip(*transitions))

    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,batch.next_state)),
                                  device=device, dtype=torch.uint8)
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    state_action_values = policy_net(state_batch).gather(1, action_batch)

    #DDQN
    policy_action_indices = policy_net(non_final_next_states).max(1)[1].detach()
    target_actions = target_net(non_final_next_states).detach()
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    next_state_values[non_final_mask] = target_actions.gather(1, policy_action_indices.unsqueeze(1)).squeeze()
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))

    
    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()
    return loss.item()

In [76]:
losses = [0]
torch_end = torch.from_numpy(end_state).float().unsqueeze(0)
expert_trajectories = [sample_expert_trajectory()]
for i in range(4000):
    if DIVERSE_EXPERT:
        expert_trajectories.append(sample_expert_trajectory())
    for torch_state, action, torch_next_state, reward in expert_trajectories[-1]:
        memory.push(combine(torch_state, torch_end), action, combine(torch_next_state, torch_end), reward)

for i_episode in range(num_episodes):
    state = start_state
    next_state = None
    final = i_episode == num_episodes - 1
    trajectory = []
    
    if i_episode % 500 == 0:
        scheduler.step()
    if BACKPLAY:
        trajectory_index = np.random.choice(len(expert_trajectories))
        expert_trajectory = expert_trajectories[trajectory_index]
        
        start_index = np.random.choice(len(expert_trajectory))
        if np.random.random_sample() < 0.5 or final:
            start_index = 0
        state = expert_trajectory[start_index][0].squeeze().data.numpy()
    
    
    
    for t in range(EPISODE_LENGTH):
        torch_state = torch.from_numpy(state).float().unsqueeze(0)

        action = select_action(torch_state, torch_end, final)
            
        next_state = step(state, action.item())
        torch_next_state = torch.from_numpy(next_state).float().unsqueeze(0)
        reward = get_reward(state, action.item(), next_state)
        done = is_done(state)
        reward = torch.tensor([reward], device=device)

        if done:
            torch_next_state = None

        trajectory.append((torch_state, action, torch_next_state, reward))

        loss = optimize_model()
        losses.append(loss if loss is not None else 0)
        if final:
            print(state)
        if done:
            break
        state = next_state

    if i_episode % 1000 == 0:
        print(i_episode)
        print(np.mean(np.asarray(losses)))
        losses = []

    #HER updates with goal as real goal OR last state in trajectory
    for torch_state, action, torch_next_state, reward in trajectory:
        memory.push(combine(torch_state, torch_end), action, combine(torch_next_state, torch_end), reward)
        if not is_done(state):
            torch_local_goal = torch.from_numpy(state).float().unsqueeze(0)
            local_reward = torch.equal(torch_next_state, torch_local_goal)
            new_reward = torch.add(reward, local_reward)
            memory.push(combine(torch_state, torch_local_goal), action, combine(torch_next_state, torch_local_goal), new_reward)
                    
                          
    # Update the target network
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())
    
        
print('Complete')




0
2.1680260206523694
1000
0.04666393666250584
2000
0.12309491512466351
3000
0.3514441646369647
4000
0.27041079194123224
5000
0.26135663845699747
6000
0.13802155052809245
7000
0.19490906885353051
8000
0.38986544502118736
9000
0.4837424426547403
10000
0.3542641099774379
11000
0.1161667428795856
12000
0.08058793691787965
13000
0.07076517516224032
14000
0.03808578766909613
15000
0.007336091710721726
16000
0.00420067164045383
17000
0.008227111381766425
18000
0.010017994821338838
19000
0.007055548939348065
20000
0.00666453392841716
21000
0.008478398370269268
22000
0.004295759275985837
23000
0.0033461334071038656
24000
0.002422072687700963
25000
0.0029711568362704504
26000
0.01688124093129019
27000
0.0070400458811909995
28000
0.005273298605128459
29000
0.002779538278213819
30000
0.007325481924704427
31000
0.007920948982089506
32000
0.0068493822990630384
33000
0.0044673483301123544
34000
0.013496104122349044
35000
0.015354310896390553
36000
0.0050932630577498915
37000
0.002679690474500781
3800

In [77]:
def predicted_distance(state, goal):
    torch_state = torch.from_numpy(state).float().unsqueeze(0)
    torch_goal = torch.from_numpy(goal).float().unsqueeze(0)
    x = torch.cat((torch_state, torch_goal), 1)
    return -1 * policy_net(x).max(1)[0].view(1, 1).item()

def predicted_measure_ball(state, goal, n = 500):
    measure = {}
    torch_state = torch.from_numpy(state).float().unsqueeze(0)
    torch_goal = torch.from_numpy(goal).float().unsqueeze(0)
    x = torch.cat((torch_state, torch_goal), 1)
    for i in range(n):
        action = select_action_entropy(torch_state, torch_goal, alpha = 4.0)
        next_state = step(state, action.item())
        key = tuple(next_state)
        measure[key] = measure.get(key, 0) + 1.0/n
    return measure

def wasserstein(a, b, goal):
    if tuple(a) == tuple(b):
        return 0
    m_a = predicted_measure_ball(a, goal)
    m_b = predicted_measure_ball(b, goal)
    states = list(set(m_a.keys()) | set(m_b.keys()))
    
    m_a = [m_a.get(s, 0.0) for s in states]
    m_b = [m_b.get(s, 0.0) for s in states]
    dist = [[distance(np.asarray(s1),np.asarray(s2)) for s1 in states] for s2 in states]
    return emd(np.asarray(m_a), np.asarray(m_b), np.asarray(dist))

def curvature(a, b, goal):
    return 1 - wasserstein(a, b, goal) / max(distance(a, b), 0.001)

def forward_curvature(state, goal):
    #Assumes deterministic dynamics!
    next_states = [step(state, a) for a in range(4)]
    return min([curvature(state, next_state, goal) for next_state in next_states])

def print_space(fn, goal):
    np.set_printoptions(precision=2, suppress=True)
    for red in [0,1]:
        for blue in [0,1]:
            grid = np.zeros((num_row, num_col))
            for row in range(num_row):
                for col in range(num_col):
                    state = np.asarray((row, col, blue, red))
                    grid[row, col] = fn(state, goal)
            print("red = " + str(red))
            print("blue = " + str(blue))
            print(grid)

a = np.asarray([2,0,0,0])
b = np.asarray([1,0,0,0])
print(predicted_measure_ball(start_state, end_state))
print(curvature(a,b,end_state))

{(0, 1, 0, 0): 0.9800000000000008, (0, 0, 0, 0): 0.020000000000000004}
0.0020009859999993163


In [78]:
print_space(distance, end_state)
print('\n')
print_space(predicted_distance, end_state)

red = 0
blue = 0
[[20. 19. 18. 17. 16. 17. 18. 19.]
 [19. 20. 19. 18. 17. 18. 19. 20.]
 [18. 19. 20. 19. 18. 19. 20. 21.]
 [17. 18. 19. 20. 19. 20. 21. 22.]
 [16. 17. 18. 19. 20. 21. 22. 23.]
 [17. 18. 19. 20. 21. 22. 23. 24.]
 [18. 19. 20. 21. 22. 23. 24. 25.]
 [19. 20. 21. 22. 23. 24. 25. 26.]]
red = 0
blue = 1
[[12. 13. 14. 15. 16. 17. 18. 19.]
 [11. 12. 13. 14. 15. 16. 17. 18.]
 [10. 11. 12. 13. 14. 15. 16. 17.]
 [ 9. 10. 11. 12. 13. 14. 15. 16.]
 [ 8.  9. 10. 11. 12. 13. 14. 15.]
 [ 9. 10. 11. 12. 13. 14. 15. 16.]
 [10. 11. 12. 13. 14. 15. 16. 17.]
 [11. 12. 13. 14. 15. 16. 17. 18.]]
red = 1
blue = 0
[[12. 11. 10.  9.  8.  9. 10. 11.]
 [13. 12. 11. 10.  9. 10. 11. 12.]
 [14. 13. 12. 11. 10. 11. 12. 13.]
 [15. 14. 13. 12. 11. 12. 13. 14.]
 [16. 15. 14. 13. 12. 13. 14. 15.]
 [17. 16. 15. 14. 13. 14. 15. 16.]
 [18. 17. 16. 15. 14. 15. 16. 17.]
 [19. 18. 17. 16. 15. 16. 17. 18.]]
red = 1
blue = 1
[[12. 11. 10.  9.  8.  7.  6.  7.]
 [11. 10.  9.  8.  7.  6.  5.  6.]
 [10.  9.  8.  7.  

In [79]:
print_space(forward_curvature, end_state)



red = 0
blue = 0
[[ 0.   -0.   -0.02 -0.03  0.66 -0.26 -0.   -0.  ]
 [-0.2  -0.09 -0.2  -0.14 -0.01  0.   -0.01 -0.  ]
 [-0.09 -0.09 -0.08 -0.16 -0.   -0.03 -0.01 -0.  ]
 [-1.98 -0.1  -0.12 -0.19 -0.01 -0.   -0.02 -0.  ]
 [-0.   -1.99 -0.12 -0.09 -0.   -0.04 -0.04 -0.  ]
 [-1.06 -0.08 -0.11 -0.06 -0.01 -0.09 -1.94 -1.  ]
 [-1.27 -1.46 -0.   -0.   -0.01 -2.   -2.   -1.  ]
 [-1.18 -1.48 -0.   -0.   -0.   -2.   -2.   -0.  ]]
red = 0
blue = 1
[[-0.01 -0.01 -0.08 -0.03 -0.11 -0.01 -0.06  0.  ]
 [-0.01 -0.02 -0.06 -0.04  0.   -0.08 -0.02  0.  ]
 [ 0.02 -0.03 -0.06 -0.11 -0.03 -0.11 -0.1  -0.01]
 [-0.03 -0.   -0.07 -0.   -0.08 -0.06 -0.02 -0.01]
 [ 0.91 -0.01 -0.   -0.   -0.   -0.   -0.   -0.5 ]
 [-0.14 -0.2  -0.15 -0.07 -0.02 -0.03 -0.51 -0.51]
 [-0.   -0.   -0.   -0.   -0.   -0.03 -0.97 -0.95]
 [-0.   -0.   -0.   -0.   -0.   -1.14 -0.99 -0.  ]]
red = 1
blue = 0
[[-0.    0.01 -0.02 -0.03  0.49 -0.47 -0.   -0.  ]
 [-0.14 -0.15 -0.16 -0.07 -0.   -0.04 -0.01 -0.  ]
 [-0.01 -0.15 -0.14 -0.05 -0.