In [1]:
import numpy as np
import gym
from collections import deque
import random
import torch.autograd
import os
import time
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F 
from torch.autograd import Variable
import sys
import pickle
import matplotlib.pyplot as plt
import pybullet as p 
# if torch.cuda.is_available():
#     device = torch.device("cuda:0")  # you can continue going on here, like cuda:1 cuda:2....etc. 
#     print("Running on the GPU")
# else:
#     device = torch.device("cpu")
#     print("Running on the CPU")
device = torch.device("cpu")
import networkx as nx
from tqdm import tqdm
import dgl
import morphsim as m
from graphenvs import HalfCheetahGraphEnv
import itertools

Using backend: pytorch


In [2]:
class Network(nn.Module):
    def __init__(
        self,
        input_size,
        output_size,
        hidden_sizes,
        batch_size=256, # Needed only for batch norm
        with_batch_norm=False,
        activation=None
    ):
        super(Network, self).__init__()
        self.hidden_sizes = hidden_sizes
        self.input_size = input_size
        self.output_size = output_size
        
        self.layers = nn.ModuleList()

        self.layers.append(nn.Linear(self.input_size, hidden_sizes[0]))
        if with_batch_norm:
#             self.layers.append(nn.BatchNorm1d(batch_size))
            self.layers.append(nn.LayerNorm(normalized_shape=(hidden_sizes[0])))
        self.layers.append(nn.ReLU())
        
        for i in range(len(hidden_sizes) - 1):
            self.layers.append(nn.Linear(hidden_sizes[i], hidden_sizes[i+1]))
            if with_batch_norm:
#                 self.layers.append(nn.BatchNorm1d(batch_size))
                self.layers.append(nn.LayerNorm(normalized_shape=(hidden_sizes[i+1])))
            self.layers.append(nn.ReLU())
        
        self.layers.append(nn.Linear(hidden_sizes[len(hidden_sizes) - 1], self.output_size))
        
        if activation is not None:
            self.layers.append(activation())
            
    def forward(self, x):
        out = x
        
        for layer in self.layers:
            out = layer(out)
            
        return out

In [3]:
class GraphNeuralNetwork(nn.Module):
    def __init__(
        self,
        inputNetwork,
        messageNetwork,
        updateNetwork,
        outputNetwork,
        numMessagePassingIterations,
        encoder = True
    ):
        
        super(GraphNeuralNetwork, self).__init__()
                
        self.inputNetwork = inputNetwork
        self.messageNetwork = messageNetwork
        self.updateNetwork = updateNetwork
        self.outputNetwork = outputNetwork
        
        self.numMessagePassingIterations = numMessagePassingIterations
        self.encoder = encoder
        
    def inputFunction(self, nodes):
        return {'state' : self.inputNetwork(nodes.data['input'])}
    
    def messageFunction(self, edges):
        
        batchSize = edges.src['state'].shape[1]
        edgeData = edges.data['feature'].repeat(batchSize, 1).T.unsqueeze(-1)
        nodeInput = edges.src['input']
        
#         print(edges.src['state'].shape)
#         print(nodeInput.shape)
        return {'m' : self.messageNetwork(torch.cat((edges.src['state'], edgeData, nodeInput), -1))}
        

    def updateFunction(self, nodes):
        return {'state': self.updateNetwork(torch.cat((nodes.data['m_hat'], nodes.data['state']), -1))}
    
    def outputFunction(self, nodes):
        
#         numNodes, batchSize, stateSize = graph.ndata['state'].shape
#         return self.outputNetwork.forward(graph.ndata['state'])
        return {'output': self.outputNetwork(nodes.data['state'])}


    def forward(self, graph, state):
        
        self.update_states_in_graph(graph, state)
        
        graph.apply_nodes(self.inputFunction)
        
        for messagePassingIteration in range(self.numMessagePassingIterations):
            graph.update_all(self.messageFunction, dgl.function.max('m', 'm_hat'), self.updateFunction)
        
        graph.apply_nodes(self.outputFunction)
        
        output = graph.ndata['output']
        
        return output
    
    def update_states_in_graph(self, graph, state):
        
        if self.encoder:
            if len(state.shape) == 1:
                state = state.unsqueeze(0)

            numGraphFeature = 6
            numGlobalStateInformation = 5
            numLocalStateInformation = 2
            numStateVar = state.shape[1]
            globalInformation = state[:, 0:5]
            batch_size = state.shape[0]
            numNodes = (numStateVar - 5) // 2

            nodeData = torch.empty((numNodes, batch_size, numGraphFeature + numGlobalStateInformation + numLocalStateInformation)).to(device)

            nodeData[:, :, 0:numGlobalStateInformation] = globalInformation            
            for nodeIdx in range(numNodes):
                # Assign local state information
                nodeData[nodeIdx, :, numGlobalStateInformation] = state[:, 5 + nodeIdx]
                nodeData[nodeIdx, :, numGlobalStateInformation + 1] = state[:, 5 + numNodes + nodeIdx]
                # Assign global features from graph
                nodeData[nodeIdx, :, numGlobalStateInformation + 2 : numGlobalStateInformation + 2 + numGraphFeature] = graph.ndata['feature'][nodeIdx]

            graph.ndata['input'] = nodeData
        
        else:
            numNodes, batchSize, inputSize = state.shape
            nodeData = torch.empty((numNodes, batchSize, inputSize + 6)).to(device)
            nodeData[:, :, :inputSize] = state
            nodeData[:, :, inputSize : inputSize + 6] = graph.ndata['feature'].unsqueeze(dim=1).repeat_interleave(batchSize, dim=1)
#             for nodeIdx in range(numNodes):
#                 nodeData[nodeIdx, :, inputSize : inputSize + 6] = graph.ndata['feature'][nodeIdx]
            
            graph.ndata['input'] = nodeData
        

In [4]:
states = {}
actions = {}
rewards = {}
next_states = {}
dones = {}
env = {}

for morphIdx in range(7):

    prefix = '../datasets/{}/'.format(morphIdx)
    
    states[morphIdx] = np.load(prefix + 'states_array.npy')[:500]
    actions[morphIdx] = np.load(prefix + 'actions_array.npy')[:500]
    rewards[morphIdx] = np.load(prefix + 'rewards_array.npy')[:500]
    next_states[morphIdx] = np.load(prefix + 'next_states_array.npy')[:500]
    dones[morphIdx] = np.load(prefix + 'dones_array.npy')[:500]
    
    env[morphIdx] = HalfCheetahGraphEnv(None)
    env[morphIdx].set_morphology(morphIdx)
    env[morphIdx].reset()

NoneType: None


None
*************************************************************************************************************


NoneType: None


None
*************************************************************************************************************


NoneType: None


None
*************************************************************************************************************


NoneType: None


None
*************************************************************************************************************


NoneType: None


None
*************************************************************************************************************


NoneType: None


None
*************************************************************************************************************
None
*************************************************************************************************************


NoneType: None


In [5]:
hidden_sizes = [256, 256]

inputSize = 13
stateSize = 64
messageSize = 64
latentSize = 2
numMessagePassingIterations = 4
batch_size = 1024
numBatchesPerTrainingStep = 1
minRandomDistance = 1
maxSequentialDistance = 0.04
with_batch_norm = True

# # Encoder Networks 
encoderInputNetwork = Network(inputSize, stateSize, hidden_sizes, with_batch_norm=with_batch_norm)
encoderMessageNetwork = Network(stateSize + inputSize + 1, messageSize, hidden_sizes, with_batch_norm=with_batch_norm, activation=nn.Tanh)
encoderUpdateNetwork = Network(stateSize + messageSize, stateSize, hidden_sizes, with_batch_norm=with_batch_norm)
encoderOutputNetwork = Network(stateSize, latentSize, hidden_sizes, with_batch_norm=with_batch_norm, activation=nn.Tanh)
encoderGNN = GraphNeuralNetwork(encoderInputNetwork, encoderMessageNetwork, encoderUpdateNetwork, encoderOutputNetwork, numMessagePassingIterations, encoder=True).to(device)

# # Decoder Networks
decoderInputNetwork = Network(latentSize + 6, stateSize, hidden_sizes, with_batch_norm=with_batch_norm)
decoderMessageNetwork = Network(stateSize + latentSize + 7, messageSize, hidden_sizes, with_batch_norm=with_batch_norm, activation=nn.Tanh)
decoderUpdateNetwork = Network(stateSize + messageSize, stateSize, hidden_sizes, with_batch_norm=with_batch_norm)
decoderOutputNetwork = Network(stateSize, 7, hidden_sizes, with_batch_norm=with_batch_norm)
decoderGNN = GraphNeuralNetwork(decoderInputNetwork, decoderMessageNetwork, decoderUpdateNetwork, decoderOutputNetwork, numMessagePassingIterations, encoder=False).to(device)


encoderGNN.load_state_dict(torch.load('encoderGNN.pt'))
decoderGNN.load_state_dict(torch.load('decoderGNN.pt'))

# Optimizer
lr =  1e-5
optimizer = optim.Adam(itertools.chain(
                    encoderInputNetwork.parameters(), encoderMessageNetwork.parameters(), 
                    encoderUpdateNetwork.parameters(), encoderOutputNetwork.parameters(),
                    decoderInputNetwork.parameters(), decoderMessageNetwork.parameters(), 
                    decoderUpdateNetwork.parameters(), decoderOutputNetwork.parameters()),
                    lr, weight_decay=0)

lr_lambda = lambda epoch: 0.7
lr_scheduler = optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda)
criterion  = nn.MSELoss(reduction='none')

In [6]:
prefix = 'multitask-'
numMessagePassingIterations = 4

encoderInputNetwork.load_state_dict(torch.load(prefix + 'encoderInputNetwork' + '.pt'))
encoderMessageNetwork.load_state_dict(torch.load(prefix + 'encoderMessageNetwork' + '.pt'))
encoderUpdateNetwork.load_state_dict(torch.load(prefix + 'encoderUpdateNetwork' + '.pt'))
encoderOutputNetwork.load_state_dict(torch.load(prefix + 'encoderOutputNetwork' + '.pt'))
encoderGNN = GraphNeuralNetwork(encoderInputNetwork, encoderMessageNetwork, encoderUpdateNetwork, encoderOutputNetwork, numMessagePassingIterations, encoder=True).to(device)

decoderInputNetwork.load_state_dict(torch.load(prefix + 'decoderInputNetwork' + '.pt'))
decoderMessageNetwork.load_state_dict(torch.load(prefix + 'decoderMessageNetwork' + '.pt'))
decoderUpdateNetwork.load_state_dict(torch.load(prefix + 'decoderUpdateNetwork' + '.pt'))
decoderOutputNetwork.load_state_dict(torch.load(prefix + 'decoderOutputNetwork' + '.pt'))
decoderGNN = GraphNeuralNetwork(decoderInputNetwork, decoderMessageNetwork, decoderUpdateNetwork, decoderOutputNetwork, numMessagePassingIterations, encoder=False).to(device)

encoderInputNetwork.eval()
encoderMessageNetwork.eval()
encoderUpdateNetwork.eval()
encoderOutputNetwork.eval()

decoderInputNetwork.eval()
decoderMessageNetwork.eval()
decoderUpdateNetwork.eval()
decoderOutputNetwork.eval()

Network(
  (layers): ModuleList(
    (0): Linear(in_features=64, out_features=256, bias=True)
    (1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (2): ReLU()
    (3): Linear(in_features=256, out_features=256, bias=True)
    (4): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (5): ReLU()
    (6): Linear(in_features=256, out_features=7, bias=True)
  )
)

In [7]:
a = torch.rand(2, 19)

In [8]:
morphIdx = 4

with torch.no_grad():
    g1 = env[morphIdx].get_graph()._get_dgl_graph().to('cpu')
    
    latentEncodings = encoderGNN(g1, a)
    originalInput = g1.ndata['input'][:, :, :7]

    g2 = env[morphIdx].get_graph()._get_dgl_graph().to('cpu')
    stateReconstruction = decoderGNN(g2, latentEncodings)
    stateReconstruction[:, :, 0:5] = stateReconstruction[:, :, 0:5].mean(dim=0)

In [9]:
print(g2.ndata['feature'])

tensor([[-0.5000,  0.0000,  0.0000,  0.3000,  0.0000,  0.7000],
        [-0.5000,  0.0000, -0.3000,  0.2000,  0.0000,  0.7000],
        [-0.5000,  0.0000, -0.5000,  0.2000,  0.0000,  0.7000],
        [-0.5000,  0.0000, -0.7000,  0.1500, -1.5700,  0.7000],
        [ 0.5000,  0.0000,  0.0000,  0.3000,  0.0000,  0.7000],
        [ 0.5000,  0.0000, -0.3000,  0.2000,  0.0000,  0.7000],
        [ 0.5000,  0.0000, -0.5000,  0.1500, -1.5700,  0.7000]])


In [100]:
print(a)
print(latentEncodings)
print(stateReconstruction)

tensor([[0.4082, 0.6828, 0.1377, 0.3356, 0.3284, 0.2448, 0.6077, 0.8391, 0.9633,
         0.3126, 0.0998, 0.7740, 0.5747, 0.2868, 0.4104, 0.0566, 0.5307, 0.4024,
         0.2544],
        [0.9969, 0.3216, 0.0069, 0.1808, 0.1510, 0.5144, 0.7217, 0.0204, 0.6995,
         0.8897, 0.2783, 0.8410, 0.9073, 0.0085, 0.3420, 0.1271, 0.6714, 0.5995,
         0.4799]])
tensor([[[-0.9305,  0.9545],
         [-0.9174,  0.9401]],

        [[-0.8990,  0.9438],
         [-0.8871,  0.9327]],

        [[-0.8137,  0.9071],
         [-0.8146,  0.8944]],

        [[-0.9405,  0.9658],
         [-0.9433,  0.9636]],

        [[-0.9150,  0.9417],
         [-0.9072,  0.9360]],

        [[-0.9005,  0.9560],
         [-0.8909,  0.9458]],

        [[-0.9325,  0.9459],
         [-0.9363,  0.9530]]])
tensor([[[-0.4061,  6.7473,  0.6282,  0.1058,  2.1843,  0.5489,  1.5153],
         [-0.4112,  6.7795,  0.6402,  0.0977,  2.1788,  0.5659,  1.5150]],

        [[-0.4061,  6.7473,  0.6282,  0.1058,  2.1843,  1.8745,  1.66

In [90]:
print(torch.nn.MSELoss(reduction='none')(stateReconstruction, originalInput).mean().item())

10.338685989379883


In [72]:
decoderGNN.state_dict()

OrderedDict([('inputNetwork.layers.0.weight',
              tensor([[-0.2555, -0.1739,  0.1580,  ...,  0.0895,  0.0754,  0.0935],
                      [ 0.2648, -0.1616,  0.0540,  ...,  0.1292, -0.3203,  0.2862],
                      [ 0.1270, -0.1272,  0.1431,  ..., -0.1147,  0.2002, -0.1269],
                      ...,
                      [ 0.2968,  0.1648,  0.2995,  ..., -0.2269,  0.0887,  0.1168],
                      [-0.2694, -0.4318,  0.0039,  ..., -0.0209, -0.2338, -0.0135],
                      [-0.0488, -0.0623, -0.2562,  ...,  0.1377,  0.3263, -0.0899]])),
             ('inputNetwork.layers.0.bias',
              tensor([ 0.1060, -0.3399, -0.0059,  0.2553,  0.0247,  0.1560,  0.3391,  0.0981,
                       0.1841, -0.2273,  0.2045, -0.3481, -0.0547,  0.2828, -0.0760, -0.1732,
                      -0.2082, -0.0143,  0.2593, -0.2022, -0.2030, -0.1152, -0.1895, -0.2356,
                       0.1740, -0.1269, -0.3403, -0.0389, -0.1594,  0.1803,  0.1764, -0.0820,
