In [44]:
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
from PIL import Image
from pyemd import emd

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

# set up matplotlib
# Detect whether an inline (notebook) backend is active; IPython's display
# helper is only imported in that case.
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

# Switch matplotlib to interactive mode.
plt.ion()

# if gpu is to be used
# Use CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [45]:
# One replay-buffer record: (state, action, next_state, reward).
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):
    """Fixed-capacity ring buffer of ``Transition`` records.

    Once ``capacity`` records are stored, each new record overwrites the
    oldest one.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Saves a transition."""
        buffer = self.memory
        # Grow the list by one placeholder slot until it reaches capacity.
        if len(buffer) < self.capacity:
            buffer.append(None)
        slot = self.position
        buffer[slot] = Transition(*args)
        # Advance the write cursor, wrapping around at capacity.
        self.position = (slot + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` stored records drawn without replacement."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of records currently stored."""
        return len(self.memory)

def combine(state, goal):
    """Concatenate ``state`` and ``goal`` along dimension 1.

    A ``None`` state is passed through unchanged as ``None``.
    """
    return None if state is None else torch.cat((state, goal), 1)

In [46]:
class miniDQN(nn.Module):
    """Two-layer fully connected Q-network: Linear -> ReLU -> Linear.

    Args:
        hidden: width of the hidden layer.
        n_inputs: size of the input vector (default 8, the width previously
            hard-coded here).
        n_actions: number of Q-values produced per input (default 4).

    Both weight matrices use Xavier-uniform initialisation; biases keep
    PyTorch's default initialisation.
    """

    def __init__(self, hidden, n_inputs = 8, n_actions = 4):
        super(miniDQN, self).__init__()
        self.W1 = nn.Linear(n_inputs, hidden)
        self.W2 = nn.Linear(hidden, n_actions)
        nn.init.xavier_uniform_(self.W1.weight)
        nn.init.xavier_uniform_(self.W2.weight)

    def forward(self, x):
        """Return one Q-value per action for each row of ``x``."""
        x = F.relu(self.W1(x))
        return self.W2(x)

In [61]:
# Grid dimensions.
num_row = 5
num_col = 5

# Action index -> human-readable name and (row delta, col delta).
action_word = ["RIGHT", "LEFT", "DOWN", "UP"]
action_direction = [(0,1), (0,-1), (1,0), (-1,0)]
# Cell-type legend; not referenced elsewhere in the visible code.
block = { "empty" : 0, "blue" : 1, "red" : 2, "goal" : 3, "wall" : 4}

# (row, col) coordinates of the special cells.
reds = [(3,0)]
blues = [(0,3)]
start = (0,0)
end = (4,4)
walls = []

# States are (row, col, blue_flag, red_flag) arrays: the start has both flags
# cleared, the goal state has both flags set at the end cell.
start_state = np.asarray(list(start) + [0,0])
end_state = np.asarray(list(end) + [1,1])


def is_done(s):
    """Return True when both flags are set and the agent is on the end cell.

    Args:
        s: state as (row, col, blue, red).

    Returns:
        A real ``bool``; the bare ``and`` chain would otherwise hand back the
        raw falsy flag value (e.g. ``0`` or ``0.0``).
    """
    row, col, blue, red = tuple(s)
    return bool(blue and red and (row, col) == end)

def step(s, a):
    """Apply action ``a`` to state ``s`` and return the resulting state.

    A move that would leave the grid or enter a wall returns ``s`` itself.
    Otherwise the new position is returned, with the blue/red flags switched
    on when the new cell is listed in ``blues``/``reds`` (flags already set
    stay set).
    """
    row, col, blue, red = tuple(s)
    d_row, d_col = action_direction[a]
    row, col = row + d_row, col + d_col

    off_grid = row < 0 or row >= num_row or col < 0 or col >= num_col
    if off_grid or (row, col) in walls:
        return s

    # A flag only changes while it is still falsy.
    if not blue:
        blue = (row, col) in blues
    if not red:
        red = (row, col) in reds

    return np.asarray((row, col, blue, red))

def get_reward(s):
    """Reward for arriving in ``s``: 1.0 if ``is_done(s)``, otherwise -1.0."""
    return 1.0 if is_done(s) else -1.0

def L1(s1, s2):
    """Sum of absolute component-wise differences between ``s1`` and ``s2``."""
    delta = np.asarray(s1) - np.asarray(s2)
    return 1.0 * np.abs(delta).sum()

def get_expert_skeleton():
    trajs = {}
    for blue in blues:
        for red in reds:
            trajs[(blue, red, end)] = L1(start, blue) + L1(blue, red) + L1(red, end)
            trajs[(red, blue, end)] = L1(start, red) + L1(red, blue) + L1(blue, end)
    
    return min(trajs, key=trajs.get)


In [62]:
#Very adhoc, requires no dynamics
def distance(a, b):
    """Heuristic shortest-path length between states ``a`` and ``b``.

    States are (row, col, blue, red). When a flag differs between the two
    states the route is forced through a cell that sets that flag:

    * both flags equal  -> plain L1 between the states;
    * only blue differs -> via the best blue cell;
    * only red differs  -> via the best red cell;
    * both differ       -> via one red and one blue cell, taking the cheaper
      visiting order over every (red, blue) pair. Trying both orders makes
      the result symmetric in ``a`` and ``b``.

    Walls are ignored.
    """
    row_a, col_a, blue_a, red_a = tuple(a)
    row_b, col_b, blue_b, red_b = tuple(b)
    pos_a = (row_a, col_a)
    pos_b = (row_b, col_b)

    blue_match = blue_a == blue_b
    red_match = red_a == red_b

    if blue_match and red_match:
        return L1(a, b)
    if red_match:
        return min(L1(pos_a, blue) + L1(blue, pos_b) for blue in blues)
    if blue_match:
        return min(L1(pos_a, red) + L1(red, pos_b) for red in reds)
    # Both flags differ: consider every red/blue combination in both orders.
    return min(
        min(L1(pos_a, red) + L1(red, blue) + L1(blue, pos_b),
            L1(pos_a, blue) + L1(blue, red) + L1(red, pos_b))
        for red in reds
        for blue in blues
    )

In [63]:
# --- Hyperparameters -------------------------------------------------------
BATCH_SIZE = 128          # transitions sampled per optimize_model() call
GAMMA = 0.999             # discount factor applied to bootstrapped values
EPS_START = 0.9           # initial epsilon for epsilon-greedy exploration
EPS_END = 0.05            # asymptotic epsilon
EPS_DECAY = 4000          # decay constant (in select_action calls)
TARGET_UPDATE = 50        # episodes between target-network syncs
HIDDEN_LAYER_SIZE = 20
EPISODE_LENGTH = 25       # maximum steps per episode
LEARNING_RATE = 0.0008

DIVERSE_EXPERT = False    # expert breaks ties randomly when True
BACKPLAY = True           # start episodes from states along the expert path

num_episodes = 20000

# Seed every RNG the notebook draws from. Seeding only `random` left the
# network initialisation (torch) and NumPy sampling non-reproducible.
SEED = 5
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Online network plus a frozen copy used for bootstrapped targets.
policy_net = miniDQN(HIDDEN_LAYER_SIZE).to(device)
target_net = miniDQN(HIDDEN_LAYER_SIZE).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = optim.Adam(policy_net.parameters(), lr = LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.99999)
memory = ReplayMemory(40000)

In [64]:
# Number of select_action calls so far; drives the epsilon decay schedule.
steps_done = 0

def select_action(state, goal, test = False):
    """Epsilon-greedy action selection on the goal-conditioned policy network.

    Args:
        state: tensor of shape (1, state_dim).
        goal: tensor of shape (1, goal_dim); concatenated to ``state``.
        test: when True, always act greedily.

    Returns:
        A (1, 1) long tensor on ``device`` holding the action index.

    Every call (including test calls) advances ``steps_done``.
    """
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if test or sample > eps_threshold:
        with torch.no_grad():
            # Inputs are built on the CPU elsewhere in the notebook; move them
            # to the network's device before the forward pass.
            x = torch.cat((state, goal), 1).to(device)
            return policy_net(x).max(1)[1].view(1, 1)
    return torch.tensor([[random.randrange(4)]], device=device, dtype=torch.long)

def select_action_entropy(state, goal, alpha):
    """Sample an action from the Boltzmann distribution softmax(alpha * Q).

    Args:
        state: tensor of shape (1, state_dim).
        goal: tensor of shape (1, goal_dim).
        alpha: inverse temperature; larger values concentrate on the argmax.

    Returns:
        A (1, 1) long tensor on ``device`` holding the sampled action index.

    The probabilities come from ``torch.softmax`` in float64. The naive
    ``exp(alpha * Q) / sum`` underflows to 0/0 for strongly negative Q-values
    and yields NaN probabilities.
    """
    x = torch.cat((state, goal), 1).to(device)
    with torch.no_grad():
        q_values = policy_net(x).reshape(-1)
    probs = torch.softmax(alpha * q_values.double(), dim=0).cpu().numpy()
    action = int(np.random.choice(len(probs), p = probs))
    return torch.tensor([[action]], device=device, dtype=torch.long)
    

def select_expert_action(s, expert_skeleton, diverse = False):
    """Greedy expert: choose the action whose successor is L1-closest to the
    next waypoint.

    Args:
        s: current state; its successors come from ``step``.
        expert_skeleton: sequence of waypoint coordinates; only the first
            entry is used.
        diverse: when True, break ties uniformly at random; otherwise take
            the lowest-index best action.

    Returns:
        A (1, 1) long tensor on ``device`` holding the action index.
    """
    waypoint = expert_skeleton[0]
    action_scores = np.zeros(4)
    for a in range(4):
        s_prime = step(s, a)
        action_scores[a] = L1((s_prime[0], s_prime[1]), waypoint)
    if diverse:
        best = np.flatnonzero(action_scores == action_scores.min())
        action = np.random.choice(best)
    else:
        action = np.argmin(action_scores)
    return torch.tensor([[int(action)]], device=device, dtype=torch.long)

def sample_expert_trajectory():
    """Roll out the waypoint-following expert from ``start_state``.

    Returns a list of (state, action, next_state, reward) tuples, where the
    states are (1, state_dim) float tensors and ``next_state`` is ``None`` for
    the terminal transition. The rollout stops at the first done state or
    after ``EPISODE_LENGTH`` steps.
    """
    waypoints = get_expert_skeleton()
    state = start_state
    trajectory = []

    for _ in range(EPISODE_LENGTH):
        # Drop the current waypoint once the agent stands on it.
        if (state[0], state[1]) == waypoints[0]:
            waypoints = waypoints[1:]
        action = select_expert_action(state, waypoints, diverse = DIVERSE_EXPERT)

        torch_state = torch.from_numpy(state).float().unsqueeze(0)
        successor = step(state, action.item())
        reward = torch.tensor([get_reward(successor)], device=device)
        done = is_done(successor)
        if done:
            torch_next_state = None
        else:
            torch_next_state = torch.from_numpy(successor).float().unsqueeze(0)

        trajectory.append((torch_state, action, torch_next_state, reward))
        state = successor
        if done:
            break
    return trajectory

In [66]:
def optimize_model():
    """Run one Double-DQN optimisation step on a batch sampled from ``memory``.

    Returns:
        The scalar Huber loss as a float, or ``None`` when the replay memory
        holds fewer than ``BATCH_SIZE`` transitions (no update is done).

    The next action is chosen by ``policy_net`` and evaluated by
    ``target_net``. Terminal transitions (``next_state is None``) bootstrap
    from 0.
    """
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    batch = Transition(*zip(*transitions))

    # Boolean mask (uint8 masks are deprecated for indexing) of the
    # transitions that have a successor state.
    non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                  device=device, dtype=torch.bool)
    non_final = [s for s in batch.next_state if s is not None]

    # Replay tensors may have been created on the CPU; move them to `device`.
    state_batch = torch.cat(batch.state).to(device)
    action_batch = torch.cat(batch.action).to(device)
    reward_batch = torch.cat(batch.reward).to(device)

    state_action_values = policy_net(state_batch).gather(1, action_batch)

    #DDQN
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    if non_final:  # torch.cat([]) raises, so skip when every sample is terminal
        non_final_next_states = torch.cat(non_final).to(device)
        with torch.no_grad():
            policy_action_indices = policy_net(non_final_next_states).max(1)[1].unsqueeze(1)
            target_values = target_net(non_final_next_states).gather(1, policy_action_indices)
        # squeeze(1) keeps a 1-D result even for a single non-final sample.
        next_state_values[non_final_mask] = target_values.squeeze(1)
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    # Clamp every gradient element to [-1, 1].
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 1)
    optimizer.step()
    return loss.item()

In [67]:
# --- Training -----------------------------------------------------------
# Running list of per-step losses; reset after each progress printout.
losses = [0]
expert_trajectory = sample_expert_trajectory()
# The true goal state as a (1, state_dim) float tensor.
torch_end = torch.from_numpy(end_state).float().unsqueeze(0)
# Pre-fill the replay memory with 1000 copies of the expert trajectory,
# each transition conditioned on the true goal.
for i in range(1000):
    for torch_state, action, torch_next_state, reward in expert_trajectory:
        memory.push(combine(torch_state, torch_end), action, combine(torch_next_state, torch_end), reward)

for i_episode in range(num_episodes):
    next_state = None
    # The last episode is run greedily (passed as `test` to select_action)
    # and prints every visited state.
    final = i_episode == num_episodes - 1
    trajectory = []
    
    # The LR scheduler is stepped once every 500 episodes (including episode 0).
    if i_episode % 500 == 0:
        scheduler.step()
    if BACKPLAY:
        # Start from a uniformly random state on the expert trajectory; with
        # probability 0.2 (and always on the final episode) start from index 0.
        start_index = np.random.choice(len(expert_trajectory))
        if np.random.random_sample() < 0.2 or final:
            start_index = 0
        state = expert_trajectory[start_index][0].squeeze().data.numpy()
    else:
        state = start_state
    
    
    
    for t in range(EPISODE_LENGTH):
        torch_state = torch.from_numpy(state).float().unsqueeze(0)

        action = select_action(torch_state, torch_end, final)
            
        next_state = step(state, action.item())
        torch_next_state = torch.from_numpy(next_state).float().unsqueeze(0)
        reward = get_reward(next_state)
        done = is_done(next_state)
        reward = torch.tensor([reward], device=device)

        # Terminal transitions are stored with next_state = None.
        if done:
            torch_next_state = None

        # Transitions are buffered here and pushed to memory after the episode.
        trajectory.append((torch_state, action, torch_next_state, reward))
        state = next_state

        # One optimisation step per environment step; optimize_model returns
        # None until the memory holds a full batch, which is logged as 0.
        loss = optimize_model()
        losses.append(loss if loss is not None else 0)
        if final:
            print(state)
        if done:
            break

    # Every 1000 episodes: print the episode index and the mean loss since
    # the previous printout, then reset the accumulator.
    if i_episode % 1000 == 0:
        print(i_episode)
        print(np.mean(np.asarray(losses)))
        losses = []

    #HER updates with goal as real goal OR last state in trajectory
    # Each transition is stored with the real goal. If the episode did not
    # finish (`state` is the last state reached), it is stored a second time
    # with that last state as the goal.
    # NOTE(review): the relabelled reward is `reward + 1` when next_state
    # equals the local goal (0.0, given get_reward returns -1.0 for non-done
    # states), and that transition keeps a non-None next_state. Confirm this
    # differs from the real-goal case (+1.0, terminal) on purpose.
    for torch_state, action, torch_next_state, reward in trajectory:
        memory.push(combine(torch_state, torch_end), action, combine(torch_next_state, torch_end), reward)
        if not is_done(state):
            torch_local_goal = torch.from_numpy(state).float().unsqueeze(0)
            local_reward = torch.equal(torch_next_state, torch_local_goal)
            new_reward = torch.add(reward, local_reward)
            memory.push(combine(torch_state, torch_local_goal), action, combine(torch_next_state, torch_local_goal), new_reward)
                    
                          
    # Update the target network
    # (copy the policy weights every TARGET_UPDATE episodes)
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())
    
        
print('Complete')




0
0.10976577177643776
1000
0.034756281135490986
2000
0.0669450320500668
3000
0.1255519562301086
4000
0.1615730989518105
5000
0.10344901194569385
6000
0.02891480982868688
7000
0.012584909677275314
8000
0.007967893652050157
9000
0.005176566027666484
10000
0.003099541823204673
11000
0.002312721141716282
12000
0.0023100369732335065
13000
0.0016060443382644326
14000
0.0015930945076919301
15000
0.0016017675265469185
16000
0.0011980306033430575
17000
0.0009320711940304085
18000
0.0010434620746009758
19000
0.001176036336413585
[0. 1. 0. 0.]
[0. 2. 0. 0.]
[0. 3. 1. 0.]
[1. 3. 1. 0.]
[2. 3. 1. 0.]
[2. 2. 1. 0.]
[2. 1. 1. 0.]
[2. 0. 1. 0.]
[3. 0. 1. 1.]
[3. 1. 1. 1.]
[3. 2. 1. 1.]
[3. 3. 1. 1.]
[3. 4. 1. 1.]
[4. 4. 1. 1.]
Complete


In [95]:
def predicted_distance(state, goal):
    """Return the greedy Q-value (max over actions) for ``state`` given ``goal``.

    Args:
        state: 1-D NumPy array describing the state.
        goal: 1-D NumPy array describing the goal.

    Returns:
        The largest Q-value predicted by ``policy_net`` as a Python float.
    """
    torch_state = torch.from_numpy(state).float().unsqueeze(0)
    torch_goal = torch.from_numpy(goal).float().unsqueeze(0)
    x = torch.cat((torch_state, torch_goal), 1).to(device)
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        return policy_net(x).max(1)[0].item()

def predicted_measure_ball(state, goal, n = 500):
    """Empirical one-step successor distribution under the entropy policy.

    Samples ``n`` actions with ``select_action_entropy`` (alpha = 2.0),
    applies each to ``state`` with ``step``, and returns a dict mapping each
    successor state (as a tuple) to its observed frequency.

    Visits are counted as integers and normalised once, so the frequencies
    do not pick up the rounding drift of repeatedly adding ``1.0 / n``.
    """
    torch_state = torch.from_numpy(state).float().unsqueeze(0)
    torch_goal = torch.from_numpy(goal).float().unsqueeze(0)
    visits = {}
    for _ in range(n):
        action = select_action_entropy(torch_state, torch_goal, alpha = 2.0)
        key = tuple(step(state, action.item()))
        visits[key] = visits.get(key, 0) + 1
    return {key: hits / n for key, hits in visits.items()}

def wasserstein(a, b, goal):
    """Earth mover's distance between the successor measures of ``a`` and ``b``.

    Both measures come from ``predicted_measure_ball`` and are compared over
    the union of their support, with ``distance`` as the ground metric.
    Identical states short-circuit to 0.

    The ground matrix is built so that ``dist[i][j]`` is
    ``distance(states[i], states[j])`` (row = source bin). Both histograms
    and the matrix are forced to float64.
    NOTE(review): pyemd is assumed to require float64 inputs — confirm
    against its documentation.
    """
    if tuple(a) == tuple(b):
        return 0
    m_a = predicted_measure_ball(a, goal)
    m_b = predicted_measure_ball(b, goal)
    states = list(set(m_a.keys()) | set(m_b.keys()))

    hist_a = np.asarray([m_a.get(s, 0.0) for s in states], dtype=np.float64)
    hist_b = np.asarray([m_b.get(s, 0.0) for s in states], dtype=np.float64)
    dist = np.asarray(
        [[distance(np.asarray(src), np.asarray(dst)) for dst in states] for src in states],
        dtype=np.float64,
    )
    return emd(hist_a, hist_b, dist)

def curvature(a, b, goal):
    """Return ``1 - W(a, b) / d(a, b)``.

    ``W`` is ``wasserstein`` and ``d`` is ``distance``; the denominator is
    floored at 0.001.
    """
    transport_cost = wasserstein(a, b, goal)
    base_distance = max(distance(a, b), 0.001)
    return 1 - transport_cost / base_distance

def local_curvature(state, goal):
    """Smallest curvature between ``state`` and any of its four successors.

    Successors are enumerated by applying each action once with ``step``,
    which assumes the dynamics are deterministic.
    """
    neighbours = [step(state, action) for action in range(4)]
    return min(curvature(state, neighbour, goal) for neighbour in neighbours)

def print_space(fn, goal):
    """Print ``fn(state, goal)`` over the whole grid for each flag combination.

    For every (red, blue) pair in {0, 1} x {0, 1} a ``num_row`` x ``num_col``
    grid is filled with ``fn`` evaluated at (row, col, blue, red) and printed
    under a two-line header. NumPy print options are set globally to two
    decimals with scientific notation suppressed.
    """
    np.set_printoptions(precision=2, suppress=True)
    flag_pairs = [(red, blue) for red in (0, 1) for blue in (0, 1)]
    for red, blue in flag_pairs:
        grid = np.zeros((num_row, num_col))
        # np.ndindex walks the cells row by row.
        for row, col in np.ndindex(num_row, num_col):
            grid[row, col] = fn(np.asarray((row, col, blue, red)), goal)
        print("red = " + str(red))
        print("blue = " + str(blue))
        print(grid)

# Spot check: curvature between two vertically adjacent cells in column 0
# (rows 2 and 1, both flags cleared), conditioned on the true goal.
a = np.asarray([2,0,0,0])
b = np.asarray([1,0,0,0])
print(curvature(a,b,end_state))

0.00200099799999931


In [96]:
# Print predicted_distance(state, end_state) for every cell and flag combination.
print_space(predicted_distance, end_state)

red = 0
blue = 0
[[-11.96 -10.95  -9.95 -11.1  -13.36]
 [-12.98 -11.97 -10.97  -9.78  -8.91]
 [-14.96 -14.18 -12.39 -10.45  -8.86]
 [-18.45 -16.51 -14.58 -12.64 -10.68]
 [-20.64 -18.7  -16.77 -14.83 -12.87]]
red = 0
blue = 1
[[-6.   -6.99 -7.98 -8.98 -9.97]
 [-5.   -6.   -6.99 -7.99 -8.98]
 [-3.98 -4.97 -5.99 -7.   -8.02]
 [-5.09 -4.13 -5.19 -6.26 -7.07]
 [-6.19 -4.33 -4.15 -4.9  -6.44]]
red = 1
blue = 0
[[-58.29 -58.28 -58.25 -59.17 -60.13]
 [-61.45 -61.82 -62.01 -62.92 -62.89]
 [-64.74 -65.3  -65.15 -64.62 -63.35]
 [-67.44 -66.54 -65.27 -64.15 -62.93]
 [-65.85 -64.94 -64.03 -63.09 -62.15]]
red = 1
blue = 1
[[-4.1  -4.42 -3.48 -4.91 -6.58]
 [-4.58 -3.63 -2.69 -1.75 -2.29]
 [-3.79 -2.85 -1.91 -0.96  0.01]
 [-2.98 -1.98 -0.99 -0.    1.  ]
 [-1.98 -0.99  0.    0.99  1.45]]


In [97]:
# Print local_curvature(state, end_state) for every cell and flag combination.
print_space(local_curvature, end_state)



red = 0
blue = 0
[[-0.04 -0.04 -0.08  0.   -0.92]
 [-0.   -0.08 -0.1  -0.51  0.15]
 [-2.   -0.06 -0.07 -0.11  0.02]
 [-0.   -2.   -0.04 -0.02  0.  ]
 [-0.   -0.   -0.   -0.   -0.  ]]
red = 0
blue = 1
[[-0.29 -0.06 -0.14 -0.13 -0.2 ]
 [-0.27 -0.17 -0.15 -0.18 -0.2 ]
 [-0.12  0.   -0.07 -0.06 -0.03]
 [ 0.11 -0.45 -0.09 -0.09 -0.07]
 [-0.84 -0.03 -0.08 -0.07 -0.07]]
red = 1
blue = 0
[[-0.   -0.   -0.   -0.   -0.97]
 [-0.   -0.   -0.   -1.97 -0.  ]
 [-0.   -0.   -0.   -0.   -0.  ]
 [-0.   -0.   -0.   -0.   -0.  ]
 [-0.   -0.   -0.   -0.   -0.  ]]
red = 1
blue = 1
[[-0.13 -0.09 -0.3  -0.37 -0.03]
 [-0.1  -0.12 -0.11 -0.08 -0.09]
 [-0.1  -0.07 -0.15 -0.02 -0.09]
 [-0.08 -0.09 -0.1  -0.02  0.  ]
 [ 0.05 -0.   -0.1  -0.08  0.67]]
