In [1]:
%config Completer.use_jedi = False

In [2]:
import numpy as np
import gym
from collections import deque
import random
import torch.autograd
import os
import math
import time
from scipy.ndimage.filters import uniform_filter1d
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F 
from torch.autograd import Variable
import sys
import pickle
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.cm as cm
import matplotlib.markers as markers
import pybullet as p 
if torch.cuda.is_available():
    device = torch.device("cuda:0")  # you can continue going on here, like cuda:1 cuda:2....etc. 
    print("Running on the GPU")
else:
    device = torch.device("cpu")
    print("Running on the CPU")
import networkx as nx
from tqdm import tqdm
import dgl
import morphsim as m
from graphenvs import HalfCheetahGraphEnv
import itertools
import queue

%matplotlib widget

Running on the GPU


Using backend: pytorch


In [3]:
class NetworkInverseDynamics(nn.Module):
    def __init__(
        self,
        input_size,
        output_size,
        hidden_sizes,
        with_batch_norm=False,
        activation=None
    ):
        super(NetworkInverseDynamics, self).__init__()
        self.hidden_sizes = hidden_sizes
        self.input_size = input_size
        self.output_size = output_size
        
        self.layers = nn.ModuleList()

        self.layers.append(nn.Linear(self.input_size, hidden_sizes[0]))
        if with_batch_norm:
            self.layers.append(nn.LayerNorm(normalized_shape=(hidden_sizes[0])))
        self.layers.append(nn.ReLU())
        
        for i in range(len(hidden_sizes) - 1):
            self.layers.append(nn.Linear(hidden_sizes[i], hidden_sizes[i+1]))
            if with_batch_norm:
                self.layers.append(nn.LayerNorm(normalized_shape=(hidden_sizes[i+1])))
            self.layers.append(nn.ReLU())
        
        self.layers.append(nn.Linear(hidden_sizes[len(hidden_sizes) - 1], self.output_size))
        
        if activation is not None:
            self.layers.append(activation())
            
    def forward(self, x):
        out = x
        
        for layer in self.layers:
            out = layer(out)
            
        return out


In [4]:
class GNNInverseDynamics(nn.Module):
    def __init__(
        self,
        inputNetwork,
        messageNetwork,
        updateNetwork,
        outputNetwork,
        numMessagePassingIterations,
        withInputNetwork = True
    ):
        
        super(GNNInverseDynamics, self).__init__()
                
        self.inputNetwork = inputNetwork
        self.messageNetwork = messageNetwork
        self.updateNetwork = updateNetwork
        self.outputNetwork = outputNetwork
        
        self.numMessagePassingIterations = numMessagePassingIterations
        self.withInputNetwork = withInputNetwork
        
    def inputFunction(self, nodes):
        return {'state' : self.inputNetwork(nodes.data['input'])}
    
    def messageFunction(self, edges):
        
        batchSize = edges.src['state'].shape[1]
        edgeData = edges.data['feature'].repeat(batchSize, 1).T.unsqueeze(-1)
        nodeInput = edges.src['input']
        
        return {'m' : self.messageNetwork(torch.cat((edges.src['state'], edgeData, nodeInput), -1))}
    
    def updateFunction(self, nodes):
        return {'state': self.updateNetwork(torch.cat((nodes.data['m_hat'], nodes.data['state']), -1))}
    
    def outputFunction(self, nodes):
        
        return {'output': self.outputNetwork(nodes.data['state'])}


    def forward(self, graph, state):
        
        self.update_states_in_graph(graph, state)
        
        if self.withInputNetwork:
            graph.apply_nodes(self.inputFunction)
        
        for messagePassingIteration in range(self.numMessagePassingIterations):
            graph.update_all(self.messageFunction, dgl.function.mean('m', 'm_hat'), self.updateFunction)
        
        graph.apply_nodes(self.outputFunction)
        
        output = graph.ndata['output']
        output = output.squeeze(-1).mean(0)
                
        return output
    
    def update_states_in_graph(self, graph, state):
        if len(state.shape) == 1:
            state = state.unsqueeze(0)
        
        numGraphFeature = 6
        numGlobalStateInformation = 5
        numLocalStateInformation = 2
        numStateVar = state.shape[1] // 2
        globalInformation = torch.cat((state[:, 0:5], state[:, numStateVar:numStateVar+5]), -1)
        
        numNodes = (numStateVar - 5) // 2

        nodeData = torch.empty((numNodes, state.shape[0], numGraphFeature + 2 * numGlobalStateInformation + 2 * numLocalStateInformation)).to(device)
        for nodeIdx in range(numNodes):

            # Assign global features from graph
            nodeData[nodeIdx, :, :6] = graph.ndata['feature'][nodeIdx]
            # Assign local state information
            nodeData[nodeIdx, :, 16] = state[:, 5 + nodeIdx]
            nodeData[nodeIdx, :, 17] = state[:, 5 + numNodes + nodeIdx]
            nodeData[nodeIdx, :, 18] = state[:, numStateVar + 5 + nodeIdx]
            nodeData[nodeIdx, :, 19] = state[:, numStateVar + 5 + numNodes + nodeIdx]

        # Assdign global state information
        nodeData[:, :, 6:16] = globalInformation
        
        if self.withInputNetwork:
            graph.ndata['input'] = nodeData        
        
        else:
            graph.ndata['state'] = nodeData


In [5]:
class NetworkAutoEncoder(nn.Module):
    def __init__(
        self,
        input_size,
        output_size,
        hidden_sizes,
        batch_size=256, # Needed only for batch norm
        with_batch_norm=False,
        activation=None
    ):
        super(NetworkAutoEncoder, self).__init__()
        self.hidden_sizes = hidden_sizes
        self.input_size = input_size
        self.output_size = output_size
        
        self.layers = nn.ModuleList()

        self.layers.append(nn.Linear(self.input_size, hidden_sizes[0]))
        if with_batch_norm:
#             self.layers.append(nn.BatchNorm1d(batch_size))
            self.layers.append(nn.LayerNorm(normalized_shape=(hidden_sizes[0])))
        self.layers.append(nn.ReLU())
        
        for i in range(len(hidden_sizes) - 1):
            self.layers.append(nn.Linear(hidden_sizes[i], hidden_sizes[i+1]))
            if with_batch_norm:
#                 self.layers.append(nn.BatchNorm1d(batch_size))
                self.layers.append(nn.LayerNorm(normalized_shape=(hidden_sizes[i+1])))
            self.layers.append(nn.ReLU())
        
        self.layers.append(nn.Linear(hidden_sizes[len(hidden_sizes) - 1], self.output_size))
        
        if activation is not None:
            self.layers.append(activation())
            
    def forward(self, x):
        out = x
        
        for layer in self.layers:
            out = layer(out)
            
        return out

In [6]:
class GraphNeuralNetworkAutoEncoder(nn.Module):
    def __init__(
            self,
            inputNetwork,
            messageNetwork,
            updateNetwork,
            outputNetwork,
            numMessagePassingIterations,
            encoder=True
    ):

        super(GraphNeuralNetworkAutoEncoder, self).__init__()

        self.inputNetwork = inputNetwork
        self.messageNetwork = messageNetwork
        self.updateNetwork = updateNetwork
        self.outputNetwork = outputNetwork

        self.numMessagePassingIterations = numMessagePassingIterations
        self.encoder = encoder

    def inputFunction(self, nodes):
        return {'state': self.inputNetwork(nodes.data['input'])}

    def messageFunction(self, edges):

        batchSize = edges.src['state'].shape[1]
        edgeData = edges.data['feature'].repeat(batchSize, 1).T.unsqueeze(-1)
        nodeInput = edges.src['input']

        #         print(edges.src['state'].shape)
        #         print(nodeInput.shape)
        return {'m': self.messageNetwork(torch.cat((edges.src['state'], edgeData, nodeInput), -1))}

    def updateFunction(self, nodes):
        return {'state': self.updateNetwork(torch.cat((nodes.data['m_hat'], nodes.data['state']), -1))}

    def outputFunction(self, nodes):

        #         numNodes, batchSize, stateSize = graph.ndata['state'].shape
        #         return self.outputNetwork.forward(graph.ndata['state'])
        return {'output': self.outputNetwork(nodes.data['state'])}

    def forward(self, graph, state):

        self.update_states_in_graph(graph, state)

        graph.apply_nodes(self.inputFunction)

        for messagePassingIteration in range(self.numMessagePassingIterations):
            graph.update_all(self.messageFunction, dgl.function.max('m', 'm_hat'), self.updateFunction)

        graph.apply_nodes(self.outputFunction)

        output = graph.ndata['output']

        if self.encoder:
            output = F.normalize(output, dim=-1)

        return output

    def update_states_in_graph(self, graph, state):

        if self.encoder:
            if len(state.shape) == 1:
                state = state.unsqueeze(0)

            numGraphFeature = 6
            numGlobalStateInformation = 5
            numLocalStateInformation = 2
            numStateVar = state.shape[1]
            globalInformation = state[:, 0:5]
            batch_size = state.shape[0]
            numNodes = (numStateVar - 5) // 2

            nodeData = torch.empty(
                (numNodes, batch_size, numGraphFeature + numGlobalStateInformation + numLocalStateInformation)).to(
                device)

            nodeData[:, :, 0:numGlobalStateInformation] = globalInformation
            for nodeIdx in range(numNodes):
                # Assign local state information
                nodeData[nodeIdx, :, numGlobalStateInformation] = state[:, 5 + nodeIdx]
                nodeData[nodeIdx, :, numGlobalStateInformation + 1] = state[:, 5 + numNodes + nodeIdx]
                # Assign global features from graph
                nodeData[nodeIdx, :, numGlobalStateInformation + 2: numGlobalStateInformation + 2 + numGraphFeature] = \
                graph.ndata['feature'][nodeIdx]

            graph.ndata['input'] = nodeData

        else:
            numNodes, batchSize, inputSize = state.shape
            nodeData = torch.empty((numNodes, batchSize, inputSize + 6)).to(device)
            nodeData[:, :, :inputSize] = state
            nodeData[:, :, inputSize: inputSize + 6] = graph.ndata['feature'].unsqueeze(dim=1).repeat_interleave(
                batchSize, dim=1)
            #             for nodeIdx in range(numNodes):
            #                 nodeData[nodeIdx, :, inputSize : inputSize + 6] = graph.ndata['feature'][nodeIdx]

            graph.ndata['input'] = nodeData

In [33]:
states = {}
actions = {}
rewards = {}
next_states = {}
dones = {}
env = {}

for morphIdx in [5]:

    prefix = '../datasets/{}/'.format(morphIdx)
    
    states[morphIdx] = torch.from_numpy(np.load(prefix + 'states_array.npy'))
    actions[morphIdx] = torch.from_numpy(np.load(prefix + 'actions_array.npy'))
    rewards[morphIdx] = torch.from_numpy(np.load(prefix + 'rewards_array.npy'))
    next_states[morphIdx] = torch.from_numpy(np.load(prefix + 'next_states_array.npy'))
    dones[morphIdx] = torch.from_numpy(np.load(prefix + 'dones_array.npy'))
    
    env[morphIdx] = HalfCheetahGraphEnv(None)
    env[morphIdx].set_morphology(morphIdx)
    env[morphIdx].reset()

None
*************************************************************************************************************


NoneType: None


In [8]:
def getNewGraph(env, morphIdx):
    return env[morphIdx].get_graph()._get_dgl_graph()

In [9]:
def fibonacci_sphere(samples=1):

    points = []
    phi = math.pi * (3. - math.sqrt(5.))  # golden angle in radians

    for i in range(samples):
        y = 1 - (i / float(samples - 1)) * 2  # y goes from 1 to -1
        radius = math.sqrt(1 - y * y)  # radius at y

        theta = phi * i  # golden angle increment

        x = math.cos(theta) * radius
        z = math.sin(theta) * radius

        points.append((x, y, z))

    return np.array(points)

In [10]:
hidden_sizes = [256, 256]

inputSize = 13
stateSize = 64
messageSize = 64
latentSize = 3
numMessagePassingIterations = 6
with_batch_norm = True

# # Encoder Networks
encoderInputNetwork = NetworkAutoEncoder(inputSize, stateSize, hidden_sizes, with_batch_norm=with_batch_norm)
encoderMessageNetwork = NetworkAutoEncoder(stateSize + inputSize + 1, messageSize, hidden_sizes, with_batch_norm=with_batch_norm, activation=nn.Tanh)
encoderUpdateNetwork = NetworkAutoEncoder(stateSize + messageSize, stateSize, hidden_sizes, with_batch_norm=with_batch_norm)
encoderOutputNetwork = NetworkAutoEncoder(stateSize, latentSize, hidden_sizes, with_batch_norm=with_batch_norm)
encoderGNN = GraphNeuralNetworkAutoEncoder(encoderInputNetwork, encoderMessageNetwork, encoderUpdateNetwork, encoderOutputNetwork, numMessagePassingIterations, encoder=True).to(device)

# # Decoder Networks
decoderInputNetwork = NetworkAutoEncoder(latentSize + 6, stateSize, hidden_sizes, with_batch_norm=with_batch_norm)
decoderMessageNetwork = NetworkAutoEncoder(stateSize + latentSize + 7, messageSize, hidden_sizes, with_batch_norm=with_batch_norm, activation=nn.Tanh)
decoderUpdateNetwork = NetworkAutoEncoder(stateSize + messageSize, stateSize, hidden_sizes, with_batch_norm=with_batch_norm)
decoderOutputNetwork = NetworkAutoEncoder(stateSize, 7, hidden_sizes, with_batch_norm=with_batch_norm)
decoderGNN = GraphNeuralNetworkAutoEncoder(decoderInputNetwork, decoderMessageNetwork, decoderUpdateNetwork, decoderOutputNetwork, numMessagePassingIterations, encoder=False).to(device)

# encoderGNN.load_state_dict(torch.load('../models/new/4-latent-single-GNN-AutoEncoder/5/0.0-1.5/encoderGNN.pt'))
# decoderGNN.load_state_dict(torch.load('../models/new/4-latent-single-GNN-AutoEncoder/5/0.0-1.5/decoderGNN.pt'))

print(encoderGNN.load_state_dict(torch.load('../models/new/3-latent-single-GNN-AutoEncoder/5/0.4-1.33/' + 'encoderGNN.pt')))
print(decoderGNN.load_state_dict(torch.load('../models/new/3-latent-single-GNN-AutoEncoder/5/0.4-1.33/' + 'decoderGNN.pt')))

# print(encoderGNN.load_state_dict(torch.load('../models/new/3-latent-single-GNN-AutoEncoder/5/no-contrastive/' + 'encoderGNN.pt')))
# print(decoderGNN.load_state_dict(torch.load('../models/new/3-latent-single-GNN-AutoEncoder/5/no-contrastive/' + 'decoderGNN.pt')))


# print(encoderGNN.load_state_dict(torch.load('../models/new/3-latent-single-GNN-AutoEncoder/5/no-contrastive/' + 'encoderGNN.pt')))
# print(decoderGNN.load_state_dict(torch.load('../models/new/3-latent-single-GNN-AutoEncoder/5/no-contrastive/' + 'decoderGNN.pt')))


<All keys matched successfully>
<All keys matched successfully>


In [11]:
hidden_sizes = [256, 256]

inputSize = 20
stateSize = 64
messageSize = 64
outputSize = 1
numMessagePassingIterations = 6
with_batch_norm = True

inputNetwork = NetworkInverseDynamics(inputSize, stateSize, hidden_sizes, with_batch_norm)
messageNetwork = NetworkInverseDynamics(stateSize + inputSize + 1, messageSize, hidden_sizes, with_batch_norm, nn.Tanh)
updateNetwork = NetworkInverseDynamics(stateSize + messageSize, stateSize, hidden_sizes, with_batch_norm)
outputNetwork = NetworkInverseDynamics(stateSize, outputSize, hidden_sizes, with_batch_norm, nn.Sigmoid)

inverseDynamics = GNNInverseDynamics(inputNetwork, messageNetwork, updateNetwork, outputNetwork, numMessagePassingIterations).to(device)
inverseDynamics.load_state_dict(torch.load('mixed-delta-validTransition.pt'))

<All keys matched successfully>

In [12]:
def reconstructStateFromGraph(graph_data):
    num_nodes, batch_size, data_size = graph_data.shape
    output = torch.empty((batch_size, 5 + 2 * num_nodes))
    output[:, :5] = graph_data[:, :, :5].mean(dim=0)
    output[:, 5 : 5 + num_nodes] = graph_data[:, :, 5].T
    output[:, 5 + num_nodes:] = graph_data[:, :, 6].T
    
    return output

In [None]:
# !!!!
# Construct Graph
# !!!!

morphIdx = 5 # Morphology of the Agent
alpha = 0.4 # Largest distance expected between sequential states
num_offsets = 128 # Number of states that we will `attempt` to add to the graph on every iteration
transition_threshold = 0.95 # Lowest allowed likelihood for transition (as given by valid-transition-network) allowed in the graph
num_samples = 10000 # Size of the subset of points we will choose from when deciding
max_dst_threshold = 0.15
min_dst_threshold = 0.03
jitter_parameter = 5e-2 # Mean of the distribution from which to sample jitter; UNPROVEN

# Just to get things started, randomly sample a state from the dataset, encode it, and then add it to the graph
saved_states = []
saved_encodings = []
adjacency_list = []
explored_state = []
failed_encodings = []
succesful_encodings = []
success_rate = []
within_range_list = []

random_idx = np.random.choice(int(1e6))
random_state = states[morphIdx][random_idx].cpu()
g = getNewGraph(env, morphIdx)
random_encoding = encoderGNN(g, random_state.unsqueeze(0))[:, 0].detach().cpu()
saved_states.append(random_state)
saved_encodings.append(random_encoding)
explored_state.append(False)
adjacency_list.append([])

t0 = time.time()


with torch.no_grad():
    
    i = 0
    #     for i in range(50000):
    while len(saved_states) < int(5e4):
        
        i += 1
        
        if i % 250 == 1:
            print('Reached iteration {} with {} states'.format(i, len(saved_states)))
        
        if i >= 100 and len(saved_states) == 1:
            saved_states = []
            saved_encodings = []
            adjacency_list = []
            explored_state = []
            failed_encodings = []
            succesful_encodings = []
            success_rate = []
            within_range_list = []

            random_idx = np.random.choice(int(1e6))
            random_state = torch.from_numpy(states[morphIdx][random_idx])
            g = getNewGraph(env, morphIdx)
            random_encoding = encoderGNN(g, random_state.unsqueeze(0))[:, 0].detach().cpu()
            saved_states.append(random_state)
            saved_encodings.append(random_encoding)
            explored_state.append(False)
            adjacency_list.append([])
            i=0
            print('Restarting...')
        
        # np array of size batch: min(len(saved_states), num_samples) 
        random_subset_indeces = np.random.choice(len(saved_states), size=min(len(saved_states), num_samples))
        # tensor of shape (batch, 19)
        random_subset_states = [saved_states[i] for i in random_subset_indeces]
        # tensor of shape (batch, 7, 3)
        random_subset_encodings = torch.stack([saved_encodings[i] for i in random_subset_indeces]).to(device)
        # boolean tensor of shape (batch)
        random_subset_explored = torch.BoolTensor([explored_state[i] for i in random_subset_indeces]).to(device)
        # tensor of shape (7, batch, 3)
        random_subset_encodings = random_subset_encodings.transpose(0, 1)
        
        
        # tensor of size (7, batch, batch)
        distances = torch.cdist(random_subset_encodings, random_subset_encodings)
                
        # New Way
        num_neighbors = (distances.mean(0) <= alpha).sum(-1)
        num_neighbors += distances.size(1) * random_subset_explored
        fewest_neighbors_idx = torch.argmin(num_neighbors)
        
        # tensor of shape (batch, 7, 3)
        random_subset_encodings = random_subset_encodings.transpose(0, 1)
        current_encoding = random_subset_encodings[fewest_neighbors_idx]
        current_state = random_subset_states[fewest_neighbors_idx]
        
        distances = None
        num_neighbors = None
        jitter = None
        offsets = None
        new_encodings = None
        new_states = None
        transitions = None
        kept_states = None
        kept_encodings = None
        random_subset_states = None
        random_subset_encodings = None
        random_subset_explored = None
        torch.cuda.empty_cache()            
        
        # tensor of shape (7, batch, 3)
        
        current_encoding_avg_dst = torch.cdist(current_encoding, current_encoding).mean()
        offsets = torch.normal(0, 1, (1, num_offsets, current_encoding.size(1))).repeat(current_encoding.size(0), 1, 1).to(device)
        
        jitter = torch.normal(0, jitter_parameter, size=offsets.size()).to(device)
        offsets += jitter
        offsets = F.normalize(offsets, dim=-1)
        offsets *= alpha
        
        # Convert coordinates from sphere into global ones, and project onto original sphere
        new_encodings = current_encoding.unsqueeze(1) - offsets
        new_encodings = F.normalize(new_encodings, dim=-1)
        
        
        tmp_dst = torch.cdist(new_encodings.transpose(0,1), new_encodings.transpose(0,1)).mean(-1).mean(-1)
        
        under_max_dst = tmp_dst < max_dst_threshold
        over_min_dst = tmp_dst > min_dst_threshold
        within_dst_range = under_max_dst *over_min_dst

        within_range_list.append(within_dst_range.sum() / num_offsets)        
        new_encodings = new_encodings[:, within_dst_range]
        new_encodings = new_encodings.to(device)
        
        g = getNewGraph(env, morphIdx)
        new_states = decoderGNN(g, new_encodings)
        new_states[:, :, :5] = new_states[:, :, :5].mean(dim=0)
        new_states = reconstructStateFromGraph(new_states)
        transitions = current_state.repeat(new_states.size(0), 2)
        transitions [:, transitions.size(1) // 2:] -= new_states

        g = getNewGraph(env, morphIdx)
        probabilities = inverseDynamics(g, transitions).cpu()
        
        valid_transition_indeces = probabilities > transition_threshold

        kept_states = new_states[valid_transition_indeces]
        kept_encodings = new_encodings[:, valid_transition_indeces]
        
        fewest_neighbors_idx_global = random_subset_indeces[fewest_neighbors_idx]
        explored_state[fewest_neighbors_idx_global] = True
        adjacency_list[fewest_neighbors_idx_global].extend(range(len(saved_states), len(saved_states) + kept_states.size(0)))
        
        current_encoding = current_encoding.cpu()
        if kept_states.size(0) == 0:
            failed_encodings.append(current_encoding)
        else:
            succesful_encodings.append(current_encoding)
        
        success_rate.append(kept_states.size(0) / num_offsets)
        
        kept_states = kept_states.cpu()
        kept_encodings = kept_encodings.cpu()
        for new_idx in range(kept_states.size(0)):
            saved_states.append(kept_states[new_idx])
            saved_encodings.append(kept_encodings[:, new_idx])
            explored_state.append(False)
            adjacency_list.append([])
        
print(time.time() - t0)

Reached iteration 1 with 1 states
Reached iteration 251 with 825 states
Reached iteration 501 with 1395 states
Reached iteration 751 with 2079 states
Reached iteration 1001 with 2592 states
Reached iteration 1251 with 3260 states
Reached iteration 1501 with 3934 states
Reached iteration 1751 with 4525 states
Reached iteration 2001 with 5095 states
Reached iteration 2251 with 5866 states
Reached iteration 2501 with 6587 states
Reached iteration 2751 with 7126 states
Reached iteration 3001 with 7817 states
Reached iteration 3251 with 8573 states
Reached iteration 3501 with 9340 states
Reached iteration 3751 with 10019 states
Reached iteration 4001 with 10745 states
Reached iteration 4251 with 11478 states
Reached iteration 4501 with 12293 states
Reached iteration 4751 with 13020 states
Reached iteration 5001 with 13725 states
Reached iteration 5251 with 14416 states
Reached iteration 5501 with 15278 states
Reached iteration 5751 with 15953 states
Reached iteration 6001 with 16515 states


In [30]:
if not torch.is_tensor(saved_encodings):
    saved_encodings = torch.stack(saved_encodings)

if not torch.is_tensor(saved_states):
    saved_states = torch.stack(saved_states)

In [None]:
if fig:
    plt.close(fig)
fig = plt.figure()
within_range_list = torch.tensor(within_range_list)
plt.plot(np.arange(within_range_list.size(0)), uniform_filter1d(within_range_list, size=128))
plt.show()

In [224]:
# torch.save(torch.stack(saved_encodings), 'saved_encodings_1e6')
# torch.save(torch.stack(saved_states), 'saved_states_1e6')
# with open('adjecency_list_1e6', 'wb') as file:
#     pickle.dump(adjacency_list, file)

In [48]:
saved_encodings = torch.load('saved_encodings_1e6').to(device)
saved_states = torch.load('saved_states_1e6').to(device)
with open('adjecency_list_1e6', 'rb') as file:
    adjacency_list = pickle.load(file)

In [25]:
# plt.close(fig)
fig = plt.figure()

x = saved_encodings[:, 0, 0]
y = saved_encodings[:, 0, 1]
z = saved_encodings[:, 0, 2]

ax = plt.axes(projection ='3d')
ax.scatter3D(x, y, z, alpha=0.1)
# for idx in range(100):
#     ax.scatter3D(succesful_encodings[-idx, 0, 0], succesful_encodings[-idx, 0, 1], succesful_encodings[-idx, 0, 2], c='red')

# for idx in range(100):
#     ax.scatter3D(succesful_encodings[idx, 0, 0], succesful_encodings[idx, 0, 1], succesful_encodings[idx, 0, 2], c='green')

# for idx in indices:
#     ax.scatter3D(saved_encodings[idx, 0, 0], saved_encodings[idx, 0, 1], saved_encodings[idx, 0, 2], c='green')

ax.view_init(elev=0, azim=45)
fig.savefig('1-graph-sphere.png')

ax.view_init(elev=30, azim=90)
fig.savefig('2-graph-sphere.png')


ax.view_init(elev=15, azim=0)
fig.savefig('3-graph-sphere.png')

ax.view_init(elev=90, azim=0)
fig.savefig('4-graph-sphere.png')

fig.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [26]:
distances = torch.cdist(saved_encodings, saved_encodings).mean(-1).mean(-1).numpy()
fig, axs = plt.subplots(2)
axs[0].plot(np.arange(distances.shape[0]), np.log10(uniform_filter1d(distances, size=4096)))
axs[1].boxplot(np.log10(distances))
fig.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [27]:
try:
    plt.close(fig)
except:
    pass
fig = plt.figure()


if not torch.is_tensor(saved_encodings):
    saved_encodings = torch.stack(saved_encodings)

batch_size, num_nodes, dimensionality = saved_encodings.shape
ax = plt.axes(projection ='3d')
random_idx = np.random.choice(int(1e5))
sphere_skeleton = fibonacci_sphere(1000)

ax.scatter3D(sphere_skeleton[:, 0], sphere_skeleton[:, 1], sphere_skeleton[:, 2], alpha=0.02)

for node in range(num_nodes):
    x = saved_encodings[random_idx, node, 0]
    y = saved_encodings[random_idx, node, 1]
    z = saved_encodings[random_idx, node, 2]

    ax.scatter3D(x, y, z)

ax.view_init(elev=0, azim=45)
fig.savefig('one-state-all-dimensions/view1.png')

ax.view_init(elev=30, azim=90)
fig.savefig('one-state-all-dimensions/view2.png')


ax.view_init(elev=15, azim=0)
fig.savefig('one-state-all-dimensions/view3.png')

ax.view_init(elev=90, azim=0)
fig.savefig('one-state-all-dimensions/view4.png')

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

IndexError: index 92305 is out of bounds for dimension 0 with size 50002

In [68]:
random_idx = np.random.choice(int(1e6))
random_state = torch.from_numpy(states[morphIdx][random_idx]).unsqueeze(0)
g = getNewGraph(env, morphIdx).cpu()
random_encoding = encoderGNN(g, random_state)[:, 0].detach().cpu()
encoding_dst = torch.norm(saved_encodings - random_encoding, dim=-1).mean(-1)
encoding_dst_values, encoding_dst_indices = torch.topk(encoding_dst, k=1024, largest=False)
g = getNewGraph(env, morphIdx).cpu()
decoded_states = decoderGNN(g, saved_encodings[indices].transpose(0,1))
decoded_states[:, :, :5] = decoded_states[:, :, :5].mean(dim=0)
decoded_states = reconstructStateFromGraph(decoded_states)
state_reconstructin_dst = ((decoded_states - random_state) ** 2).mean(-1)
state_dst_values, state_dst_indices = torch.topk(state_reconstructin_dst, k=100, largest=False)
for i in range(100):
    print(state_dst_values[i].item(), encoding_dst_values[state_dst_indices[i]].item())

0.31565554174505817 0.08468402922153473
0.3255173258025794 0.10427697002887726
0.3510498726671994 0.08252780884504318
0.36165502126080984 0.09851332753896713
0.37021095213548727 0.06532897055149078
0.37191065732923806 0.0619584359228611
0.37394992900795737 0.10268601030111313
0.3750120293497547 0.07801061123609543
0.37545736456748524 0.07362300902605057
0.3769017676786426 0.07642907649278641
0.3780011946544113 0.09347216784954071
0.3791947771960488 0.0682683065533638
0.38064690722856637 0.08713258802890778
0.38295015723153397 0.06372345983982086
0.38657553353577023 0.05872494727373123
0.3880243956575652 0.0633518174290657
0.38804781124472537 0.07083785533905029
0.38924660810007455 0.07098906487226486
0.39073408879923666 0.08038192242383957
0.39139079762926793 0.09408371150493622
0.3956888588529802 0.09944630414247513
0.3965938772642021 0.10067739337682724
0.39834836678336943 0.05604859068989754
0.3989453734784868 0.059946928173303604
0.3993954404678788 0.07223814725875854
0.40290628337

In [109]:
with torch.no_grad():
    random_indeces = np.random.choice(int(1e6), size=8192)
    random_states = torch.from_numpy(states[morphIdx][random_indeces])
    g1 = getNewGraph(env, morphIdx).cpu()
    random_encodings = encoderGNN(g1, random_states)
    g2 = getNewGraph(env, morphIdx).cpu()
    states_reconstructed = decoderGNN(g2, random_encodings)
    states_reconstructed[:, :, :5] = states_reconstructed[:, :, :5].mean(dim=0)
    states_reconstructed = reconstructStateFromGraph(states_reconstructed)

In [36]:
alls = []
states[morphIdx] = states[morphIdx].to(device)
for i in range(saved_states.size(0)):
    if i % 5000 == 0:
        print(i)
    min_dst = ((states[morphIdx] - saved_states[i].to(device)) ** 2).mean(-1).min()
    alls.append(min_dst)

if fig:
    plt.close(fig)
    
fig = plt.figure()
plt.boxplot(torch.stack(alls).cpu().numpy())
fig.show()

0
5000
10000
15000
20000
25000
30000
35000
40000
45000
50000


In [15]:
def addNodeToGraph(new_state, new_encoding, saved_encodings, saved_states, adjacency_list, encoderGNN, decoderGNN, 
                   inverseDynamics, device, env, morphIdx, alpha=0.4, batch_size=1024, transition_threshold=0.9, forwards=True):
    
    with torch.no_grad():
    
        if new_encoding is None:
            g = getNewGraph(env, morphIdx).to(device)
            new_encoding = encoderGNN(g, new_state.unsqueeze(0)).squeeze(1)

        # Check if node already exists in the graph

#         smallest_mse_exisiting_states = ((saved_states - new_state) ** 2).mean(-1).min()
#         if smallest_mse_exisiting_states <= 1e-6:
#             print('State already exists')
#             return

        # tensor of size (len(saved_encodings),)
        encoding_distances = torch.norm(saved_encodings - new_encoding, dim=-1).mean(-1)
        boolean_indeces_within_range = encoding_distances <= alpha
        indeces_within_range = torch.arange(end=saved_encodings.size(0))[boolean_indeces_within_range]
        states_within_range = saved_states[indeces_within_range]
                
        num_batches = int(np.ceil((states_within_range.size(0) / batch_size)))
        
        valid_neighbors = []
        for batch in range(num_batches):
            
            transition_states = states_within_range[batch * batch_size : (batch + 1) * batch_size]
            
            if forwards:
                transitions = new_state.repeat(transition_states.size(0), 2)
                transitions [:, transitions.size(1) // 2:] -= transition_states
            else:
                transitions = transition_states.repeat(1, 2)
                transitions [:, transitions.size(1) // 2:] -= new_state
                
            g = getNewGraph(env, morphIdx).to(device)
            probabilities = inverseDynamics(g, transitions)
            boolean_valid_transition_indeces = probabilities > transition_threshold
            possibile_batch_indeces = torch.arange(start=batch * batch_size, end=min((batch + 1) * batch_size, states_within_range.size(0)))
            valid_batch_indeces = possibile_batch_indeces[boolean_valid_transition_indeces]
            valid_global_indeces = indeces_within_range[valid_batch_indeces]
            
            valid_neighbors.extend(valid_batch_indeces.tolist())
    
    saved_encodings = torch.cat((saved_encodings, new_encoding.unsqueeze(0)), dim=0)
    saved_states = torch.cat((saved_states, new_state.unsqueeze(0)), dim=0)
    
    if forwards:
        adjacency_list.append(valid_neighbors)
    else:
        adjacency_list.append([])
        new_state_idx = len(adjacency_list) - 1
        for neigbor_idx in valid_neighbors:
            adjacency_list[neigbor_idx].append(new_state_idx)

In [31]:
new_state = torch.from_numpy(states[morphIdx][1]).cuda()
t0 = time.time()
addNodeToGraph(new_state, None, saved_encodings, saved_states, adjacency_list, encoderGNN, decoderGNN, inverseDynamics, device, env, morphIdx, forwards=True)

In [21]:
new_state = torch.from_numpy(states[morphIdx][999]).cuda()
t0 = time.time()
addNodeToGraph(new_state, None, saved_encodings, saved_states, adjacency_list, encoderGNN, decoderGNN, inverseDynamics, device, env, morphIdx, forwards=False)

In [54]:
lengths = []
is_zero = 0
for l in adjacency_list:
    lengths.append(len(l))
    if len(l) == 0:
        is_zero += 1
        
lengths = np.array(lengths)
print(len(lengths) - is_zero)

if fig:
    plt.close(fig)
fig = plt.figure()
plt.boxplot(np.array(lengths))
fig.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [74]:
def bfs(adjacency_list, start, end):
    # maintain a queue of paths
    queue = []
    # push the first path into the queue
    queue.append([start])
    while queue:
        # get the first path from the queue
        path = queue.pop(0)
        # get the last node from the path
        node = path[-1]
        # path found
        if node == end:
            return path
        # enumerate all adjacent nodes, construct a 
        # new path and push it into the queue
        neighbors = adjacency_list[node]
        for adjacent in neighbors:
            print(path)
            new_path = list(path)
            print(new_path)
            new_path.append(adjacent)
            queue.append(new_path)