In [22]:
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
from PIL import Image
from pyemd import emd

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()

# if gpu is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [23]:
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Saves a transition."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

def combine(state, goal):
    if state is None:
        return None
    return torch.cat((state, goal), 1)

In [24]:
class miniDQN(nn.Module):

    def __init__(self, hidden):
        super(miniDQN, self).__init__()
        self.W1 = nn.Linear(8, hidden)
        self.W2 = nn.Linear(hidden, 4)
        nn.init.xavier_uniform_(self.W1.weight)
        nn.init.xavier_uniform_(self.W2.weight)

    def forward(self, x):
        x = F.relu(self.W1(x))
        return self.W2(x)

In [25]:
num_row = 10
num_col = 10

action_direction = [(0,1), (0,-1), (1,0), (-1,0)]
block = { "empty" : 0, "blue" : 1, "red" : 2, "goal" : 3, "wall" : 4}

reds = [(3,0)]
blues = [(0,3)]
start = (0,0)
end = (4,4)
walls = []

start_state = np.asarray(list(start) + [0,0])
end_state = np.asarray(list(end) + [1,1])


def is_done(s):
    row, col, blue, red = tuple(s)
    return blue and red and (row, col) == end

def step(s, a):
    row, col, blue, red = tuple(s)
    a_row, a_col = action_direction[a]

    row = row + a_row
    col = col + a_col

    if row < 0 or row >= num_row or col < 0 or col >= num_col or (row, col) in walls:
        return s

    blue = blue or (row, col) in blues
    red = red or (row, col) in reds

    return np.asarray((row, col, blue, red))

def get_reward(s, a, s_prime):
    if is_done(s):
        return 0.0
    return -1.0

def L1(s1, s2):
    return 1.0 * np.sum(np.abs(np.asarray(s1) - np.asarray(s2)))

In [26]:
#Very adhoc, requires no dynamics
distance_cache = {}

def distance(a, b):
    global distance_cache
    ta = tuple(a)
    tb = tuple(b)
    if (ta,tb) in distance_cache:
        return distance_cache[(ta,tb)]
    
    row_a, col_a, blue_a, red_a = ta
    row_b, col_b, blue_b, red_b = tb
    
    blue_match = blue_a == blue_b
    red_match = red_a == red_b
    
    if red_match and not blue_match:
        dist = min([L1((row_a,col_a),blue) + L1(blue, (row_b, col_b)) for blue in blues])
    elif not red_match and blue_match:
        dist = min([L1((row_a,col_a),red) + L1(red, (row_b, col_b)) for red in reds])
    elif not red_match and not blue_match:
        x = min([L1((row_a,col_a),red) + L1(red, blue) + L1(blue, (row_b, col_b)) for red, blue in zip(reds, blues)])
        y = min([L1((row_a,col_a),blue) + L1(blue, red) + L1(red, (row_b, col_b)) for red, blue in zip(reds, blues)])
        dist = min(x,y)
    else:
        dist = L1(ta, tb)
    distance_cache[(ta,tb)] = dist
    return dist

In [27]:
BATCH_SIZE = 128
GAMMA = 0.9999
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 6000
TARGET_UPDATE = 50
HIDDEN_LAYER_SIZE = 20
EPISODE_LENGTH = 30
LEARNING_RATE = 0.0005

DIVERSE_EXPERT = True

num_episodes = 40000
random.seed(10)
distance_cache = {}

policy_net = miniDQN(HIDDEN_LAYER_SIZE).to(device)
target_net = miniDQN(HIDDEN_LAYER_SIZE).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = optim.Adam(policy_net.parameters(), lr = LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.99999)
memory = ReplayMemory(40000)
expert_memory = ReplayMemory(40000)

In [28]:
steps_done = 0

#Epsilon greedy actions
def select_action(state, goal, test = False):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if test or sample > eps_threshold:
        with torch.no_grad():
            x = torch.cat((state, goal), 1)
            return policy_net(x).max(1)[1].view(1, 1)
    else:
        return torch.tensor([[random.randrange(4)]], device=device, dtype=torch.long)

#Max entropy action
def select_action_entropy(state, goal, alpha):
    x = torch.cat((state, goal), 1)
    potential = (alpha * policy_net(x).detach()).exp().squeeze()
    dist = (potential / potential.sum()).numpy()
    return torch.tensor([[np.random.choice(4, p = dist)]], device=device, dtype=torch.long)
    

def select_expert_action(state, goal):
    row, col, blue, red = tuple(state)
    action_scores = np.zeros(4)
    for a in range(4):
        #Assumes determinism
        next_state = step(state, a)
        action_scores[a] = distance(next_state, goal)
    action = np.random.choice(np.flatnonzero(action_scores == action_scores.min()))
    return torch.tensor([[action]], device=device, dtype=torch.long)    

def sample_expert_trajectory():
    state = start_state
    next_state = None

    trajectory = []
    
    for t in range(EPISODE_LENGTH):
        action = select_expert_action(state, end_state)
        
        torch_state = torch.from_numpy(state).float().unsqueeze(0)
        next_state = step(state, action.item())
        torch_next_state = torch.from_numpy(next_state).float().unsqueeze(0)
        reward = get_reward(state, action.item(), next_state)
        done = is_done(state)
        reward = torch.tensor([reward], device=device)
        if done:
            torch_next_state = None

        trajectory.append((torch_state, action, torch_next_state, reward))
        if done:
            break
        state = next_state
    return trajectory

In [29]:
def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    #randomly learn over expert trajectories or rollout buffer
    transitions = memory.sample(BATCH_SIZE) if np.random.random() > 0.5 else expert_memory.sample(BATCH_SIZE)
    batch = Transition(*zip(*transitions))

    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,batch.next_state)),
                                  device=device, dtype=torch.uint8)
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    state_action_values = policy_net(state_batch).gather(1, action_batch)

    #DDQN
    policy_action_indices = policy_net(non_final_next_states).max(1)[1].detach()
    target_actions = target_net(non_final_next_states).detach()
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    next_state_values[non_final_mask] = target_actions.gather(1, policy_action_indices.unsqueeze(1)).squeeze()
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))

    
    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()
    return loss.item()

In [30]:
losses = [0]
torch_end = torch.from_numpy(end_state).float().unsqueeze(0)
expert_trajectory = sample_expert_trajectory()
for i in range(1000):
    if DIVERSE_EXPERT:
        expert_trajectory = sample_expert_trajectory()
    for torch_state, action, torch_next_state, reward in expert_trajectory:
        memory.push(combine(torch_state, torch_end), action, combine(torch_next_state, torch_end), reward)
        expert_memory.push(combine(torch_state, torch_end), action, combine(torch_next_state, torch_end), reward)

for i_episode in range(num_episodes):
    state = start_state
    next_state = None
    final = i_episode == num_episodes - 1
    trajectory = []
    
    if i_episode % 1000 == 0:
        scheduler.step()
    
    for t in range(EPISODE_LENGTH):
        torch_state = torch.from_numpy(state).float().unsqueeze(0)

        action = select_action(torch_state, torch_end, final)
            
        next_state = step(state, action.item())
        torch_next_state = torch.from_numpy(next_state).float().unsqueeze(0)
        reward = get_reward(state, action.item(), next_state)
        done = is_done(state)
        reward = torch.tensor([reward], device=device)

        if done:
            torch_next_state = None

        trajectory.append((torch_state, action, torch_next_state, reward))

        loss = optimize_model()
        losses.append(loss if loss is not None else 0)
        if final:
            print(state)
        if done:
            break
        state = next_state

    if i_episode % 1000 == 0:
        print(i_episode)
        print(np.mean(np.asarray(losses)))
        losses = []

    #HER updates with goal as real goal OR last state in trajectory
    for torch_state, action, torch_next_state, reward in trajectory:
        memory.push(combine(torch_state, torch_end), action, combine(torch_next_state, torch_end), reward)
        if not is_done(state):
            torch_local_goal = torch.from_numpy(state).float().unsqueeze(0)
            local_reward = torch.equal(torch_next_state, torch_local_goal)
            new_reward = torch.add(reward, local_reward)
            memory.push(combine(torch_state, torch_local_goal), action, combine(torch_next_state, torch_local_goal), new_reward)
                    
                          
    # Update the target network
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())
    
        
print('Complete')




0
0.6705022069715685
1000
0.051104274848802804
2000
0.142641825228066
3000
0.13363902751413226
4000
0.06597173460574471
5000
0.03276165601295216
6000
0.019828053240719422
7000
0.01728976880849489
8000
0.013971030690448095
9000
0.008806942815859286
10000
0.006305696236034837
11000
0.006190971785322995
12000
0.007404287356154946
13000
0.007323471727103416
14000
0.00637918957009906
15000
0.006865213173742146
16000
0.0063906448522493145
17000
0.005398927473592558
18000
0.006953048875291738
19000
0.008520299124991486
20000
0.006445021310289573
21000
0.006287495155593853
22000
0.00660173453577885
23000
0.006569667589129197
24000
0.005646780846157743
25000
0.004008308242556866
26000
0.004358893985784069
27000
0.004313568410255656
28000
0.004155850381836531
29000
0.003946640858522049
30000
0.004228500055397435
31000
0.0052311608782148085
32000
0.007312002069736764
33000
0.006580786054960197
34000
0.005132740762430584
35000
0.003852381779058704
36000
0.006080211624781541
37000
0.008038819697875

In [31]:
def predicted_distance(state, goal):
    torch_state = torch.from_numpy(state).float().unsqueeze(0)
    torch_goal = torch.from_numpy(goal).float().unsqueeze(0)
    x = torch.cat((torch_state, torch_goal), 1)
    return -1 * policy_net(x).max(1)[0].view(1, 1).item()

def predicted_measure_ball(state, goal, n = 500):
    measure = {}
    torch_state = torch.from_numpy(state).float().unsqueeze(0)
    torch_goal = torch.from_numpy(goal).float().unsqueeze(0)
    x = torch.cat((torch_state, torch_goal), 1)
    for i in range(n):
        action = select_action_entropy(torch_state, torch_goal, alpha = 4.0)
        next_state = step(state, action.item())
        key = tuple(next_state)
        measure[key] = measure.get(key, 0) + 1.0/n
    return measure

def wasserstein(a, b, goal):
    if tuple(a) == tuple(b):
        return 0
    m_a = predicted_measure_ball(a, goal)
    m_b = predicted_measure_ball(b, goal)
    states = list(set(m_a.keys()) | set(m_b.keys()))
    
    m_a = [m_a.get(s, 0.0) for s in states]
    m_b = [m_b.get(s, 0.0) for s in states]
    dist = [[distance(np.asarray(s1),np.asarray(s2)) for s1 in states] for s2 in states]
    return emd(np.asarray(m_a), np.asarray(m_b), np.asarray(dist))


In [32]:
values = np.zeros((num_row, num_col, 2, 2))
for index in np.ndindex(num_row, num_col, 2, 2):
    values[index] = predicted_distance(np.asarray(index), end_state)
print(np.round(values[:,:,0,0], 2))
print(np.round(values[:,:,1,0], 2))
print(np.round(values[:,:,0,1], 2))
print(np.round(values[:,:,1,1], 2))

[[14.   13.   12.   13.96 14.86 15.8  16.48 17.31 18.14 18.96]
 [13.   13.93 13.27 14.34 14.93 15.15 15.24 15.33 15.75 16.34]
 [12.04 12.94 13.73 12.73 12.96 13.19 13.42 13.65 13.92 14.2 ]
 [11.89 12.47 12.92 12.52 13.04 13.69 14.36 15.06 15.76 16.46]
 [15.   14.65 15.06 15.71 16.94 17.63 18.34 19.04 19.74 20.44]
 [16.38 16.03 16.41 17.16 19.   21.61 22.31 23.01 23.71 24.41]
 [16.95 17.09 18.01 19.03 20.19 23.68 26.28 26.99 27.69 28.39]
 [17.14 17.62 18.8  20.33 21.74 25.22 28.36 30.96 31.66 32.36]
 [17.77 18.25 19.41 20.95 22.36 25.39 29.73 33.05 35.63 36.26]
 [18.4  18.88 20.02 21.58 22.98 25.14 29.66 33.66 37.02 39.82]]
[[ 8.14  9.07 10.02 11.01 12.18 13.17 14.89 15.5  15.91 15.61]
 [ 7.17  8.1   9.1  10.06 10.81 11.77 12.39 13.   13.44 13.18]
 [ 6.1   7.16  8.11  9.04  9.87 10.27 10.67 11.2  11.68 11.42]
 [ 4.94  6.12  7.14  8.11  9.33 10.69 11.69 12.7  13.55 13.67]
 [ 8.19  7.42  8.03  8.98 10.57 12.6  13.66 14.66 15.53 15.63]
 [ 7.63  8.03  9.17 10.12 10.77 14.25 15.62 16.63 17.5

In [33]:
def reflect_state(state):
    row, col, blue, red = tuple(state)
    return np.asarray((col, row, red, blue))

def compare(state):
    reflected = reflect_state(state)
    m_a = predicted_measure_ball(state, end_state)
    m_b = predicted_measure_ball(reflected, end_state)
    print(m_a)
    print(m_b)

compare(np.asarray((0,1,1,1)))
    

{(1, 1, 1, 1): 0.7960000000000006, (0, 2, 1, 1): 0.19600000000000015, (0, 1, 1, 1): 0.004, (0, 0, 1, 1): 0.004}
{(1, 1, 1, 1): 0.6200000000000004, (2, 0, 1, 1): 0.3720000000000003, (1, 0, 1, 1): 0.008}
