In [None]:
import os
from pathlib import Path
from typing import Union, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Linear
from torch_geometric.data import Data
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.norm import GraphNorm
from torch_geometric.typing import OptPairTensor, Adj, Size

In [None]:
from torch_geometric.data import Data

In [None]:
# Number of 1-day graphs per input window; passed as `window_size` when the model is built below.
WINDOW_SIZE = 7
# Name of the training-run directory (a timestamp) that holds the saved checkpoints.
TRAINING_RUN = "2022-03-21-01_05_53"
# Index of the checkpoint file (model_<idx>.pth) to load; also used in plot titles and output file names.
MODEL_IDX = 25

# Define Model

In [None]:
class WeightedSAGEConv(MessagePassing):
    r"""GraphSAGE convolution that scales neighbor messages by edge weights.

    Based on the GraphSAGE operator from the `"Inductive Representation Learning on
    Large Graphs" <https://arxiv.org/abs/1706.02216>`_ paper.

    Copied from torch_geometric.nn.SageConv and then modified by Sesti et. al and Juan
    Jose Garau to take edge weights into account in message-passing step.

    math:
        \mathbf{x}^{\prime}_i = \mathbf{W}_l \cdot
        \mathrm{mean}_{j \in \mathcal{N}(i)} \left( e_{j,i} \, \mathbf{x}_j \right)
        + \mathbf{W}_r \mathbf{x}_i

    where W_l is ``lin_l`` (applied to the aggregated neighbor messages), W_r is
    ``lin_r`` (applied to the node's own features, only when ``root_weight`` is
    True) and e_{j,i} is the weight of the edge between nodes j and i. The
    aggregation is the plain mean of the weighted messages; it is not divided by
    the sum of the edge weights.

    Args:
        in_channels (int or tuple): Size of each input sample. A tuple
            corresponds to the sizes of source and target dimensionalities.
        out_channels (int): Size of each output sample.
        normalize (bool, optional): If set to True, output features
            will be l_2-normalized (default: False).
        training (bool, optional): Initial train/eval mode of the layer
            (default: True).
        root_weight (bool, optional): If set to False, the layer will not add
            the transformed root node features to the output. (default: True)
        bias (bool, optional): If set to False, the layer will not learn
            an additive bias. (default: True)
        **kwargs (optional): Additional arguments of
            torch_geometric.nn.conv.MessagePassing.
    """

    def __init__(self, 
                in_channels: Union[int, Tuple[int, int]],
                out_channels: int, 
                normalize: bool = False,
                training: bool = True,
                root_weight = True,
                bias: bool = True, 
                **kwargs):
        super(WeightedSAGEConv, self).__init__(aggr='mean', **kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize
        self.root_weight = root_weight

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        # lin_l transforms the aggregated neighbor messages, lin_r the root features.
        self.lin_l = Linear(in_channels[0], out_channels, bias=bias)
        if self.root_weight:
            self.lin_r = Linear(in_channels[1], out_channels, bias=False)
        
        self.reset_parameters()

        # Use train() rather than assigning self.training directly so that the
        # child modules end up in the same mode as this layer.
        self.train(training)

    def reset_parameters(self):
        """Re-initialize the parameters of the linear sub-layers."""
        self.lin_l.reset_parameters()
        if self.root_weight:
            self.lin_r.reset_parameters()

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_weight: Tensor = None,
                size: Size = None) -> Tensor:
        """Run one round of edge-weighted mean aggregation.

        Args:
            x: Node features, or a (source, target) pair of feature tensors.
            edge_index: Graph connectivity passed on to ``propagate``.
            edge_weight: Optional per-edge weights used to scale the messages.
                When None, messages are aggregated unweighted.
            size: Optional size hint passed on to ``propagate``.

        Returns:
            Tensor with ``out_channels`` features per target node.
        """

        if isinstance(x, Tensor):
            x: OptPairTensor = (x, x)

        # propagate_type: (x: OptPairTensor)
        out = self.propagate(edge_index, x=x, size=size, edge_weight=edge_weight)
        out = self.lin_l(out)

        x_r = x[1]
        if self.root_weight and x_r is not None:
            out = out + self.lin_r(x_r)

        if self.normalize:
            out = F.normalize(out, p=2., dim=-1)

        return out

    def message(self, x_i: Tensor, x_j: Tensor, edge_weight) -> Tensor:
        """
        Constructs messages from node j to node i in analogy to ϕΘ for each edge in 
        edge_index. This function can take any argument as input which was initially 
        passed to propagate(). Furthermore, tensors passed to propagate() can be 
        mapped to the respective nodes i and j by appending _i or _j to the variable 
        name, .e.g. x_i and x_j.

        x_i.shape and x_j.shape is [num_edges, embedding dim (num_features or graph emb dim)]
        edge_weight is None, [num_edges] or [num_edges, 1]
        """
        # forward() defaults edge_weight to None; fall back to unweighted messages
        # instead of failing on `x_j * None`.
        if edge_weight is None:
            return x_j

        # A flat [num_edges] weight vector would not broadcast against
        # [num_edges, dim]; give it a trailing singleton dimension first.
        if edge_weight.dim() == 1:
            edge_weight = edge_weight.unsqueeze(-1)

        return x_j * edge_weight  # [num_edges, dim] * [num_edges, 1] = [num_edges, dim]

    # message_and_aggregate() is intentionally not implemented: sparse tensors are
    # not used, so the fused code path would never be called.

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, self.out_channels)


class DynamicAdjSAGE(torch.nn.Module):
    """
    Architecture
    1 Weighted GraphSAGE layer -> 1 Weighted GraphSAGE layer -> concat outputs of both GraphSAGE layers 
        -> 2 stacked LSTM cells -> concat both LSTM hidden states for the last day with one original
        input feature column of each of the N days -> MLP with ReLU activation.

    Architecture expects to receive 1 day graphs, so that adjacency matrix can change every day. N 
    1-day graphs passed through, then take model output for N+1 day. 

    The module keeps per-window state (`concat_feat_list`) between forward() calls, so the
    days of one window must be passed in order with day_idx = 0 .. window_size - 1.

    Constructor Arguments:
        node_features: number of features each node in 1-day graph contains
        emb_dim: embedding dimension that WeightedGraphSAGE layers will output
        window_size: number of 1-day graphs that will be passed through network before predicting next day
        output: number of features to predict for each node on N+1 day
        training: whether the model is being loaded for training or test. Affects things like dropout if present.
        lstm_type: what type of LSTM to use (["vanilla"])
        name: free-form model name stored on the instance
    """

    def __init__(self,
                node_features: int = 3, 
                emb_dim: int = 16,
                window_size: int = 14,
                output: int = 1, 
                training: bool = True,
                lstm_type: str = 'vanilla',
                name: str = "DASAGE"):
        super(DynamicAdjSAGE, self).__init__()
        assert lstm_type in ["vanilla"], "unsupported lstm_type: {}".format(lstm_type)

        self.emb_dim = emb_dim
        self.window_size = window_size
        self.lstm_type = lstm_type
        self.name = name

        normalize_graphsage_layers = False

        self.sage1 = WeightedSAGEConv(in_channels=node_features, out_channels=self.emb_dim, normalize=normalize_graphsage_layers, training=training)
        self.sage2 = WeightedSAGEConv(in_channels=self.emb_dim, out_channels=self.emb_dim, normalize=normalize_graphsage_layers, training=training)

        self.graph_norm_1 = GraphNorm(self.emb_dim)
        self.graph_norm_2 = GraphNorm(self.emb_dim)

        # lstm1 consumes the concatenated outputs of both GraphSAGE layers (2 * emb_dim).
        self.lstm1 = nn.LSTMCell(input_size=2 * self.emb_dim, hidden_size=self.emb_dim)
        self.lstm2 = nn.LSTMCell(input_size=self.emb_dim, hidden_size=self.emb_dim)
        
        # Readout MLP. Its input is one raw feature column per day of the window
        # plus the hidden states of both LSTM cells.
        self.act1 = torch.nn.ReLU()
        self.lin1 = torch.nn.Linear(self.window_size + (2 * self.emb_dim), 13)
        self.act2 = torch.nn.ReLU()
        self.lin2 = torch.nn.Linear(13, output)

        # Orthogonal initialization is available through init_weights() but is
        # not applied by default.
        
        # For concatenating features across each day of time window
        self.concat_feat_list = []

        # Use train() rather than assigning self.training directly so that all
        # sub-modules end up in the same mode as the model itself.
        self.train(training)
    
    def init_weights(self):
        """Initialize all 2-D weight matrices of the model with orthogonal matrices."""
        nn.init.orthogonal_(self.sage1.lin_l.weight)
        nn.init.orthogonal_(self.sage1.lin_r.weight)
        nn.init.orthogonal_(self.sage2.lin_l.weight)
        nn.init.orthogonal_(self.sage2.lin_r.weight)
        # The GraphNorm weights are skipped: orthogonal_ only supports tensors
        # with 2+ dimensions.

        # Pytorch LSTMCell only has 2 weight matrices, each one is 4*hidden_size * output_size,
        # meaning these 2 matrices contain the 8 LSTM kernels we are trying to initialize
        nn.init.orthogonal_(self.lstm1.weight_ih)
        nn.init.orthogonal_(self.lstm1.weight_hh)
        nn.init.orthogonal_(self.lstm2.weight_ih)
        nn.init.orthogonal_(self.lstm2.weight_hh)

        nn.init.orthogonal_(self.lin1.weight)
        nn.init.orthogonal_(self.lin2.weight)
        print("Ran init_weights().")

    def forward(self, data: Data, h_1: Tensor=None, c_1: Tensor=None, 
                h_2: Tensor=None, c_2: Tensor=None, day_idx: int=0):
        """Process one 1-day graph of the current window.

        Args:
            data: 1-day graph providing `x`, `edge_index` and `edge_attr`.
            h_1, c_1: hidden and cell state of the first LSTM cell from the previous
                day; initialized to zeros when None.
            h_2, c_2: hidden and cell state of the second LSTM cell from the previous
                day; initialized to zeros when None.
            day_idx: position of this graph within the window (0-based). The readout
                MLP only runs when day_idx == window_size - 1.

        Returns:
            Tuple (x, h_1, c_1, h_2, c_2). On the last day of the window x is the
            readout output [num_nodes, output]; on every other day it is the
            concatenated GraphSAGE embeddings [num_nodes, 2 * emb_dim]. The
            remaining entries are the updated LSTM states.
        """
        # Get data from snapshot
        # x is [num_nodes, node_features]
        # edge_index is [2, num_edges]
        # edge_attr is [num_edges, 1]
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr

        # A new window starts at day 0: drop features left over from a window that
        # was never finished so they cannot leak into (or grow with) this one.
        if day_idx == 0:
            self.concat_feat_list.clear()

        graphsage_outputs = []
        # Keep the raw feature column at index 1 of every day for the skip connection.
        self.concat_feat_list.append(x[:,1:2])

        # First GNN Layer
        x = self.sage1(x, edge_index, edge_attr)
        x = self.graph_norm_1(x)
        x = F.relu(x)  
        graphsage_outputs.append(x)
        
        # Second GNN Layer
        x = self.sage2(x, edge_index, edge_attr)
        x = self.graph_norm_2(x)
        x = F.relu(x)
        graphsage_outputs.append(x)

        x = torch.cat(graphsage_outputs, dim=1)  # [num_nodes, 2 * self.emb_dim]

        # Initialize hidden and cell states if None. new_zeros() creates them on
        # the same device and with the same dtype as the embeddings.
        if h_1 is None:
            h_1 = x.new_zeros(x.shape[0], self.emb_dim)
        if c_1 is None:
            c_1 = x.new_zeros(x.shape[0], self.emb_dim)
        if h_2 is None:
            h_2 = x.new_zeros(x.shape[0], self.emb_dim)
        if c_2 is None:
            c_2 = x.new_zeros(x.shape[0], self.emb_dim)

        # RNN Layer
        h_1, c_1 = self.lstm1(x, (h_1, c_1))  # h_1 and c_1 both become [num_nodes, self.emb_dim]
        h_2, c_2 = self.lstm2(h_1, (h_2, c_2))  # h_2 and c_2 both become [num_nodes, self.emb_dim]
        
        if day_idx == self.window_size - 1:
            # Skip connection
            concat_feat = torch.cat(self.concat_feat_list, dim=1)  # [num_nodes, number of days collected]
            x = torch.cat((concat_feat, h_1, h_2), dim=1)  # [num_nodes, days collected + 2 * self.emb_dim]
            self.concat_feat_list.clear()

            # Readout and activation layers
            x = self.act1(x)
            x = self.lin1(x)
            x = self.act2(x)
            x = self.lin2(x)

        return x, h_1, c_1, h_2, c_2

In [None]:
model = DynamicAdjSAGE(node_features=2, 
        emb_dim=10, 
        window_size=WINDOW_SIZE, 
        output=1, 
        training=False, 
        lstm_type="vanilla", 
        name="DASAGE")

# Directory that holds the saved training runs. When not running on the original
# author's laptop, point the DCSAGE_RUNS_DIR environment variable at your local
# copy instead of editing this cell.
DEFAULT_RUNS_DIR = "/Users/syedrizvi/Desktop/Projects/GNN_Project/DCSAGE/Training-Code/training-runs-multiple-models"
runs_dir = Path(os.environ.get("DCSAGE_RUNS_DIR", DEFAULT_RUNS_DIR))
checkpoint_path = runs_dir / TRAINING_RUN / "model_{}.pth".format(MODEL_IDX)
if not checkpoint_path.is_file():
    raise FileNotFoundError(
        "Checkpoint not found: {} (set DCSAGE_RUNS_DIR to the directory containing the training runs)".format(checkpoint_path)
    )

# map_location="cpu" lets checkpoints saved on a GPU load on CPU-only machines;
# the weights are converted to NumPy arrays further down, which requires CPU tensors.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Extract Weights

In [None]:
# Report the weight shapes of both GraphSAGE layers, layer by layer:
# lin_l holds the weights for the aggregated neighbor features,
# lin_r the weights for the node's own features.
for sage_layer in (model.sage1, model.sage2):
    for linear in (sage_layer.lin_l, sage_layer.lin_r):
        print(linear.weight.shape)

In [None]:
# Pull the GraphSAGE weight matrices out of the model as NumPy arrays
# (lin_l: aggregated-neighbor weights, lin_r: self weights).
sage1_lin_l, sage1_lin_r, sage2_lin_l, sage2_lin_r = [
    linear.weight.detach().numpy()
    for sage_layer in (model.sage1, model.sage2)
    for linear in (sage_layer.lin_l, sage_layer.lin_r)
]

# Confirm the extracted arrays have the expected shapes.
for sage_weights in (sage1_lin_l, sage1_lin_r, sage2_lin_l, sage2_lin_r):
    print(sage_weights.shape)

In [None]:
# Report the hidden-to-hidden and input-to-hidden weight shapes of both LSTM cells.
for lstm_cell in (model.lstm1, model.lstm2):
    print(lstm_cell.weight_hh.shape)
    print(lstm_cell.weight_ih.shape)

In [None]:
# Pull the LSTM weight matrices out of the model as NumPy arrays
# (hh: hidden-to-hidden, ih: input-to-hidden).
lstm1_hh, lstm1_ih, lstm2_hh, lstm2_ih = [
    weight.detach().numpy()
    for lstm_cell in (model.lstm1, model.lstm2)
    for weight in (lstm_cell.weight_hh, lstm_cell.weight_ih)
]

# Visualize Extracted Weights as Heatmaps and Distribution Plots

In [None]:
def visualize_weights_heatmap(model_weights, title, save_name, model_idx=None):
    """Save a heatmap of a weight matrix, with summary statistics in the title.

    Args:
        model_weights: NumPy array of weights to plot.
        title: Description of the weights, used in the figure title.
        save_name: Suffix of the output file; the figure is written to
            "model<model_idx>_<save_name>".
        model_idx: Model index shown in the title and file name. Defaults to
            the notebook-level MODEL_IDX.
    """
    if model_idx is None:
        model_idx = MODEL_IDX

    mean = model_weights.mean()
    std = model_weights.std()
    median = np.median(model_weights)

    # Draw on a dedicated figure instead of pyplot's implicit current figure, so
    # nothing from an earlier plot can bleed into this one.
    fig, ax = plt.subplots()
    sns.heatmap(model_weights, ax=ax)
    ax.set_title("Model {} {} \n(Mean: {:.4f}, Median: {:.4f}, STD: {:.4f})".format(model_idx, title, mean, median, std))
    fig.savefig("model{}_{}".format(model_idx, save_name), bbox_inches="tight", facecolor="white")
    # Close the figure so it is released and not rendered as an empty plot.
    plt.close(fig)

In [None]:
# One heatmap per GraphSAGE weight matrix: (weights, figure title, output file suffix).
for weights, title, save_name in (
    (sage1_lin_l, "GraphSAGE Layer 1 Neighbor Aggr Weights", "sage_1_neighbor_weights_heatmap"),
    (sage1_lin_r, "GraphSAGE Layer 1 Self Aggr Weights", "sage_1_self_weights_heatmap"),
    (sage2_lin_l, "GraphSAGE Layer 2 Neighbor Aggr Weights", "sage_2_neighbor_weights_heatmap"),
    (sage2_lin_r, "GraphSAGE Layer 2 Self Aggr Weights", "sage_2_self_weights_heatmap"),
):
    visualize_weights_heatmap(weights, title=title, save_name=save_name)

In [None]:
# One heatmap per LSTM weight matrix: (weights, figure title, output file suffix).
for weights, title, save_name in (
    (lstm1_hh, "LSTM Cell 1 hh", "lstm_1_hh_heatmap"),
    (lstm1_ih, "LSTM Cell 1 ih", "lstm_1_ih_heatmap"),
    (lstm2_hh, "LSTM Cell 2 hh", "lstm_2_hh_heatmap"),
    (lstm2_ih, "LSTM Cell 2 ih", "lstm_2_ih_heatmap"),
):
    visualize_weights_heatmap(weights, title=title, save_name=save_name)

In [None]:
def visualize_weights_distribution(model_weights, title, save_name, model_idx=None, xlim=(-1, 1)):
    """Save a histogram (with KDE) of the weight values, with summary statistics in the title.

    Args:
        model_weights: NumPy array of weights; it is flattened before plotting.
        title: Description of the weights, used in the figure title.
        save_name: Suffix of the output file; the figure is written to
            "model<model_idx>_<save_name>".
        model_idx: Model index shown in the title and file name. Defaults to
            the notebook-level MODEL_IDX.
        xlim: (low, high) limits of the x-axis. Defaults to (-1, 1); pass None
            to let matplotlib choose the range from the data.
    """
    if model_idx is None:
        model_idx = MODEL_IDX

    mean = model_weights.mean()
    std = model_weights.std()
    median = np.median(model_weights)

    # Draw on a dedicated figure instead of pyplot's implicit current figure, so
    # nothing from an earlier plot can bleed into this one.
    fig, ax = plt.subplots()
    sns.histplot(model_weights.flatten(), kde=True, ax=ax)
    ax.set_title("Model {} {} \n(Mean: {:.4f}, Median: {:.4f}, STD: {:.4f})".format(model_idx, title, mean, median, std))
    if xlim is not None:
        ax.set_xlim(*xlim)
    fig.savefig("model{}_{}".format(model_idx, save_name), bbox_inches="tight", facecolor="white")
    # Close the figure so it is released and not rendered as an empty plot.
    plt.close(fig)

In [None]:
# One distribution plot per GraphSAGE weight matrix: (weights, figure title, output file suffix).
for weights, title, save_name in (
    (sage1_lin_l, "GraphSAGE Layer 1 Neighbor Aggr Weights", "sage_1_neighbor_weights_distrib"),
    (sage1_lin_r, "GraphSAGE Layer 1 Self Aggr Weights", "sage_1_self_weights_distrib"),
    (sage2_lin_l, "GraphSAGE Layer 2 Neighbor Aggr Weights", "sage_2_neighbor_weights_distrib"),
    (sage2_lin_r, "GraphSAGE Layer 2 Self Aggr Weights", "sage_2_self_weights_distrib"),
):
    visualize_weights_distribution(weights, title=title, save_name=save_name)

In [None]:
# One distribution plot per LSTM weight matrix: (weights, figure title, output file suffix).
for weights, title, save_name in (
    (lstm1_hh, "LSTM Cell 1 hh", "lstm_1_hh_distrib"),
    (lstm1_ih, "LSTM Cell 1 ih", "lstm_1_ih_distrib"),
    (lstm2_hh, "LSTM Cell 2 hh", "lstm_2_hh_distrib"),
    (lstm2_ih, "LSTM Cell 2 ih", "lstm_2_ih_distrib"),
):
    visualize_weights_distribution(weights, title=title, save_name=save_name)