In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import hypll
from hypll.manifolds.poincare_ball import Curvature, PoincareBall
from hypll.optim import RiemannianAdam
import hypll.nn as hnn
from hypll.tensors import TangentTensor
from matplotlib.animation import FuncAnimation
import argparse
from pyramid import create_pyramid
from continuous_maze import bfs, gen_traj, plot_traj, ContinuousGridEnvironment, TrajectoryDataset, LabelDataset
from hyperbolic_networks import HyperbolicMLP, hyperbolic_infoNCE_loss, manifold_map
from networks import StateActionEncoder, StateEncoder, CategoricalEncoder, infoNCE_loss
import os

import wandb

device = "cuda" if torch.cuda.is_available() else "cpu"

def evaluate(maze, num_trials, encoder1, encoder2, manifold, max_steps=100, hyperbolic=False, eps=10., step_size=0.5, verbose=False):
    valid_indices = np.argwhere(maze == 0)
    np.random.shuffle(valid_indices)
    
    results = []
    for i in range(num_trials):
        with torch.no_grad():
            start, end = np.random.randint(0, len(valid_indices), size=2)
            start = tuple(valid_indices[start])
            end = tuple(valid_indices[end])
            
            goal = torch.tensor(end).to(torch.float32).to(device).unsqueeze(0)
            # if hyperbolic:
            #     goal = manifold_map(goal, manifold=manifold)
            goal = encoder2(goal)
            
            # print(start)
            env = ContinuousGridEnvironment(maze, start, {})
            
            def reached(cur_pos, goal_pos):
                # print(f'cur pos: {cur_pos}')
                cur_pos = (int(cur_pos[0]), int(cur_pos[1]))
                goal_pos = (int(goal_pos[0]), int(goal_pos[1]))
                return cur_pos == goal_pos
            
            def step():
                cur_pos = env.agent_position
                if verbose:
                    print(f'cur_pos: {cur_pos}, goal: {goal}')
                activations = []
                angles = torch.linspace(0., 2 * torch.pi, 16)
                for a in angles:
                    action = torch.tensor([torch.sin(a), torch.cos(a)])
                    cur = torch.tensor([cur_pos[0], cur_pos[1], torch.sin(a), torch.cos(a)]).to(device, torch.float32)
                    # if hyperbolic:
                    #     cur = manifold_map(cur, manifold)
                    cur = encoder1(cur)

                    # MANIFOLD EVAL
                    if hyperbolic:
                        activations.append((action, -manifold.dist(x=cur, y=goal)))
                    else:
                        activations.append((action, -torch.norm(cur - goal)))
                        
            

                best_action = activations[np.argmax([x[1].cpu() for x in activations])][0]
                angle = np.arctan2(best_action[0], best_action[1]) + np.random.normal() * eps * (2 * np.pi / 360)
                best_action = torch.tensor(np.array([np.sin(angle), np.cos(angle)]))
                env.move_agent(best_action)
                # print(f'agent position: {env.agent_position}')
                
                
            def SPL(maze, start, end, num_steps, success): # Success weighted by (normalized inverse) Path Length
                if not success:
                    return 0
                else:
                    p = num_steps * step_size
                    l = len(bfs(maze, start, end))
                    return (l / max(p, l))
            
            steps = 0
            while not reached(env.agent_position, end):
                if steps > max_steps:
                    break
                step()
                steps += 1
                
            result = (not reached(env.agent_position, end), steps, SPL(maze, start, end, steps, reached(env.agent_position, end)))
            if verbose:
                print(reached(env.agent_position, end))
                print(f'start: {start}, goal: {end}, end_pos: {env.agent_position}, steps: {steps}')
                print(results)
                
            results.append(result)
    
    return results


def get_maze(name):
    maze = np.zeros((10, 10))
    
    if 'blank' in name:
        print('blank maze')
        maze = np.zeros((10, 10))
    elif 'slit' in name:
        print('slit maze')
        maze = np.zeros((11, 11))
        maze[:,5] = 1
        maze[5, 5] = 0
    elif 'blocker' in name:
        maze = np.zeros((11, 11))
        maze[3,:] = 1
        maze[3, 10] = 0
    elif 'nested_pyramid' in name:
        maze = create_pyramid(np.zeros((2, 2)), 2)[0]
    else:
        maze = create_pyramid(np.zeros((2, 2)), 1)[0]

    return maze

def save_models(encoder1, encoder2, best_encoder1, best_encoder2, epoch, best_epoch, name=''):
    os.makedirs('models', exist_ok=True)
    torch.save(encoder1.state_dict(), f'models/{name}_encoder1_epoch_{epoch}.pth')
    torch.save(encoder2.state_dict(), f'models/{name}_encoder2_epoch_{epoch}.pth')
    torch.save(best_encoder1, f'models/{name}_best_encoder1_epoch_{best_epoch}.pth')
    torch.save(best_encoder2, f'models/{name}_best_encoder2_epoch_{best_epoch}.pth')

    

In [4]:
maze = get_maze('blocker')

wandb.init(
    project='noproject', 
    name='noname', 
    # Track hyperparameters and run metadata
    config={
        "embedding_dim": 8,
        "eval_trials": 100,
        "max_steps": 100,
        "hyperbolic": False,
        "num_epochs": 16,
        "temperature": 0.1,
        "batch_size": 64,
        "num_negatives": 64,
        "learning_rate": 0.001,
        "architecture": "MLP",
        "maze": maze,
        "num_trajectories": 1000,
        "maze_type": 'blocker',
        "gamma":0.1,
        "hyp_layers": 1
    }
)

# configs
config = wandb.config

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


In [9]:
print(config)
manifold = PoincareBall(c=Curvature(value=0.1, requires_grad=True))

dataset = TrajectoryDataset(maze, config.num_trajectories, embedding_dim=config.embedding_dim, num_negatives=10, gamma=config.gamma)
dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True, num_workers=1)
street_dataset = LabelDataset(maze, size=1000, embedding_dim=config.embedding_dim, num_negatives=10)
street_dataloader = DataLoader(street_dataset, batch_size=config.batch_size, shuffle=True, num_workers=1)


encoder1 = StateActionEncoder(config.embedding_dim).to(device)
encoder2 = StateEncoder(config.embedding_dim).to(device)
street_encoder = CategoricalEncoder(street_dataset.num_categories, config.embedding_dim)
optimizer = optim.Adam(list(encoder1.parameters()) + list(encoder2.parameters()), lr=config.learning_rate)

best_spl = 0
best_encoder1 = encoder1.state_dict()
best_encoder2 = encoder2.state_dict()
best_epoch = 0

# Training loop
for epoch in range(config.num_epochs):
    total_loss = 0
    street_iterator = iter(street_dataloader)
    
    for anchor, positive, negatives in dataloader:
        try:
            s_anchor, s_positive, s_negatives, s_neg_cats = next(street_iterator)
        except StopIteration:
            street_iterator = iter(street_dataloader)
            s_anchor, s_positive, s_negatives, s_neg_cats = next(street_iterator)

        # (s,a) <-> (s)
        anchor = torch.tensor(anchor).to(device, torch.float32)
        positive = torch.tensor(positive).to(device, torch.float32)
        negatives = torch.tensor(negatives).to(device, torch.float32)


        anchor_enc = encoder1(anchor) # takes state, action tuple
        positive_enc = encoder2(positive) # takes state
        negatives_enc = encoder2(negatives)

        cur_state = anchor[:,[0,1]]
        angle = torch.arctan2(anchor[:,2], anchor[:,3])

        negative_actions = (angle + torch.pi)[:,None] + (torch.rand(config.num_negatives)[None,:].to(device) - 0.5) * (3 * torch.pi / 2)
        negative_dirs = torch.stack([torch.sin(negative_actions), torch.cos(negative_actions)]).moveaxis(0, -1)
        negative_full = torch.cat((cur_state.unsqueeze(1).expand(-1, config.num_negatives, -1), negative_dirs), dim=-1).to(device)
        neg_action_enc = encoder1(negative_full)

        action_loss = infoNCE_loss(positive_enc, anchor_enc, neg_action_enc, config.temperature, metric_type=1)
        future_loss = infoNCE_loss(anchor_enc, positive_enc, negatives_enc, config.temperature, metric_type=1)

        loss = action_loss + future_loss

        s_anchor = torch.tensor(s_anchor).to(device)
        s_positive = torch.tensor(s_positive).to(device, torch.float32)
        s_negatives = torch.tensor(s_negatives).to(device, torch.float32)
        s_neg_cats = torch.tensor(s_neg_cats).to(device)

        s_anchor_enc = street_encoder(s_anchor) # takes state, action tuple
        s_positive_enc = encoder2(s_positive) # takes state
        s_negatives_enc = encoder2(s_negatives)
        s_neg_cats_enc = street_encoder(s_neg_cats)

        s_loss = infoNCE_loss(s_anchor_enc, s_positive_enc, s_negatives_enc, config.temperature, metric_type=1)
        s2_loss = infoNCE_loss(s_positive_enc, s_anchor_enc, s_neg_cats_enc, config.temperature, metric_type=1)
        loss += s_loss + s2_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    loss = total_loss / len(dataloader)
    evals = evaluate(maze, config.eval_trials, encoder1, encoder2, manifold, max_steps=config.max_steps, hyperbolic=config.hyperbolic, eps=50.)
    acc = np.mean([x[2] for x in evals])
    fail = np.mean([x[0] for x in evals])

    metrics = {
        "epoch": epoch + 1,
        "loss": loss,
        "spl": acc,
        "fail": fail
    }
    wandb.log(metrics)

    if acc > best_spl:
        best_spl = acc
        best_encoder1 = encoder1.state_dict()
        best_encoder2 = encoder2.state_dict()
        best_epoch = epoch + 1

#     if epoch % 32 == 0:
#         save_models(encoder1, encoder2, best_encoder1, best_encoder2, epoch + 1, best_epoch, experiment_name)

    print(f'Epoch {epoch+1}, Loss: {loss}, SPL: {acc}, Failure %: {fail}')

# save_models(encoder1, encoder2, best_encoder1, best_encoder2, epoch + 1, best_epoch, experiment_name)


{'embedding_dim': 8, 'eval_trials': 100, 'max_steps': 100, 'hyperbolic': False, 'num_epochs': 16, 'temperature': 0.1, 'batch_size': 64, 'num_negatives': 64, 'learning_rate': 0.001, 'architecture': 'MLP', 'maze': '[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0.]\n [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]', 'num_trajectories': 1000, 'maze_type': 'blocker', 'gamma': 0.1, 'hyp_layers': 1}
gamma: 0.1


  anchor = torch.tensor(anchor).to(device, torch.float32)
  positive = torch.tensor(positive).to(device, torch.float32)
  negatives = torch.tensor(negatives).to(device, torch.float32)
  s_anchor = torch.tensor(s_anchor).to(device)
  s_positive = torch.tensor(s_positive).to(device, torch.float32)
  s_negatives = torch.tensor(s_negatives).to(device, torch.float32)
  s_neg_cats = torch.tensor(s_neg_cats).to(device)


Epoch 1, Loss: 12.599339246749878, SPL: 0.08655559494269172, Failure %: 0.9
Epoch 2, Loss: 10.938816964626312, SPL: 0.1350411914382147, Failure %: 0.83
Epoch 3, Loss: 10.658222734928131, SPL: 0.2329696693758978, Failure %: 0.65
Epoch 4, Loss: 10.440049171447754, SPL: 0.22008432397872507, Failure %: 0.67
Epoch 5, Loss: 10.194572269916534, SPL: 0.18762195864005918, Failure %: 0.74
Epoch 6, Loss: 10.127788543701172, SPL: 0.28371740034557863, Failure %: 0.63
Epoch 7, Loss: 10.019772589206696, SPL: 0.28327553253113025, Failure %: 0.62
Epoch 8, Loss: 9.837863385677338, SPL: 0.29450731530111196, Failure %: 0.59
Epoch 9, Loss: 9.847705364227295, SPL: 0.18088182866608168, Failure %: 0.73
Epoch 10, Loss: 9.821453154087067, SPL: 0.2658004419441617, Failure %: 0.63
Epoch 11, Loss: 9.875501453876495, SPL: 0.16753876229040351, Failure %: 0.75
Epoch 12, Loss: 9.800004124641418, SPL: 0.32594911908023333, Failure %: 0.52
Epoch 13, Loss: 9.660109221935272, SPL: 0.25685650699538276, Failure %: 0.65
Epoch

In [8]:
street_dataset[3]

(11,
 array([0.23100757, 0.06919214]),
 array([[ 8.82555558,  2.38111242],
        [ 3.60719127,  7.32444709],
        [10.45399338,  9.69942179],
        [ 5.9742097 ,  4.98081709],
        [ 8.85545035,  2.13547802],
        [ 4.86092919, 10.31306719],
        [ 4.65853024,  5.35934154],
        [ 8.33903978,  6.26368967],
        [ 2.60227142,  9.98562671],
        [ 1.16966521,  1.09311244]]),
 array([ 5,  5,  5, 17, 19, 12, 13, 21,  5,  3]))

In [None]:
maze