In [1]:
import numpy as np
import gym
from collections import deque
import random
import torch.autograd
import os
import time
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F 
from torch.autograd import Variable
import sys
import pickle
import matplotlib.pyplot as plt
import pybullet as p 
from torch.utils.data.dataloader import DataLoader
import pybullet 
import pybullet_envs.gym_pendulum_envs 
import pybullet_envs.gym_locomotion_envs
if torch.cuda.is_available():
    device = torch.device("cuda:0")  # you can continue going on here, like cuda:1 cuda:2....etc. 
    print("Running on the GPU")
else:
    device = torch.device("cpu")
    print("Running on the CPU")
import networkx as nx
from tqdm import tqdm
import dgl
import morphsim as m
from graphenvs import HalfCheetahGraphEnv
import itertools

Running on the GPU


Using backend: pytorch


In [2]:
class Network(nn.Module):
    def __init__(
        self,
        input_size,
        output_size,
        hidden_sizes,
        batch_size,
        with_batch_norm=False,
        activation=None
    ):
        super(Network, self).__init__()
        self.hidden_sizes = hidden_sizes
        self.input_size = input_size
        self.output_size = output_size
        
        self.layers = nn.ModuleList()

        self.layers.append(nn.Linear(self.input_size, hidden_sizes[0]))
        if with_batch_norm:
            self.layers.append(nn.BatchNorm1d(batch_size))
        self.layers.append(nn.ReLU())
        
        for i in range(len(hidden_sizes) - 1):
            self.layers.append(nn.Linear(hidden_sizes[i], hidden_sizes[i+1]))
            if with_batch_norm:
                self.layers.append(nn.BatchNorm1d(batch_size))
            self.layers.append(nn.ReLU())
        
        self.layers.append(nn.Linear(hidden_sizes[len(hidden_sizes) - 1], self.output_size))
        
        if activation is not None:
            self.layers.append(activation())
            
    def forward(self, x):
        out = x
        
        for layer in self.layers:
            out = layer(out)
            
        return out

In [37]:
class GraphNeuralNetwork(nn.Module):
    def __init__(
        self,
        inputNetwork,
        messageNetwork,
        updateNetwork,
        outputNetwork,
        numMessagePassingIterations,
        encoder = True
    ):
        
        super(GraphNeuralNetwork, self).__init__()
                
        self.inputNetwork = inputNetwork
        self.messageNetwork = messageNetwork
        self.updateNetwork = updateNetwork
        self.outputNetwork = outputNetwork
        
        self.numMessagePassingIterations = numMessagePassingIterations
        self.encoder = encoder
        
    def inputFunction(self, nodes):
        return {'state' : self.inputNetwork(nodes.data['input'])}
    
    def messageFunction(self, edges):
        
        edgeData = edges.data['feature'].repeat(edges.src['state'].shape[1], 1).T.unsqueeze(-1)
        return {'m' : self.messageNetwork(torch.cat((edges.src['state'], edgeData), -1))}
        
 
    def updateFunction(self, nodes):
        return {'state': self.updateNetwork(torch.cat((nodes.data['m_hat'], nodes.data['state']), -1))}
    
    def outputFunction(self, nodes):
        
#         numNodes, batchSize, stateSize = graph.ndata['state'].shape
#         return self.outputNetwork.forward(graph.ndata['state'])
        return {'output': self.outputNetwork(nodes.data['state'])}


    def forward(self, graph, state):
        
        self.update_states_in_graph(graph, state)
        
        graph.apply_nodes(self.inputFunction)
        
        for messagePassingIteration in range(self.numMessagePassingIterations):
            graph.update_all(self.messageFunction, dgl.function.max('m', 'm_hat'), self.updateFunction)
        
        graph.apply_nodes(self.outputFunction)
        
        output = graph.ndata['output']
        
        return output
    
    def update_states_in_graph(self, graph, state):
        
        if self.encoder:
            if len(state.shape) == 1:
                state = state.unsqueeze(0)

            numGraphFeature = 6
            numGlobalStateInformation = 5
            numLocalStateInformation = 2
            numStateVar = state.shape[1] 
            globalInformation = state[:, 0:5]

            numNodes = (numStateVar - 5) // 2

            nodeData = torch.empty((numNodes, state.shape[0], numGraphFeature + numGlobalStateInformation + numLocalStateInformation)).to(device)

            nodeData[:, :, 0:numGlobalStateInformation] = globalInformation

            for nodeIdx in range(numNodes):
                # Assign global features from graph
                nodeData[nodeIdx, :, numGlobalStateInformation:numGlobalStateInformation + numGraphFeature] = graph.ndata['feature'][nodeIdx]
                # Assign local state information
                nodeData[nodeIdx, :, numGlobalStateInformation + numGraphFeature] = state[:, 5 + nodeIdx]
                nodeData[nodeIdx, :, numGlobalStateInformation + numGraphFeature + 1] = state[:, 5 + numNodes + nodeIdx]

            graph.ndata['input'] = nodeData
        
        else:
            numNodes, batchSize, inputSize = state.shape
            nodeData = torch.empty((numNodes, batchSize, inputSize + 6))
            nodeData[:, :, :inputSize] = state
            for nodeIdx in range(numNodes):
                nodeData[nodeIdx, :, inputSize : inputSize + 6] = graph.ndata['feature'][nodeIdx]
                
            graph.ndata['input'] = nodeData.to(device)
        

In [38]:
states = {}
actions = {}
rewards = {}
next_states = {}
dones = {}
env = {}

for morphIdx in range(7):

    prefix = '../datasets/{}/'.format(morphIdx)
    
    states[morphIdx] = np.load(prefix + 'states_array.npy')
    actions[morphIdx] = np.load(prefix + 'actions_array.npy')
    rewards[morphIdx] = np.load(prefix + 'rewards_array.npy')
    next_states[morphIdx] = np.load(prefix + 'next_states_array.npy')
    dones[morphIdx] = np.load(prefix + 'dones_array.npy')
    
    env[morphIdx] = HalfCheetahGraphEnv(None)
    env[morphIdx].set_morphology(morphIdx)
    env[morphIdx].reset()

NoneType: None
NoneType: None


None
*************************************************************************************************************
None
*************************************************************************************************************


NoneType: None


None
*************************************************************************************************************


NoneType: None
NoneType: None


None
*************************************************************************************************************
None
*************************************************************************************************************


NoneType: None


None
*************************************************************************************************************
None
*************************************************************************************************************


NoneType: None


In [39]:
X_test = {}
X_train = {}
Y_test = {}
Y_train = {}

for morphIdx in range(7):
    X = states[morphIdx]
    Y = next_states[morphIdx]
    permutation = np.random.permutation(X.shape[0])
    X = X[permutation]
    X_test[morphIdx] = torch.from_numpy(X[:100000]).float()
    X_train[morphIdx] = torch.from_numpy(X[100000:]).float()
    Y = Y[permutation]
    Y_test[morphIdx] = torch.from_numpy(Y[:100000]).float()
    Y_train[morphIdx] = torch.from_numpy(Y[100000:]).float()

In [54]:
hidden_sizes = [64, 64]

inputSize = 13
stateSize = 32
messageSize = 32
latentSize = 16
numMessagePassingIterations = 4
batch_size = 1024
numBatchesPerTrainingStep = 1
minRandomDistance = 1

# Encoder Networks 
encoderInputNetwork = Network(inputSize, stateSize, hidden_sizes, batch_size)
encoderMessageNetwork = Network(stateSize + 1, messageSize, hidden_sizes, batch_size, activation=nn.Tanh)
encoderUpdateNetwork = Network(stateSize + messageSize, stateSize, hidden_sizes, batch_size)
encoderOutputNetwork = Network(stateSize, latentSize, hidden_sizes, batch_size, activation=nn.Tanh)
encoderGNN = GraphNeuralNetwork(encoderInputNetwork, encoderMessageNetwork, encoderUpdateNetwork, encoderOutputNetwork, numMessagePassingIterations, encoder=True).to(device)

# Decoder Networks
decoderInputNetwork = Network(latentSize + 6, stateSize, hidden_sizes, batch_size)
decoderMessageNetwork = Network(stateSize + 1, messageSize, hidden_sizes, batch_size, activation=nn.Tanh)
decoderUpdateNetwork = Network(stateSize + messageSize, stateSize, hidden_sizes, batch_size)
decoderOutputNetwork = Network(stateSize, inputSize, hidden_sizes, batch_size, activation=nn.Tanh)
decoderGNN = GraphNeuralNetwork(decoderInputNetwork, decoderMessageNetwork, decoderUpdateNetwork, decoderOutputNetwork, numMessagePassingIterations, encoder=False).to(device)

# Optimizer
lr = 1e-4
optimizer = optim.Adam(itertools.chain(
                    encoderInputNetwork.parameters(), encoderMessageNetwork.parameters(), 
                    encoderUpdateNetwork.parameters(), encoderOutputNetwork.parameters(),
                    decoderInputNetwork.parameters(), decoderMessageNetwork.parameters(), 
                    decoderUpdateNetwork.parameters(), decoderOutputNetwork.parameters()),
                    lr, weight_decay=0)

lr_lambda = lambda epoch: 0.8
lr_scheduler = optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda)
criterion  = nn.L1Loss(reduction='none')

In [None]:
numTrainingBatches = int(np.ceil(X_train[0].shape[0] / batch_size))
numTestingBatches = int(np.ceil(X_test[0].shape[0] / batch_size))

trainLosses = {}
testLosses = {}
validLosses = {}
trainingIdxs = [1]
validationIdxs = []

encoderInputNetworkGradients = []
encoderMessageNetworkGradients = []
encoderUpdateNetworkGradients = []
encoderOutputNetworkGradients = []
decoderInputNetworkGradients = []
decoderMessageNetworkGradients = []
decoderUpdateNetworkGradients = []
decoderOutputNetworkGradients = []

for morphIdx in range(7):
    trainLosses[morphIdx] = []
    testLosses[morphIdx] = []
    validLosses[morphIdx] = []

for epoch in range(10):
    
    for morphIdx in trainingIdxs:
        permutation = np.random.permutation(X_train[morphIdx].shape[0])
        X_train[morphIdx] = X_train[morphIdx][permutation]
        Y_train[morphIdx] = Y_train[morphIdx][permutation]
        
    stepLoss = None
    encoderGraph = []
    decoderGraph = []

    for batch in range(0, numTrainingBatches, numBatchesPerTrainingStep):
                
        t0 = time.time()
        
        for morphIdx in trainingIdxs:
            numNodes = (X_train[morphIdx].shape[1] - 5) // 2
            trainLosses[morphIdx].append(np.zeros(2))

        for batchOffset in range(numBatchesPerTrainingStep):
            
            if batch + batchOffset >= numTrainingBatches:
                break
                
            for morphIdx in trainingIdxs:
                encoderGraph.append(env[morphIdx].get_graph()._get_dgl_graph())
                decoderGraph.append(env[morphIdx].get_graph()._get_dgl_graph())

                current_states = X_train[morphIdx][(batch+batchOffset) * batch_size:(batch+batchOffset+1)*batch_size]
                next_states = Y_train[morphIdx][(batch+batchOffset) * batch_size:(batch+batchOffset+1)*batch_size]
                random_indexes = np.random.choice(X_train[0].shape[0],size=current_states.shape[0], replace=False)
                random_states = X_train[morphIdx][random_indexes]
                
                encoderInput = torch.cat((current_states, next_states, random_states), dim=0).to(device)
                latent_states = encoderGNN.forward(encoderGraph[-1], encoderInput)
                curren_state_reconstruction = decoderGNN.forward(decoderGraph[-1], latent_states[:, 0:current_states.shape[0], :])
                
                temp_batch_size = current_states.shape[0]
                autoencoder_loss = criterion(encoderGraph[-1].ndata['input'][:, 0:temp_batch_size, :], curren_state_reconstruction).mean()
                contrastive_loss = criterion(latent_states[:, 0:temp_batch_size, :], latent_states[:, temp_batch_size:temp_batch_size * 2, :]).mean()
                contrastive_loss += torch.max(torch.zeros(1).to(device), minRandomDistance - criterion(latent_states[:, 0:temp_batch_size, :], latent_states[:, 2 * temp_batch_size: 3 * temp_batch_size, :]).mean()).mean()
                
                trainLosses[morphIdx][-1][0] += autoencoder_loss.item()
                trainLosses[morphIdx][-1][1] += contrastive_loss.item()

                if stepLoss is None:
                    stepLoss = autoencoder_loss + contrastive_loss

                else:
                    stepLoss += autoencoder_loss + contrastive_loss
                    
        trainLosses[morphIdx][-1] /= numBatchesPerTrainingStep
        stepLoss /= numBatchesPerTrainingStep
        optimizer.zero_grad()
        stepLoss.backward()
        
        if batch % 50 == 49:
            print('Batch {} in {}s - Train AE {} CO {} | Test AE {} CO {}'.format(
                batch, np.round(time.time() - t0, decimals=1), np.round(trainLosses[morphIdx][-1][0], decimals=3), 
                np.round(trainLosses[morphIdx][-1][1], decimals=3), np.round(testLosses[morphIdx][-1][0], decimals=3), np.round(testLosses[morphIdx][-1][1], decimals=3)))
            
#             s = 0
#             for parameter in encoderInputNetwork.parameters():
#                 s += torch.abs(parameter.grad).mean()
#             encoderInputNetworkGradients.append(s.item())

#             s = 0
#             for parameter in encoderMessageNetwork.parameters():
#                 s += torch.abs(parameter.grad).mean()
#             encoderMessageNetworkGradients.append(s.item())

#             s = 0
#             for parameter in encoderUpdateNetwork.parameters():
#                 s += torch.abs(parameter.grad).mean()
#             encoderUpdateNetworkGradients.append(s.item())

#             s = 0
#             for parameter in encoderOutputNetwork.parameters():
#                 s += torch.abs(parameter.grad).mean()
#             encoderOutputNetworkGradients.append(s.item())

#             s = 0
#             for parameter in decoderInputNetwork.parameters():
#                 s += torch.abs(parameter.grad).mean()
#             decoderInputNetworkGradients.append(s.item())

#             s = 0
#             for parameter in decoderMessageNetwork.parameters():
#                 s += torch.abs(parameter.grad).mean()
#             decoderMessageNetworkGradients.append(s.item())

#             s = 0
#             for parameter in decoderUpdateNetwork.parameters():
#                 s += torch.abs(parameter.grad).mean()
#             decoderUpdateNetworkGradients.append(s.item())

#             s = 0
#             for parameter in decoderOutputNetwork.parameters():
#                 s += torch.abs(parameter.grad).mean()
#             decoderOutputNetworkGradients.append(s.item())

#             print('Gradients: Encoder Input {} | Encoder Message {} | Encoder  Update {} | Encoder Output {}'.format(
#                 np.log10(encoderInputNetworkGradients[-1]), np.log10(encoderMessageNetworkGradients[-1]), np.log10(encoderUpdateNetworkGradients[-1]), np.log10(encoderOutputNetworkGradients[-1])))    

#             print('Gradients: Decoder Input {} | Decoder Message {} | Decoder  Update {} | Decoder Output {}'.format(
#                 np.log10(decoderInputNetworkGradients[-1]), np.log10(decoderMessageNetworkGradients[-1]), np.log10(decoderUpdateNetworkGradients[-1]), np.log10(decoderOutputNetworkGradients[-1])))    
            
            print()
            
        optimizer.step()
        
        stepLoss = None
        currentStateEncoderGraph = []
        currentStateDecoderGraph = []
        nextStateEncoderGraph = []
        randomStateEncoderGraph = []
        randomStateGraph = []
        
        
        numBatchesForExectution = 50                
        for morphIdx in trainingIdxs:
            testLosses[morphIdx].append(np.zeros(2))
            for batch_ in np.random.choice(np.arange(numTestingBatches-1), numBatchesForExectution):
                
                encoder_graph = env[morphIdx].get_graph()._get_dgl_graph()
                decoder_graph = env[morphIdx].get_graph()._get_dgl_graph()

                current_states = X_test[morphIdx][(batch_+batchOffset) * batch_size:(batch_+batchOffset+1)*batch_size]
                next_states = Y_test[morphIdx][(batch_+batchOffset) * batch_size:(batch_+batchOffset+1)*batch_size]
                random_indexes = np.random.choice(X_test[0].shape[0],size=batch_size, replace=False)
                random_states = X_test[morphIdx][random_indexes]
                
                encoderInput = torch.cat((current_states, next_states, random_states), dim=0).to(device)

                latent_states = encoderGNN.forward(encoder_graph, encoderInput)
                curren_state_reconstruction = decoderGNN.forward(decoder_graph, latent_states[:, 0:batch_size, :])
                
                autoencoder_loss = criterion(encoderGraph[-1].ndata['input'][:, 0:batch_size, :], curren_state_reconstruction).mean()
                contrastive_loss = criterion(latent_states[:, 0:batch_size, :], latent_states[:, batch_size:batch_size * 2, :]).mean()
                contrastive_loss += torch.max(torch.zeros(1).to(device), minRandomDistance - criterion(latent_states[:, 0:batch_size, :], latent_states[:, 2 * batch_size: 3 * batch_size, :]).mean()).mean()
                
                testLosses[morphIdx][-1][0] += autoencoder_loss.item()
                testLosses[morphIdx][-1][1] += contrastive_loss.item()
            testLosses[morphIdx][-1] /= numBatchesForExectution
            
#         for morphIdx in validationIdxs:
#             numNodes = (X_train[morphIdx].shape[1] - 6) // 2
#             validLosses[morphIdx].append(torch.zeros(numNodes))
#             for batch_ in np.random.choice(np.arange(numTestingBatches-1), numBatchesForExectution):

#                 g = env[morphIdx].get_graph()._get_dgl_graph()
#                 x = X_test[morphIdx][batch_ * batch_size:(batch_+1)*batch_size].to(device)
#                 y = Y_test[morphIdx][batch_ * batch_size:(batch_+1)*batch_size].to(device)

#                 y_hat = gnn.forward(g, x)
#                 loss = criterion(y, y_hat).mean(dim=0)

#                 validLosses[morphIdx][-1] += loss.cpu().detach()
#             validLosses[morphIdx][-1] /= numBatchesForExectution

#         print('\n************** Batch {} in {} **************\n'.format(batch, time.time() - t0))
#         for morphIdx in trainingIdxs:
#             print('Training Idx {} \nTrain Loss {} \nTest Loss {}\n'.format(morphIdx, trainLosses[morphIdx][-1], testLosses[morphIdx][-1]))
#         for morphIdx in validationIdxs:
#             print('Valid Idx {} | Loss {}'.format(morphIdx, validLosses[morphIdx][-1]))
            
#         if batch % 20 ==0:
#             print('Gradients: Input {} | Message {} | Update {} | Output {}'.format(
#                 inputNetworkGradients[-1], messageNetworkGradients[-1], updateNetworkGradients[-1], outputNetworkGradients[-1]))    
#         if batch % 100 == 99:
#             lr_scheduler.step()



Batch 49 in 0.0s - Train AE 0.917 CO 1.0 | Test AE 0.918 CO 1.0

Batch 99 in 0.1s - Train AE 0.899 CO 0.977 | Test AE 0.905 CO 0.977

Batch 149 in 0.1s - Train AE 0.816 CO 0.601 | Test AE 0.818 CO 0.594

Batch 199 in 0.1s - Train AE 0.727 CO 0.244 | Test AE 0.73 CO 0.202

Batch 249 in 0.1s - Train AE 0.721 CO 0.202 | Test AE 0.724 CO 0.208

Batch 299 in 0.1s - Train AE 0.721 CO 0.188 | Test AE 0.719 CO 0.2

Batch 349 in 0.1s - Train AE 0.724 CO 0.268 | Test AE 0.717 CO 0.199

Batch 399 in 0.1s - Train AE 0.713 CO 0.196 | Test AE 0.717 CO 0.205

Batch 449 in 0.1s - Train AE 0.719 CO 0.194 | Test AE 0.723 CO 0.203

Batch 499 in 0.1s - Train AE 0.721 CO 0.224 | Test AE 0.718 CO 0.205

Batch 549 in 0.1s - Train AE 0.714 CO 0.186 | Test AE 0.713 CO 0.196

Batch 599 in 0.1s - Train AE 0.705 CO 0.234 | Test AE 0.71 CO 0.193

Batch 649 in 0.1s - Train AE 0.691 CO 0.173 | Test AE 0.687 CO 0.199

Batch 699 in 0.1s - Train AE 0.684 CO 0.191 | Test AE 0.697 CO 0.191

Batch 749 in 0.1s - Train AE 0

In [None]:
morphIdx = 0
lossArr = torch.stack(testLosses[morphIdx]).T
fig, ax = plt.subplots(1, sharex=True)
for i in range(lossArr.shape[0]):
    ax.plot(range(lossArr.shape[1]), torch.log10(lossArr[i]))
plt.legend(range(lossArr.shape[0]))
plt.xlabel('Training Step')
plt.grid()
plt.ylabel('Smooth L1 Loss')
plt.title('Per Node Loss Morphology {}, Train = {}'.format(morphIdx, morphIdx in trainingIdxs))
plt.savefig('per-node-loss-{}.jpg'.format(morphIdx))
plt.show()

In [None]:
# Cell for producing Per Node Loss for each Morphology

for morphIdx in range(7):
    if morphIdx in trainingIdxs:
        lossArr = torch.stack(testLosses[morphIdx]).T
    else:
        lossArr = torch.stack(validLosses[morphIdx]).T
    
    fig, ax = plt.subplots(1, sharex=True)
    for i in range(lossArr.shape[0]):
        ax.plot(range(lossArr.shape[1]), lossArr[i])
    plt.legend(range(lossArr.shape[0]))
    plt.xlabel('Training Step')
    plt.grid()
    plt.ylabel('Smooth L1 Loss')
    plt.title('Per Node Loss Morphology {}, Train = {}'.format(morphIdx, morphIdx in trainingIdxs))
    plt.savefig('per-node-loss-{}.jpg'.format(morphIdx))
    plt.show()
    

In [None]:
fig, ax = plt.subplots(1, sharex=True)
for morphIdx in trainingIdxs:
    lossArr = torch.stack(testLosses[morphIdx]).mean(dim=1)
    ax.plot(range(lossArr.shape[0]), lossArr)
for morphIdx in validationIdxs:
    lossArr = torch.stack(validLosses[morphIdx]).mean(dim=1)
    ax.plot(range(lossArr.shape[0]), lossArr)

plt.xlabel('Training Step')
plt.ylabel('Smooth L1 Loss')
plt.title('Mean Node Loss per Morphology')
plt.legend(trainingIdxs + validationIdxs)
plt.savefig('mean-node-losses.jpg')
plt.show()

In [None]:
fig, ax = plt.subplots(1, sharex=True)
ax.plot(range(len(inputNetworkGradients)), np.log10(inputNetworkGradients))
ax.plot(range(len(messageNetworkGradients)), np.log10(messageNetworkGradients))
ax.plot(range(len(updateNetworkGradients)), np.log10(updateNetworkGradients))
ax.plot(range(len(outputNetworkGradients)), np.log10(outputNetworkGradients))
plt.xlabel('Training Steps')
plt.ylabel('Mean Absolute Gradient (Log 10)')
plt.title('Gradient Magnitudes Over Time')
plt.legend(['Input Network', 'Message Network', 'Update Network', ' Output Network'])
plt.savefig('gradients.jpg')
plt.show()

In [None]:
trainLosses = {}
testLosses = {}
validLosses = {}
trainingIdxs = []
validationIdxs = []

inputNetworkGradients = []
messageNetworkGradients = []
updateNetworkGradients = []
outputNetworkGradients = []


for morphIdx in range(7):
    trainLosses[morphIdx] = []
    testLosses[morphIdx] = []
    validLosses[morphIdx] = []

for index in [0]:
    
    inputNetwork = Network(inputSize, stateSize, hidden_sizes, batch_size, with_batch_norm)
    messageNetwork = Network(stateSize + 1, messageSize, hidden_sizes, batch_size, with_batch_norm, nn.Tanh)
    updateNetwork = Network(stateSize + messageSize, stateSize, hidden_sizes, batch_size, with_batch_norm)
    outputNetwork = Network(stateSize, outputSize, hidden_sizes, batch_size, with_batch_norm, nn.Tanh)

    gnn = GraphNeuralNetwork(inputNetwork, messageNetwork, updateNetwork, outputNetwork, numMessagePassingIterations).to(device)

    optimizer = optim.Adam(itertools.chain(inputNetwork.parameters(), messageNetwork.parameters(), updateNetwork.parameters(), outputNetwork.parameters())
                           , lr, weight_decay=1e-5)
    
    trainingIdxs = [index]
    
    for epoch in range(10):

        for morphIdx in trainingIdxs:
            permutation = np.random.permutation(X_train[morphIdx].shape[0])
            X_train[morphIdx] = X_train[morphIdx][permutation]
            Y_train[morphIdx] = Y_train[morphIdx][permutation]

        stepLoss = None
        graphs = []
        numAggregatedBatches = 0

        for batch in range(0, numTrainingBatches, numBatchesPerTrainingStep):

            inputNetwork.train()
            messageNetwork.train()
            updateNetwork.train()
            outputNetwork.train()

            t0 = time.time()

            for morphIdx in trainingIdxs:
                numNodes = (X_train[morphIdx].shape[1] - 6) // 2
                trainLosses[morphIdx].append(torch.zeros(numNodes))

            for batchOffset in range(numBatchesPerTrainingStep):

                if batch + batchOffset >= numTrainingBatches:
                    break

                for morphIdx in trainingIdxs:
                    graphs.append(env[morphIdx].get_graph()._get_dgl_graph())
                    x = X_train[morphIdx][(batch+batchOffset) * batch_size:(batch+batchOffset+1)*batch_size].to(device)
                    y = Y_train[morphIdx][(batch+batchOffset) * batch_size:(batch+batchOffset+1)*batch_size].to(device)

                    y_hat = gnn.forward(graphs[-1], x)

                    loss_tmp = criterion(y, y_hat).mean(dim=0)

                    trainLosses[morphIdx][-1] += loss_tmp.cpu().detach() / numBatchesPerTrainingStep

                    if stepLoss is None:
                        stepLoss = loss_tmp.mean()

                    else:
                        stepLoss += loss_tmp.mean()

            optimizer.zero_grad()
            stepLoss.backward()


            s = 0
            for parameter in inputNetwork.parameters():
                s += torch.abs(parameter.grad).mean()
            inputNetworkGradients.append(s.item())

            s = 0
            for parameter in messageNetwork.parameters():
                s += torch.abs(parameter.grad).mean()
            messageNetworkGradients.append(s.item())

            s = 0        
            for parameter in updateNetwork.parameters():
                s += torch.abs(parameter.grad).mean()
            updateNetworkGradients.append(s.item())

            s = 0        
            for parameter in outputNetwork.parameters():
                s += torch.abs(parameter.grad).mean()
            outputNetworkGradients.append(s.item())

            optimizer.step()

            stepLoss = None
            graphs = []

            inputNetwork.eval()
            messageNetwork.eval()
            updateNetwork.eval()
            outputNetwork.eval()

            numBatchesForExectution = 50
            for morphIdx in trainingIdxs:
                numNodes = (X_train[morphIdx].shape[1] - 6) // 2
                testLosses[morphIdx].append(torch.zeros(numNodes))
                for batch_ in np.random.choice(np.arange(numTestingBatches-1), numBatchesForExectution):
                    g = env[morphIdx].get_graph()._get_dgl_graph()
                    x = X_test[morphIdx][batch_ * batch_size:(batch_+1)*batch_size].to(device)
                    y = Y_test[morphIdx][batch_ * batch_size:(batch_+1)*batch_size].to(device)
                    y_hat = gnn.forward(g, x)
                    loss = criterion(y, y_hat).mean(dim=0)
                    testLosses[morphIdx][-1] += loss.cpu().detach()
                testLosses[morphIdx][-1] /= numBatchesForExectution

            for morphIdx in validationIdxs:
                numNodes = (X_train[morphIdx].shape[1] // - 6) // 2
                validLosses[morphIdx].append(torch.zeros(numNodes))
                for batch_ in np.random.choice(np.arange(numTestingBatches-1), numBatchesForExectution):

                    g = env[morphIdx].get_graph()._get_dgl_graph()
                    x = X_test[morphIdx][batch_ * batch_size:(batch_+1)*batch_size].to(device)
                    y = Y_test[morphIdx][batch_ * batch_size:(batch_+1)*batch_size].to(device)

                    y_hat = gnn.forward(g, x)
                    loss = criterion(y, y_hat).mean(dim=0)

                    validLosses[morphIdx][-1] += loss.cpu().detach()
                validLosses[morphIdx][-1] /= numBatchesForExectution

            print('\n************** Batch {} in {} **************\n'.format(batch, time.time() - t0))
            for morphIdx in trainingIdxs:
                print('Training Idx {} \nTrain Loss {} \nTest Loss {}\n'.format(morphIdx, trainLosses[morphIdx][-1], testLosses[morphIdx][-1]))
            for morphIdx in validationIdxs:
                print('Valid Idx {} | Loss {}'.format(morphIdx, validLosses[morphIdx][-1]))
                
    lossArr = torch.stack(testLosses[index]).T
    fig, ax = plt.subplots(1, sharex=True)
    for i in range(lossArr.shape[0]):
        ax.plot(range(lossArr.shape[1]), torch.log10(lossArr[i]))
    plt.legend(range(lossArr.shape[0]))
    plt.xlabel('Training Step')
    plt.grid()
    plt.ylabel('Smooth L1 Loss')
    plt.title('Per Node Loss Morphology {}'.format(index))
    plt.savefig('xv-per-node-loss-{}.jpg'.format(morphIdx))
    plt.show()

#             if batch % 20 ==0:
#                 print('Gradients: Input {} | Message {} | Update {} | Output {}'.format(
#                     inputNetworkGradients[-1], messageNetworkGradients[-1], updateNetworkGradients[-1], outputNetworkGradients[-1]))

In [None]:
prefix = '../models/mix/'

testLossesArr = np.array(testLosses)
trainLossesArr = np.array(trainLosses)
validLossesArr = np.array(validLosses)
learningRatesArr = np.array(learningRates)

np.save(prefix + 'testLosses', testLossesArr)
np.save(prefix + 'trainLosses', trainLossesArr)
np.save(prefix + 'validLosses', validLossesArr)
np.save(prefix + 'learningRates', learningRatesArr)

# torch.save(inputNetwork, prefix + 'inputNetwork.pt')
# torch.save(outputNetwork, prefix + 'outputNetwork.pt')
# torch.save(messageNetwork, prefix + 'messageNetwork.pt')
# torch.save(updateNetwork, prefix + 'updateNetwork.pt')

fig, ax = plt.subplots(1, sharex=True)
ax.plot(range(len(testLossesArr)), testLossesArr)
ax.plot(range(len(trainLossesArr)), trainLossesArr, 'b')
ax.plot(range(len(validLossesArr)), validLossesArr)
ax.set(ylabel='Smooth L1 Loss')
# ax[1].plot(range(len(learningRates)), learningRatesArr)
# ax[1].set(ylabel='Learning Rate')
ax.legend(["Testing", "Training", "Valid"])
plt.xlabel('Batch')
plt.savefig(prefix + 'valid-0.jpg')
plt.show()

In [None]:
prefix = '../models/mix/'

inputNetwork = torch.load(prefix + 'inputNetwork.pt')
messageNetwork = torch.load(prefix + 'messageNetwork.pt')
updateNetwork = torch.load(prefix + 'updateNetwork.pt')
outputNetwork = torch.load(prefix + 'outputNetwork.pt')

gnn = GraphNeuralNetwork(inputNetwork, messageNetwork, updateNetwork, outputNetwork, numMessagePassingIterations=3).to(device)

testLoss = 0
for morphIdx in [0]:
    
    for batch in range(numTestingBatches):

        g = env[morphIdx].get_graph()._get_dgl_graph()
        x = X_test[morphIdx][batch * batch_size:(batch+1)*batch_size].to(device)
        y = Y_test[morphIdx][batch * batch_size:(batch+1)*batch_size].to(device)

        y_hat = gnn.forward(g, x)
        loss = criterion(y, y_hat)
        testLoss += loss.item()
        
print(testLoss / X_test[morphIdx].shape[0])