# Model Training
Here we will build and train our model using the graphs we previously built.

In [137]:
import numpy as np
import pandas as pd
import torch as th
import dgl
import scipy
import networkx as nx
from tqdm import tqdm

from dgl.data.utils import save_graphs, load_graphs, split_dataset

import dgl.nn as dglnn
import torch.nn as nn
import torch.nn.functional as F

## Load Data

In [139]:
#load our graphs from before
glist, label_dict = load_graphs("./data_final.bin")
unsup_graph, sup_graph = glist[0], glist[1]

## Prepare Data for Training

In [145]:
#fix problem where we have a label of -1 to represent the missing class
tempLabels = sup_graph.ndata['label']

#something weird - no labels from original data had class 8??? --> map our -1 class to 8 so it works with cross-entropy loss
tempLabels[tempLabels==-1] = 8

sup_graph.ndata['label'] = tempLabels
unsup_graph.ndata['label'] = tempLabels

In [146]:
#add self loops to both graphs
sup_graph = sup_graph.add_self_loop()
unsup_graph = unsup_graph.add_self_loop()

In [147]:
#split data into train,test,val
sup_split = split_dataset(sup_graph, shuffle=True, random_state=10)
unsup_split = split_dataset(sup_graph, shuffle=True, random_state=10)

#extract the splits
sup_train, sup_val, sup_test = sup_split[0], sup_split[1], sup_split[2]
unsup_train, unsup_val, unsup_test = unsup_split[0], unsup_split[1], unsup_split[2]

#convert the index based representation into boolean masks for the graphs
n = sup_graph.number_of_nodes() #total num nodes in each graph
train_mask, val_mask, test_mask = np.zeros(n, dtype=bool), np.zeros(n, dtype=bool), np.zeros(n, dtype=bool) #create empty arrays for train/val/test

#populate our boolean masks
train_mask[sup_train.indices] = True 
val_mask[sup_val.indices] = True
test_mask[sup_test.indices] = True

#embed these masks into our graph
sup_graph.ndata['train_mask'], sup_graph.ndata['val_mask'], sup_graph.ndata['test_mask'] = th.tensor(train_mask, dtype=bool), th.tensor(val_mask, dtype=bool), th.tensor(test_mask, dtype=bool)
unsup_graph.ndata['train_mask'], unsup_graph.ndata['val_mask'], unsup_graph.ndata['test_mask'] = th.tensor(train_mask), th.tensor(val_mask), th.tensor(test_mask)



In [148]:
sum(sup_split[0].indices == unsup_split[0].indices) == len(sup_split[0]) #confirm we have the same nodes in both splits

True

In [150]:
#extract and store information from each graph

#sup
sup_node_features = sup_graph.ndata['features']
sup_node_labels = sup_graph.ndata['label']

#unsup
unsup_node_features = unsup_graph.ndata['features']
unsup_node_labels = unsup_graph.ndata['label']


#general graph characteristics - doesn't matter which graph
train_mask = sup_graph.ndata['train_mask']
valid_mask = sup_graph.ndata['val_mask']
test_mask = sup_graph.ndata['test_mask']
n_features = sup_node_features.shape[1]
n_labels = int(sup_node_labels.max().item() + 1)

## Build Model

In [151]:
#simple base GraphSage model
class SAGE(nn.Module):
    def __init__(self, in_feats, hid_feats1, hid_feats2, out_feats):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(
            in_feats=in_feats, out_feats = hid_feats1, aggregator_type='pool', activation= F.relu))
        self.layers.append(dglnn.SAGEConv(
            in_feats=hid_feats1, out_feats = hid_feats2, aggregator_type='mean', activation= F.relu))
        self.layers.append(nn.Linear(in_features=hid_feats2, out_features = out_feats))
        self.layers.append(nn.Softmax(dim=-1))

    def forward(self, graph, inputs):
        h = inputs
        for i, layer in enumerate(self.layers):
            # inputs are features of nodes
            #different cases for graph layers and fully connected layers
            if i<=1:
                h = layer(graph, h)
            else:
                h = layer(h)
        
        return h

In [152]:
#function to evaluate model
def evaluate(model, graph, features, labels, mask):
    model.eval()
    with th.no_grad():
        logits = model(graph, features)
        logits = logits[mask]
        labels = labels[mask]
        _, indices = th.max(logits, dim=1)
        correct = th.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)

In [153]:
#training loop
model = SAGE(in_feats=n_features, hid_feats1=256, hid_feats2=128, out_feats=n_labels)
opt = th.optim.Adam(model.parameters())

for epoch in tqdm(range(10)):
    model.train()
    
    # forward propagation by using all nodes - extract separate logits
    sup_logits = model(sup_graph, sup_node_features)
    unsup_logits = model(unsup_graph, unsup_node_features)
    
    # compute losses
    sup_loss = F.cross_entropy(sup_logits[train_mask], sup_node_labels[train_mask])
    unsup_loss = F.cross_entropy(unsup_logits[train_mask], unsup_node_labels[train_mask])
    
    # compute validation accuracy
    sup_acc = evaluate(model, sup_graph, sup_node_features, sup_node_labels, valid_mask)
    unsup_acc = evaluate(model, unsup_graph, unsup_node_features, unsup_node_labels, valid_mask)
    
    #add additional term which is average euclidean distance between logits
    additional_term = th.dist(sup_logits, unsup_logits,2)/n
    
    #get total loss
    loss = sup_loss + unsup_loss + additional_term
    
    # backward propagation
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())

    # Save model
    th.save(model.state_dict(), f'./models/model{epoch}')















  0%|          | 0/10 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A[A[A[A[A













 10%|█         | 1/10 [13:21<2:00:14, 801.62s/it][A[A[A[A[A[A[A[A[A[A[A[A[A[A

9.020709037780762
















 20%|██        | 2/10 [25:31<1:44:00, 780.01s/it][A[A[A[A[A[A[A[A[A[A[A[A[A[A

7.84567403793335
















 30%|███       | 3/10 [37:47<1:29:28, 766.97s/it][A[A[A[A[A[A[A[A[A[A[A[A[A[A

7.841334342956543
















 40%|████      | 4/10 [50:14<1:16:06, 761.04s/it][A[A[A[A[A[A[A[A[A[A[A[A[A[A

7.841331958770752
















 50%|█████     | 5/10 [1:02:59<1:03:30, 762.02s/it][A[A[A[A[A[A[A[A[A[A[A[A[A[A

7.841331958770752
















 60%|██████    | 6/10 [1:17:12<52:37, 789.33s/it]  [A[A[A[A[A[A[A[A[A[A[A[A[A[A

7.841331958770752
















 70%|███████   | 7/10 [1:31:25<40:25, 808.42s/it][A[A[A[A[A[A[A[A[A[A[A[A[A[A

7.841331958770752
















 80%|████████  | 8/10 [1:45:24<27:15, 817.72s/it][A[A[A[A[A[A[A[A[A[A[A[A[A[A

7.841331958770752
















 90%|█████████ | 9/10 [1:59:23<13:43, 823.95s/it][A[A[A[A[A[A[A[A[A[A[A[A[A[A

7.841331958770752
















100%|██████████| 10/10 [2:15:55<00:00, 874.38s/it][A[A[A[A[A[A[A[A[A[A[A[A[A[A

7.841331958770752


In [156]:
unsup_acc

0.596098944615328

In [None]:
#Accuracy: 59.6%
#Loss: 7.84 - stable after 3 epochs