# CGNN - Triad Prediction

In [None]:
import torch
import torch.nn as nn
import numpy as np
from torch.distributions.bernoulli import Bernoulli
from torch.distributions.normal import Normal
from torch.nn import init
from random import shuffle, randint
import operator
import networkx as nx
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.cluster import SpectralClustering
from sklearn.manifold import TSNE
from torch_geometric.datasets import Reddit, PPI, Planetoid
import os
from itertools import combinations, combinations_with_replacement
from sklearn.metrics import f1_score, accuracy_score
import sys
import pickle

## Define the dataset, the type of prediction and the number of samples

In [None]:
# Root directory under which the datasets/ folder is created.
PATH_TO_DATASETS_DIRECTORY = './'
# Dataset to load: one of 'reddit', 'cora', 'citeseer', 'pubmed'.
DATASET = 'cora'
# Prediction task: 'node', 'link' or 'triad'.
PREDICTION = 'triad'
# Number of hidden-embedding samples drawn per split from the collider model.
NUM_SAMPLES = 1
# Number of repeated runs (not referenced by the cells shown here).
RUN_COUNT = 1

In [None]:
# Build only the requested dataset. Constructing every Planetoid/Reddit object
# up front downloads and processes all of them (Reddit alone is very large),
# even though only DATASET is used.
dataset_factories = {
    'reddit': lambda: Reddit(root=PATH_TO_DATASETS_DIRECTORY + '/datasets/Reddit'),
    'cora': lambda: Planetoid(root=PATH_TO_DATASETS_DIRECTORY + '/datasets/Cora/', name='Cora'),
    'citeseer': lambda: Planetoid(root=PATH_TO_DATASETS_DIRECTORY + '/datasets/CiteSeer/', name='CiteSeer'),
    'pubmed': lambda: Planetoid(root=PATH_TO_DATASETS_DIRECTORY + '/datasets/PubMed/', name='PubMed'),
}
if DATASET not in dataset_factories:
    raise ValueError("Unknown DATASET %r; expected one of %s" % (DATASET, sorted(dataset_factories)))
dataset = dataset_factories[DATASET]()
# Kept so that code indexing `datasets[DATASET]` keeps working.
datasets = {DATASET: dataset}
data = dataset[0]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Number of output classes of the supervised predictor for each task.
predictions = dict(
    node=dataset.num_classes,  # one class per node label
    link=2,                    # binary
    triad=4,                   # sample_triads labels triads 0, 1, 2, 3
)

In [None]:
dataset_types = ['train', 'validation', 'test']
triads_stores = dict()
# Directory with the pre-computed triad pickles; override with TRIAD_STORE_DIR.
triad_loc = os.environ.get('TRIAD_STORE_DIR', '/scratch-ml00/bsriniv/pos2struct/triad_store/')
# Store key -> filename suffix used on disk (<split>/<DATASET>_triad_<suffix>.pickle).
triad_file_suffixes = {'zeros': 'zero', 'ones': 'one', 'twos': 'two', 'threes': 'three'}

for set_nature in dataset_types:
    triads_stores[set_nature] = dict()
    for store_key, suffix in triad_file_suffixes.items():
        triad_filename = os.path.join(triad_loc, set_nature, DATASET + '_triad_' + suffix + '.pickle')
        # NOTE: pickle.load can execute arbitrary code -- only point this at trusted files.
        with open(triad_filename, 'rb') as f:
            triads_stores[set_nature][store_key] = pickle.load(f)

## Parallel Splash Generator

In [None]:
# Width of every hidden layer of the collider networks.
num_neurons = 256
# Width of the per-node hidden state carried through Gibbs sampling.
ds_rho_hidden = 256
# Width of the observed node feature vectors.
ds_rho_visible = dataset.num_features


def parallel_splash_variant(adjacency_matrix, num_splashes=64):
    """
        Parallel Splash Gibbs Sampling: partition the graph's nodes into
        "splashes" grown around randomly chosen roots, then interleave the
        splashes into sampling steps.

        Args:
            adjacency_matrix: square tensor; a non-zero entry [u, v] marks v as
                a neighbour of u.
            num_splashes: number of roots to seed. Additional roots are added
                automatically whenever no splash can grow any further (e.g. for
                nodes in components that no root reached).

        Returns:
            batches_: dict step -> list of node ids sampled at that step (at
                most one node per splash per step; every node appears exactly
                once across all steps).
            batch_info: dict step -> list aligned with batches_[step]; entry j
                holds the neighbours of batches_[step][j] that lie in the same
                splash, right-padded with adjacency_matrix.shape[1] (the index
                of the padding row) so all lists of a step have equal length.
    """
    def neighbors_of(node):
        # Column indices of the non-zero entries of the node's row.
        return set(adjacency_matrix[node].nonzero().view(-1).tolist())

    max_value = adjacency_matrix.shape[0]
    max_value_hidden = adjacency_matrix.shape[1]
    batches = dict()
    # root -> {candidate node: score}; the score counts shared neighbours with the splash.
    neighborhood_tracker = dict()

    # Seed the splashes with random roots, never choosing a neighbour of an earlier root.
    root_allowed_nodes = set(range(max_value))
    for _ in range(num_splashes):
        if not root_allowed_nodes:
            break
        candidates = list(root_allowed_nodes)
        shuffle(candidates)
        root_node = candidates.pop()
        root_allowed_nodes.discard(root_node)
        batches[root_node] = [root_node]
        root_node_neighbors = neighbors_of(root_node)
        neighborhood_tracker[root_node] = dict.fromkeys(root_node_neighbors, 1)
        root_allowed_nodes -= root_node_neighbors

    allowed_nodes = set(range(max_value)) - set(batches)
    prev_size = len(allowed_nodes)
    while allowed_nodes:
        # Grow every splash by (at most) one node per pass.
        for root in batches:
            tracker = neighborhood_tracker[root]
            node_selected = None
            while tracker:
                candidate = max(tracker.items(), key=operator.itemgetter(1))[0]
                if candidate in allowed_nodes:
                    node_selected = candidate
                    break
                # Already assigned elsewhere: drop the stale candidate instead of
                # adding the node to a second splash.
                del tracker[candidate]
            if node_selected is None:
                continue  # this splash cannot grow any further

            node_neighbors = neighbors_of(node_selected)
            allowed_nodes.remove(node_selected)
            batches[root].insert(0, node_selected)

            # Splashes that also had node_selected as a candidate lose it and
            # lower the scores of its neighbours; the growing splash raises them
            # and picks up neighbours that are still unassigned.
            for r, other_tracker in neighborhood_tracker.items():
                if node_selected not in other_tracker:
                    continue
                del other_tracker[node_selected]
                if r != root:
                    for neigh in node_neighbors:
                        if neigh in other_tracker:
                            other_tracker[neigh] -= 1
                else:
                    for neigh in node_neighbors:
                        if neigh in other_tracker:
                            other_tracker[neigh] += 1
                        elif neigh in allowed_nodes:
                            other_tracker[neigh] = 1

        curr_size = len(allowed_nodes)
        if curr_size == prev_size and curr_size != 0:
            # No splash grew: start a new splash from an unassigned node. Its
            # candidates are its unassigned NEIGHBOURS (the original code used
            # `row == 0`, i.e. non-neighbours, which also re-added assigned nodes).
            new_root_node = allowed_nodes.pop()
            batches[new_root_node] = [new_root_node]
            neighborhood_tracker[new_root_node] = dict.fromkeys(neighbors_of(new_root_node) & allowed_nodes, 1)
            curr_size -= 1
        prev_size = curr_size

    # Step i samples the i-th node of every splash that is long enough.
    num_batches = max((len(members) for members in batches.values()), default=0)
    batches_ = {k: [] for k in range(num_batches)}
    batch_info = {k: [] for k in range(num_batches)}
    for i in range(num_batches):
        curr_step_max_len = 0
        for members in batches.values():
            if i >= len(members):
                continue
            node = members[i]
            batches_[i].append(node)
            curr_node_neigh = list(neighbors_of(node).intersection(set(members)))
            batch_info[i].append(curr_node_neigh)
            curr_step_max_len = max(curr_step_max_len, len(curr_node_neigh))
        for node_neigh in batch_info[i]:
            node_neigh.extend([max_value_hidden] * (curr_step_max_len - len(node_neigh)))

    return batches_, batch_info


## Colliders Learning Network

In [None]:
class ColliderNetworks(nn.Module):
    """
    Learned Gibbs-sampling-style updates of per-node hidden states.

    For every node being updated, the hidden states and the visible features of
    its listed neighbours are each summed and passed through a two-layer MLP
    ("deepsets"). The two set representations are concatenated and mapped
    through a noisy MLP to the node's new hidden state.

    NOTE(review): colliders_mlp_2 outputs `num_neurons` values that are written
    into rows of width `ds_rho_hidden`, so the two must be equal.
    """

    def __init__(self, weight_norm=True):
        # NOTE(review): `weight_norm` is accepted but not used.
        super(ColliderNetworks, self).__init__()
        # Deepsets MLPs: one for visible features, one for hidden states.
        self.rho_mlp_visible_1 = nn.Linear(ds_rho_visible, num_neurons)
        self.rho_mlp_visible_2 = nn.Linear(num_neurons, num_neurons)
        self.rho_mlp_visible_1_dropout = nn.Dropout(p=0.5)
        self.rho_mlp_hidden_1 = nn.Linear(ds_rho_hidden, num_neurons)
        self.rho_mlp_hidden_2 = nn.Linear(num_neurons, num_neurons)
        self.rho_mlp_hidden_1_dropout = nn.Dropout(p=0.5)

        # Sampling MLP: concatenated set representations -> new hidden state.
        self.colliders_mlp_1 = nn.Linear(2*num_neurons, num_neurons)
        self.colliders_mlp_2 = nn.Linear(num_neurons, num_neurons)
        self.colliders_mlp_1_dropout_mean = nn.Dropout(p=0.5)
        self.colliders_mlp_1_dropout_variance = nn.Dropout(p=0.5)

        # Despite its name this activation is a Tanh; the attribute name is kept.
        self.relu_activation = nn.Tanh()
        self.sigmoid_activation = nn.Sigmoid()
        self.loss_func = nn.BCELoss()

        self.rho_mlp_visible_1_norm = nn.BatchNorm1d(num_neurons)
        self.rho_mlp_hidden_1_norm = nn.BatchNorm1d(num_neurons)
        self.colliders_mlp_1_norm = nn.BatchNorm1d(num_neurons)

        # Xavier-uniform weights, zero biases (same initialisation order as before).
        for layer in (self.rho_mlp_visible_1, self.rho_mlp_visible_2,
                      self.rho_mlp_hidden_1, self.rho_mlp_hidden_2,
                      self.colliders_mlp_1, self.colliders_mlp_2):
            init.xavier_uniform_(layer.weight)
            layer.bias.data.fill_(0)

    @staticmethod
    def _maybe_batch_norm(norm_layer, out):
        """Apply `norm_layer`, skipping it when BatchNorm cannot run (it raises
        ValueError for a single row in training mode)."""
        try:
            return norm_layer(out)
        except ValueError:
            return out

    def deepsets(self, input_tensor, hidden=False):
        """
        Returns the set representation of `input_tensor`.

        The set elements (dim 1) are summed, then passed through
        Linear -> BatchNorm (when applicable) -> Tanh -> Dropout -> Linear,
        using the hidden-state MLP if `hidden` else the visible-feature MLP.
        """
        deepsets_sum = torch.sum(input_tensor, dim=1)
        if hidden:
            layer_1, norm = self.rho_mlp_hidden_1, self.rho_mlp_hidden_1_norm
            dropout, layer_2 = self.rho_mlp_hidden_1_dropout, self.rho_mlp_hidden_2
        else:
            layer_1, norm = self.rho_mlp_visible_1, self.rho_mlp_visible_1_norm
            dropout, layer_2 = self.rho_mlp_visible_1_dropout, self.rho_mlp_visible_2
        out = self._maybe_batch_norm(norm, layer_1(deepsets_sum))
        out = self.relu_activation(out)
        out = dropout(out)
        return layer_2(out)

    def forward(self, adjacency_matrix, hidden_embeddings, visible_feats, batches, batch_info, num_times=5):
        """
        Runs `num_times` Gibbs sweeps, updating `hidden_embeddings` in place, and
        returns it. `adjacency_matrix` is not used here (the structure comes from
        `batches` / `batch_info`).
        """
        for iteration in range(num_times):
            self.gibbs_sampling(hidden_embeddings, visible_feats, batches, batch_info)
        return hidden_embeddings

    def reconstruction_loss(self, adjacency_matrix, hidden_embeddings):
        """
        Reconstruction loss: class-balanced BCE between sigmoid(H H^T) and the
        adjacency matrix, over all n*n ordered node pairs.

        Rows of `hidden_embeddings` beyond the first n (the padding row) are
        ignored. For graphs with more than 5000 nodes a random subset of roughly
        half of the nodes is used. Edges are weighted by the fraction of
        non-edges and non-edges by the fraction of edges, counted on the pairs
        actually used.
        """
        max_value = adjacency_matrix.shape[0]
        # Always drop the padding row(s): the old code kept them when n == 5000,
        # which made the shapes of the prediction and the target disagree.
        hidden_embeddings = hidden_embeddings[:max_value]
        if max_value > 5000:
            selection = torch.randint(low=0, high=2, size=(max_value,)).bool()
            adjacency_matrix = adjacency_matrix[selection][:, selection]
            hidden_embeddings = hidden_embeddings[selection.to(hidden_embeddings.device)]
        edge_probability = torch.sigmoid(torch.matmul(hidden_embeddings, torch.t(hidden_embeddings)))
        preds = edge_probability.view(-1)
        target = torch.reshape(adjacency_matrix, (-1,)).to(preds.device).type(torch.float)
        target_inverse = 1.0 - target
        # Count pairs on the (possibly subsampled) target, not on the full n*n.
        total_edges_non_edges = target.numel()
        num_non_edges = (target == 0).sum().item()
        num_edges = total_edges_non_edges - num_non_edges
        class_weights = target*(num_non_edges/total_edges_non_edges) + target_inverse*(num_edges/total_edges_non_edges)
        return F.binary_cross_entropy(preds, target, weight=class_weights)

    def gibbs_sampling(self, hidden_embeddings, visible_feats, batches, batch_info):
        """
        Runs one complete Gibbs sweep, writing new hidden states into
        `hidden_embeddings` in place.

        For each step k, the nodes `batches[k]` are updated from the hidden
        states and visible features of the neighbours in `batch_info[k]`
        (padding indices point at the last, all-zero row).
        """
        for num_batch in range(len(batches)):
            neighbour_index = torch.LongTensor(batch_info[num_batch]).to(device)
            hidden_dependencies = hidden_embeddings[neighbour_index].detach()
            visible_dependencies = visible_feats[neighbour_index].detach()

            set_rep_hidden = self.deepsets(hidden_dependencies, hidden=True)
            set_rep_visible = self.deepsets(visible_dependencies, hidden=False)
            # NOTE(review): detaching here means no gradient reaches the
            # rho_mlp_* (deepsets) layers; only colliders_mlp_* get trained.
            # Kept as-is -- confirm whether this is intended.
            posterior = torch.cat((set_rep_hidden, set_rep_visible), 1).detach()

            out = self._maybe_batch_norm(self.colliders_mlp_1_norm, self.colliders_mlp_1(posterior))
            out = self.relu_activation(out)

            # Two independent dropout masks of the same activation serve as the
            # "mean" and the noise scale (the scale can be negative: it is a
            # Tanh output). Noise is N(0, 1), drawn directly on var's device.
            mean = self.colliders_mlp_1_dropout_mean(out)
            var = self.colliders_mlp_1_dropout_variance(out)
            noise = torch.randn_like(var)
            out = mean + torch.mul(noise, var)

            hidden_embeddings[batches[num_batch]] = self.colliders_mlp_2(out)

## Build the non-overlapping induced subgraphs and corrupt a small fraction of the edges

In [None]:
# Every node outside the validation and test splits is used for training.
# Boolean ops are used because `1 - mask` raises on the bool masks of current
# torch_geometric releases (and `~` on a uint8 mask would give 254/255).
data.train_mask = ~(data.val_mask.bool() | data.test_mask.bool())

# Dense adjacency matrix, allocated directly as int16 (no float n*n temporary).
# NOTE(review): this is O(n^2) memory -- prohibitive for Reddit-sized graphs.
adj_mat = torch.zeros((data.num_nodes, data.num_nodes), dtype=torch.short)
edges = data.edge_index.t()
adj_mat[edges[:, 0], edges[:, 1]] = 1

# Induced subgraph of each split: keep the split's rows and columns.
adj_train = adj_mat[data.train_mask][:, data.train_mask]
adj_validation = adj_mat[data.val_mask][:, data.val_mask]
adj_test = adj_mat[data.test_mask][:, data.test_mask]


def corrupt_adj(adj_mat, task, percent=1):
    """
    Returns the corrupted version of the adjacency matrix.

    For task == 'link', `percent`% of the edges (read from the upper triangle)
    are removed symmetrically, and up to the same number of distinct non-edges
    (pairs of distinct nodes that are not adjacent in `adj_mat`) are added
    symmetrically.

    Args:
        adj_mat: square adjacency tensor; assumed symmetric (undirected graph).
        task: corruption type; only 'link' is supported.
        percent: percentage of edges to flip in each direction.

    Returns:
        (adj_mat_corrupted, false_edges, false_non_edges): a corrupted copy
        (`adj_mat` itself is untouched), the [u, v] pairs that were added and
        the [u, v] pairs that were removed.

    Raises:
        ValueError: if `task` is not 'link'.
    """
    if task != 'link':
        raise ValueError("Unsupported corruption task %r; only 'link' is implemented" % (task,))

    edges = adj_mat.triu().nonzero()
    num_edges = edges.shape[0]
    num_to_corrupt = min(int(percent/100.0 * num_edges), num_edges)
    adj_mat_corrupted = adj_mat.clone()
    false_edges, false_non_edges = [], []
    if num_to_corrupt <= 0:
        return adj_mat_corrupted, false_edges, false_non_edges

    # Edge corruption: remove distinct edges (sampling without replacement).
    for idx in np.random.choice(num_edges, size=num_to_corrupt, replace=False):
        u, v = edges[int(idx)].tolist()
        adj_mat_corrupted[u, v] = 0
        adj_mat_corrupted[v, u] = 0
        false_non_edges.append([u, v])

    # Non-edge corruption: consecutive pairs of random nodes. A pair is used if
    # it joins two distinct nodes, is not an edge of the ORIGINAL graph in either
    # direction (so removed edges are not re-added) and was not picked already.
    random_nodes = np.random.randint(adj_mat.shape[0], size=6*num_to_corrupt).tolist()
    chosen = set()
    for u, v in zip(random_nodes, random_nodes[1:]):
        if len(false_edges) == num_to_corrupt:
            break
        key = (min(u, v), max(u, v))
        if u == v or key in chosen or adj_mat[u, v] != 0 or adj_mat[v, u] != 0:
            continue
        chosen.add(key)
        adj_mat_corrupted[u, v] = 1
        adj_mat_corrupted[v, u] = 1
        false_edges.append([u, v])
    return adj_mat_corrupted, false_edges, false_non_edges



# Flip 1% of the edges of each induced subgraph (order kept: train, val, test).
adj_train_corrupted, train_false_edges, train_false_non_edges = corrupt_adj(adj_mat=adj_train, task='link', percent=1)
adj_val_corrupted, val_false_edges, val_false_non_edges = corrupt_adj(adj_mat=adj_validation, task='link', percent=1)
adj_test_corrupted, test_false_edges, test_false_non_edges = corrupt_adj(adj_mat=adj_test, task='link', percent=1)

# Partition each corrupted subgraph into (at least) 64 splashes / sampling steps.
train_batches, train_batch_info = parallel_splash_variant(adj_train_corrupted, num_splashes=64)
val_batches, val_batch_info = parallel_splash_variant(adj_val_corrupted, num_splashes=64)
test_batches, test_batch_info = parallel_splash_variant(adj_test_corrupted, num_splashes=64)

## Training the Collider Network

In [None]:
def _init_collider_state(num_nodes, node_mask):
    """Return (hidden, visible): N(0, 1) hidden states for `num_nodes` nodes plus
    an all-zero padding row, and the masked node features plus a zero padding row."""
    normal_init = Normal(torch.zeros(num_nodes + 1, ds_rho_hidden), torch.ones(num_nodes + 1, ds_rho_hidden))
    hidden = normal_init.sample().to(device)
    hidden[-1] = 0.0
    visible = data.x[node_mask].to(device)
    visible = torch.cat((visible, torch.zeros((1, data.num_features), device=device)))
    return hidden, visible


validation_loss = float('inf')
torch.cuda.empty_cache()
collider_sample = ColliderNetworks().to(device)
colliders_model = 'best_colliders_model.model'
optimizer = torch.optim.Adam(collider_sample.parameters(), lr=0.001)

for epoch in range(50):
    print("Epoch Num: ", epoch)
    torch.cuda.empty_cache()
    optimizer.zero_grad()
    hidden_embeddings, visible_feats = _init_collider_state(adj_train_corrupted.shape[0], data.train_mask)
    hidden_embeddings = collider_sample.forward(adjacency_matrix=adj_train_corrupted, hidden_embeddings=hidden_embeddings, batches=train_batches, batch_info=train_batch_info, num_times=2, visible_feats=visible_feats)
    loss = collider_sample.reconstruction_loss(adjacency_matrix=adj_train_corrupted, hidden_embeddings=hidden_embeddings)
    loss.backward()
    print("Training Loss: ", loss.item())
    sys.stdout.flush()
    with torch.no_grad():
        # Validation loss with the current (pre-step) parameters; keep the best model.
        hidden_embeddings, visible_feats = _init_collider_state(adj_val_corrupted.shape[0], data.val_mask)
        hidden_embeddings = collider_sample.forward(adjacency_matrix=adj_val_corrupted, hidden_embeddings=hidden_embeddings, batches=val_batches, batch_info=val_batch_info, num_times=2, visible_feats=visible_feats)
        compute_val_loss = collider_sample.reconstruction_loss(adjacency_matrix=adj_val_corrupted, hidden_embeddings=hidden_embeddings).item()
        if compute_val_loss < validation_loss:
            validation_loss = compute_val_loss
            print("Validation Loss: ", validation_loss)
            torch.save(collider_sample.state_dict(), colliders_model)
    optimizer.step()

## Load the best saved colliders model

In [None]:
# Reload the best checkpoint; map_location lets a checkpoint saved on GPU load on CPU.
collider_sample = ColliderNetworks().to(device)
collider_sample.load_state_dict(torch.load(colliders_model, map_location=device))

## Generate Multiple Samples for Train, Validation and Test using the Colliders Model with different normal inputs

In [None]:
# Draw NUM_SAMPLES hidden-embedding samples per split from the trained collider
# model: fresh N(0, 1) initial states (plus a zero padding row) and 5 Gibbs sweeps.
hidden_samples_train = []
hidden_samples_validation = []
hidden_samples_test = []
for sample in range(NUM_SAMPLES):
    print("Sample No:: ", sample)
    with torch.no_grad():
        # Same order as before: train, validation, test.
        for split_adj, split_mask, split_batches, split_batch_info, split_store in (
            (adj_train_corrupted, data.train_mask, train_batches, train_batch_info, hidden_samples_train),
            (adj_val_corrupted, data.val_mask, val_batches, val_batch_info, hidden_samples_validation),
            (adj_test_corrupted, data.test_mask, test_batches, test_batch_info, hidden_samples_test),
        ):
            split_rows = split_adj.shape[0] + 1  # +1 for the padding row
            normal_init = Normal(torch.zeros(split_rows, ds_rho_hidden), torch.ones(split_rows, ds_rho_hidden))
            hidden_embeddings = torch.Tensor(normal_init.sample()).to(device)
            hidden_embeddings[-1] = torch.zeros((1, ds_rho_hidden)).to(device)
            hidden_embeddings = hidden_embeddings.detach()
            temp_holder = torch.zeros((1, data.num_features)).to(device)
            visible_feats = torch.cat((data.x[split_mask].to(device), temp_holder))
            hidden_embeddings = collider_sample.forward(
                adjacency_matrix=split_adj,
                hidden_embeddings=hidden_embeddings,
                batches=split_batches,
                batch_info=split_batch_info,
                num_times=5,
                visible_feats=visible_feats,
            )
            split_store.append(hidden_embeddings)

In [None]:
# Drop the trailing padding row of each sample and append the observed node
# features, giving one (num_nodes, ds_rho_hidden + num_features) matrix per sample.
for samples_list, split_mask in ((hidden_samples_train, data.train_mask),
                                 (hidden_samples_validation, data.val_mask),
                                 (hidden_samples_test, data.test_mask)):
    for idx in range(NUM_SAMPLES):
        without_padding = samples_list[idx][:-1]
        samples_list[idx] = torch.cat((without_padding, data.x[split_mask].to(device)), 1)

## Define the Supervised Learning Network

In [None]:
num_neurons = 256
# Each node's input is its sampled hidden state (ds_rho_hidden wide) followed by
# its features (see the concatenation of the hidden samples with data.x).
input_rep = ds_rho_hidden + data.num_features


class StructMLP(nn.Module):
    """Deepsets classifier over node sets (single nodes, pairs or triads), averaged over hidden-embedding samples."""

    def __init__(self, node_set_size=1):
        super(StructMLP, self).__init__()

        self.node_set_size = node_set_size
        # Deepsets: phi = ds_layer_1 -> ReLU -> ds_layer_2 per node, sum over the set, then rho_layer_1.
        self.ds_layer_1 = nn.Linear(input_rep, num_neurons)
        self.ds_layer_2 = nn.Linear(num_neurons, num_neurons)
        self.rho_layer_1 = nn.Linear(num_neurons, num_neurons)
        # NOTE(review): rho_layer_2 and sigmoid are not used in forward; they are
        # kept so saved state_dicts keep loading.
        self.rho_layer_2 = nn.Linear(num_neurons, num_neurons)

        # One hidden layer predictor.
        self.layer1 = nn.Linear(num_neurons, num_neurons)
        self.layer2 = nn.Linear(num_neurons, predictions[PREDICTION])
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_tensor, samples):
        """
        Args:
            input_tensor: stacked node representations, presumably
                (num_samples, num_nodes, input_rep) -- one slice per sample.
            samples: LongTensor (batch, set_size) of node indices; each row is a node set.

        Returns:
            (batch, predictions[PREDICTION]) unnormalised class scores; the set
            representation is averaged over the samples in input_tensor.
        """
        sum_tensor = torch.zeros(samples.shape[0], num_neurons).to(device)
        for sample_rep in input_tensor:
            # Gather the referenced nodes first: phi is applied row-wise, so this
            # equals phi(all nodes)[samples] without processing the whole graph.
            node_rep = sample_rep.view(-1, input_rep)[samples]
            x = self.ds_layer_2(self.relu(self.ds_layer_1(node_rep)))
            x = torch.sum(x, dim=1)
            sum_tensor = sum_tensor + self.rho_layer_1(x)

        x = sum_tensor / input_tensor.shape[0]

        # One hidden layer for the predictor.
        x = self.layer1(x)
        x = self.relu(x)
        return self.layer2(x)

    def compute_loss(self, input_tensor, samples, target):
        """Cross-entropy between forward(input_tensor, samples) and integer class `target`."""
        pred = self.forward(input_tensor, samples)
        return F.cross_entropy(pred, target)

In [None]:
# Size of the node sets being classified: 1 for nodes, 2 for links, 3 otherwise (triads).
node_set_size = {'node': 1, 'link': 2}.get(PREDICTION, 3)

mlp_model = 'best_mlp_model.model'
mlp = StructMLP(node_set_size).to(device)
mlp_optimizer = torch.optim.Adam(mlp.parameters(), lr=0.001)

## Training the Supervised Learning Network

In [None]:
def sample_triads(set_nature, small_samples=100):
    """
    Draw a labelled mini-batch of triads from `triads_stores[set_nature]`.

    From each store ('zeros', 'ones', 'twos', 'threes' -> labels 0..3),
    min(small_samples, store size) rows are drawn uniformly WITH replacement.
    Empty stores are skipped (the old code crashed on them).

    Returns:
        (triads, labels): a long tensor of shape (k, 3) and a long tensor of
        shape (k,), both moved to `device`.
    """
    picked_rows, labels = [], []
    for label, store_key in enumerate(('zeros', 'ones', 'twos', 'threes')):
        store = triads_stores[set_nature][store_key]
        store_size = store.shape[0]
        if store_size == 0:
            continue
        picked = torch.as_tensor(store[np.random.randint(store_size, size=min(small_samples, store_size))])
        picked_rows.append(picked)
        labels.append(torch.full((picked.shape[0],), label, dtype=torch.long))

    if not picked_rows:
        out = torch.zeros((0, 3), dtype=torch.long)
        target = torch.zeros(0, dtype=torch.long)
    else:
        out = torch.cat(picked_rows, dim=0).view(-1, 3).type(torch.long)
        target = torch.cat(labels, dim=0)
    return out.to(device), target.to(device)

In [None]:
epochs = 50
validation_loss = float('inf')
small_samples = 200
# The sampled hidden representations are fixed inputs: stack (and detach) them once
# instead of re-stacking every epoch.
input_ = torch.stack(hidden_samples_train).detach()
input_val = torch.stack(hidden_samples_validation).detach()
for num_epoch in range(epochs):
    mlp_optimizer.zero_grad()
    sampled, target = sample_triads('train', small_samples=small_samples)
    loss = mlp.compute_loss(input_, sampled, target=target)
    print("Training Loss: ", loss.item())
    with torch.no_grad():
        # Validation loss with the current (pre-step) parameters; keep the best model.
        sampled, target = sample_triads('validation', small_samples=small_samples)
        compute_val_loss = mlp.compute_loss(input_val, sampled, target=target).item()
        if compute_val_loss < validation_loss:
            validation_loss = compute_val_loss
            print("Validation Loss: ", validation_loss)
            torch.save(mlp.state_dict(), mlp_model)
    loss.backward()
    mlp_optimizer.step()

## Load the best model

In [None]:
# Reload the best checkpoint; map_location lets a checkpoint saved on GPU load on CPU.
mlp = StructMLP(node_set_size).to(device)
mlp.load_state_dict(torch.load(mlp_model, map_location=device))

## Forward pass on the test graphs

In [None]:
# Evaluate the reloaded MLP on a random draw of test triads.
small_samples = 200
sampled_test, target_test = sample_triads('test', small_samples)
t_test = target_test.cpu().numpy()

input_test = torch.stack(hidden_samples_test).detach()
with torch.no_grad():
    test_pred = mlp.forward(input_test, sampled_test)
    pred = F.log_softmax(test_pred, dim=1)
# Predicted class = index of the largest log-probability per triad.
pred = np.argmax(pred.detach().cpu().numpy(), axis=1)

## Test Results

In [None]:
# Report test metrics, computing and printing them one at a time.
for label, metric, metric_kwargs in (
    ("Test Micro F1 Score: ", f1_score, {'average': 'micro'}),
    ("Test Weighted F1 Score: ", f1_score, {'average': 'weighted'}),
    ("Test Accuracy Score: ", accuracy_score, {}),
):
    print(label, metric(t_test, pred, **metric_kwargs))