In [1]:
# Use matplotlib's interactive notebook backend so figures render as live widgets
%matplotlib notebook
# Load the autoreload extension and reload imported modules before each cell
# executes, so edits to local packages are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2


In [2]:
import itertools

import inept
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch


# Global seaborn theme applied to every figure created below
sns.set_theme(
    context='paper',
    style='white',
    palette='Set2',
)

In [3]:
# Resources
# https://openai.com/research/emergent-tool-use
# https://glouppe.github.io/info8004-advanced-machine-learning/pdf/pleroy-hide-and-seek.pdf
# https://github.com/nikhilbarhate99/PPO-PyTorch/blob/master/PPO.py#L38


In [4]:
# Reproducibility: one seed shared by every RNG used in the notebook, so the
# libraries cannot drift apart when the value is changed
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

### Create Environment

In [5]:
# Problem size: number of nodes and dimensionality of the trajectory space
num_nodes, num_dims = 100, 2

# Two synthetic modalities of uniform random features, one row per node
# (1000 and 2000 feature columns respectively)
M1 = torch.rand(num_nodes, 1000)
M2 = torch.rand(num_nodes, 2000)

# Trajectory environment built over both modalities
# State layout noted by the author: x, y, vx, vy
env = inept.environments.trajectory(M1, M2, dim=num_dims)

### Train Policy

In [6]:
# Make policy
# Per-node feature width: 2*num_dims entries (presumably positions and
# velocities -- confirm against the environment) plus every column of both
# modalities. The reshape of `node_entities` below relies on this width.
input_dims = 2*num_dims+M1.shape[1]+M2.shape[1]
policy = inept.models.PPO(input_dims, num_dims)

# Off-diagonal mask: idx[i, j] is True exactly when i != j, so row i selects
# every node except node i. It does not change between timesteps, so it is
# built once here with a single vectorized call instead of a Python double loop.
idx = ~torch.eye(num_nodes, dtype=torch.bool)

# Run model over all nodes, recording the environment state at every timestep
states = []
for timestep in range(101):
    # Get current state (with modality features, as input to the policy)
    state = env.get_state(include_modalities=True)
    # Record the state without modalities for the animation
    states.append(env.get_state())

    # Get self features for each node
    self_entity = state

    # Get node features for each state: broadcast the full state to every
    # node, then drop each node's own row so it only sees the other nodes
    node_entities = state.unsqueeze(0).expand(num_nodes, *state.shape)
    node_entities = node_entities[idx].reshape(num_nodes, num_nodes-1, input_dims)

    # Get actions from policy
    actions = policy.act(self_entity, node_entities)

    # Update velocities, scaling the actions by the environment's `delta`
    env.add_velocities(env.delta * actions)

    # Step environment
    env.step()

# Stack the per-timestep states along a new leading (time) axis
states = np.stack(states, axis=0)

### Animate Latent Space

In [7]:
# Set up the canvas and make its axes the current pyplot target
fig, ax = plt.subplots()
plt.sca(ax)
ax.set(xlim=[-1, 1], ylim=[-1, 1])
ax.axis('equal')

# Draw the first recorded timestep; the first `num_dims` state columns are
# used as the plotted coordinates
initial_positions = states[0, :, :num_dims]
sct = ax.scatter(*initial_positions.T)

# Update function
def update(frame):
    """Move the scatter points to their recorded positions at `frame`.

    Parameters
    ----------
    frame : int
        Index along the first axis of the module-level `states` array.

    Returns
    -------
    tuple
        One-element tuple holding the modified scatter artist. FuncAnimation
        expects an iterable of artists from its callback (required when
        `blit=True`), so the artist is wrapped rather than returned bare.
    """
    # The first `num_dims` state columns are the plotted coordinates
    sct.set_offsets(states[frame, :, :num_dims])

    return (sct,)

# Rich-display wrapper used to embed the rendered animation in the cell output
from IPython.display import HTML

# Build the animation: one frame per recorded timestep, 50 ms between frames
n_frames = states.shape[0]
ani = animation.FuncAnimation(fig, update, frames=n_frames, interval=50)

# Save animation
# ani.save('test.gif')

# Render the animation as an HTML/JavaScript player and display it
HTML(ani.to_jshtml())

<IPython.core.display.Javascript object>

In [8]:
# TODO: PPO
# DETAILS: Implement the old policy (snapshot taken at the beginning of each epoch)
# DETAILS: Make sure to train on variable `node_entities` length so new nodes may be added