# Temporal Co-Training

#### Team Details

| Name | SRN | Section
|:------|:-------: | :----: |
| Harshith Mohan Kumar | PES1UG19CS276|    E
| Vishruth Veerendranath | PES1UG19CS577 |    I
| Divya Shekar | PES1UG19CS148 |    C
| Vibha Masti | PES1UG19CS565 |    I

#### Instructions

##### Software Used
- pytorch version: 1.11.0+cu102
- numpy version: 1.21.6
- networkx version: 2.6.3
- pytorch geometric version: 2.0.4 
- pytorch geometric temporal version: 0.51.0 

#### Prior to running notebook


#### Note
- This notebook was created on [Deepnote](https://deepnote.com/workspace)
- A [Dockerfile](https://drive.google.com/file/d/121BRb5qTTCkgojAswX2BHWy6pAg3jmA4/view?usp=sharing) containing some of the dependencies
- A [requirements.txt](https://drive.google.com/file/d/1UwaQxbWx3rnLVoVj-5whY6QF4DObIUf8/view?usp=sharing) with additional dependencies
- The dataset used can be found [here](https://drive.google.com/drive/folders/10FOTa6HXPqX8Pf5WRoRwcFnW9BrNZEIX) - corresponding GitHub repo [here](https://github.com/liyaguang/DCRNN)

## References

1. [Traffic prediction using PyGT](https://colab.research.google.com/drive/132hNQ0voOtTVk3I4scbD3lgmPTQub0KR?usp=sharing#scrollTo=iOc-jbFckFHn)
2. [Corresponding video](https://www.youtube.com/watch?v=Rws9mf1aWUs&t=1s)
3. [Co-Training for Image Recognition](https://github.com/AlanChou/Deep-Co-Training-for-Semi-Supervised-Image-Recognition/blob/master/main.py)
4. [Co-Training for Sentiment Analysis](https://github.com/Deep-Co-Training/Deep-Co-Training)
5. [“GC-LSTM: Graph Convolution Embedded LSTM for Dynamic Link Prediction.”](https://arxiv.org/abs/1812.04206)
6. [“A3T-GCN: Attention Temporal Graph Convolutional Network for Traffic Forecasting.”](https://arxiv.org/abs/2006.11583)

## Description of the Project


Problem statement:

Scarcity of labelled data places a significant limitation on the power of any supervised machine learning model. This challenge is further amplified by the domain-specific expertise and hours of manual labour involved in labelling. In the case when even the available data is not enough, a strategy of **pesudo-labeling** using **semi-supervised learning** can be of help.

In this project we explore the following problems:

- Semi-supervised learning on static graphs with temporal 
- To perform traffic forecasting (node regression) using traffic readings contained in a static graph with temporal signals 
- To perform co-training using the following two views of the dataset:
    - Attention based GNN
    - LSTM GNN
- To compare the results from co-training with an equivalent supervised model



## Outline

1. Data Preperation
2. Supervised Learning
3. Co-Training (Semi-Supervised Learning)
4. Evaluation
5. Analysis

#### Imports

In [1]:
# !pip3 install -U torch-scatter -f https://data.pyg.org/whl/torch-1.12.1+cu113.html
# !pip3 install -U torch-sparse -f https://data.pyg.org/whl/torch-1.12.1+cu113.html
# !pip3 install torch-geometric

# !pip3 install -U torch-cluster -f https://data.pyg.org/whl/torch-1.12.1+cu113.html
# !pip3 install -U torch-spline-conv -f https://data.pyg.org/whl/torch-1.12.1+cu113.html

!pip3 install torch_geometric_temporal



In [2]:
!pip list

Package                  Version
------------------------ ---------
argon2-cffi              21.3.0
argon2-cffi-bindings     21.2.0
attrs                    21.4.0
backcall                 0.2.0
beautifulsoup4           4.10.0
bleach                   5.0.0
brotlipy                 0.7.0
certifi                  2021.10.8
cffi                     1.15.0
chardet                  4.0.0
charset-normalizer       2.0.4
conda                    4.11.0
conda-build              3.21.7
conda-content-trust      0+unknown
conda-package-handling   1.7.3
cryptography             36.0.0
cycler                   0.11.0
Cython                   0.29.32
debugpy                  1.6.0
decorator                4.4.2
defusedxml               0.7.1
dnspython                2.2.1
entrypoints              0.4
fastjsonschema           2.15.3
filelock                 3.4.2
fonttools                4.33.3
glob2                    0.7
idna                     3.3
importlib-resources

In [3]:
# Pytorch
import torch
import torch_geometric
from torch_geometric_temporal.dataset import METRLADatasetLoader
from torch_geometric_temporal.signal import temporal_signal_split
from torch_geometric_temporal.signal import StaticGraphTemporalSignal

# Standard
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt

# Other
import random
from tqdm import tqdm
from tqdm import trange
import copy
import itertools
import os


In [4]:
print(torch.cuda.is_available())
print(torch.version.cuda)
print(torch.cuda.device_count())
# print(torch.cuda.current_device())
# print(torch.cuda.device(0))
# print(torch.cuda.get_device_name(0))

True
11.3
1


In [5]:
# setting device on GPU if available, else CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

#Additional Info when using cuda
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')

Using device: cuda

NVIDIA GeForce RTX 3080
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


## Data Preperation
Approach:
1. Load dataset from [PyTorch Geometric Temporal](https://pytorch-geometric-temporal.readthedocs.io/en/latest/index.html)
2. Split dataset into Train/Test (80% test data, 20% test data - we used a subset of the entire dataset due to RAM constraints)
3. Mask nodes within each timestep to form unlabeled data (70% of nodes unlabeled, 30% of nodes labeled)

### Data Ingestion
In this section we extract the [METR-LA](https://drive.google.com/drive/folders/10FOTa6HXPqX8Pf5WRoRwcFnW9BrNZEIX) static temporal signal graph using [torch geometric temporal](https://pytorch-geometric-temporal.readthedocs.io/en/latest/_modules/torch_geometric_temporal/dataset/metr_la.html). We then split the dataset into `train`/`test` (part of `trunc`) and `rest` (discarded values as the dataset is too large to train on).


**Dataset description:**

A traffic forecasting dataset based on Los Angeles Metropolitan traffic conditions. The dataset contains traffic readings collected from 207 loop detectors on highways in Los Angeles County in aggregated 5 minute intervals for 4 months between March 2012 to June 2012.

For further details on the version of the sensor network and discretization see: [“Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting”](https://arxiv.org/abs/1707.01926). The corresponding [GitHub](https://github.com/liyaguang/DCRNN) repository details the dataset and its representation.

It's crutial to understand that the *graph structure does not change over timeperiods in this temporal data, only the node "features/labels" do*.

NOTE: Any further reference to `timestep` would refer to an interval of 5 minutes



In [6]:
# Load the dataset from METRLA Pytorch Library
metrla_loader = METRLADatasetLoader()

# num_timesteps_in = 1 because in spatial view we input only one snapshot, to predict the next (num_timesteps_out = 1)
metrla_dataset = metrla_loader.get_dataset(num_timesteps_in=1, num_timesteps_out=1)

# Create data splits for train/test
# train_dataset, rest_dataset = temporal_signal_split(metrla_dataset, train_ratio=0.8)
train_dataset, test_dataset = temporal_signal_split(metrla_dataset, train_ratio=0.3)

In [7]:
#print("Size of trunc dataset (timesteps):",len(set(trunc_dataset)))
# print("Size of train dataset (timesteps):",len(set(train_dataset)))
#print("Size of test dataset (timesteps):",len(set(test_dataset)))

### Graph Visualization


In [8]:
def plot_graph(G, layout=nx.spring_layout, node_size=800, font_size=8, node_color='lightgreen'):
    """
    Plot a networkx graph
    """

    plt.figure(figsize=(15, 15))
    pos = layout(G)

    nx.draw_networkx(G, pos=pos, node_size=node_size, font_size=font_size, node_color=node_color)
    nx.draw_networkx_edges(G, pos=pos)
    
    # edge_labels = nx.get_edge_attributes(G, 'Weight')
    # nx.draw_networkx_edge_labels(G, pos=pos, edge_labels=edge_labels)
    plt.show()

### Masking a random subset of nodes

To perform **Semi-Supervised Learning** on the METR-LA Dataset which is fully labelled, we **mask the labels of a certain percentage of nodes** (70%). The nodes that are masked are picked randomly. Masking is performed by setting the label values to `NaN` across all timesteps.

This would be anologous to the failure of a certain percentage of sensors on the highway. 

In [9]:
# Number of nodes
nodes = list(train_dataset)[-1].num_nodes

# Unlabelled fraction
u_fraction = 0.7
random.seed(42)

# Range of numbers from 0 to 206 (number of nodes)
all_nodes = np.arange(nodes)

# Masking random sample of nodes with NaN
masked_nodes = np.array(random.sample(range(nodes), int(u_fraction*nodes)))

# Indices of true labelled nodes
non_masked = np.setdiff1d(all_nodes, masked_nodes)

# Print shape of all nodes, unlabelled and labelled nodes
print(all_nodes.shape, masked_nodes.shape, non_masked.shape)

# Create semi-supervised training dataset
data_semi = copy.deepcopy(train_dataset)

# Graph edges for the last timestep
edge_index = list(train_dataset)[-1].edge_index

(207,) (144,) (63,)


In [10]:
masked_nodes

array([163,  28,   6, 189,  70,  62,  57,  35, 188,  26, 173, 203, 139,
        22, 151, 108,   8,   7,  23,  55,  59, 129, 154, 204, 143,  50,
       166, 179, 194, 107,  56, 114, 150,  71,   1,  40, 191,  87, 196,
        39, 187,  86, 197, 198,  97,  24,  91,  88, 184,  67,  11, 117,
       137,  31,  96,  20, 141,  75,  92, 147,  49,  17, 156,  58,  74,
       192, 186,  25, 162, 168, 116,  93,  41,  94,  90,  53,  68,  18,
        43, 201, 174, 140,  48,  34, 118,  81, 159, 158, 205, 169, 134,
       177,  98,  99, 195,  29, 105,   4, 103, 171,  51, 123, 190,  27,
        72, 160, 115, 170,  83,  63, 181,  82, 182, 185,  33, 145, 153,
       119, 130, 148, 142,  54, 165, 106,  46, 122, 101,  65, 138, 144,
       183,  14,  19, 126,  85, 104, 146, 124, 125, 157,  32, 172, 149,
       110])

In [11]:
# Replace masked node labels with np.nan in train_semi
for snapshot in data_semi:
    snapshot.y[masked_nodes] = np.nan

In [12]:
# Labels of nodes in the last timestep
list(data_semi)[-1].y

tensor([[-2.6522],
        [    nan],
        [ 0.6878],
        [-0.7657],
        [    nan],
        [ 0.3909],
        [    nan],
        [    nan],
        [    nan],
        [ 0.7125],
        [ 0.3724],
        [    nan],
        [-1.8976],
        [ 0.6878],
        [    nan],
        [-1.4832],
        [-1.8234],
        [    nan],
        [    nan],
        [    nan],
        [    nan],
        [ 0.5270],
        [    nan],
        [    nan],
        [    nan],
        [    nan],
        [    nan],
        [    nan],
        [    nan],
        [    nan],
        [ 0.0569],
        [    nan],
        [    nan],
        [    nan],
        [    nan],
        [    nan],
        [ 0.7187],
        [-2.6522],
        [-1.8934],
        [    nan],
        [    nan],
        [    nan],
        [-1.2048],
        [    nan],
        [ 0.7435],
        [ 0.6816],
        [    nan],
        [-0.9945],
        [    nan],
        [    nan],
        [    nan],
        [    nan],
        [-1.

In [13]:
# X values for the first timestep
data_semi[0].x[masked_nodes]

tensor([[[ 7.8058e-01],
         [-1.7292e+00]],

        [[ 6.5687e-01],
         [-1.7292e+00]],

        [[ 5.7028e-01],
         [-1.7292e+00]],

        [[ 8.0532e-01],
         [-1.7292e+00]],

        [[ 4.7132e-01],
         [-1.7292e+00]],

        [[-5.4413e-02],
         [-1.7292e+00]],

        [[ 3.2906e-01],
         [-1.7292e+00]],

        [[ 4.1565e-01],
         [-1.7292e+00]],

        [[ 6.5687e-01],
         [-1.7292e+00]],

        [[ 3.7854e-01],
         [-1.7292e+00]],

        [[ 5.0225e-01],
         [-1.7292e+00]],

        [[ 7.6202e-01],
         [-1.7292e+00]],

        [[ 4.9606e-01],
         [-1.7292e+00]],

        [[ 5.0734e-02],
         [-1.7292e+00]],

        [[-2.3487e-02],
         [-1.7292e+00]],

        [[ 4.9606e-01],
         [-1.7292e+00]],

        [[ 2.9814e-01],
         [-1.7292e+00]],

        [[ 6.6924e-01],
         [-1.7292e+00]],

        [[ 2.2392e-01],
         [-1.7292e+00]],

        [[ 3.4143e-01],
         [-1.7292e+00]],



In [14]:
# Last timestep semi-supervised data
list(data_semi)[-1]

Data(x=[207, 2, 1], edge_index=[2, 1722], edge_attr=[1722], y=[207, 1])

In [15]:
# Calcualte the number of timesteps to be used as batches during training
total_timesteps = len(list(data_semi)) + 1

## Supervised Learning


### Model Initialization

For our two models, we have used one Attention based GNN and the other being a LSTM GNN. The reason behind the choice of these two models is the difference in their respective data processing techniques. The LSTM model places higher weightage on the most recent timesteps (linear processing) while the Attention based model distributes this emphasis across different timesteps (non-linear processing). It is our belief that this qualifies the implementation in terms of the conditional independence of two views required by the co-training algorithm. 

In this section we define the architecture for both of these models which is suitable with static temporal signals.

We have used the Adam optimizer for both GNN models defined.

#### GCLSTM
An implementation of the the Integrated Graph Convolutional Long Short Term Memory Cell. For details see this paper: [“GC-LSTM: Graph Convolution Embedded LSTM for Dynamic Link Prediction.”](https://arxiv.org/abs/1812.04206)

In [16]:
from torch_geometric_temporal.nn.recurrent import A3TGCN, GCLSTM

class LSTMGNN(torch.nn.Module):
    def __init__(self, model_class, node_features, out_channels, **kwargs):
        # kwargs for GCLSTM : {K: K} -> Chebyshev filter order
        periods = kwargs.get('periods', 1)
        K = kwargs.get('K', 1)
        super(LSTMGNN, self).__init__()

        # Attention Temporal Graph Convolutional Cell
        self.tgnn = model_class(in_channels=node_features, 
                           out_channels=out_channels, 
                           **kwargs)
    
        # Equals single-shot prediction
        self.linear = torch.nn.Linear(out_channels, 1)

    def forward(self, x, edge_index):
        """
        x = Node features for T time steps
        edge_index = Graph edge indices
        """
        h, c = self.tgnn(x, edge_index)
#         print('h shape', h.size())
        h = F.relu(h)
        h = self.linear(h)
#         print('h after linear', h.size())
        return h

# model1 = LSTMGNN(GCLSTM, node_features=2, out_channels=nodes, K=2)
# optimizer1 = torch.optim.Adam(params=model1.parameters(), lr = 0.01)

#### ASTGCN
An implementation of the Attention Based Spatial-Temporal Graph Convolutional Cell: [“Attention Based Spatial-Temporal Graph Convolutional Networks for Traffic Flow Forecasting.”](https://ojs.aaai.org/index.php/AAAI/article/view/3881)

In [17]:
import torch.nn.functional as F
from torch_geometric_temporal.nn.attention import ASTGCN
from torch.nn import MSELoss

class AttentionGNN(torch.nn.Module):
    def __init__(self, model_class, node_features, edge_index, **kwargs):
        # kwargs for A3TGCN : {periods=periods}
        # kwargs for GCLSTM : {K: K} -> Chebyshev filter order
        nb_block = kwargs.get('nb_block', 1)
        K = kwargs.get('K', 1)
        nb_chev_filters = kwargs.get('nb_chev_filters', 2)
        nb_time_filters = kwargs.get('nb_time_filters', 2)
        time_strides = kwargs.get('time_strides', 2)
        num_for_predict = kwargs.get('num_for_predict', 1)
        len_input = kwargs.get('len_input', 1)
        num_of_vertices = kwargs.get('num_of_vertices', 207) 

        # K = kwargs.get('K', 1)
        super(AttentionGNN, self).__init__()

        # Attention Temporal Graph Convolutional Cell
        self.tgnn = model_class(in_channels=node_features, 
                            nb_block = nb_block,
                            K = K,
                            nb_chev_filter = nb_chev_filters,
                            nb_time_filter = nb_time_filters,
                            time_strides = time_strides,
                            num_for_predict = num_for_predict,
                            len_input = len_input,
                            num_of_vertices = num_of_vertices,
                            )
    
        # Equals single-shot prediction
        self.linear = torch.nn.Linear(nodes, 1)

    def forward(self, x, edge_index):
        """
        x = Node features for T time steps
        edge_index = Graph edge indices
        """
#         print('Dimensions of x', x.shape)
#         print('Dimensions of permute x.permute(0, 3, 1, 2)', x.permute(0, 3, 1, 2).shape)
        h = self.tgnn(x, edge_index).squeeze()
#         print('Dimensions of h', h.shape)
        # h = F.relu(h)
        # h = self.linear(h)
#         print('h after linear', h.size())
        return h

# model2 = AttentionGNN(A3TGCN, node_features=2, periods=1, out_channels=nodes)
# optimizer2 = torch.optim.Adam(params=model2.parameters(), lr = 0.01)

### Training

In [18]:
def train_super_helper(epoch, model, optimizer, data_semi, non_masked, reshape=False):
    '''
    Helper method to train the supervised model.
    '''
    model.train()
    loss = MSELoss()
    
    # Tensor of size : (number of timesteps in data, )
    loss_total = torch.zeros(len(list(data_semi))).to(device)

    for timestep, snapshot in enumerate(data_semi):
        if timestep % 10 == 0:
            print(timestep)
        
        # Load data to gpu
        snapshot.x = snapshot.x.to(device)
        snapshot.y = snapshot.y.to(device)
        snapshot.edge_index = snapshot.edge_index.to(device)
        
        # Forward passes for a timestep
#         print('Shape of snapshot', snapshot.x.shape)
#         print('Shape of snapshot.x.reshape(-1, 2)', snapshot.x.reshape(-1, 2).shape)
#         print('Shape of snapshot.x.unsqueeze(0)', snapshot.x.unsqueeze(0).shape)
        if reshape:
            yhat = model.forward(snapshot.x.reshape(-1, 2), snapshot.edge_index).reshape(-1)
#             print('LSTM yhat shape', yhat.size())
        else:
            yhat = model.forward(snapshot.x.unsqueeze(0), snapshot.edge_index).reshape(-1)
#             print('Attention yhat shape', yhat.size())
        
        # Individual model losses for that timestep (only for labelled data)
#         print('Shape of yhat[non_masked]', yhat[non_masked].shape)
#         print('Shape of snapshot.y.reshape(-1)[non_masked]', snapshot.y.reshape(-1)[non_masked].shape)
#         print('Dtype of yhat[non_masked]', yhat[non_masked].dtype)
#         print('DType of snapshot.y.reshape(-1)[non_masked]', snapshot.y.reshape(-1)[non_masked].dtype)
        loss_model = loss(yhat[non_masked], snapshot.y.reshape(-1)[non_masked])
        
        loss_total[timestep] = loss_model.item()

        optimizer.zero_grad()

        loss_model.backward()

        optimizer.step()
    
    
    print(f'Train MSE for {epoch}:', torch.sum(loss_total))
    return torch.sum(loss_total)

@torch.no_grad()
def test_super(model, data_labelled, reshape):
    '''
    Method to evaluate the supervised model.
    '''
    model.eval()
    loss = MSELoss()
    
    loss_total = torch.zeros(len(list(data_labelled))).to(device)

    for timestep, snapshot in enumerate(data_labelled):
        
        # Load data to gpu
        snapshot.x = snapshot.x.to(device)
        snapshot.y = snapshot.y.to(device)
        snapshot.edge_index = snapshot.edge_index.to(device)
        
        # Forward passes for a timestep
        if reshape:
            yhat = model.forward(snapshot.x.reshape(-1, 2), snapshot.edge_index).reshape(-1)
        else:
            yhat = model.forward(snapshot.x.unsqueeze(0), snapshot.edge_index).reshape(-1)

        
        loss_model = loss(yhat, snapshot.y.reshape(-1))
        loss_total[timestep] = loss_model.item()
    

    print('Test MSE:', torch.sum(loss_total))
    return torch.sum(loss_total)

In [19]:
def train_super(epochs, model, optimizer, data_semi, data_labelled, non_masked, reshape=False):
    '''
    Method for training the supervised model.
        Args:
            epochs: Number of epochs to train
            model: Temporal Signal model
            optimizer: Optimizer for corresponding model
            data_semi: Masked training dataset
            data_labelled: Non-masked training dataset
            non_masked: Non-masked indices
            reshape: bool for ATTN model
    '''
    # Losses for all epochs
    train_loss = np.empty((0))

    # Train the model
    for epoch in trange(epochs):
        # Single loss value (float) for every epoch
        loss_epoch = train_super_helper(epoch, model, optimizer, data_semi, non_masked, reshape)
        
        # Append the loss
        train_loss = np.append(train_loss, loss_epoch.cpu().detach().numpy())

    # Test the model on the labelled dataset
    test_loss = test_super(model, data_labelled, reshape)
    return train_loss, test_loss

In [37]:
# Train LSTM model using supervised learning
epochs = 2
model_lstm_super = LSTMGNN(GCLSTM, node_features=2, out_channels=nodes, K=2)
optimizer_lstm_super = torch.optim.Adam(params=model_lstm_super.parameters(), lr = 0.01)
data_labelled = train_dataset

In [None]:
# Load Model to GPU
model_lstm_super.to(device)

In [None]:
train_loss_super_lstm, test_loss_super_lstm = train_super(epochs, model_lstm_super, optimizer_lstm_super, data_semi, data_labelled, non_masked, reshape=True)

print(f'Train loss for LSTM: {train_loss_super_lstm}')
print(f'Test loss for LSTM: {test_loss_super_lstm}')


In [None]:
# Train Attention model using supervised learning
# model_attn_super = AttentionGNN(ASTGCN, node_features=2, edge_index=data_semi.edge_index, len_input=2)
model_attn_super = AttentionGNN(ASTGCN, 
                            node_features=2, 
                            edge_index=data_semi.edge_index,
                            nb_block=1, 
                            nb_chev_filters=2, 
                            nb_time_filters=2,
                            time_strides=1,
                            num_for_predict=1,
                            len_input=1,
                            num_of_vertices=nodes
                            )
optimizer_attn_super = torch.optim.Adam(params=model_attn_super.parameters(), lr=0.01)
# train_data = train_dataset

In [None]:
model_attn_super.to(device)

In [None]:
epochs = 50
train_loss_super_attn, test_loss_super_attn = train_super(epochs, model_attn_super, optimizer_attn_super, data_semi, data_labelled, non_masked, reshape=False)

print(f'Train loss for Attn: {train_loss_super_attn}')
print(f'Test loss for Attn: {test_loss_super_attn}')


In [None]:
train_loss_super_attn

### Results

Here are the results from supervised learning on the non-masked dataset:

| Model |   MSE |
| ----- | -------: |
| LSTM Based | 0.324 |
| Attention Based | 0.370 |

**Note: These values are subject to change upon rerunning the notebook

## Co-Training (Semi-Supervised)

[Co-training](https://www.cs.cmu.edu/~avrim/Papers/cotrain.pdf) is a **semi-supervised learning technique** which trains two classifiers based on two different views of data. It was originally proposed by Tom Mitchell and Avrim Blum, two iconic figures from CMU. In this project we have adapted the algorithm to run on static temporal graph networks. 

Our custom algorithm:
1. Create dataset which consists of pools of labeled and unlabeled data (proportion of unlabeled data is much higher)
2. Feed the data into the two GNN models which function as two views of the data (such that the two models can learn different features off of each other)
3. Calculate the absolute difference between the predictions (traffic speed) obtained by the two models and pick the most similar predictions as the top-k most confident pseudo labels (intuition being that the predictions which are consistent across both models are more likely to be accurate)
4. Append the pseudo labels to the pool of labeled data. This way, the amount of labeled data increases across each iteration
5. Continue until all unlabeled data has been converted to labeled data 
6. Final predictions are a simple average of the predictions of the two models

In [20]:
masked_nodes

array([163,  28,   6, 189,  70,  62,  57,  35, 188,  26, 173, 203, 139,
        22, 151, 108,   8,   7,  23,  55,  59, 129, 154, 204, 143,  50,
       166, 179, 194, 107,  56, 114, 150,  71,   1,  40, 191,  87, 196,
        39, 187,  86, 197, 198,  97,  24,  91,  88, 184,  67,  11, 117,
       137,  31,  96,  20, 141,  75,  92, 147,  49,  17, 156,  58,  74,
       192, 186,  25, 162, 168, 116,  93,  41,  94,  90,  53,  68,  18,
        43, 201, 174, 140,  48,  34, 118,  81, 159, 158, 205, 169, 134,
       177,  98,  99, 195,  29, 105,   4, 103, 171,  51, 123, 190,  27,
        72, 160, 115, 170,  83,  63, 181,  82, 182, 185,  33, 145, 153,
       119, 130, 148, 142,  54, 165, 106,  46, 122, 101,  65, 138, 144,
       183,  14,  19, 126,  85, 104, 146, 124, 125, 157,  32, 172, 149,
       110])

In [21]:
non_masked

array([  0,   2,   3,   5,   9,  10,  12,  13,  15,  16,  21,  30,  36,
        37,  38,  42,  44,  45,  47,  52,  60,  61,  64,  66,  69,  73,
        76,  77,  78,  79,  80,  84,  89,  95, 100, 102, 109, 111, 112,
       113, 120, 121, 127, 128, 131, 132, 133, 135, 136, 152, 155, 161,
       164, 167, 175, 176, 178, 180, 193, 199, 200, 202, 206])

### Helper Methods

In [22]:
# Function to append pseudo labels to the pool of labeled data the model is trained on at the end of every iteration
def append_top_k(data, pseudo1, pseudo2, non_masked, masked_nodes, top_k_indices):
    '''
    Args:
        data: Pytorch Temporal Signal Object
        top_k: Dictonary where keys = node_index and value = [pseudo_label for all timesteps]
    Returns:
        data_appended: Pytorch Temporal Signal Object with the top_k values
    '''
    # global non_masked
    # global masked_nodes
    # Replace masked node labels with predicted values of epoch

    print('Check', pseudo1.shape, pseudo2.shape)

    avg = torch.div(torch.add(pseudo1, pseudo2), 2)
    avg_selected = avg[:, top_k_indices]

    print('Avg', avg_selected.shape)

    for i, snapshot in enumerate(data):
        snapshot.y[top_k_indices] = avg_selected[i].reshape(-1,1)
    
    # non_masked = np.array(torch.cat((torch.tensor(non_masked), top_k_indices)))
    # Update non_masked and masked after adding topk 
    print('Before', non_masked.shape, masked_nodes.shape)
    non_masked = np.concatenate((non_masked, np.array(top_k_indices)))
    masked_nodes = np.setdiff1d(all_nodes, non_masked)
    print('After', non_masked.shape, masked_nodes.shape, '\n')

    return data, non_masked, masked_nodes

In [23]:
# Import library for progress bar - using tensors - too much ram - OLD works for model1
# Helper Train method
def train_single_model(epoch, model, optimizer, data_semi, k, non_masked, reshape=False):
    # Train the two models
    # global non_masked
    model.train()

    loss = MSELoss()

#    yhat_total = torch.zeros((total_timesteps, nodes))
    print('Len of data semi', len(list(data_semi)))
    loss_total = torch.zeros(len(list(data_semi))).to(device)

    # Print log and statistics for epoch:
    print()
    print("="*20)
    print('epoch:',epoch+1)

    for timestep, snapshot in enumerate(data_semi):
        
        # Load data to gpu
        snapshot.x = snapshot.x.to(device)
        snapshot.y=snapshot.y.to(device)
        snapshot.edge_index=snapshot.edge_index.to(device)
        
        if timestep % 10 == 0:
            print(timestep)

        if reshape:
            yhat = model.forward(snapshot.x.reshape(-1, 2), snapshot.edge_index).reshape(-1)
        else:
            yhat = model.forward(snapshot.x.unsqueeze(0), snapshot.edge_index).reshape(-1)
        
        # Individual model losses for that timestep (only for labelled data)
        loss_model = loss(yhat[non_masked], snapshot.y.reshape(-1)[non_masked])
        

        optimizer.zero_grad()
        loss_model.backward()
        optimizer.step()

        # yhat_total[timestep+1,:] = yhat
        loss_total[timestep] = loss_model.item()
    
    print(f'Train MSE for epoch {epoch}:', torch.sum(loss_total))

    return torch.sum(loss_total)


@torch.no_grad()
# def test(epoch, model, data, non_masked, reshape=False):
def test_semi_single(epoch, model, data_labelled, reshape=False):
    # Put the model into Eval mode
    model.eval()
    loss = MSELoss()
    loss_total = torch.zeros(len(list(data_labelled))).to(device)

    # Generate Predictions
    for timestep, snapshot in enumerate(data_labelled):
        
        # Load data to gpu
        snapshot.x = snapshot.x.to(device)
        snapshot.y=snapshot.y.to(device)
        snapshot.edge_index=snapshot.edge_index.to(device)
        
        # Forward passes for a timestep
        if reshape:
            yhat = model.forward(snapshot.x.reshape(-1, 2), snapshot.edge_index).reshape(-1)
        else:
            yhat = model.forward(snapshot.x.unsqueeze(0), snapshot.edge_index).reshape(-1)

    
        # loss_model = loss(yhat[non_masked], snapshot.y.reshape(-1)[non_masked])
        loss_model = loss(yhat, snapshot.y.reshape(-1))
        
        loss_total[timestep]=loss_model.item()
        
        
    # print('Testing epoch:', epoch)
    print(f'Test MSE for {epoch}:', torch.sum(loss_total))
    return torch.sum(loss_total)
    #TODO: Generate MSE and print

In [24]:
@torch.no_grad()
def pseudo_labels(model, data_semi, reshape=False):
    '''
    Method to calculate the pseudo labels after one epoch of training.
    Args:
        model: Trained temporal model
        data_semi: Unlabelled dataset
        reshape:
    Return:
        yhat_total: Tensor with all the y outputs for the Temporal Signal Series
    '''

    yhat_total = torch.zeros((len(list(data_semi)), nodes))
    
    # Put the model into Eval mode
    model.eval()

    # Generate Predictions
    for timestep, snapshot in enumerate(data_semi):
        
        # Load data to gpu
        snapshot.x = snapshot.x.to(device)
        snapshot.y=snapshot.y.to(device)
        snapshot.edge_index=snapshot.edge_index.to(device)
        
        # Forward passes for a timestep
        if reshape:
            yhat = model.forward(snapshot.x.reshape(-1, 2), snapshot.edge_index).reshape(-1)
        else:
            yhat = model.forward(snapshot.x.unsqueeze(0), snapshot.edge_index).reshape(-1)
            
        #yhat_total[timestep+1,:] = yhat
        yhat_total[timestep,:] = yhat
        
    return yhat_total

In [25]:
#picking the top-k most confident predictions as pseudo labels 
@torch.no_grad()
def top_k(pred1, pred2, non_masked, masked_nodes, k):
    
    '''
    y_pred1, y_pred2 : Dictionary with node index as key and predicted value as value
    k : Number of predictions to pick as pseudolabels

    Returns dictionary of top-k node labels node labels and corresponding mean values of the predictions
    '''
    # global non_masked
    diff = torch.abs(torch.sub(pred1, pred2)).sum(dim = 0)
    print('Diff', diff.shape)

    # Set the diff of the non masked (labeled) nodes to be 
    # infinity, so that they are not picked ni the top k
    diff[non_masked] = float('inf')
    print('Diff Masked', diff.shape)

    # TODO: make sure that the topk do not contain 'inf' in 
    # later epochs (Check if this works)
    k = min(masked_nodes.shape[0], k)

    topk_res = torch.topk(diff, k, largest=False)
    # print(topk_res)

    return topk_res.indices

In [26]:
masked_nodes.shape[0]

144

In [27]:
non_masked.shape

(63,)

In [28]:
# Before training
torch.isnan(data_semi[-1].y).sum()

tensor(144)

### Training

**Warning**: This method takes atleast an hour to train. So fire it up, grab some popcorn, go watch a tv show and come back :)

In [29]:
t = np.empty((10,2,2))
t

array([[[4.67494734e-310, 0.00000000e+000],
        [4.67494756e-310, 0.00000000e+000]],

       [[4.67494756e-310, 4.67494759e-310],
        [4.67494758e-310, 4.67494730e-310]],

       [[4.67494730e-310, 4.67494758e-310],
        [1.25986740e-321, 4.67494758e-310]],

       [[4.67494730e-310, 4.67494730e-310],
        [4.67494757e-310, 1.25986740e-321]],

       [[4.67494758e-310, 5.43472210e-323],
        [0.00000000e+000, 0.00000000e+000]],

       [[5.26354551e-315, 5.43472210e-323],
        [0.00000000e+000, 0.00000000e+000]],

       [[1.39178292e-320, 4.67494757e-310],
        [7.41098469e-323, 3.05176043e-005]],

       [[0.00000000e+000, 4.94065646e-324],
        [4.67494758e-310, 4.67494758e-310]],

       [[4.67494758e-310, 4.67494758e-310],
        [1.13635099e-322, 0.00000000e+000]],

       [[0.00000000e+000, 4.67480936e-310],
        [1.13635099e-322, 0.00000000e+000]]])

In [30]:
t = np.append(t, [[32],[23]])
t

array([4.67494734e-310, 0.00000000e+000, 4.67494756e-310, 0.00000000e+000,
       4.67494756e-310, 4.67494759e-310, 4.67494758e-310, 4.67494730e-310,
       4.67494730e-310, 4.67494758e-310, 1.25986740e-321, 4.67494758e-310,
       4.67494730e-310, 4.67494730e-310, 4.67494757e-310, 1.25986740e-321,
       4.67494758e-310, 5.43472210e-323, 0.00000000e+000, 0.00000000e+000,
       5.26354551e-315, 5.43472210e-323, 0.00000000e+000, 0.00000000e+000,
       1.39178292e-320, 4.67494757e-310, 7.41098469e-323, 3.05176043e-005,
       0.00000000e+000, 4.94065646e-324, 4.67494758e-310, 4.67494758e-310,
       4.67494758e-310, 4.67494758e-310, 1.13635099e-322, 0.00000000e+000,
       0.00000000e+000, 4.67480936e-310, 1.13635099e-322, 0.00000000e+000,
       3.20000000e+001, 2.30000000e+001])

In [31]:
np.empty((0,1))

array([], shape=(0, 1), dtype=float64)

In [32]:
# Custom train method built for co-training with two models

def co_train(epochs, data_semi, data_labelled, model_lstm, model_attn, optimizer_lstm, optimizer_attn, non_masked, masked_nodes, k):
    '''
    Args:
        train_semi: StaticGraphTemporalSignal
        model_1: Temporal model
        model_2: Temporal attention based model
        optimizer_1: Optimizer for Temporal
        optimizer_2: Optimizer for Attention 
        k: Number of masked nodes pseudolabels to append in each epoch
    '''
    train_loss_total_lstm = np.empty((0,1))
    train_loss_total_attn = np.empty((0,1))
    
    #test_pseudo_lstm_total = np.empty((0,1))
    #test_pseudo_attn_total = np.empty((0,1))
    test_lstm_total = np.empty((0,1))
    test_attn_total = np.empty((0,1))
    

    for epoch in trange(epochs):
        print('\nnon masked:', masked_nodes.size)
        #k = min(k, masked_nodes.size)
        
        print('k:', k)
        
        
        
        # Train and obtain top_k pseudo labels
        print('LSTM Model:')
        train_loss_lstm = train_single_model(epoch, model_lstm, optimizer_lstm, data_semi, k, non_masked, reshape=True)
        
        # CHECK DATA_SEMI SIZE
        #print(f'DATA SEMI SIZE FOR VV: {data_semi}')
        
        print('Attention Model:')
        train_loss_attn = train_single_model(epoch, model_attn, optimizer_attn, data_semi, k, non_masked, reshape=False)
        
        train_loss_total_lstm = np.append(train_loss_total_lstm, train_loss_lstm.cpu().detach().numpy())
        train_loss_total_attn = np.append(train_loss_total_attn, train_loss_attn.cpu().detach().numpy())
    
        
        # Generate Pesudo-labels
        pseudo_labels_lstm = pseudo_labels(model_lstm, data_semi, reshape=True)
        pseudo_labels_attn = pseudo_labels(model_attn, data_semi)

        print('Pseudo1', pseudo_labels_lstm.shape)
        print('Pseudo2', pseudo_labels_attn.shape)

        # TODO: Find top k using sum of squared diff between 2 models
        top_k_node_labels = top_k(pseudo_labels_lstm, pseudo_labels_attn, non_masked, masked_nodes, k)

        # print(top_k_node_labels, sorted(masked_nodes))

        assert not any(np.isin(top_k_node_labels, non_masked))

        # print('Topk shape', top_k_node_labels.shape)

        # Append top_k pesudo labels to the train_semi
        data_semi, non_masked, masked_nodes = append_top_k(data_semi, pseudo_labels_lstm, pseudo_labels_attn, non_masked, masked_nodes, top_k_node_labels)
        
        del pseudo_labels_lstm
        del pseudo_labels_attn


        # Test Model - True Validation
        test_lstm = test_semi_single(epoch, model_lstm, data_labelled, reshape=True)
        test_attn = test_semi_single(epoch, model_attn, data_labelled)
        

        #test_pesudo_lstm_total = np.append(test_pesudo_lstm_total, test_pesudo_lstm.cpu().detach().numpy())
        #test_pesudo_atten_total = np.append(test_pesudo_atten_total, test_pesudo_lstm.cpu().detach().numpy())
        test_lstm_total = np.append(test_lstm_total, test_lstm.cpu().detach().numpy())
        test_attn_total = np.append(test_attn_total, test_attn.cpu().detach().numpy())
    
    
    return {'train_loss_total_lstm':train_loss_total_lstm, 
            'train_loss_total_attn':train_loss_total_attn,
            'test_lstm_total':test_lstm_total,
            'test_attn_total':test_attn_total}

In [33]:
# Redefine Model to avoid continuous training
model_lstm_co = LSTMGNN(GCLSTM, node_features=2, out_channels=nodes, K=2)
optimizer_lstm_co = torch.optim.Adam(params=model_lstm_co.parameters(), lr = 0.01)

model_lstm_co.to(device)

LSTMGNN(
  (tgnn): GCLSTM(
    (conv_i): ChebConv(207, 207, K=2, normalization=sym)
    (conv_f): ChebConv(207, 207, K=2, normalization=sym)
    (conv_c): ChebConv(207, 207, K=2, normalization=sym)
    (conv_o): ChebConv(207, 207, K=2, normalization=sym)
  )
  (linear): Linear(in_features=207, out_features=1, bias=True)
)

In [34]:
model_attn_co = AttentionGNN(ASTGCN, 
                            node_features=2, 
                            edge_index=data_semi.edge_index,
                            nb_block=1, 
                            nb_chev_filters=2, 
                            nb_time_filters=2,
                            time_strides=1,
                            num_for_predict=1,
                            len_input=1,
                            num_of_vertices=nodes
                            )
optimizer_attn_co = torch.optim.Adam(params=model_attn_co.parameters(), lr = 0.01)

model_attn_co.to(device)

AttentionGNN(
  (tgnn): ASTGCN(
    (_blocklist): ModuleList(
      (0): ASTGCNBlock(
        (_temporal_attention): TemporalAttention()
        (_spatial_attention): SpatialAttention()
        (_chebconv_attention): ChebConvAttention(2, 2, K=1, normalization=None)
        (_time_convolution): Conv2d(2, 2, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1))
        (_residual_convolution): Conv2d(2, 2, kernel_size=(1, 1), stride=(1, 1))
        (_layer_norm): LayerNorm((2,), eps=1e-05, elementwise_affine=True)
      )
    )
    (_final_conv): Conv2d(1, 1, kernel_size=(1, 2), stride=(1, 1))
  )
  (linear): Linear(in_features=207, out_features=1, bias=True)
)

In [38]:
# Call co-train
k = 4
epochs=30
# co_train(epochs, data_semi, data_labelled, model_lstm, model_attn, optimizer_lstm, optimizer_attn, non_masked, masked_nodes, k):
total_dict = co_train(epochs, data_semi, data_labelled, model_lstm_co, model_attn_co, optimizer_lstm_co, optimizer_attn_co, non_masked, masked_nodes, k)



  0%|                                                                                                  | 0/30 [00:00<?, ?it/s]


non masked: 144
k: 4
LSTM Model:
Len of data semi 10281

epoch: 1
0
10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
160
170
180
190
200
210
220
230
240
250
260
270
280
290
300
310
320
330
340
350
360
370
380
390
400
410
420
430
440
450
460
470
480
490
500
510
520
530
540
550
560
570
580
590
600
610
620
630
640
650
660
670
680
690
700
710
720
730
740
750
760
770
780
790
800
810
820
830
840
850
860
870
880
890
900
910
920
930
940
950
960
970
980
990
1000
1010
1020
1030
1040
1050
1060
1070
1080
1090
1100
1110
1120
1130
1140
1150
1160
1170
1180
1190
1200
1210
1220
1230
1240
1250
1260
1270
1280
1290
1300
1310
1320
1330
1340
1350
1360
1370
1380
1390
1400
1410
1420
1430
1440
1450
1460
1470
1480
1490
1500
1510
1520
1530
1540
1550
1560
1570
1580
1590
1600
1610
1620
1630
1640
1650
1660
1670
1680
1690
1700
1710
1720
1730
1740
1750
1760
1770
1780
1790
1800
1810
1820
1830
1840
1850
1860
1870
1880
1890
1900
1910
1920
1930
1940
1950
1960
1970
1980
1990
2000
2010
2020
2030
2040
2050
2060
2070
2080

6060
6070
6080
6090
6100
6110
6120
6130
6140
6150
6160
6170
6180
6190
6200
6210
6220
6230
6240
6250
6260
6270
6280
6290
6300
6310
6320
6330
6340
6350
6360
6370
6380
6390
6400
6410
6420
6430
6440
6450
6460
6470
6480
6490
6500
6510
6520
6530
6540
6550
6560
6570
6580
6590
6600
6610
6620
6630
6640
6650
6660
6670
6680
6690
6700
6710
6720
6730
6740
6750
6760
6770
6780
6790
6800
6810
6820
6830
6840
6850
6860
6870
6880
6890
6900
6910
6920
6930
6940
6950
6960
6970
6980
6990
7000
7010
7020
7030
7040
7050
7060
7070
7080
7090
7100
7110
7120
7130
7140
7150
7160
7170
7180
7190
7200
7210
7220
7230
7240
7250
7260
7270
7280
7290
7300
7310
7320
7330
7340
7350
7360
7370
7380
7390
7400
7410
7420
7430
7440
7450
7460
7470
7480
7490
7500
7510
7520
7530
7540
7550
7560
7570
7580
7590
7600
7610
7620
7630
7640
7650
7660
7670
7680
7690
7700
7710
7720
7730
7740
7750
7760
7770
7780
7790
7800
7810
7820
7830
7840
7850
7860
7870
7880
7890
7900
7910
7920
7930
7940
7950
7960
7970
7980
7990
8000
8010
8020
8030
8040
8050



  3%|██▉                                                                                    | 1/30 [03:55<1:53:43, 235.30s/it]

Test MSE for 0: tensor(12727.8945, device='cuda:0')

non masked: 140
k: 4
LSTM Model:
Len of data semi 10281

epoch: 2
0
10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
160
170
180
190
200
210
220
230
240
250
260
270
280
290
300
310
320
330
340
350
360
370
380
390
400
410
420
430
440
450
460
470
480
490
500
510
520
530
540
550
560
570
580
590
600
610
620
630
640
650
660
670
680
690
700
710
720
730
740
750
760
770
780
790
800
810
820
830
840
850
860
870
880
890
900
910
920
930
940
950
960
970
980
990
1000
1010
1020
1030
1040
1050
1060
1070
1080
1090
1100
1110
1120
1130
1140
1150
1160
1170
1180
1190
1200
1210
1220
1230
1240
1250
1260
1270
1280
1290
1300
1310
1320
1330
1340
1350
1360
1370
1380
1390
1400
1410
1420
1430
1440
1450
1460
1470
1480
1490
1500
1510
1520
1530
1540
1550
1560
1570
1580
1590
1600
1610
1620
1630
1640
1650
1660
1670
1680
1690
1700
1710
1720
1730
1740
1750
1760
1770
1780
1790
1800
1810
1820
1830
1840
1850
1860
1870
1880
1890
1900
1910
1920
1930
1940
1950
1960
1970
19

5970
5980
5990
6000
6010
6020
6030
6040
6050
6060
6070
6080
6090
6100
6110
6120
6130
6140
6150
6160
6170
6180
6190
6200
6210
6220
6230
6240
6250
6260
6270
6280
6290
6300
6310
6320
6330
6340
6350
6360
6370
6380
6390
6400
6410
6420
6430
6440
6450
6460
6470
6480
6490
6500
6510
6520
6530
6540
6550
6560
6570
6580
6590
6600
6610
6620
6630
6640
6650
6660
6670
6680
6690
6700
6710
6720
6730
6740
6750
6760
6770
6780
6790
6800
6810
6820
6830
6840
6850
6860
6870
6880
6890
6900
6910
6920
6930
6940
6950
6960
6970
6980
6990
7000
7010
7020
7030
7040
7050
7060
7070
7080
7090
7100
7110
7120
7130
7140
7150
7160
7170
7180
7190
7200
7210
7220
7230
7240
7250
7260
7270
7280
7290
7300
7310
7320
7330
7340
7350
7360
7370
7380
7390
7400
7410
7420
7430
7440
7450
7460
7470
7480
7490
7500
7510
7520
7530
7540
7550
7560
7570
7580
7590
7600
7610
7620
7630
7640
7650
7660
7670
7680
7690
7700
7710
7720
7730
7740
7750
7760
7770
7780
7790
7800
7810
7820
7830
7840
7850
7860
7870
7880
7890
7900
7910
7920
7930
7940
7950
7960



  7%|█████▊                                                                                 | 2/30 [07:50<1:49:49, 235.33s/it]

Test MSE for 1: tensor(10954.1621, device='cuda:0')

non masked: 136
k: 4
LSTM Model:
Len of data semi 10281

epoch: 3
0
10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
160
170
180
190
200
210
220
230
240
250
260
270
280
290
300
310
320
330
340
350
360
370
380
390
400
410
420
430
440
450
460
470
480
490
500
510
520
530
540
550
560
570
580
590
600
610
620
630
640
650
660
670
680
690
700
710
720
730
740
750
760
770
780
790
800
810
820
830
840
850
860
870
880
890
900
910
920
930
940
950
960
970
980
990
1000
1010
1020
1030
1040
1050
1060
1070
1080
1090
1100
1110
1120
1130
1140
1150
1160
1170
1180
1190
1200
1210
1220
1230
1240
1250
1260
1270
1280
1290
1300
1310
1320
1330
1340
1350
1360
1370
1380
1390
1400
1410
1420
1430
1440
1450
1460
1470
1480
1490
1500
1510
1520
1530
1540
1550
1560
1570
1580
1590
1600
1610
1620
1630
1640
1650
1660
1670
1680
1690
1700
1710
1720
1730
1740
1750
1760
1770
1780
1790
1800
1810
1820
1830
1840
1850
1860
1870
1880
1890
1900
1910
1920
1930
1940
1950
1960
1970
19

5960
5970
5980
5990
6000
6010
6020
6030
6040
6050
6060
6070
6080
6090
6100
6110
6120
6130
6140
6150
6160
6170
6180
6190
6200
6210
6220
6230
6240
6250
6260
6270
6280
6290
6300
6310
6320
6330
6340
6350
6360
6370
6380
6390
6400
6410
6420
6430
6440
6450
6460
6470
6480
6490
6500
6510
6520
6530
6540
6550
6560
6570
6580
6590
6600
6610
6620
6630
6640
6650
6660
6670
6680
6690
6700
6710
6720
6730
6740
6750
6760
6770
6780
6790
6800
6810
6820
6830
6840
6850
6860
6870
6880
6890
6900
6910
6920
6930
6940
6950
6960
6970
6980
6990
7000
7010
7020
7030
7040
7050
7060
7070
7080
7090
7100
7110
7120
7130
7140
7150
7160
7170
7180
7190
7200
7210
7220
7230
7240
7250
7260
7270
7280
7290
7300
7310
7320
7330
7340
7350
7360
7370
7380
7390
7400
7410
7420
7430
7440
7450
7460
7470
7480
7490
7500
7510
7520
7530
7540
7550
7560
7570
7580
7590
7600
7610
7620
7630
7640
7650
7660
7670
7680
7690
7700
7710
7720
7730
7740
7750
7760
7770
7780
7790
7800
7810
7820
7830
7840
7850
7860
7870
7880
7890
7900
7910
7920
7930
7940
7950



 10%|████████▋                                                                              | 3/30 [11:38<1:44:18, 231.78s/it]

Test MSE for 2: tensor(3519.3962, device='cuda:0')

non masked: 132
k: 4
LSTM Model:
Len of data semi 10281

epoch: 4
0
10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
160
170
180
190
200
210
220
230
240
250
260
270
280
290
300
310
320
330
340
350
360
370
380
390
400
410
420
430
440
450
460
470
480
490
500
510
520
530
540
550
560
570
580
590
600
610
620
630
640
650
660
670
680
690
700
710
720
730
740
750
760
770
780
790
800
810
820
830
840
850
860
870
880
890
900
910
920
930
940
950
960
970
980
990
1000
1010
1020
1030
1040
1050
1060
1070
1080
1090
1100
1110
1120
1130
1140
1150
1160
1170
1180
1190
1200
1210
1220
1230
1240
1250
1260
1270
1280
1290
1300
1310
1320
1330
1340
1350
1360
1370
1380
1390
1400
1410
1420
1430
1440
1450
1460
1470
1480
1490
1500
1510
1520
1530
1540
1550
1560
1570
1580
1590
1600
1610
1620
1630
1640
1650
1660
1670
1680
1690
1700
1710
1720
1730
1740
1750
1760
1770
1780
1790
1800
1810
1820
1830
1840
1850
1860
1870
1880
1890
1900
1910
1920
1930
1940
1950
1960
1970
198

5960
5970
5980
5990
6000
6010
6020
6030
6040
6050
6060
6070
6080
6090
6100
6110
6120
6130
6140
6150
6160
6170
6180
6190
6200
6210
6220
6230
6240
6250
6260
6270
6280
6290
6300
6310
6320
6330
6340
6350
6360
6370
6380
6390
6400
6410
6420
6430
6440
6450
6460
6470
6480
6490
6500
6510
6520
6530
6540
6550
6560
6570
6580
6590
6600
6610
6620
6630
6640
6650
6660
6670
6680
6690
6700
6710
6720
6730
6740
6750
6760
6770
6780
6790
6800
6810
6820
6830
6840
6850
6860
6870
6880
6890
6900
6910
6920
6930
6940
6950
6960
6970
6980
6990
7000
7010
7020
7030
7040
7050
7060
7070
7080
7090
7100
7110
7120
7130
7140
7150
7160
7170
7180
7190
7200
7210
7220
7230
7240
7250
7260
7270
7280
7290
7300
7310
7320
7330
7340
7350
7360
7370
7380
7390
7400
7410
7420
7430
7440
7450
7460
7470
7480
7490
7500
7510
7520
7530
7540
7550
7560
7570
7580
7590
7600
7610
7620
7630
7640
7650
7660
7670
7680
7690
7700
7710
7720
7730
7740
7750
7760
7770
7780
7790
7800
7810
7820
7830
7840
7850
7860
7870
7880
7890
7900
7910
7920
7930
7940
7950



 13%|███████████▌                                                                           | 4/30 [15:25<1:39:38, 229.96s/it]

Test MSE for 3: tensor(3288.4727, device='cuda:0')

non masked: 128
k: 4
LSTM Model:
Len of data semi 10281

epoch: 5
0
10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
160
170
180
190
200
210
220
230
240
250
260
270
280
290
300
310
320
330
340
350
360
370
380
390
400
410
420
430
440
450
460
470
480
490
500
510
520
530
540
550
560
570
580
590
600
610
620
630
640
650
660
670
680
690
700
710
720
730
740
750
760
770
780
790
800
810
820
830
840
850
860
870
880
890
900
910
920
930
940
950
960
970
980
990
1000
1010
1020
1030
1040
1050
1060
1070
1080
1090
1100
1110
1120
1130
1140
1150
1160
1170
1180
1190
1200
1210
1220
1230
1240
1250
1260
1270
1280
1290
1300
1310
1320
1330
1340
1350
1360
1370
1380
1390
1400
1410
1420
1430
1440
1450
1460
1470
1480
1490
1500
1510
1520
1530
1540
1550
1560
1570
1580
1590
1600
1610
1620
1630
1640
1650
1660
1670
1680
1690
1700
1710
1720
1730
1740
1750
1760
1770
1780
1790
1800
1810
1820
1830
1840
1850
1860
1870
1880
1890
1900
1910
1920
1930
1940
1950
1960
1970
198

5950
5960
5970
5980
5990
6000
6010
6020
6030
6040
6050
6060
6070
6080
6090
6100
6110
6120
6130
6140
6150
6160
6170
6180
6190
6200
6210
6220
6230
6240
6250
6260
6270
6280
6290
6300
6310
6320
6330
6340
6350
6360
6370
6380
6390
6400
6410
6420
6430
6440
6450
6460
6470
6480
6490
6500
6510
6520
6530
6540
6550
6560
6570
6580
6590
6600
6610
6620
6630
6640
6650
6660
6670
6680
6690
6700
6710
6720
6730
6740
6750
6760
6770
6780
6790
6800
6810
6820
6830
6840
6850
6860
6870
6880
6890
6900
6910
6920
6930
6940
6950
6960
6970
6980
6990
7000
7010
7020
7030
7040
7050
7060
7070
7080
7090
7100
7110
7120
7130
7140
7150
7160
7170
7180
7190
7200
7210
7220
7230
7240
7250
7260
7270
7280
7290
7300
7310
7320
7330
7340
7350
7360
7370
7380
7390
7400
7410
7420
7430
7440
7450
7460
7470
7480
7490
7500
7510
7520
7530
7540
7550
7560
7570
7580
7590
7600
7610
7620
7630
7640
7650
7660
7670
7680
7690
7700
7710
7720
7730
7740
7750
7760
7770
7780
7790
7800
7810
7820
7830
7840
7850
7860
7870
7880
7890
7900
7910
7920
7930
7940



 17%|██████████████▌                                                                        | 5/30 [19:14<1:35:45, 229.83s/it]

Test MSE for 4: tensor(3199.4810, device='cuda:0')

non masked: 124
k: 4
LSTM Model:
Len of data semi 10281

epoch: 6
0
10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
160
170
180
190
200
210
220
230
240
250
260
270
280
290
300
310
320
330
340
350
360
370
380
390
400
410
420
430
440
450
460
470
480
490
500
510
520
530
540
550
560
570
580
590
600
610
620
630
640
650
660
670
680
690
700
710
720
730
740
750
760
770
780
790
800
810
820
830
840
850
860
870
880
890
900
910
920
930
940
950
960
970
980
990
1000
1010
1020
1030
1040
1050
1060
1070
1080
1090
1100
1110
1120
1130
1140
1150
1160
1170
1180
1190
1200
1210
1220
1230
1240
1250
1260
1270
1280
1290
1300
1310
1320
1330
1340
1350
1360
1370
1380
1390
1400
1410
1420
1430
1440
1450
1460
1470
1480
1490
1500
1510
1520
1530
1540
1550
1560
1570
1580
1590
1600
1610
1620
1630
1640
1650
1660
1670
1680
1690
1700
1710
1720
1730
1740
1750
1760
1770
1780
1790
1800
1810
1820
1830
1840
1850
1860
1870
1880
1890
1900
1910
1920
1930
1940
1950
1960
1970
198

5960
5970
5980
5990
6000
6010
6020
6030
6040
6050
6060
6070
6080
6090
6100
6110
6120
6130
6140
6150
6160
6170
6180
6190
6200
6210
6220
6230
6240
6250
6260
6270
6280
6290
6300
6310
6320
6330
6340
6350
6360
6370
6380
6390
6400
6410
6420
6430
6440
6450
6460
6470
6480
6490
6500
6510
6520
6530
6540
6550
6560
6570
6580
6590
6600
6610
6620
6630
6640
6650
6660
6670
6680
6690
6700
6710
6720
6730
6740
6750
6760
6770
6780
6790
6800
6810
6820
6830
6840
6850
6860
6870
6880
6890
6900
6910
6920
6930
6940
6950
6960
6970
6980
6990
7000
7010
7020
7030
7040
7050
7060
7070
7080
7090
7100
7110
7120
7130
7140
7150
7160
7170
7180
7190
7200
7210
7220
7230
7240
7250
7260
7270
7280
7290
7300
7310
7320
7330
7340
7350
7360
7370
7380
7390
7400
7410
7420
7430
7440
7450
7460
7470
7480
7490
7500
7510
7520
7530
7540
7550
7560
7570
7580
7590
7600
7610
7620
7630
7640
7650
7660
7670
7680
7690
7700
7710
7720
7730
7740
7750
7760
7770
7780
7790
7800
7810
7820
7830
7840
7850
7860
7870
7880
7890
7900
7910
7920
7930
7940
7950



 20%|█████████████████▍                                                                     | 6/30 [23:04<1:31:54, 229.79s/it]

Test MSE for 5: tensor(3270.2769, device='cuda:0')

non masked: 120
k: 4
LSTM Model:
Len of data semi 10281

epoch: 7
0
10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
160
170
180
190
200
210
220
230
240
250
260
270
280
290
300
310
320
330
340
350
360
370
380
390
400
410
420
430
440
450
460
470
480
490
500
510
520
530
540
550
560
570
580
590
600
610
620
630
640
650
660
670
680
690
700
710
720
730
740
750
760
770
780
790
800
810
820
830
840
850
860
870
880
890
900
910
920
930
940
950
960
970
980
990
1000
1010
1020
1030
1040
1050
1060
1070
1080
1090
1100
1110
1120
1130
1140
1150
1160
1170
1180
1190
1200
1210
1220
1230
1240
1250
1260
1270
1280
1290
1300
1310
1320
1330
1340
1350
1360
1370
1380
1390
1400
1410
1420
1430
1440
1450
1460
1470
1480
1490
1500
1510
1520
1530
1540
1550
1560
1570
1580
1590
1600
1610
1620
1630
1640
1650
1660
1670
1680
1690
1700
1710
1720
1730
1740
1750
1760
1770
1780
1790
1800
1810
1820
1830
1840
1850
1860
1870
1880
1890
1900
1910
1920
1930
1940
1950
1960
1970
198

5970
5980
5990
6000
6010
6020
6030
6040
6050
6060
6070
6080
6090
6100
6110
6120
6130
6140
6150
6160
6170
6180
6190
6200
6210
6220
6230
6240
6250
6260
6270
6280
6290
6300
6310
6320
6330
6340
6350
6360
6370
6380
6390
6400
6410
6420
6430
6440
6450
6460
6470
6480
6490
6500
6510
6520
6530
6540
6550
6560
6570
6580
6590
6600
6610
6620
6630
6640
6650
6660
6670
6680
6690
6700
6710
6720
6730
6740
6750
6760
6770
6780
6790
6800
6810
6820
6830
6840
6850
6860
6870
6880
6890
6900
6910
6920
6930
6940
6950
6960
6970
6980
6990
7000
7010
7020
7030
7040
7050
7060
7070
7080
7090
7100
7110
7120
7130
7140
7150
7160
7170
7180
7190
7200
7210
7220
7230
7240
7250
7260
7270
7280
7290
7300
7310
7320
7330
7340
7350
7360
7370
7380
7390
7400
7410
7420
7430
7440
7450
7460
7470
7480
7490
7500
7510
7520
7530
7540
7550
7560
7570
7580
7590
7600
7610
7620
7630
7640



 20%|█████████████████▍                                                                     | 6/30 [24:48<1:39:13, 248.07s/it]

7650
7660
7670
7680





KeyboardInterrupt: 

In [None]:
total_dict

In [None]:
# After training
torch.isnan(data_semi[-1].y).sum()

In [None]:
from datetime import datetime

now = datetime.now().strftime('%H_%M-%d_%m_%y')
now

In [None]:
# DO NOT ARCHIVE - FOR SAVING MODELS
import os

#try:
#    number = max(int(x) for x in os.listdir('saved_models')) + 1
#except Exception as e:
#    print(e)
#    number = 1

if not os.path.exists(f'saved_models/{now}'):
    os.makedirs(f'saved_models/{now}')
    
torch.save(model_lstm_super.state_dict(), f'saved_models/{now}/model_lstm_super.pt')
torch.save(model_attn_super.state_dict(), f'saved_models/{now}/model_attn_super.pt')

torch.save(model_lstm_co.state_dict(), f'saved_models/{now}/model_lstm_co.pt')
torch.save(model_attn_co.state_dict(), f'saved_models/{now}/model_attn_co.pt')

In [None]:
non_masked

### Results

Here are the results from co-training on the masked (partially labeled) data:

| Model |   MSE |
| ----- | :-----: |
| LSTM Based | 0.0316 |
| Attention Based | 0.0574 |

## Evaluation

In this project we compare supervised learning and semi-supervised learning using mean squared error as our primary metric.

In this section we plot the MSE vs epoch for the two supervised and semi-supervised models. We demonstrate that even with an advantage of being able to "see" or utilize a fully labled dataset with no missing node labels, the supervised learning method fails to compete with the semi-supervised co-training algorithm. 



#### MSE vs Epoch

In [None]:
# plotting method
def plot_mse_epoch(y,n):
    epoch_list = [i for i in range(n)]
    plt.plot(epoch_list, y)

##### Supervised

#### LSTM

In [None]:
30*273

In [None]:
train_loss_super_lstm.shape

In [None]:
plot_mse_epoch(train_loss_super_lstm,epochs)

In [None]:
test_loss_super_lstm.cpu().detach().numpy()

#### ATTENTION 

In [None]:
plot_mse_epoch(train_loss_super_attn, epochs)

In [None]:
test_loss_super_attn.cpu().detach().numpy()

### Co-Training

In [None]:
total_dict

In [None]:
if not os.path.exists(f'saved_losses/{now}'):
    os.makedirs(f'saved_losses/{now}')

In [None]:
np.save(f'./saved_losses/{now}/super_train_loss_lstm',train_loss_super_lstm)
np.save(f'./saved_losses/{now}/super_train_loss_super_attn',train_loss_super_attn)

np.save(f'./saved_losses/{now}/co_train_loss_total_lstm',total_dict['train_loss_total_lstm'])
np.save(f'./saved_losses/{now}/co_train_loss_total_attn',total_dict['train_loss_total_attn'])
np.save(f'./saved_losses/{now}/co_test_lstm_total',total_dict['test_lstm_total'])
np.save(f'./saved_losses/{now}/co_test_attn_total',total_dict['test_attn_total'])

In [None]:
plot_mse_epoch(total_dict['train_loss_total_lstm'], epochs)

In [None]:
plot_mse_epoch(total_dict['train_loss_total_attn'], epochs)

In [None]:
plot_mse_epoch(total_dict['test_lstm_total'], epochs)

In [None]:
plot_mse_epoch(total_dict['test_attn_total'], epochs)

<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=e3c63109-179d-4951-97a2-5a13ae332211' target="_blank">
 </img>
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>

## Load Saved Models

### Co-Training Models 

In [None]:
model_lstm_co = LSTMGNN(GCLSTM, node_features=2, out_channels=nodes, K=2)
model_attn_co = AttentionGNN(A3TGCN, node_features=2, periods=1, out_channels=nodes)

model_lstm_co.load_state_dict(torch.load('/NAM/saved_models/23_33-19_06_22/model_lstm_co.pt'))
model_attn_co.load_state_dict(torch.load('/NAM/saved_models/23_33-19_06_22/model_attn_co.pt'))

In [None]:
model_lstm_co

In [None]:
model_attn_co

###  Load Supervised Models

In [None]:
model_lstm_super = LSTMGNN(GCLSTM, node_features=2, out_channels=nodes, K=2)
model_attn_super = AttentionGNN(A3TGCN, node_features=2, periods=1, out_channels=nodes)

model_lstm_super.load_state_dict(torch.load('/NAM/saved_models/23_33-19_06_22/model_lstm_super.pt'))
model_attn_super.load_state_dict(torch.load('/NAM/saved_models/23_33-19_06_22/model_attn_super.pt'))

In [None]:
print(model_lstm_super)

In [None]:
print(model_attn_super)

#### Load Model to GPU

In [None]:
model_attn_co.to(device)
model_lstm_co.to(device)
model_lstm_super.to(device)
model_attn_super.to(device)

## Evaluate on Train Set 

In [None]:
# Co-Training Validation Test
@torch.no_grad()
def validation_test_cotrain(model_lstm, model_a3t, input_data):
    # Put the models into Eval mode
    model_lstm.eval()
    model_a3t.eval()
    # Define and initialize Loss
    loss = MSELoss()
    loss_total = torch.zeros(len(list(input_data))).to(device)
    
    node = 69
    yhat_for_node = []
    # Generate Predictions
    for timestep, snapshot in enumerate(input_data):
        
        # Load data to gpu
        snapshot.x = snapshot.x.to(device)
        snapshot.y = snapshot.y.to(device)
        snapshot.edge_index = snapshot.edge_index.to(device)
        
        # Forward passes for a timestep
        model_lstm_yhat = model_lstm.forward(snapshot.x.reshape(-1, 2), snapshot.edge_index).reshape(-1)
        model_a3t_yhat = model_a3t.forward(snapshot.x, snapshot.edge_index).reshape(-1)
        
        yhat_cotrain = torch.div(torch.add(model_lstm_yhat, model_a3t_yhat), 2)
        
        yhat_for_node.append(yhat_cotrain[node].cpu().detach().numpy().item())
        
        loss_model = loss(yhat_cotrain[masked_nodes], snapshot.y[masked_nodes].reshape(-1))
        loss_total[timestep]=loss_model.item()
        
    return yhat_cotrain, yhat_for_node, loss_total

In [None]:
yhat_cotrain, yhat_for_node_cotrain, loss_total_cotrain = validation_test_cotrain(model_lstm_co, model_attn_co, rest_dataset)


In [None]:
yhat_cotrain.cpu().detach().numpy()

In [None]:
yhat_for_node

In [None]:
torch.sum(loss_total_cotrain)

In [None]:
# Supervised Validation Test
@torch.no_grad()
def supervised_test_cotrain(model, input_data,reshape=False):
    # Put the models into Eval mode
    model.eval()
    # Define and initialize Loss
    loss = MSELoss()
    loss_total = torch.zeros(len(list(input_data))).to(device)
    
    node = 69
    yhat_for_node = []
    # Generate Predictions
    for timestep, snapshot in enumerate(input_data):

        # Load data to gpu
        snapshot.x = snapshot.x.to(device)
        snapshot.y = snapshot.y.to(device)
        snapshot.edge_index = snapshot.edge_index.to(device)
        
        # Forward passes for a timestep
        if reshape:
            yhat = model.forward(snapshot.x.reshape(-1, 2), snapshot.edge_index).reshape(-1)
        else:
            yhat = model.forward(snapshot.x, snapshot.edge_index).reshape(-1)

        yhat_for_node.append(yhat[node].cpu().detach().numpy().item())

        loss_model = loss(yhat[masked_nodes], snapshot.y[masked_nodes].reshape(-1))
        loss_total[timestep]=loss_model.item()

    return yhat, yhat_for_node, loss_total

In [None]:
yhat_lstm, yhat_for_node_lstm, loss_total_lstm = supervised_test_cotrain(model_lstm_super,rest_dataset,reshape=True)


In [None]:
yhat_for_node_lstm

In [None]:
torch.sum(loss_total_lstm)

In [None]:
yhat_attn, yhat_for_node_attn, loss_total_attn = supervised_test_cotrain(model_attn_super,rest_dataset)


In [None]:
torch.sum(loss_total_attn)

## Load npy saved files

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt

In [None]:
epochs=30

In [None]:
!pwd

In [None]:
saved_losses_dir = '/NAM/saved_losses/23_33-19_06_22'
print("Directory exists:",os.path.isdir(saved_losses_dir))
os.path.abspath(saved_losses_dir)

In [None]:
train_loss_total_lstm=np.load(saved_losses_dir+"/co_test_lstm_total.npy")

In [None]:
train_loss_total_attn=np.load(saved_losses_dir+"/co_train_loss_total_attn.npy")

In [None]:
test_lstm_total=np.load(saved_losses_dir+"/co_test_lstm_total.npy")

In [None]:
test_attn_total=np.load(saved_losses_dir+"/co_test_attn_total.npy")

In [None]:
train_loss_super_lstm=np.load(saved_losses_dir+"/super_train_loss_lstm.npy")
train_loss_super_attn=np.load(saved_losses_dir+"/super_train_loss_super_attn.npy")

In [None]:
total_dict={'train_loss_total_lstm':train_loss_total_lstm, 
            'train_loss_total_attn':train_loss_total_attn,
            'test_lstm_total':test_lstm_total,
            'test_attn_total':test_attn_total}

In [None]:
train_loss_total_lstm[-1]+train_loss_total_lstm[-1]//2

In [None]:
train_loss_super_lstm

In [None]:
train_loss_super_attn

### Plot Loss Figures

In [None]:
def multi_plot(y,n,name):
    epoch_list = [i for i in range(n)]
    for i,l in y:
        plt.plot(epoch_list, i, label=l)
    plt.xlabel("Timestep")
    plt.ylabel("Mph")
    leg = plt.legend(loc='lower right')
    plt.title(name)
    plt.savefig('saved_figs/'+name+'.eps', format='eps')

In [None]:
# EPS Plot saving method
def plot_save(y,n,name):
    epoch_list = [i for i in range(n)]
    plt.plot(epoch_list, y)
    plt.xlabel("Epochs")
    plt.ylabel("MSE")
    plt.title(name)
#     plt.savefig('saved_figs/'+name+'.eps', format='eps')

### Plot Figures from Validation 

In [None]:
true_values=[]
for timestep, snapshot in enumerate(rest_dataset):
    true_values.append(snapshot.y[69].detach().numpy().item())
len(true_values)

In [None]:
multi_plot([
            (true_values[:288],'True values'),
            (yhat_for_node_lstm[:288],'Supervised - LSTM'),
                (yhat_for_node_attn[:288],'Supervised - A3T'),
(yhat_for_node_cotrain[:288],'Co-Training')],
           288,'Speed vs Timesteps')


### Plot Figures from Train

In [None]:
multi_plot([(train_loss_total_lstm, 'Train'),(test_lstm_total,'Test')],epochs,'Co-Training - GC-LSTM')

In [None]:
multi_plot([(train_loss_total_attn,'Train'),(test_attn_total,'Test')],epochs,'A3T-GCN')

In [None]:
multi_plot([(train_loss_total_lstm,'Co-Training'),(train_loss_super_lstm,'Supervised')],epochs,"Co-Training vs Supervised - GC-LSTM")

In [None]:
multi_plot([(train_loss_total_attn,'Co-Training'),(train_loss_super_attn,'Supervised')],epochs,"Co-Training vs Supervised - A3T-GCN")

In [None]:
plot_save(train_loss_total_lstm,epochs,'Co-Training - GC-LSTM')