In [1]:
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.manifold import TSNE

In [2]:
# Node-perturbation dataset; the perturbed/unperturbed dataloaders below read this same file.
NODE_PERT_NPZ_PATH = "./datasets/10_continents_dataset_v18_node_pert.npz"
dataset_v18 = np.load(NODE_PERT_NPZ_PATH)
dataset_v18.files

['feature_matrix_smooth',
 'feature_matrix_unsmooth',
 'flight_matrix_unscaled',
 'flight_matrix_log10_scaled']

In [3]:
# Root of the DCSAGE Training-Code checkout. The default is the original author's machine;
# set the DCSAGE_TRAINING_CODE_DIR environment variable to run this notebook elsewhere.
TRAINING_CODE_DIR = os.environ.get(
    "DCSAGE_TRAINING_CODE_DIR",
    "/Users/syedrizvi/Desktop/Projects/GNN_Project/DCSAGE/Training-Code",
)
dataset_v18_train = np.load(
    os.path.join(TRAINING_CODE_DIR, "datasets", "10_continents_dataset_v18_training.npz"))
dataset_v18_train.files

['train_features_log10',
 'train_log10_scaled_flight_matrix',
 'train_unscaled_flight_matrix',
 'val_features_log10',
 'val_log10_scaled_flight_matrix',
 'val_unscaled_flight_matrix',
 'test_features_log10_unsmooth',
 'test_features_log10_smooth',
 'test_log10_scaled_flight_matrix',
 'test_unscaled_flight_matrix']

In [4]:
# Largest raw (unscaled) flight value in each split, printed in train / val / test order.
for matrix_key in ("train_unscaled_flight_matrix",
                   "val_unscaled_flight_matrix",
                   "test_unscaled_flight_matrix"):
    print(dataset_v18_train[matrix_key].max())

897.0
537.0
815.0


## Define WeightedSAGEConv and DCSAGE Architectures

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor
from torch_geometric.data import Data
from torch_geometric.nn.norm import GraphNorm


from typing import Union, Tuple
from torch.nn import Linear
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.typing import OptPairTensor, Adj, Size


class WeightedSAGEConv(MessagePassing):
    r"""The GraphSAGE operator from the `"Inductive Representation Learning on
    Large Graphs" <https://arxiv.org/abs/1706.02216>`_ paper, with edge-weighted messages.

    Copied from torch_geometric.nn.SAGEConv and then modified by Sesti et al. and Juan
    Jose Garau to take edge weights into account in the message-passing step.

    math:
        \mathbf{x}^{\prime}_i = \mathbf{W}_l \cdot
        \mathrm{mean}_{j \in \mathcal{N}(i)} \left( e_{j,i} \, \mathbf{x}_j \right)
        + \mathbf{W}_r \mathbf{x}_i

    W_l is ``lin_l`` (carries the optional bias) applied to the aggregated neighbour
    messages; W_r is ``lin_r`` (no bias, only created when ``root_weight=True``) applied to
    the target node's own features. e_{j,i} is the edge weight (messages are unscaled when
    no weights are passed).

    Args:
        in_channels (int or tuple): Size of each input sample. A tuple
            corresponds to the sizes of source and target dimensionalities.
        out_channels (int): Size of each output sample.
        normalize (bool, optional): If set to True, output features
            will be l_2-normalized (default: False).
        training (bool, optional): Initial value of the module's ``training`` flag
            (``.train()`` / ``.eval()`` override it later). (default: True)
        root_weight (bool, optional): If set to False, the layer omits the W_r x_i
            self term. (default: True)
        bias (bool, optional): If set to False, the layer will not learn
            an additive bias. (default: True)
        **kwargs (optional): Additional arguments of
            torch_geometric.nn.conv.MessagePassing.
    """

    def __init__(self,
                in_channels: Union[int, Tuple[int, int]],
                out_channels: int,
                normalize: bool = False,
                training: bool = True,
                root_weight: bool = True,
                bias: bool = True,
                **kwargs):
        super(WeightedSAGEConv, self).__init__(aggr='mean', **kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize
        self.root_weight = root_weight
        self.training = training

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        # lin_l projects the aggregated neighbour messages, lin_r the node's own features.
        self.lin_l = Linear(in_channels[0], out_channels, bias=bias)
        if self.root_weight:
            self.lin_r = Linear(in_channels[1], out_channels, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the linear projections."""
        self.lin_l.reset_parameters()
        if self.root_weight:
            self.lin_r.reset_parameters()

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_weight: Tensor = None,
                size: Size = None) -> Tensor:
        """Apply the layer.

        Args:
            x: node features, or a (source, target) feature pair.
            edge_index: graph connectivity passed to ``propagate``.
            edge_weight: optional per-edge weights, [num_edges] or [num_edges, 1]; None means
                plain (unweighted) mean aggregation.
            size: optional sizes forwarded to ``propagate``.

        Returns:
            Tensor of shape [num_target_nodes, out_channels].
        """
        if isinstance(x, Tensor):
            x: OptPairTensor = (x, x)

        # Weighted mean of neighbour features, then the neighbour projection (W_l).
        out = self.propagate(edge_index, x=x, size=size, edge_weight=edge_weight)
        out = self.lin_l(out)

        # Self term (W_r x_i) on the target nodes.
        x_r = x[1]
        if self.root_weight and x_r is not None:
            out += self.lin_r(x_r)

        if self.normalize:
            out = F.normalize(out, p=2., dim=-1)

        return out

    def message(self, x_i: Tensor, x_j: Tensor, edge_weight) -> Tensor:
        """
        Constructs messages from node j to node i for each edge in edge_index.
        ``propagate`` maps tensors it was given to per-edge source/target values via
        the ``_j`` / ``_i`` suffixes, so x_i and x_j have shape [num_edges, dim].

        The message is x_j scaled by the edge weight. A 1-D [num_edges] weight is
        reshaped to [num_edges, 1] so it broadcasts across the feature dimension; with
        no weight the unscaled x_j is sent.
        """
        if edge_weight is None:
            return x_j
        if edge_weight.dim() == 1:
            edge_weight = edge_weight.unsqueeze(-1)
        return x_j * edge_weight  # [num_edges, dim] * [num_edges, 1] = [num_edges, dim]

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, self.out_channels)


class DynamicAdjSAGE(torch.nn.Module):
    """Graph + recurrent model (DCSAGE) that is called once per day of a time window.

    Each call runs two weighted GraphSAGE layers (each followed by GraphNorm and ReLU) on that
    day's graph, concatenates both layer outputs, and feeds the result through two stacked
    LSTMCells whose states the caller passes back in on the next day. Column 1 of every day's
    node features is buffered; on the last day of the window (``day_idx == window_size - 1``)
    the buffered columns are concatenated with both LSTM hidden states and mapped through
    ReLU -> Linear(window_size + 2 * emb_dim, 13) -> ReLU -> Linear(13, output).

    Args:
        node_features: number of input features per node.
        emb_dim: GraphSAGE embedding size and LSTM hidden size.
        window_size: number of days per window (also the number of buffered feature columns).
        output: number of outputs per node.
        training: initial value of the module's ``training`` flag (``.eval()`` overrides it).
        lstm_type: only "vanilla" is supported.
        name: display name stored on the module.
    """

    def __init__(self,
                node_features: int = 3,
                emb_dim: int = 16,
                window_size: int = 14,
                output: int = 1,
                training: bool = True,
                lstm_type: str = 'vanilla',
                name: str = "DASAGE"):
        super(DynamicAdjSAGE, self).__init__()
        assert lstm_type in ["vanilla"]

        self.emb_dim = emb_dim
        self.window_size = window_size
        self.training = training
        self.lstm_type = lstm_type
        self.name = name

        normalize_graphsage_layers = False

        self.sage1 = WeightedSAGEConv(in_channels=node_features, out_channels=self.emb_dim, normalize=normalize_graphsage_layers, training=self.training)
        self.sage2 = WeightedSAGEConv(in_channels=self.emb_dim, out_channels=self.emb_dim, normalize=normalize_graphsage_layers, training=self.training)

        self.graph_norm_1 = GraphNorm(self.emb_dim)
        self.graph_norm_2 = GraphNorm(self.emb_dim)

        # lstm1 consumes both GraphSAGE outputs concatenated, hence 2 * emb_dim inputs.
        self.lstm1 = nn.LSTMCell(input_size=2 * self.emb_dim, hidden_size=self.emb_dim)
        self.lstm2 = nn.LSTMCell(input_size=self.emb_dim, hidden_size=self.emb_dim)

        self.act1 = torch.nn.ReLU()
        # One buffered feature column per day plus h_1 and h_2.
        self.lin1 = torch.nn.Linear(self.window_size + (2 * self.emb_dim), 13)
        self.act2 = torch.nn.ReLU()
        self.lin2 = torch.nn.Linear(13, output)

        # self.init_weights()  # Initialize weights with orthogonal matrices

        # For concatenating features across each day of time window
        self.concat_feat_list = []

    def init_weights(self):
        """Orthogonally initialise the GraphSAGE, LSTMCell and linear weight matrices."""
        nn.init.orthogonal_(self.sage1.lin_l.weight)
        nn.init.orthogonal_(self.sage1.lin_r.weight)
        nn.init.orthogonal_(self.sage2.lin_l.weight)
        nn.init.orthogonal_(self.sage2.lin_r.weight)
        # nn.init.orthogonal_(self.graph_norm_1.weight)  # Only tensors with 2+ dimensions are supported
        # nn.init.orthogonal_(self.graph_norm_2.weight)

        # Pytorch LSTMCell only has 2 weight matrices, each one is 4*hidden_size * output_size,
        # meaning these 2 matrices contain the 8 LSTM kernels we are trying to initialize
        nn.init.orthogonal_(self.lstm1.weight_ih)
        nn.init.orthogonal_(self.lstm1.weight_hh)
        nn.init.orthogonal_(self.lstm2.weight_ih)
        nn.init.orthogonal_(self.lstm2.weight_hh)

        nn.init.orthogonal_(self.lin1.weight)
        nn.init.orthogonal_(self.lin2.weight)
        print("Ran init_weights().")

    def forward(self, data: Data, h_1: Tensor=None, c_1: Tensor=None,
                h_2: Tensor=None, c_2: Tensor=None, day_idx: int=0):
        """Process one day of a window.

        Args:
            data: one day's graph with ``x`` [num_nodes, node_features], ``edge_index`` and
                ``edge_attr`` (used as GraphSAGE edge weights).
            h_1, c_1, h_2, c_2: LSTM states from the previous day; None means zeros.
            day_idx: position of this day in the window; 0 starts a new window.

        Returns:
            (x, h_1, c_1, h_2, c_2). On the last day of the window ``x`` is the prediction
            [num_nodes, output]; on other days it is the concatenated GraphSAGE embedding
            [num_nodes, 2 * emb_dim].
        """
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr

        # A new window starts at day 0. Drop features buffered by a window that was never
        # run to its last day; otherwise the concatenation below has the wrong width for lin1.
        if day_idx == 0:
            self.concat_feat_list.clear()
        self.concat_feat_list.append(x[:, 1:2])

        graphsage_outputs = []

        x = self.sage1(x, edge_index, edge_attr)  # x becomes [num_nodes, self.emb_dim]
        self.graphsage1_out = x
        x = self.graph_norm_1(x)
        x = F.relu(x)
        graphsage_outputs.append(x)

        x = self.sage2(x, edge_index, edge_attr)  # x becomes [num_nodes, self.emb_dim]
        self.graphsage2_out = x
        x = self.graph_norm_2(x)
        x = F.relu(x)
        graphsage_outputs.append(x)

        x = torch.cat(graphsage_outputs, dim=1)  # x becomes [num_nodes, 2 * self.emb_dim]
        self.graphsage_concat = x

        # Missing LSTM states start at zero, on the same device/dtype as the embeddings.
        if h_1 is None:
            h_1 = x.new_zeros(x.shape[0], self.emb_dim)
        if c_1 is None:
            c_1 = x.new_zeros(x.shape[0], self.emb_dim)
        if h_2 is None:
            h_2 = x.new_zeros(x.shape[0], self.emb_dim)
        if c_2 is None:
            c_2 = x.new_zeros(x.shape[0], self.emb_dim)

        h_1, c_1 = self.lstm1(x, (h_1, c_1))  # h_1 and c_1 both become [num_nodes, self.emb_dim]
        h_2, c_2 = self.lstm2(h_1, (h_2, c_2))  # h_2 and c_2 both become [num_nodes, self.emb_dim]

        if day_idx == self.window_size - 1:
            concat_feat = torch.cat(self.concat_feat_list, dim=1)  # [num_nodes, window_size]
            x = torch.cat((concat_feat, h_1, h_2), dim=1)
            self.concat_feat_list.clear()

            x = self.act1(x)
            x = self.lin1(x)
            x = self.act2(x)
            x = self.lin2(x)

        return x, h_1, c_1, h_2, c_2

## Define Dataloader

In [6]:
from covid_10country_perturb_dataset import Covid10CountriesPerturbedDataset, Covid10CountriesUnperturbedDataset
from torch.utils.data import DataLoader
from torch_geometric.data import Data

In [7]:
# Checkpoint to analyse: training-run timestamp and model index within that run.
TRAINING_RUN = "2022-03-21-01_05_53"
MODEL_IDX = 7

# Days per input window fed to the model.
WINDOW_SIZE = 7

# Sum aggregator filtered models: [0 1 3 29 37 41 43 63 77 78 82 96]
# Sum agg: model 10 is sensitive negatively

In [8]:
# Ten perturbed dataloaders: perturbed_dataloaders[k] serves the dataset with node k perturbed.
NODE_PERT_NPZ_PATH = "./datasets/10_continents_dataset_v18_node_pert.npz"


def _build_loader(window_dataset):
    # A single large batch; shuffle=False keeps windows in dataset order.
    return DataLoader(window_dataset, batch_size=800, shuffle=False)


perturbed_dataloaders = [
    _build_loader(Covid10CountriesPerturbedDataset(
        dataset_npz_path=NODE_PERT_NPZ_PATH,
        window_size=WINDOW_SIZE,
        data_split="entire-dataset-smooth",
        perturb_country_idx=perturbed_idx,
        avg_graph_structure=False))
    for perturbed_idx in range(10)
]

# One unperturbed dataloader over the same smoothed data, used as the baseline.
unperturbed_dataset = Covid10CountriesUnperturbedDataset(
    dataset_npz_path=NODE_PERT_NPZ_PATH,
    window_size=WINDOW_SIZE,
    data_split="entire-dataset-smooth",
    avg_graph_structure=False)
unperturbed_dataloader = _build_loader(unperturbed_dataset)


## Define Model

In [9]:
model = DynamicAdjSAGE(
        node_features=2,
        emb_dim=10,
        window_size=WINDOW_SIZE,
        output=1,
        training=True,
        lstm_type="vanilla",
        name="DCSAGE")

# Checkpoints live under <Training-Code>/training-runs-multiple-models/<run>/model_<idx>.pth.
# These are WINDOW_SIZE-day models (lin1 expects WINDOW_SIZE + 2 * emb_dim inputs).
# The default root is the original author's machine; set DCSAGE_TRAINING_CODE_DIR elsewhere.
TRAINING_CODE_DIR = os.environ.get(
    "DCSAGE_TRAINING_CODE_DIR",
    "/Users/syedrizvi/Desktop/Projects/GNN_Project/DCSAGE/Training-Code",
)
checkpoint_path = os.path.join(
    TRAINING_CODE_DIR, "training-runs-multiple-models", TRAINING_RUN, "model_" + str(MODEL_IDX) + ".pth")
# Nothing in this notebook moves tensors to a GPU, so load onto the CPU; this also lets a
# checkpoint saved from a GPU run load on a CPU-only machine.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

DynamicAdjSAGE(
  (sage1): WeightedSAGEConv(2, 10)
  (sage2): WeightedSAGEConv(10, 10)
  (graph_norm_1): GraphNorm(10)
  (graph_norm_2): GraphNorm(10)
  (lstm1): LSTMCell(20, 10)
  (lstm2): LSTMCell(10, 10)
  (act1): ReLU()
  (lin1): Linear(in_features=27, out_features=13, bias=True)
  (act2): ReLU()
  (lin2): Linear(in_features=13, out_features=1, bias=True)
)

## LSTM Activation Plot Showing Each Day in 7-Day Window

Notes:
* 1 model, 560 windows in dataset (see the printed count below), 7 days in each window
* 1 day graph: [10, 10]  -  10 nodes (continents), embedding dim 10

In [10]:
num_windows = len(unperturbed_dataloader.dataset)
print(num_windows)

560


### Get unperturbed LSTM and graphsage activations

In [11]:
# Per-day accumulators: entry d collects one [1, num_nodes, dim] tensor per window for day d.
lstm_h1_activations_list = [[] for _ in range(WINDOW_SIZE)]
lstm_h2_activations_list = [[] for _ in range(WINDOW_SIZE)]

gs_1_activations_list = [[] for _ in range(WINDOW_SIZE)]
gs_2_activations_list = [[] for _ in range(WINDOW_SIZE)]
gs_concat_activations_list = [[] for _ in range(WINDOW_SIZE)]

with torch.no_grad():
    for batch_window_node_feat, batch_window_edge_idx, batch_window_edge_attr, batch_window_labels in unperturbed_dataloader:
        for window_idx in range(len(batch_window_node_feat)):
            window_node_feat = batch_window_node_feat[window_idx]
            window_edge_idx = batch_window_edge_idx[window_idx]
            window_edge_attr = batch_window_edge_attr[window_idx]

            # Start each window from a clean model buffer. A partially-run window, e.g. from
            # a later day-0-only cell executed first, would otherwise leave stale features.
            model.concat_feat_list.clear()
            h_1, c_1, h_2, c_2 = None, None, None, None
            for day_idx in range(len(window_node_feat)):
                day_node_feat = window_node_feat[day_idx]
                day_edge_idx = window_edge_idx[day_idx]
                day_edge_attr = window_edge_attr[day_idx]

                # The first -1 marks where edge padding starts. A day without padding keeps
                # every column; list.index(-1) would raise ValueError there.
                edge_sources = day_edge_idx[0].tolist()
                cutoff_idx = edge_sources.index(-1) if -1 in edge_sources else len(edge_sources)
                day_edge_idx = day_edge_idx[:, :cutoff_idx]
                day_edge_attr = day_edge_attr[:cutoff_idx, :]

                day_graph = Data(x=day_node_feat, edge_index=day_edge_idx, edge_attr=day_edge_attr)
                y_hat, h_1, c_1, h_2, c_2 = model(day_graph, h_1, c_1, h_2, c_2, day_idx)

                # Accumulate LSTM and graphsage outputs for each window
                lstm_h1_activations_list[day_idx].append(h_1.unsqueeze(0))  # h_1 is [10, 10]
                lstm_h2_activations_list[day_idx].append(h_2.unsqueeze(0))

                gs_1_activations_list[day_idx].append(model.graphsage1_out.unsqueeze(0))
                gs_2_activations_list[day_idx].append(model.graphsage2_out.unsqueeze(0))
                gs_concat_activations_list[day_idx].append(model.graphsage_concat.unsqueeze(0))

In [12]:
def _per_day_stack(per_day_lists):
    """Turn per-day lists of [1, nodes, dim] tensors into one [days, windows, nodes, dim] array."""
    return np.array([torch.cat(day_list, dim=0).numpy() for day_list in per_day_lists])


unpert_lstm_h1_activations = _per_day_stack(lstm_h1_activations_list)
unpert_lstm_h2_activations = _per_day_stack(lstm_h2_activations_list)
unpert_gs_1_activations = _per_day_stack(gs_1_activations_list)
unpert_gs_2_activations = _per_day_stack(gs_2_activations_list)
unpert_gs_concat_activations = _per_day_stack(gs_concat_activations_list)

for stacked in (unpert_lstm_h1_activations, unpert_lstm_h2_activations, unpert_gs_1_activations,
                unpert_gs_2_activations, unpert_gs_concat_activations):
    print(stacked.shape)

(7, 560, 10, 10)
(7, 560, 10, 10)
(7, 560, 10, 10)
(7, 560, 10, 10)
(7, 560, 10, 20)


### Get perturbed LSTM and GraphSAGE activations

In [13]:
# pert_*[k] holds the per-day activation arrays for the run where node k is perturbed.
pert_lstm_h1_act = []
pert_lstm_h2_act = []
pert_gs_1_act = []
pert_gs_2_act = []
pert_gs_concat_act = []

with torch.no_grad():
    for dataloader in perturbed_dataloaders:
        lstm_h1_act = [[] for _ in range(WINDOW_SIZE)]
        lstm_h2_act = [[] for _ in range(WINDOW_SIZE)]

        gs_1_act = [[] for _ in range(WINDOW_SIZE)]
        gs_2_act = [[] for _ in range(WINDOW_SIZE)]
        gs_concat_act = [[] for _ in range(WINDOW_SIZE)]

        for batch_window_node_feat, batch_window_edge_idx, batch_window_edge_attr, batch_window_labels in dataloader:
            for window_idx in range(len(batch_window_node_feat)):
                window_node_feat = batch_window_node_feat[window_idx]
                window_edge_idx = batch_window_edge_idx[window_idx]
                window_edge_attr = batch_window_edge_attr[window_idx]

                # Start each window from a clean model buffer (guards against stale features
                # left by a partially-run window).
                model.concat_feat_list.clear()
                h_1, c_1, h_2, c_2 = None, None, None, None
                for day_idx in range(len(window_node_feat)):
                    day_node_feat = window_node_feat[day_idx]
                    day_edge_idx = window_edge_idx[day_idx]
                    day_edge_attr = window_edge_attr[day_idx]

                    # The first -1 marks where edge padding starts. A day without padding
                    # keeps every column; list.index(-1) would raise ValueError there.
                    edge_sources = day_edge_idx[0].tolist()
                    cutoff_idx = edge_sources.index(-1) if -1 in edge_sources else len(edge_sources)
                    day_edge_idx = day_edge_idx[:, :cutoff_idx]
                    day_edge_attr = day_edge_attr[:cutoff_idx, :]

                    day_graph = Data(x=day_node_feat, edge_index=day_edge_idx, edge_attr=day_edge_attr)
                    y_hat, h_1, c_1, h_2, c_2 = model(day_graph, h_1, c_1, h_2, c_2, day_idx)

                    # Accumulate LSTM and GraphSAGE outputs for each window
                    lstm_h1_act[day_idx].append(h_1.unsqueeze(0))  # h_1 is [10, 10]
                    lstm_h2_act[day_idx].append(h_2.unsqueeze(0))

                    gs_1_act[day_idx].append(model.graphsage1_out.unsqueeze(0))
                    gs_2_act[day_idx].append(model.graphsage2_out.unsqueeze(0))
                    gs_concat_act[day_idx].append(model.graphsage_concat.unsqueeze(0))

        # Per day: concatenate across windows, then stack days -> [days, windows, nodes, dim].
        lstm_h1_act = np.array([torch.cat(activations, dim=0).numpy() for activations in lstm_h1_act])
        lstm_h2_act = np.array([torch.cat(activations, dim=0).numpy() for activations in lstm_h2_act])
        gs_1_act = np.array([torch.cat(activations, dim=0).numpy() for activations in gs_1_act])
        gs_2_act = np.array([torch.cat(activations, dim=0).numpy() for activations in gs_2_act])
        gs_concat_act = np.array([torch.cat(activations, dim=0).numpy() for activations in gs_concat_act])

        pert_lstm_h1_act.append(lstm_h1_act)
        pert_lstm_h2_act.append(lstm_h2_act)
        pert_gs_1_act.append(gs_1_act)
        pert_gs_2_act.append(gs_2_act)
        pert_gs_concat_act.append(gs_concat_act)

In [14]:
# Stack the per-perturbation arrays: [num_perturbed, days, windows, nodes, dim].
pert_lstm_h1_act, pert_lstm_h2_act, pert_gs_1_act, pert_gs_2_act, pert_gs_concat_act = [
    np.array(per_perturbation)
    for per_perturbation in (pert_lstm_h1_act, pert_lstm_h2_act, pert_gs_1_act,
                             pert_gs_2_act, pert_gs_concat_act)
]

for stacked in (pert_lstm_h1_act, pert_lstm_h2_act, pert_gs_1_act, pert_gs_2_act, pert_gs_concat_act):
    print(stacked.shape)

(10, 7, 560, 10, 10)
(10, 7, 560, 10, 10)
(10, 7, 560, 10, 10)
(10, 7, 560, 10, 10)
(10, 7, 560, 10, 20)


In [15]:
def plot_lstmcell_activation_per_window_day(unpert_lstm_activations, pert_lstm_activations, lstm_cell_num):
    """
    Save a grid of LSTM-cell activation histograms: one row per perturbation setting, one
    column per day of the window.

    Row 0 is the unperturbed run; row k (k >= 1) is the run with ``continents[k - 1]``
    perturbed. Each panel pools all windows, nodes and embedding dimensions for that day and
    is titled with their mean, median and standard deviation. The grid size comes from the
    array shapes (11 x 7 for the arrays in this notebook).

    Args:
        - unpert_lstm_activations: array of shape [window_size, num_windows, num_nodes, emb_dim]
            ([7, 560, 10, 10] here).
        - pert_lstm_activations: array of shape
            [num_perturbed, window_size, num_windows, num_nodes, emb_dim] ([10, 7, 560, 10, 10]
            here); num_perturbed must not exceed the 10 continent labels.
        - lstm_cell_num: LSTM cell the activations came from (1 or 2); used in title and filename.

    Writes ./model_{MODEL_IDX}_lstmcell_{lstm_cell_num}_activ_per_window_day.png and closes the figure.
    """
    continents = ["Africa", "North America", "South America", "Oceania", "Eastern Europe", "Western Europe", "Middle East", "South Asia", "Southeast-East Asia", "Central Asia"]

    unpert = np.asarray(unpert_lstm_activations)
    pert = np.asarray(pert_lstm_activations)
    num_days = unpert.shape[0]
    num_rows = 1 + pert.shape[0]

    # squeeze=False keeps `axes` 2-D even for a single row or column.
    fig, axes = plt.subplots(nrows=num_rows, ncols=num_days, figsize=(50,40), squeeze=False)
    fig.suptitle("DCSAGE Model {} LSTM Cell {} Activations Per Day of Window".format(MODEL_IDX, lstm_cell_num), fontsize= 30)

    for row_idx in range(num_rows):
        for day_idx in range(num_days):
            ax = axes[row_idx, day_idx]
            panel = unpert[day_idx] if row_idx == 0 else pert[row_idx - 1, day_idx]

            visual_df = pd.DataFrame({
                "Flattened Activation Values": panel.ravel(),
            })
            sns.histplot(ax=ax, x='Flattened Activation Values', data=visual_df, kde=True)

            mean, median, stddev = panel.mean(), np.median(panel), panel.std()
            if row_idx == 0:
                subplot_title = "(Unpert) Day {} (Mean: {:.2f}, Median: {:.2f}, Std: {:.2f})".format(day_idx, mean, median, stddev)
            else:
                subplot_title = "({} Pert) Day {} \n(Mean: {:.2f}, Median: {:.2f}, Std: {:.2f})".format(continents[row_idx - 1], day_idx, mean, median, stddev)

            ax.set_title(subplot_title)
            # LSTM hidden states are bounded to (-1, 1).
            ax.set_xlim([-1, 1])

    filename = "model_{}_lstmcell_{}_activ_per_window_day".format(MODEL_IDX, lstm_cell_num)
    fig.tight_layout()
    fig.savefig("./" + filename + '.png', bbox_inches='tight', facecolor="white")
    plt.close(fig)

In [16]:
for cell_num, (unpert_acts, pert_acts) in enumerate(
        [(unpert_lstm_h1_activations, pert_lstm_h1_act),
         (unpert_lstm_h2_activations, pert_lstm_h2_act)], start=1):
    plot_lstmcell_activation_per_window_day(unpert_acts, pert_acts, lstm_cell_num=cell_num)

### Plot GraphSAGE activation distributions unperturbed vs perturbed

In [17]:
def plot_graphsage_activation_distrib_unpert_vs_pert(unpert_lstm_activations, pert_lstm_activations, gs_layer_id, x_bound:int, y_bound:int):
    """
    Save a grid of per-node GraphSAGE activation histograms, unperturbed vs. each perturbation.

    Despite the parameter names, the inputs are GraphSAGE outputs (layer 1, layer 2 or their
    concatenation). Row 0 is the unperturbed run and row k (k >= 1) the run with
    ``continents[k - 1]`` perturbed. Column c is node c, and the panel for the perturbed node
    itself is left blank.

    *Important note: only day 0 of each window is used, so each window contributes exactly one
    day (presumably so that days shared by overlapping windows are not counted repeatedly —
    confirm the windowing stride).

    Args:
        - unpert_lstm_activations: array of shape [window_size, num_windows, num_nodes, dim]
            (e.g. [7, 560, 10, 10], or dim 20 for the concatenated layers).
        - pert_lstm_activations: array of shape [num_perturbed, window_size, num_windows, num_nodes, dim];
            num_perturbed and num_nodes must not exceed the 10 continent labels.
        - gs_layer_id: layer label used in the title and filename.
        - x_bound: panels show x in [-x_bound, x_bound].
        - y_bound: panels show y in [0, y_bound].

    Writes ./model_{MODEL_IDX}_gslayer_{gs_layer_id}_activ.png and closes the figure.
    """
    continents = ["Africa", "North America", "South America", "Oceania", "Eastern Europe", "Western Europe", "Middle East", "South Asia", "Southeast-East Asia", "Central Asia"]

    unpert = np.asarray(unpert_lstm_activations)
    pert = np.asarray(pert_lstm_activations)
    num_nodes = unpert.shape[2]
    num_rows = 1 + pert.shape[0]

    # squeeze=False keeps `axes` 2-D even for a single row or column.
    fig, axes = plt.subplots(nrows=num_rows, ncols=num_nodes, figsize=(50,60), squeeze=False)
    fig.suptitle("DCSAGE Model {} GraphSAGE Layer {} Activations".format(MODEL_IDX, gs_layer_id), fontsize= 30)

    for row_idx in range(num_rows):
        for node_idx in range(num_nodes):
            ax = axes[row_idx, node_idx]
            if node_idx == row_idx - 1:
                # Don't plot a node when it is the one being perturbed; hide the empty panel.
                ax.axis("off")
                continue

            if row_idx == 0:
                panel = unpert[0, :, node_idx, :]
            else:
                panel = pert[row_idx - 1, 0, :, node_idx, :]

            visual_df = pd.DataFrame({
                "Flattened Activation Values": panel.ravel(),
            })
            sns.histplot(ax=ax, x='Flattened Activation Values', data=visual_df, kde=True)

            mean, median, stddev = panel.mean(), np.median(panel), panel.std()
            if row_idx == 0:
                subplot_title = "(Unpert) {} Distrib \n(Mean: {:.2f}, Median: {:.2f}, Std: {:.2f})".format(continents[node_idx], mean, median, stddev)
            else:
                subplot_title = "({} Pert) {} Distrib \n(Mean: {:.2f}, Median: {:.2f}, Std: {:.2f})".format(continents[row_idx - 1], continents[node_idx], mean, median, stddev)

            ax.set_title(subplot_title)
            ax.set_xlim([-1 * x_bound, x_bound])
            ax.set_ylim([0, y_bound])

    filename = "model_{}_gslayer_{}_activ".format(MODEL_IDX, gs_layer_id)
    fig.tight_layout()
    fig.savefig("./" + filename + '.png', bbox_inches='tight', facecolor="white")
    plt.close(fig)

In [18]:
# (unperturbed, perturbed, layer id, x half-range, y max) for each GraphSAGE output.
graphsage_plot_specs = [
    (unpert_gs_1_activations, pert_gs_1_act, "1", 10, 700),
    (unpert_gs_2_activations, pert_gs_2_act, "2", 2, 700),
    (unpert_gs_concat_activations, pert_gs_concat_act, "concat", 3, 1000),
]
for unpert_acts, pert_acts, layer_id, x_lim, y_lim in graphsage_plot_specs:
    plot_graphsage_activation_distrib_unpert_vs_pert(unpert_acts, pert_acts, layer_id, x_lim, y_lim)

## Plot GraphSAGE t-SNE Embeddings

In [None]:
# Node labels in node-index order (row i of an embedding is node i).
# NOTE(review): the dataset is the 10-continent one, and the activation-histogram plots in this
# notebook label the same node indices with these continent names. Using the same list keeps
# all plots consistent. Confirm the ordering against the dataset builder.
countries = ["Africa", "North America", "South America", "Oceania", "Eastern Europe", "Western Europe", "Middle East", "South Asia", "Southeast-East Asia", "Central Asia"]

def plot_graphsage_embedding(one_day_emb, title, savename, perplexity=None, random_state=0):
    """Project one day's node embeddings to 2-D with t-SNE, label each node, then save and show.

    Args:
        one_day_emb: [num_nodes, emb_dim] embedding (torch tensor or array-like); row i is node
            ``countries[i]``, so num_nodes must equal len(countries).
        title: plot title.
        savename: path the PNG is written to.
        perplexity: t-SNE perplexity. scikit-learn requires perplexity < n_samples, and its
            default (30) exceeds the 10 nodes here. When None, a small value derived from the
            node count is used instead.
        random_state: seed for t-SNE's random initialisation, so reruns give the same layout.
    """
    if isinstance(one_day_emb, torch.Tensor):
        one_day_emb = one_day_emb.detach().cpu().numpy()
    emb = np.asarray(one_day_emb)
    assert emb.ndim == 2 and emb.shape[0] == len(countries), "Unrecognized shape"

    n_samples = emb.shape[0]
    if perplexity is None:
        perplexity = min(30.0, max(1.0, (n_samples - 1) / 3.0))
    reduced_graphsage_emb = TSNE(n_components=2, learning_rate="auto", init="random",
                                 perplexity=perplexity, random_state=random_state).fit_transform(emb)

    fig, ax = plt.subplots()
    sns.scatterplot(x=reduced_graphsage_emb[:, 0], y=reduced_graphsage_emb[:, 1], ax=ax)
    # Offset each label slightly so it does not sit on top of its marker.
    for (x_pos, y_pos), label in zip(reduced_graphsage_emb, countries):
        ax.text(x_pos + 0.2, y_pos + 0.5, label)

    ax.set_title(title)
    ax.set_xlabel("Reduced Dimension X")
    ax.set_ylabel("Reduced Dimension Y")
    fig.savefig(savename, bbox_inches="tight", facecolor="white")
    plt.show()
    plt.close(fig)

In [None]:
with torch.no_grad():
    for batch_window_node_feat, batch_window_edge_idx, batch_window_edge_attr, batch_window_labels in unperturbed_dataloader:
        for window_idx in range(500, 510):
            # Only day 0 of each window is needed for the GraphSAGE embeddings.
            day_idx = 0
            day_node_feat = batch_window_node_feat[window_idx][day_idx]
            day_edge_idx = batch_window_edge_idx[window_idx][day_idx]
            day_edge_attr = batch_window_edge_attr[window_idx][day_idx]

            # The first -1 marks where edge padding starts. A day without padding keeps every
            # column; list.index(-1) would raise ValueError there.
            edge_sources = day_edge_idx[0].tolist()
            cutoff_idx = edge_sources.index(-1) if -1 in edge_sources else len(edge_sources)
            day_edge_idx = day_edge_idx[:, :cutoff_idx]
            day_edge_attr = day_edge_attr[:cutoff_idx, :]

            day_graph = Data(x=day_node_feat, edge_index=day_edge_idx, edge_attr=day_edge_attr)
            model(day_graph, day_idx=day_idx)
            # The window's last day is never run here, so drop the feature column the model
            # buffered; otherwise the buffer keeps growing and corrupts later full-window runs.
            model.concat_feat_list.clear()

            plot_graphsage_embedding(model.graphsage1_out,
                title="GraphSAGE 1 Embedding Window {} Model {}".format(window_idx, MODEL_IDX),
                savename="win{}_model{}_graphsage_1_emb.png".format(window_idx, MODEL_IDX))
            plot_graphsage_embedding(model.graphsage2_out,
                title="GraphSAGE 2 Embedding Window {} Model {}".format(window_idx, MODEL_IDX),
                savename="win{}_model{}_graphsage_2_emb.png".format(window_idx, MODEL_IDX))

In [None]:
def plot_graphsage_embedding_10days(one_day_emb_list, title, savename, perplexity=None, random_state=0):
    """Overlay per-day 2-D t-SNE projections of node embeddings, coloured by node label.

    Each element of ``one_day_emb_list`` is one day's [num_nodes, emb_dim] embedding (torch
    tensor or array-like). Each day is projected with its own t-SNE run, so different days do not
    share a common coordinate frame. Despite the function name, any number of days is accepted.

    Args:
        one_day_emb_list: list of per-day embeddings; row i of each is node ``countries[i]``.
        title: plot title.
        savename: path the PNG is written to.
        perplexity: t-SNE perplexity. scikit-learn requires perplexity < n_samples, and its
            default (30) exceeds the 10 nodes here. When None, a small value derived from the
            node count is used instead.
        random_state: seed for t-SNE's random initialisation, for reproducible layouts.

    Side effect: calls ``sns.set_theme()``, which changes the global plotting style.
    """
    reduced_graphsage_emb_list = []
    for one_day_emb in one_day_emb_list:
        if isinstance(one_day_emb, torch.Tensor):
            one_day_emb = one_day_emb.detach().cpu().numpy()
        emb = np.asarray(one_day_emb)
        day_perplexity = perplexity
        if day_perplexity is None:
            day_perplexity = min(30.0, max(1.0, (emb.shape[0] - 1) / 3.0))
        reduced_graphsage_emb_list.append(
            TSNE(n_components=2, learning_rate="auto", init="random",
                 perplexity=day_perplexity, random_state=random_state).fit_transform(emb))

    # Long format, grouped by node: all days of node 0, then all days of node 1, ...
    country_x = []
    country_y = []
    country_name = []
    for node_idx, node_label in enumerate(countries):
        country_x += [reduced_emb[node_idx, 0] for reduced_emb in reduced_graphsage_emb_list]
        country_y += [reduced_emb[node_idx, 1] for reduced_emb in reduced_graphsage_emb_list]
        country_name += [node_label] * len(reduced_graphsage_emb_list)

    visual_df = pd.DataFrame({
        "Reduced Dimension X": country_x,
        "Reduced Dimension Y": country_y,
        "Country": country_name
    })

    sns.set_theme()
    fig, ax = plt.subplots()
    sns.scatterplot(data=visual_df, x="Reduced Dimension X", y="Reduced Dimension Y", hue="Country", ax=ax)
    ax.set_title(title)
    ax.set_xlabel("Reduced Dimension X")
    ax.set_ylabel("Reduced Dimension Y")
    fig.savefig(savename, bbox_inches="tight", facecolor="white")
    plt.show()
    plt.close(fig)

In [None]:
with torch.no_grad():
    for batch_window_node_feat, batch_window_edge_idx, batch_window_edge_attr, batch_window_labels in unperturbed_dataloader:
        graphsage1_list = []
        graphsage2_list = []
        for window_idx in range(410, 510):
            # Only day 0 of each window is needed for the GraphSAGE embeddings.
            day_idx = 0
            day_node_feat = batch_window_node_feat[window_idx][day_idx]
            day_edge_idx = batch_window_edge_idx[window_idx][day_idx]
            day_edge_attr = batch_window_edge_attr[window_idx][day_idx]

            # The first -1 marks where edge padding starts. A day without padding keeps every
            # column; list.index(-1) would raise ValueError there.
            edge_sources = day_edge_idx[0].tolist()
            cutoff_idx = edge_sources.index(-1) if -1 in edge_sources else len(edge_sources)
            day_edge_idx = day_edge_idx[:, :cutoff_idx]
            day_edge_attr = day_edge_attr[:cutoff_idx, :]

            day_graph = Data(x=day_node_feat, edge_index=day_edge_idx, edge_attr=day_edge_attr)
            model(day_graph, day_idx=day_idx)
            # The window's last day is never run here, so drop the feature column the model
            # buffered; otherwise the buffer keeps growing and corrupts later full-window runs.
            model.concat_feat_list.clear()

            graphsage1_list.append(model.graphsage1_out)
            graphsage2_list.append(model.graphsage2_out)

        plot_graphsage_embedding_10days(graphsage1_list,
            title="GraphSAGE 1 Embedding Windows 410-510 Model {}".format(MODEL_IDX),
            savename="100days410-510_model{}_graphsage_1_emb.png".format(MODEL_IDX))
        plot_graphsage_embedding_10days(graphsage2_list,
            title="GraphSAGE 2 Embedding Windows 410-510 Model {}".format(MODEL_IDX),
            savename="100days410-510_model{}_graphsage_2_emb.png".format(MODEL_IDX))