In [49]:
%config Completer.use_jedi = False

In [50]:
import numpy as np
import gym
from collections import deque
import random
import torch.autograd
import os
import math
import time
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F 
from torch.autograd import Variable
import sys
import pickle
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.cm as cm
import matplotlib.markers as markers
import pybullet as p 
if torch.cuda.is_available():
    device = torch.device("cuda:0")  # you can continue going on here, like cuda:1 cuda:2....etc. 
    print("Running on the GPU")
else:
    device = torch.device("cpu")
    print("Running on the CPU")
import networkx as nx
from tqdm import tqdm
import dgl
import morphsim as m
from graphenvs import HalfCheetahGraphEnv
import itertools
import queue

%matplotlib widget

Running on the GPU


In [51]:
class NetworkInverseDynamics(nn.Module):
    def __init__(
        self,
        input_size,
        output_size,
        hidden_sizes,
        with_batch_norm=False,
        activation=None
    ):
        super(NetworkInverseDynamics, self).__init__()
        self.hidden_sizes = hidden_sizes
        self.input_size = input_size
        self.output_size = output_size
        
        self.layers = nn.ModuleList()

        self.layers.append(nn.Linear(self.input_size, hidden_sizes[0]))
        if with_batch_norm:
            self.layers.append(nn.LayerNorm(normalized_shape=(hidden_sizes[0])))
        self.layers.append(nn.ReLU())
        
        for i in range(len(hidden_sizes) - 1):
            self.layers.append(nn.Linear(hidden_sizes[i], hidden_sizes[i+1]))
            if with_batch_norm:
                self.layers.append(nn.LayerNorm(normalized_shape=(hidden_sizes[i+1])))
            self.layers.append(nn.ReLU())
        
        self.layers.append(nn.Linear(hidden_sizes[len(hidden_sizes) - 1], self.output_size))
        
        if activation is not None:
            self.layers.append(activation())
            
    def forward(self, x):
        out = x
        
        for layer in self.layers:
            out = layer(out)
            
        return out


In [52]:
class GNNInverseDynamics(nn.Module):
    def __init__(
        self,
        inputNetwork,
        messageNetwork,
        updateNetwork,
        outputNetwork,
        numMessagePassingIterations,
        withInputNetwork = True
    ):
        
        super(GNNInverseDynamics, self).__init__()
                
        self.inputNetwork = inputNetwork
        self.messageNetwork = messageNetwork
        self.updateNetwork = updateNetwork
        self.outputNetwork = outputNetwork
        
        self.numMessagePassingIterations = numMessagePassingIterations
        self.withInputNetwork = withInputNetwork
        
    def inputFunction(self, nodes):
        return {'state' : self.inputNetwork(nodes.data['input'])}
    
    def messageFunction(self, edges):
        
        batchSize = edges.src['state'].shape[1]
        edgeData = edges.data['feature'].repeat(batchSize, 1).T.unsqueeze(-1)
        nodeInput = edges.src['input']
        
        return {'m' : self.messageNetwork(torch.cat((edges.src['state'], edgeData, nodeInput), -1))}
    
    def updateFunction(self, nodes):
        return {'state': self.updateNetwork(torch.cat((nodes.data['m_hat'], nodes.data['state']), -1))}
    
    def outputFunction(self, nodes):
        
        return {'output': self.outputNetwork(nodes.data['state'])}


    def forward(self, graph, state):
        
        self.update_states_in_graph(graph, state)
        
        if self.withInputNetwork:
            graph.apply_nodes(self.inputFunction)
        
        for messagePassingIteration in range(self.numMessagePassingIterations):
            graph.update_all(self.messageFunction, dgl.function.mean('m', 'm_hat'), self.updateFunction)
        
        graph.apply_nodes(self.outputFunction)
        
        output = graph.ndata['output']
        output = output.squeeze(-1).mean(0)
                
        return output
    
    def update_states_in_graph(self, graph, state):
        if len(state.shape) == 1:
            state = state.unsqueeze(0)
        
        numGraphFeature = 6
        numGlobalStateInformation = 5
        numLocalStateInformation = 2
        numStateVar = state.shape[1] // 2
        globalInformation = torch.cat((state[:, 0:5], state[:, numStateVar:numStateVar+5]), -1)
        
        numNodes = (numStateVar - 5) // 2

        nodeData = torch.empty((numNodes, state.shape[0], numGraphFeature + 2 * numGlobalStateInformation + 2 * numLocalStateInformation)).to(device)
        for nodeIdx in range(numNodes):

            # Assign global features from graph
            nodeData[nodeIdx, :, :6] = graph.ndata['feature'][nodeIdx]
            # Assign local state information
            nodeData[nodeIdx, :, 16] = state[:, 5 + nodeIdx]
            nodeData[nodeIdx, :, 17] = state[:, 5 + numNodes + nodeIdx]
            nodeData[nodeIdx, :, 18] = state[:, numStateVar + 5 + nodeIdx]
            nodeData[nodeIdx, :, 19] = state[:, numStateVar + 5 + numNodes + nodeIdx]

        # Assdign global state information
        nodeData[:, :, 6:16] = globalInformation
        
        if self.withInputNetwork:
            graph.ndata['input'] = nodeData        
        
        else:
            graph.ndata['state'] = nodeData


In [53]:
class NetworkAutoEncoder(nn.Module):
    def __init__(
        self,
        input_size,
        output_size,
        hidden_sizes,
        batch_size=256, # Needed only for batch norm
        with_batch_norm=False,
        activation=None
    ):
        super(NetworkAutoEncoder, self).__init__()
        self.hidden_sizes = hidden_sizes
        self.input_size = input_size
        self.output_size = output_size
        
        self.layers = nn.ModuleList()

        self.layers.append(nn.Linear(self.input_size, hidden_sizes[0]))
        if with_batch_norm:
#             self.layers.append(nn.BatchNorm1d(batch_size))
            self.layers.append(nn.LayerNorm(normalized_shape=(hidden_sizes[0])))
        self.layers.append(nn.ReLU())
        
        for i in range(len(hidden_sizes) - 1):
            self.layers.append(nn.Linear(hidden_sizes[i], hidden_sizes[i+1]))
            if with_batch_norm:
#                 self.layers.append(nn.BatchNorm1d(batch_size))
                self.layers.append(nn.LayerNorm(normalized_shape=(hidden_sizes[i+1])))
            self.layers.append(nn.ReLU())
        
        self.layers.append(nn.Linear(hidden_sizes[len(hidden_sizes) - 1], self.output_size))
        
        if activation is not None:
            self.layers.append(activation())
            
    def forward(self, x):
        out = x
        
        for layer in self.layers:
            out = layer(out)
            
        return out

In [54]:
class GraphNeuralNetworkAutoEncoder(nn.Module):
    def __init__(
            self,
            inputNetwork,
            messageNetwork,
            updateNetwork,
            outputNetwork,
            numMessagePassingIterations,
            encoder=True
    ):

        super(GraphNeuralNetworkAutoEncoder, self).__init__()

        self.inputNetwork = inputNetwork
        self.messageNetwork = messageNetwork
        self.updateNetwork = updateNetwork
        self.outputNetwork = outputNetwork

        self.numMessagePassingIterations = numMessagePassingIterations
        self.encoder = encoder

    def inputFunction(self, nodes):
        return {'state': self.inputNetwork(nodes.data['input'])}

    def messageFunction(self, edges):

        batchSize = edges.src['state'].shape[1]
        edgeData = edges.data['feature'].repeat(batchSize, 1).T.unsqueeze(-1)
        nodeInput = edges.src['input']

        #         print(edges.src['state'].shape)
        #         print(nodeInput.shape)
        return {'m': self.messageNetwork(torch.cat((edges.src['state'], edgeData, nodeInput), -1))}

    def updateFunction(self, nodes):
        return {'state': self.updateNetwork(torch.cat((nodes.data['m_hat'], nodes.data['state']), -1))}

    def outputFunction(self, nodes):

        #         numNodes, batchSize, stateSize = graph.ndata['state'].shape
        #         return self.outputNetwork.forward(graph.ndata['state'])
        return {'output': self.outputNetwork(nodes.data['state'])}

    def forward(self, graph, state):

        self.update_states_in_graph(graph, state)

        graph.apply_nodes(self.inputFunction)

        for messagePassingIteration in range(self.numMessagePassingIterations):
            graph.update_all(self.messageFunction, dgl.function.max('m', 'm_hat'), self.updateFunction)

        graph.apply_nodes(self.outputFunction)

        output = graph.ndata['output']

        if self.encoder:
            output = F.normalize(output, dim=-1)

        return output

    def update_states_in_graph(self, graph, state):

        if self.encoder:
            if len(state.shape) == 1:
                state = state.unsqueeze(0)

            numGraphFeature = 6
            numGlobalStateInformation = 5
            numLocalStateInformation = 2
            numStateVar = state.shape[1]
            globalInformation = state[:, 0:5]
            batch_size = state.shape[0]
            numNodes = (numStateVar - 5) // 2

            nodeData = torch.empty(
                (numNodes, batch_size, numGraphFeature + numGlobalStateInformation + numLocalStateInformation)).to(
                device)

            nodeData[:, :, 0:numGlobalStateInformation] = globalInformation
            for nodeIdx in range(numNodes):
                # Assign local state information
                nodeData[nodeIdx, :, numGlobalStateInformation] = state[:, 5 + nodeIdx]
                nodeData[nodeIdx, :, numGlobalStateInformation + 1] = state[:, 5 + numNodes + nodeIdx]
                # Assign global features from graph
                nodeData[nodeIdx, :, numGlobalStateInformation + 2: numGlobalStateInformation + 2 + numGraphFeature] = \
                graph.ndata['feature'][nodeIdx]

            graph.ndata['input'] = nodeData

        else:
            numNodes, batchSize, inputSize = state.shape
            nodeData = torch.empty((numNodes, batchSize, inputSize + 6)).to(device)
            nodeData[:, :, :inputSize] = state
            nodeData[:, :, inputSize: inputSize + 6] = graph.ndata['feature'].unsqueeze(dim=1).repeat_interleave(
                batchSize, dim=1)
            #             for nodeIdx in range(numNodes):
            #                 nodeData[nodeIdx, :, inputSize : inputSize + 6] = graph.ndata['feature'][nodeIdx]

            graph.ndata['input'] = nodeData

In [55]:
states = {}
actions = {}
rewards = {}
next_states = {}
dones = {}
env = {}

for morphIdx in [5]:

    prefix = '../datasets/{}/'.format(morphIdx)
    
    states[morphIdx] = np.load(prefix + 'states_array.npy')
    actions[morphIdx] = np.load(prefix + 'actions_array.npy')
    rewards[morphIdx] = np.load(prefix + 'rewards_array.npy')
    next_states[morphIdx] = np.load(prefix + 'next_states_array.npy')
    dones[morphIdx] = np.load(prefix + 'dones_array.npy')
    
    env[morphIdx] = HalfCheetahGraphEnv(None)
    env[morphIdx].set_morphology(morphIdx)
    env[morphIdx].reset()

None
*************************************************************************************************************


NoneType: None


In [56]:
def getNewGraph(env, morphIdx):
    return env[morphIdx].get_graph()._get_dgl_graph()

In [57]:
hidden_sizes = [256, 256]

inputSize = 13
stateSize = 64
messageSize = 64
latentSize = 3
numMessagePassingIterations = 6
batch_size = 1024
numBatchesPerTrainingStep = 1
minRandomDistance = 1
maxSequentialDistance = 0.04
with_batch_norm = True

# # Encoder Networks
encoderInputNetwork = NetworkAutoEncoder(inputSize, stateSize, hidden_sizes, with_batch_norm=with_batch_norm)
encoderMessageNetwork = NetworkAutoEncoder(stateSize + inputSize + 1, messageSize, hidden_sizes, with_batch_norm=with_batch_norm, activation=nn.Tanh)
encoderUpdateNetwork = NetworkAutoEncoder(stateSize + messageSize, stateSize, hidden_sizes, with_batch_norm=with_batch_norm)
encoderOutputNetwork = NetworkAutoEncoder(stateSize, latentSize, hidden_sizes, with_batch_norm=with_batch_norm)
encoderGNN = GraphNeuralNetworkAutoEncoder(encoderInputNetwork, encoderMessageNetwork, encoderUpdateNetwork, encoderOutputNetwork, numMessagePassingIterations, encoder=True).to(device)

# # Decoder Networks
decoderInputNetwork = NetworkAutoEncoder(latentSize + 6, stateSize, hidden_sizes, with_batch_norm=with_batch_norm)
decoderMessageNetwork = NetworkAutoEncoder(stateSize + latentSize + 7, messageSize, hidden_sizes, with_batch_norm=with_batch_norm, activation=nn.Tanh)
decoderUpdateNetwork = NetworkAutoEncoder(stateSize + messageSize, stateSize, hidden_sizes, with_batch_norm=with_batch_norm)
decoderOutputNetwork = NetworkAutoEncoder(stateSize, 7, hidden_sizes, with_batch_norm=with_batch_norm)
decoderGNN = GraphNeuralNetworkAutoEncoder(decoderInputNetwork, decoderMessageNetwork, decoderUpdateNetwork, decoderOutputNetwork, numMessagePassingIterations, encoder=False).to(device)

# encoderGNN.load_state_dict(torch.load('../models/new/4-latent-single-GNN-AutoEncoder/5/0.0-1.5/encoderGNN.pt'))
# decoderGNN.load_state_dict(torch.load('../models/new/4-latent-single-GNN-AutoEncoder/5/0.0-1.5/decoderGNN.pt'))

print(encoderGNN.load_state_dict(torch.load('../models/new/3-latent-single-GNN-AutoEncoder/5/0.4-1.33/' + 'encoderGNN.pt')))
print(decoderGNN.load_state_dict(torch.load('../models/new/3-latent-single-GNN-AutoEncoder/5/0.4-1.33/' + 'decoderGNN.pt')))

# print(encoderGNN.load_state_dict(torch.load('../models/new/3-latent-single-GNN-AutoEncoder/5/no-contrastive/' + 'encoderGNN.pt')))
# print(decoderGNN.load_state_dict(torch.load('../models/new/3-latent-single-GNN-AutoEncoder/5/no-contrastive/' + 'decoderGNN.pt')))


<All keys matched successfully>
<All keys matched successfully>


In [58]:
hidden_sizes = [256, 256]

inputSize = 20
stateSize = 64
messageSize = 64
outputSize = 1
numMessagePassingIterations = 6
batch_size = 2048
with_batch_norm = True
numBatchesPerTrainingStep = 1

inputNetwork = NetworkInverseDynamics(inputSize, stateSize, hidden_sizes, with_batch_norm)
messageNetwork = NetworkInverseDynamics(stateSize + inputSize + 1, messageSize, hidden_sizes, with_batch_norm, nn.Tanh)
updateNetwork = NetworkInverseDynamics(stateSize + messageSize, stateSize, hidden_sizes, with_batch_norm)
outputNetwork = NetworkInverseDynamics(stateSize, outputSize, hidden_sizes, with_batch_norm, nn.Sigmoid)

inverseDynamics = GNNInverseDynamics(inputNetwork, messageNetwork, updateNetwork, outputNetwork, numMessagePassingIterations).to(device)
inverseDynamics.load_state_dict(torch.load('mixed-delta-validTransition.pt'))

<All keys matched successfully>

In [59]:
def reconstructStateFromGraph(graph_data):
    num_nodes, batch_size, data_size = graph_data.shape
    output = torch.empty((batch_size, 5 + 2 * num_nodes))
    output[:, :5] = graph_data[:, :, :5].mean(dim=0)
    output[:, 5 : 5+num_nodes] = graph_data[:, :, 5].T
    output[:, 5+num_nodes:] = graph_data[:, :, 6].T
    
    return output

In [90]:
saved_states = []
saved_encodings = []
adjacency_list = []

morphIdx = 5
alpha = 0.4
num_offsets = 2048
transition_threshold = 0.9
num_samples = 10000

# Just to get things started
random_idx = np.random.choice(int(1e6))
random_state = torch.from_numpy(states[morphIdx][random_idx]).unsqueeze(0)
g = getNewGraph(env, morphIdx)
random_encoding = encoderGNN(g, random_state)[:, 0].detach().cpu()
saved_states.append(random_state)
saved_encodings.append(random_encoding)
adjacency_list.append([])

t0 = time.time()

with torch.no_grad():
    
    i = 0
    #     for i in range(50000):
    while len(saved_states) < 200000:
        
        i += 1
        
        if i % 1000 == 0:
            print('Reached iteration {} with {} states'.format(i, len(saved_states)))
            
        # np array of size batch: min(len(saved_states), num_samples) 
        random_subset_indeces = np.random.choice(len(saved_states), size=min(len(saved_states), num_samples))
        # tensor of shape (batch, 19)
        random_subset_states = [saved_states[i] for i in random_subset_indeces]
        # tensor of shape (batch, 7, 3)
        random_subset_encodings = torch.stack([saved_encodings[i] for i in random_subset_indeces]).to(device)
        # tensor of shape (7, batch, 3)
        random_subset_encodings = random_subset_encodings.transpose(0, 1)
        
        # tensor of shape (batch)
        distances = torch.cdist(random_subset_encodings, random_subset_encodings).sum(dim=0).sum(dim=-1)
        largest_distance_idx = torch.argmax(distances)
        
        # tensor of shape (batch, 7, 3)
        random_subset_encodings = random_subset_encodings.transpose(0, 1)
        current_encoding = random_subset_encodings[largest_distance_idx]
        current_state = random_subset_states[largest_distance_idx]
                            
        offsets = torch.normal(0, 1, (current_encoding.size(0), num_offsets, current_encoding.size(1))).to(device)
        offsets = F.normalize(offsets, dim=-1)
        offsets *= alpha
#         offsets *= np.random.uniform(low=0, high=alpha, size=num_offsets)[..., np.newaxis][np.newaxis, ...]

        # Convert coordinates from sphere into global ones, and project onto original sphere
        new_encodings = current_encoding.unsqueeze(1) - offsets
        new_encodings = F.normalize(new_encodings, dim=-1)
        new_encodings = new_encodings.to(device)

        g = getNewGraph(env, morphIdx)
        new_states = decoderGNN(g, new_encodings)
        new_states = reconstructStateFromGraph(new_states)
        transitions = current_state.repeat(num_offsets, 2)
        transitions [:, transitions.size(1) // 2:] -= new_states

        g = getNewGraph(env, morphIdx)
        probabilities = inverseDynamics(g, transitions).cpu()
        
        valid_transition_indeces = probabilities > transition_threshold

        kept_states = new_states[valid_transition_indeces]
        kept_encodings = new_encodings[:, valid_transition_indeces]
        
        largest_distance_idx_global = random_subset_indeces[largest_distance_idx]
        adjacency_list[largest_distance_idx_global].extend(range(len(saved_states), len(saved_states) + kept_states.size(0)))
        
        for new_idx in range(kept_states.size(0)):
            saved_states.append(kept_states[new_idx].cpu())
            saved_encodings.append(kept_encodings[:, new_idx].cpu())
            adjacency_list.append([])
        
print(time.time() - t0)

Reached iteration 1000 with 687 states
Reached iteration 2000 with 762 states
Reached iteration 3000 with 779 states
Reached iteration 4000 with 1233 states
Reached iteration 5000 with 1247 states
Reached iteration 6000 with 1419 states
Reached iteration 7000 with 1438 states
1313.729305267334


In [91]:
saved_encodings = torch.stack(saved_encodings)
print(saved_encodings.shape)

torch.Size([2002, 7, 3])


In [77]:
adjacency_list

[[1, 2],
 [30, 31, 32, 33, 34, 35, 36, 37],
 [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [22, 23, 24, 25, 26, 27, 28, 29, 126, 127, 128, 129, 130, 131, 132, 133, 134],
 [],
 [],
 [],
 [],
 [138, 174],
 [139, 173, 175, 176],
 [],
 [],
 [],
 [38],
 [],
 [],
 [],
 [],
 [],
 [],
 [39],
 [40, 41, 42, 43, 44, 45, 46, 47, 48, 49],
 [],
 [],
 [],
 [],
 [50, 51, 52, 53, 54, 55, 56, 57, 58, 59],
 [],
 [],
 [],
 [],
 [],
 [64, 65],
 [],
 [],
 [],
 [846, 847, 848, 888, 901, 902, 903],
 [],
 [],
 [60, 61, 62, 63],
 [],
 [],
 [],
 [67],
 [],
 [],
 [66],
 [],
 [68, 69, 70, 71],
 [],
 [],
 [73, 74],
 [],
 [72],
 [],
 [],
 [75, 76, 77],
 [78, 79, 80],
 [],
 [],
 [87],
 [],
 [81, 82, 83, 84, 85, 86],
 [],
 [],
 [],
 [88, 89, 90],
 [],
 [],
 [],
 [91, 92, 93, 94, 95, 96, 97, 98],
 [],
 [],
 [99,
  100,
  101,
  102,
  103,
  104,
  105,
  106,
  107,
  108,
  109,
  110,
  111,
  112,
  1

In [92]:
fig = plt.figure()

x = saved_encodings[:, 0, 0].cpu().numpy()
y = saved_encodings[:, 0, 1].cpu().numpy()
z = saved_encodings[:, 0, 2].cpu().numpy()

ax = plt.axes(projection ='3d')
ax.scatter3D(x, y, z)
plt.show()


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …