In [1]:
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
from PIL import Image
from pyemd import emd

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()

# if gpu is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Saves a transition."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

def combine(state, goal):
    if state is None:
        return None
    return torch.cat((state, goal), 1)

In [60]:
class miniDQN(nn.Module):

    def __init__(self, hidden):
        super(miniDQN, self).__init__()
        self.W1 = nn.Linear(8, hidden) # state + goal
        self.W2 = nn.Linear(hidden, 4)
        nn.init.xavier_uniform_(self.W1.weight)
        nn.init.xavier_uniform_(self.W2.weight)

    def forward(self, x):
        x = F.relu(self.W1(x))
        return self.W2(x)

#For now, embedding on states (not state-action pairs) because we're lazy
class embeddingNN(nn.Module):

    def __init__(self, hidden, output):
        super(embeddingNN, self).__init__()
        self.W1 = nn.Linear(8, hidden) # state + goal
        self.W2 = nn.Linear(hidden, output)
        nn.init.xavier_uniform_(self.W1.weight)
        nn.init.xavier_uniform_(self.W2.weight)

    def forward(self, x):
        x = F.relu(self.W1(x))
        return self.W2(x)

In [61]:
num_row = 10
num_col = 10

action_direction = [(0,1), (0,-1), (1,0), (-1,0)]

reds = [(3,0)]
blues = [(0,3)]
walls = []

start_state = np.asarray([0,0,0,0])
end_state = np.asarray([4,4,1,1])
torch_end = torch.from_numpy(end_state).float().unsqueeze(0)


def is_done(s):
    return np.array_equal(s, end_state)

def step(s, a):
    row, col, blue, red = tuple(s)
    a_row, a_col = action_direction[a]

    row = row + a_row
    col = col + a_col

    if row < 0 or row >= num_row or col < 0 or col >= num_col or (row, col) in walls:
        return s

    blue = blue or (row, col) in blues
    red = red or (row, col) in reds

    return np.asarray((row, col, blue, red))

def get_reward(s, a, s_prime):
    if is_done(s):
        return 0.0
    return -1.0

def L1(s1, s2):
    return 1.0 * np.sum(np.abs(np.asarray(s1) - np.asarray(s2)))

In [62]:
#Very adhoc, requires no dynamics
distance_cache = {}

def distance(a, b):
    global distance_cache
    ta = tuple(a)
    tb = tuple(b)
    if (ta,tb) in distance_cache:
        return distance_cache[(ta,tb)]
    
    row_a, col_a, blue_a, red_a = ta
    row_b, col_b, blue_b, red_b = tb
    
    blue_match = blue_a == blue_b
    red_match = red_a == red_b
    
    if red_match and not blue_match:
        dist = min([L1((row_a,col_a),blue) + L1(blue, (row_b, col_b)) for blue in blues])
    elif not red_match and blue_match:
        dist = min([L1((row_a,col_a),red) + L1(red, (row_b, col_b)) for red in reds])
    elif not red_match and not blue_match:
        x = min([L1((row_a,col_a),red) + L1(red, blue) + L1(blue, (row_b, col_b)) for red, blue in zip(reds, blues)])
        y = min([L1((row_a,col_a),blue) + L1(blue, red) + L1(red, (row_b, col_b)) for red, blue in zip(reds, blues)])
        dist = min(x,y)
    else:
        dist = L1(ta, tb)
    distance_cache[(ta,tb)] = dist
    return dist

In [70]:
BATCH_SIZE = 64
GAMMA = 0.9999
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 6000
TARGET_UPDATE = 50
HIDDEN_LAYER_SIZE = 15
EMBEDDING_SIZE = 6
EMBEDDING_LAMBDA = 0.01
EPISODE_LENGTH = 30
LEARNING_RATE = 0.0004

DIVERSE_EXPERT = True

num_episodes = 30000
random.seed(10)
distance_cache = {}

policy_net = miniDQN(HIDDEN_LAYER_SIZE).to(device)
target_net = miniDQN(HIDDEN_LAYER_SIZE).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = optim.Adam(policy_net.parameters(), lr = LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.99999)
memory = ReplayMemory(40000)
expert_memory = ReplayMemory(40000)

In [71]:
embedding_net = embeddingNN(HIDDEN_LAYER_SIZE, EMBEDDING_SIZE).to(device)

embedding_optimizer = optim.Adam(embedding_net.parameters(), lr = LEARNING_RATE)

In [72]:
steps_done = 0

#Epsilon greedy actions
def select_action(state, goal, test = False):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if test or sample > eps_threshold:
        with torch.no_grad():
            x = torch.cat((state, goal), 1)
            return policy_net(x).max(1)[1].view(1, 1)
    else:
        return torch.tensor([[random.randrange(4)]], device=device, dtype=torch.long)

#Max entropy action
def select_action_entropy(state, goal, alpha):
    x = torch.cat((state, goal), 1)
    potential = (alpha * policy_net(x).detach()).exp().squeeze()
    dist = (potential / potential.sum()).numpy()
    return torch.tensor([[np.random.choice(4, p = dist)]], device=device, dtype=torch.long)
    

def select_expert_action(state, goal):
    row, col, blue, red = tuple(state)
    action_scores = np.zeros(4)
    for a in range(4):
        #Assumes determinism
        next_state = step(state, a)
        action_scores[a] = distance(next_state, goal)
    action = np.random.choice(np.flatnonzero(action_scores == action_scores.min()))
    return torch.tensor([[action]], device=device, dtype=torch.long)    

def sample_expert_trajectory():
    state = start_state
    next_state = None

    trajectory = []
    
    for t in range(EPISODE_LENGTH):
        action = select_expert_action(state, end_state)
        
        torch_state = torch.from_numpy(state).float().unsqueeze(0)
        next_state = step(state, action.item())
        torch_next_state = torch.from_numpy(next_state).float().unsqueeze(0)
        reward = get_reward(state, action.item(), next_state)
        done = is_done(state)
        reward = torch.tensor([reward], device=device)
        if done:
            torch_next_state = None

        trajectory.append((torch_state, action, torch_next_state, reward))
        if done:
            break
        state = next_state
    return trajectory

In [73]:
def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    #randomly learn over expert trajectories or rollout buffer
    transitions = memory.sample(BATCH_SIZE) if np.random.random() > 0.5 else expert_memory.sample(BATCH_SIZE)
    batch = Transition(*zip(*transitions))

    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,batch.next_state)),
                                  device=device, dtype=torch.uint8)
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    state_action_values = policy_net(state_batch).gather(1, action_batch)

    #DDQN
    policy_action_indices = policy_net(non_final_next_states).max(1)[1].detach()
    target_actions = target_net(non_final_next_states).detach()
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    next_state_values[non_final_mask] = target_actions.gather(1, policy_action_indices.unsqueeze(1)).squeeze()
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))

    
    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()
    return loss.item()

In [74]:
def sample_buffer():
    transitions = memory.sample(BATCH_SIZE) if np.random.random() > 0.5 else expert_memory.sample(BATCH_SIZE)
    batch = Transition(*zip(*transitions))

    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,batch.next_state)),
                                  device=device, dtype=torch.uint8)
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)
    
    state_embeddings = embedding_net(state_batch)
    next_state_embeddings = torch.zeros((BATCH_SIZE, EMBEDDING_SIZE), device=device)
    next_state_embeddings[non_final_mask,:] = embedding_net(non_final_next_states)
        
    return (state_embeddings, reward_batch, next_state_embeddings)


def optimize_bisim():
    if len(memory) < BATCH_SIZE:
        return
    
    sample_1 = sample_buffer()
    sample_2 = sample_buffer()

    distance = torch.norm(sample_1[0] - sample_2[0], dim = 1)
    reward_distance = torch.abs(sample_1[1] - sample_2[1])
    next_distance = torch.norm(sample_1[2] - sample_2[2], dim = 1)
    
    TD_error = reward_distance + GAMMA * next_distance
    
    loss = F.smooth_l1_loss(distance, TD_error.detach()) #Detaching because the loss function complains
    
    loss += EMBEDDING_LAMBDA * torch.mean(torch.norm(sample_1[0], dim = 1) + torch.norm(sample_2[0], dim = 1))
    
    # Optimize the model
    embedding_optimizer.zero_grad()
    loss.backward()
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    embedding_optimizer.step()
    return loss.item()

In [75]:
losses = [0]
expert_trajectory = sample_expert_trajectory()
#pre-populate the replay buffers
for i in range(1000):
    if DIVERSE_EXPERT:
        expert_trajectory = sample_expert_trajectory()
    for torch_state, action, torch_next_state, reward in expert_trajectory:
        memory.push(combine(torch_state, torch_end), action, combine(torch_next_state, torch_end), reward)
        expert_memory.push(combine(torch_state, torch_end), action, combine(torch_next_state, torch_end), reward)

In [76]:
for i_episode in range(num_episodes):
    state = start_state
    next_state = None
    final = i_episode == num_episodes - 1
    trajectory = []
    
    if i_episode % 1000 == 0:
        scheduler.step()
    
    for t in range(EPISODE_LENGTH):
        torch_state = torch.from_numpy(state).float().unsqueeze(0)

        action = select_action(torch_state, torch_end, final)
            
        next_state = step(state, action.item())
        torch_next_state = torch.from_numpy(next_state).float().unsqueeze(0)
        reward = get_reward(state, action.item(), next_state)
        done = is_done(state)
        reward = torch.tensor([reward], device=device)

        if done:
            torch_next_state = None

        trajectory.append((torch_state, action, torch_next_state, reward))

        loss = optimize_model()
        optimize_bisim()
        losses.append(loss if loss is not None else 0)
        if final:
            print(state)
        if done:
            break
        state = next_state

    if i_episode % 1000 == 0:
        print(i_episode)
        print(np.mean(np.asarray(losses)))
        losses = []

    #HER updates with goal as real goal OR last state in trajectory
    for torch_state, action, torch_next_state, reward in trajectory:
        memory.push(combine(torch_state, torch_end), action, combine(torch_next_state, torch_end), reward)
        if not is_done(state):
            torch_local_goal = torch.from_numpy(state).float().unsqueeze(0)
            local_reward = torch.equal(torch_next_state, torch_local_goal)
            new_reward = torch.add(reward, local_reward)
            memory.push(combine(torch_state, torch_local_goal), action, combine(torch_next_state, torch_local_goal), new_reward)
                    
                          
    # Update the target network
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())
    
        
print('Complete')

0
0.6312915288632915
1000
0.08762718802983208
2000
0.1404543710306425
3000
0.17711578013387436
4000
0.10978186437314605
5000
0.050820254418993974
6000
0.03119435774156541
7000
0.022082727984731713
8000
0.015939105472129397
9000
0.014727433244492917
10000
0.0124022966361009
11000
0.008364319186442044
12000
0.006286344531851153
13000
0.011872495687833655
14000
0.014769212168452818
15000
0.013327201313902363
16000
0.014145487907746774
17000
0.016953523750812802
18000
0.011899025037636176
19000
0.006097320523448108
20000
0.0051427513863071175
21000
0.005531684914489484
22000
0.004880143143681759
23000
0.004637071082103994
24000
0.0047074228836770215
25000
0.005003519789869507
26000
0.0045951351621715646
27000
0.00456770157946877
28000
0.004796006188372141
29000
0.004373626625805216
[0 0 0 0]
[1 0 0 0]
[2 0 0 0]
[3 0 0 1]
[2 0 0 1]
[2 1 0 1]
[2 2 0 1]
[1 2 0 1]
[1 3 0 1]
[0 3 1 1]
[1 3 1 1]
[1 4 1 1]
[2 4 1 1]
[3 4 1 1]
[4 4 1 1]
Complete


In [82]:
def predicted_embedding(state, goal):
    torch_state = torch.from_numpy(state).float().unsqueeze(0)
    torch_goal = torch.from_numpy(goal).float().unsqueeze(0)
    x = torch.cat((torch_state, torch_goal), 1)
    return embedding_net(x)

print(predicted_embedding(np.asarray([2,1,0,0]), end_state))
print(predicted_embedding(np.asarray([1,2,0,0]), end_state))
print(predicted_embedding(np.asarray([2,2,0,1]), end_state))
print(predicted_embedding(np.asarray([2,2,1,0]), end_state))
print(predicted_embedding(np.asarray([2,4,1,1]), end_state))
print(predicted_embedding(np.asarray([4,2,1,1]), end_state))
print(predicted_embedding(np.asarray([4,4,1,1]), end_state))

tensor([[-0.1647, -0.4212, -0.1037,  0.0042,  0.5669, -0.1319]],
       grad_fn=<ThAddmmBackward>)
tensor([[-0.1850, -0.6927,  0.0027,  0.0071,  0.2250, -0.1876]],
       grad_fn=<ThAddmmBackward>)
tensor([[-0.6710, -0.4280, -0.0702,  0.4902,  0.0672,  0.4747]],
       grad_fn=<ThAddmmBackward>)
tensor([[ 0.1164,  0.1212,  0.0615, -0.4809,  0.1225, -0.3981]],
       grad_fn=<ThAddmmBackward>)
tensor([[-0.6487, -0.0256,  0.1402, -0.2351, -0.5352,  0.1818]],
       grad_fn=<ThAddmmBackward>)
tensor([[-0.6080,  0.5175, -0.0726, -0.2409,  0.1488,  0.2931]],
       grad_fn=<ThAddmmBackward>)
tensor([[-1.1255,  0.2376,  0.0177, -0.7216, -0.1673,  0.2399]],
       grad_fn=<ThAddmmBackward>)


In [50]:
def predicted_distance(state, goal):
    torch_state = torch.from_numpy(state).float().unsqueeze(0)
    torch_goal = torch.from_numpy(goal).float().unsqueeze(0)
    x = torch.cat((torch_state, torch_goal), 1)
    return -1 * policy_net(x).max(1)[0].view(1, 1).item()

def predicted_measure_ball(state, goal, n = 500):
    measure = {}
    torch_state = torch.from_numpy(state).float().unsqueeze(0)
    torch_goal = torch.from_numpy(goal).float().unsqueeze(0)
    x = torch.cat((torch_state, torch_goal), 1)
    for i in range(n):
        action = select_action_entropy(torch_state, torch_goal, alpha = 4.0)
        next_state = step(state, action.item())
        key = tuple(next_state)
        measure[key] = measure.get(key, 0) + 1.0/n
    return measure

def wasserstein(a, b, goal):
    if tuple(a) == tuple(b):
        return 0
    m_a = predicted_measure_ball(a, goal)
    m_b = predicted_measure_ball(b, goal)
    states = list(set(m_a.keys()) | set(m_b.keys()))
    
    m_a = [m_a.get(s, 0.0) for s in states]
    m_b = [m_b.get(s, 0.0) for s in states]
    dist = [[distance(np.asarray(s1),np.asarray(s2)) for s1 in states] for s2 in states]
    return emd(np.asarray(m_a), np.asarray(m_b), np.asarray(dist))


In [51]:
values = np.zeros((num_row, num_col, 2, 2))
for index in np.ndindex(num_row, num_col, 2, 2):
    values[index] = predicted_distance(np.asarray(index), end_state)
print(np.round(values[:,:,0,0], 2))
print(np.round(values[:,:,1,0], 2))
print(np.round(values[:,:,0,1], 2))
print(np.round(values[:,:,1,1], 2))

[[13.8  12.84 11.92 11.16  9.98 11.21 12.44 12.75 12.04 11.34]
 [13.05 13.6  12.99 11.23 11.33 12.57 13.8  13.47 12.76 12.06]
 [12.07 14.29 13.79 12.7  12.69 13.92 14.9  14.19 13.49 12.78]
 [13.91 15.14 15.26 14.18 14.05 15.28 15.62 14.92 14.21 13.51]
 [16.52 17.22 17.78 16.52 16.11 17.17 16.7  15.83 14.95 14.23]
 [18.31 19.28 20.25 19.48 18.95 19.79 18.91 18.03 17.15 16.28]
 [19.7  20.67 21.64 21.78 21.79 21.99 21.12 20.24 19.36 18.48]
 [21.09 22.06 23.03 23.77 23.63 24.2  23.32 22.45 21.57 20.69]
 [22.48 23.45 24.42 25.4  25.22 25.61 25.53 24.65 23.77 22.9 ]
 [23.87 24.84 25.81 26.76 26.14 26.43 26.72 26.46 25.49 24.52]]
[[ 8.03  9.1  10.03 10.96 10.27 10.56 11.64 12.72 13.8  13.4 ]
 [ 7.21  8.14  9.07  9.95 10.2   9.79 10.88 11.96 12.87 12.01]
 [ 6.24  7.17  8.1   8.94  9.4   8.78  9.97 11.44 12.1  11.63]
 [ 5.62  6.19  7.18  8.11  9.04  9.06 10.47 11.93 11.96 11.49]
 [ 6.35  7.72  8.65  9.59 10.52 11.04 12.34 13.5  12.85 12.21]
 [ 7.08  8.29  9.5  10.71 11.92 13.02 14.31 14.84 14.2

In [16]:
def reflect_state(state):
    row, col, blue, red = tuple(state)
    return np.asarray((col, row, red, blue))

def compare(state):
    reflected = reflect_state(state)
    m_a = predicted_measure_ball(state, end_state)
    m_b = predicted_measure_ball(reflected, end_state)
    print(m_a)
    print(m_b)

compare(np.asarray((0,1,1,1)))
    

{(1, 1, 1, 1): 0.6840000000000005, (0, 2, 1, 1): 0.3140000000000002, (0, 1, 1, 1): 0.002}
{(1, 1, 1, 1): 0.25400000000000017, (2, 0, 1, 1): 0.7460000000000006}
