# Import libraries

In [1]:
import itertools
import pandas as pd
import numpy as np
import networkx as nx
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
import torch_geometric.nn as nng 
from sklearn.metrics import roc_auc_score
from scipy.sparse import coo_matrix
from torch_geometric_temporal.signal import DynamicGraphTemporalSignal, temporal_signal_split
import matplotlib.pyplot as plt
import seaborn as sns

# Build satellite graph

## Load satellite graph node data

In [2]:
reduced, frac1 = False, 0.25
reduced_sample_alt_e, frac2, min_alt, max_alt, e_thres, sampled1 = False, 1.0, 500, 600, 0.2, False 
reduced_sample_leos, frac3, leo, sampled2,  = True, 0.25, 'leo4', True  # smallest LEO
if reduced:
    nodes_savepath = f"../datasets/space-track-ap2-graph-node-feats-reduced-{int(frac1 * 100)}.csv"
elif reduced_sample_alt_e:
    if sampled1:
        nodes_savepath = f"../datasets/space-track-ap2-graph-node-feats-reduced-{int(frac2 * 100)}-h-{min_alt}-{max_alt}-e-{int(e_thres * 100)}.csv"
    else:
        nodes_savepath = f"../datasets/space-track-ap2-graph-node-feats-reduced-h-{min_alt}-{max_alt}-e-{int(e_thres * 100)}.csv"
elif reduced_sample_leos:
    if sampled2:
        nodes_savepath = f"../datasets/space-track-ap2-graph-node-feats-{leo}-reduced-{int(frac3 * 100)}.csv"
    else:
        nodes_savepath = f"../datasets/space-track-ap2-graph-node-feats-{leo}.csv"
else:
    nodes_savepath = '../datasets/space-track-ap2-graph-node-feats.csv'

nodes_df = pd.read_csv(nodes_savepath, memory_map=True).set_index('NORAD_CAT_ID').drop(['OBJECT_NAME', 'OBJECT_ID', 'DECAY_DATE', 'CENTER_NAME', 'REF_FRAME', 'TIME_SYSTEM', 'MEAN_ELEMENT_THEORY'], axis=1)
nodes_df.head()

Unnamed: 0_level_0,MEAN_MOTION,ECCENTRICITY,INCLINATION,RA_OF_ASC_NODE,ARG_OF_PERICENTER,MEAN_ANOMALY,EPHEMERIS_TYPE,CLASSIFICATION_TYPE,REV_AT_EPOCH,BSTAR,...,OBJECT_TYPE,RCS_SIZE,CONSTELLATION_DISCOS_ID,PX,PY,PZ,VX,VY,VZ,TIMESTAMP
NORAD_CAT_ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
53,12.173182,0.009849,47.2749,245.136139,290.259668,93.520449,0,U,82552,0.001154,...,DEBRIS,MEDIUM,,-981.739653,-7532.454068,2463.44598,5.185705,0.803544,4.730939,2023-12-28 00:00:00
1314,12.917792,0.003186,90.2439,346.489073,121.537524,93.04766,0,U,51454,6.7e-05,...,PAYLOAD,LARGE,,-6109.304395,1487.265525,-4410.214654,4.005457,-0.936573,-5.909974,2023-12-28 00:00:00
1570,12.527797,0.01016,56.0579,38.416745,177.606018,281.664377,0,U,66873,0.000116,...,PAYLOAD,MEDIUM,,-3549.131956,2696.045132,6411.820026,-5.164112,-4.86615,-0.897586,2023-12-28 00:00:00
1573,12.410282,0.006615,56.0518,218.083507,320.617156,261.950267,0,U,64321,0.000149,...,PAYLOAD,MEDIUM,,2808.8364,5941.159037,-4370.62668,-5.568033,-0.634834,-4.361757,2023-12-28 00:00:00
1574,12.370203,0.006564,56.0539,282.997424,126.316083,331.609578,0,U,63479,0.000209,...,PAYLOAD,MEDIUM,,4000.103782,1984.830449,6449.951328,-2.117379,6.781488,-0.799534,2023-12-28 00:00:00


## One-hot encode Categorical columns

In [3]:
nodes_df = pd.get_dummies(nodes_df, columns=['EPHEMERIS_TYPE', 'CLASSIFICATION_TYPE', 'OBJECT_TYPE', 'RCS_SIZE','CONSTELLATION_DISCOS_ID'], drop_first=False, dummy_na=True, dtype=float)
nodes_df.head()

Unnamed: 0_level_0,MEAN_MOTION,ECCENTRICITY,INCLINATION,RA_OF_ASC_NODE,ARG_OF_PERICENTER,MEAN_ANOMALY,REV_AT_EPOCH,BSTAR,MEAN_MOTION_DOT,MEAN_MOTION_DDOT,...,OBJECT_TYPE_nan,RCS_SIZE_LARGE,RCS_SIZE_MEDIUM,RCS_SIZE_SMALL,RCS_SIZE_nan,CONSTELLATION_DISCOS_ID_3.0,CONSTELLATION_DISCOS_ID_4.0,CONSTELLATION_DISCOS_ID_5.0,CONSTELLATION_DISCOS_ID_7.0,CONSTELLATION_DISCOS_ID_nan
NORAD_CAT_ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
53,12.173182,0.009849,47.2749,245.136139,290.259668,93.520449,82552,0.001154,2.917841e-11,0.0,...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
1314,12.917792,0.003186,90.2439,346.489073,121.537524,93.04766,51454,6.7e-05,2.298905e-11,0.0,...,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
1570,12.527797,0.01016,56.0579,38.416745,177.606018,281.664377,66873,0.000116,-6.896714e-11,0.0,...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
1573,12.410282,0.006615,56.0518,218.083507,320.617156,261.950267,64321,0.000149,-6.366198e-11,0.0,...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
1574,12.370203,0.006564,56.0539,282.997424,126.316083,331.609578,63479,0.000209,-5.747262e-11,0.0,...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0


## Load satellite graph edges data

In [7]:
edges_df = pd.read_csv(nodes_savepath.replace('node-feats', 'edges'), memory_map=True)
edges_df.head()

Unnamed: 0,source,target,weight,r_dist,it_dist,ct_dist,dist,timestamp,prop
0,1585,13378,1,14413800.0,3413222.0,1658.078733,14812420.0,2023-12-28 00:00:00,True
1,13538,16456,1,326138.5,2032948.0,1392.352587,2058943.0,2023-12-28 00:00:00,True
2,13538,19785,1,1241162.0,7666014.0,1561.350018,7765839.0,2023-12-28 00:00:00,True
3,13766,16456,1,41634.67,11242640.0,636.122198,11242720.0,2023-12-28 00:00:00,True
4,14139,38736,1,554878.8,696.3717,533571.518932,769798.3,2023-12-28 00:00:00,True


In [5]:
timestamps = edges_df['timestamp'].unique()
timestamps

array(['2023-12-28 00:00:00', '2023-12-28 01:00:00',
       '2023-12-28 02:00:00', '2023-12-28 03:00:00',
       '2023-12-28 04:00:00', '2023-12-28 05:00:00',
       '2023-12-28 06:00:00', '2023-12-28 07:00:00',
       '2023-12-28 08:00:00', '2023-12-28 09:00:00',
       '2023-12-28 10:00:00', '2023-12-28 11:00:00',
       '2023-12-28 12:00:00', '2023-12-28 13:00:00',
       '2023-12-28 14:00:00', '2023-12-28 15:00:00',
       '2023-12-28 16:00:00', '2023-12-28 17:00:00',
       '2023-12-28 18:00:00', '2023-12-28 19:00:00',
       '2023-12-28 20:00:00', '2023-12-28 21:00:00',
       '2023-12-28 22:00:00', '2023-12-28 23:00:00'], dtype=object)

# Satellite Conjunction Prediction through Link Prediction

## Separate satellite graph edges into train set and test set

In [10]:
def temporal_signal_split(timestamps, train_ratio=0.8):
    train_snapshots = int(train_ratio * timestamps.shape[0])
    # necessary to reduce len(train_timestamps) timestamps to 1 such that there is only one feature matrix
    train_snapshots = train_snapshots-1 if train_snapshots % 2 == 0 else train_snapshots
    
    train_timestamps = timestamps[0:train_snapshots]
    test_timestamps = timestamps[train_snapshots:]
    
    return train_timestamps, test_timestamps

train_timestamps, test_timestamps = temporal_signal_split(timestamps)

train_edges_df = edges_df[edges_df['timestamp'].isin(train_timestamps)]
test_edges_df = edges_df[edges_df['timestamp'].isin(test_timestamps)]

print(
    f"Number of total edges: {edges_df.shape[0]}\n"
    f"Number of total edges in training set: {train_edges_df.shape[0]}\n"
    f"Number of total edges in test set: {test_edges_df.shape[0]}"
)

Number of total edges: 4746
Number of total edges in training set: 3725
Number of total edges in test set: 1021


## Separate each set edges into positive edges and negative edges

In [6]:
nodes = list(nodes_df.index.unique())

In [7]:
def sample_negative_edges_df_for_dt(nodes, positive_edges_df_in_dt, date_time):
    positive_set = set(positive_edges_df_in_dt[["source", "target"]].itertuples(index=False, name=None))

    def valid_neg_edge(src, tgt):
        return (
            # no self-loops
            src != tgt
            and
            # neither direction of the edge should be a positive one
            (src, tgt) not in positive_set
            and (tgt, src) not in positive_set
        )
    # TODO: Define edge weight and assing random weight here
    possible_neg_edges = [
        [src, tgt, 1, date_time] for src in nodes for tgt in nodes if valid_neg_edge(src, tgt)
    ]
    neg_edges = np.array(random.sample(possible_neg_edges, k=len(positive_set)))
    return {'source':neg_edges[:, 0].tolist(), 'target':neg_edges[:, 1].tolist(), 'weight':neg_edges[:, 2].tolist(), 'timestamp':neg_edges[:, 3].tolist()}

def sample_negative_edges_df(nodes, positive_edges_df, timestamps):
    edges = {'source':[], 'target':[], 'weight':[], 'timestamp':[]}
    for i in range(len(timestamps)):
        date_time = timestamps[i]
        edges_data = sample_negative_edges_df_for_dt(nodes, positive_edges_df[positive_edges_df['timestamp'] == date_time], date_time)
        edges['source'] = edges['source'] + edges_data['source']
        edges['target'] = edges['target'] + edges_data['target']
        edges['weight'] = edges['weight'] + edges_data['weight']
        edges['timestamp'] = edges['timestamp'] + edges_data['timestamp']
    edges_df = pd.DataFrame(edges)
    edges_df['source'] = edges_df['source'].astype(np.int64)
    edges_df['target'] = edges_df['target'].astype(np.int64)
    edges_df['weight'] = edges_df['weight'].astype(np.int64)
    edges_df['timestamp'] = pd.to_datetime(edges_df['timestamp'])
    return edges_df

# train_pos_edges_df = train_edges_df[['source', 'target', 'weight', 'timestamp']]
# train_neg_edges_df = sample_negative_edges_df(nodes, train_pos_edges_df, train_timestamps)
# 
# test_pos_edges_df = test_edges_df[['source', 'target', 'weight', 'timestamp']]
# test_neg_edges_df = sample_negative_edges_df(nodes, test_pos_edges_df, test_timestamps)
# 
# print(
#     f"Number of total positive edges in training set: {train_pos_edges_df.shape[0]}\n"
#     f"Number of total negative edges in training set: {train_neg_edges_df.shape[0]}\n"
#     f"Number of total positive edges in test set: {test_pos_edges_df.shape[0]}\n"
#     f"Number of total negative edges in test set: {test_neg_edges_df.shape[0]}\n"
# )
pos_edges_df = edges_df[['source', 'target', 'weight', 'timestamp']]
neg_edges_df = sample_negative_edges_df(nodes, pos_edges_df, timestamps)

print(
    f"Number of total positive edges in dataset: {pos_edges_df.shape[0]}\n"
    f"Number of total negative edges in dataset: {neg_edges_df.shape[0]}\n"
)

Number of total positive edges in dataset: 4746
Number of total negative edges in dataset: 4746


## Build positive and negative dynamic graph static signal data iterator

In [15]:
def edges_df_to_torch_data(num_nodes, node_index, nodes_df, edges_df):
    x = torch.tensor(nodes_df.values[:, :].astype(float))
    
    # Convert DataFrame to COO format
    row = edges_df['source'].map(node_index.get)
    col = edges_df['target'].map(node_index.get)
    data = [1] * len(edges_df)
    coo = coo_matrix((data, (row, col)), shape=(num_nodes, num_nodes))
    edge_index = torch.tensor(np.array([coo.row, coo.col]), dtype=torch.long)
    
    edge_attr = torch.tensor(edges_df['weight'].values.astype(float))
    
    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

def to_torch_data(timestamps, num_nodes, node_index, nodes_df, edges_df):
    torch_datas = []
    for ts in timestamps:
        nodes_df_ts = nodes_df[nodes_df['TIMESTAMP'] == ts]
        edges_df_ts = edges_df[edges_df['timestamp'] == ts]

        torch_data = edges_df_to_torch_data(num_nodes, node_index, nodes_df_ts.drop('TIMESTAMP', axis=1), edges_df_ts)
        torch_datas.append(torch_data)
    return torch_datas

num_nodes = len(nodes)
node_index = {node: i for i, node in enumerate(nodes)}

train_data = to_torch_data(train_timestamps, num_nodes, node_index, nodes_df, train_pos_edges_df)
train_pos_edges_data = train_data
train_neg_edges_data = to_torch_data(train_timestamps, num_nodes, node_index, nodes_df, train_neg_edges_df)

test_pos_edges_data = to_torch_data(test_timestamps, num_nodes, node_index, nodes_df, test_pos_edges_df)
test_neg_edges_data = to_torch_data(test_timestamps, num_nodes, node_index, nodes_df, test_neg_edges_df)

In [8]:
def feat_idx_w(num_nodes, node_index, nodes_df, edges_df):
    x = nodes_df.values[:, :].astype(float)
    
    # Convert DataFrame to COO format
    row = edges_df['source'].map(node_index.get)
    col = edges_df['target'].map(node_index.get)
    data = [1] * len(edges_df)
    coo = coo_matrix((data, (row, col)), shape=(num_nodes, num_nodes))
    edge_index = np.array([coo.row, coo.col], dtype=np.int64)
    
    edge_attr = edges_df['weight'].values.astype(float)
    
    return x, edge_index, edge_attr

def to_feats_idxs_ws(timestamps, num_nodes, node_index, nodes_df, edges_df):
    features = []
    edge_indices = []
    edge_weights = []
    for ts in timestamps:
        nodes_df_ts = nodes_df[nodes_df['TIMESTAMP'] == ts]
        edges_df_ts = edges_df[edges_df['timestamp'] == ts]

        x, edge_index, edge_attr = feat_idx_w(num_nodes, node_index, nodes_df_ts.drop('TIMESTAMP', axis=1), edges_df_ts)
        features.append(x)
        edge_indices.append(edge_index)
        edge_weights.append(edge_attr)
    return features, edge_indices, edge_weights

num_nodes = len(nodes)
node_index = {node: i for i, node in enumerate(nodes)}

pos_features, pos_edge_indices, pos_edge_weights = to_feats_idxs_ws(timestamps, num_nodes, node_index, nodes_df, pos_edges_df)
pos_dataset = DynamicGraphTemporalSignal(pos_edge_indices, pos_edge_weights, pos_features, [None]*len(pos_features))

neg_features, neg_edge_indices, neg_edge_weights = to_feats_idxs_ws(timestamps, num_nodes, node_index, nodes_df, neg_edges_df)
neg_dataset = DynamicGraphTemporalSignal(neg_edge_indices, neg_edge_weights, neg_features, [None]*len(neg_features))

In [9]:
train_pos_dataset, test_pos_dataset = temporal_signal_split(pos_dataset, train_ratio=0.8)
train_neg_dataset, test_neg_dataset = temporal_signal_split(neg_dataset, train_ratio=0.8)

print(
    f"Number of snapshots in train set of positive edges : {train_pos_dataset.snapshot_count}\n"
    f"Number of snapshots in train set of negative edges : {train_neg_dataset.snapshot_count}\n"
    f"Number of snapshots in test set of positive edges : {test_pos_dataset.snapshot_count}\n"
    f"Number of snapshots in test set of negative edges : {test_neg_dataset.snapshot_count}\n"
)

Number of snapshots in train set of positive edges : 19
Number of snapshots in train set of negative edges : 19
Number of snapshots in test set of positive edges : 5
Number of snapshots in test set of negative edges : 5


## Spatial-Temporal Graph Neural Network (STGNN)

### RNN-based approach

### CNN-based approach

In [10]:
class TemporalConv(nn.Module):
    r""" Implementation from PyTorch Geometric Temporal.
    Temporal convolution block applied to nodes in the STGCN Layer
    For details see: `"Spatio-Temporal Graph Convolutional Networks:
    A Deep Learning Framework for Traffic Forecasting."
    <https://arxiv.org/abs/1709.04875>`_ Based off the temporal convolution
     introduced in "Convolutional Sequence to Sequence Learning"  <https://arxiv.org/abs/1709.04875>`_

    Args:
        in_channels (int): Number of input features.
        out_channels (int): Number of output features.
        kernel_size (int): Convolutional kernel size.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
        super(TemporalConv, self).__init__()
        self.conv_1 = nn.Conv2d(in_channels, out_channels, (1, kernel_size))
        self.conv_2 = nn.Conv2d(in_channels, out_channels, (1, kernel_size))
        self.conv_3 = nn.Conv2d(in_channels, out_channels, (1, kernel_size))

    def forward(self, X: torch.FloatTensor) -> torch.FloatTensor:
        """Forward pass through temporal convolution block.

        Arg types:
            * **X** (torch.FloatTensor) -  Input data of shape
                (batch_size, input_time_steps, num_nodes, in_channels).

        Return types:
            * **H** (torch.FloatTensor) - Output data of shape
                (batch_size, in_channels, num_nodes, input_time_steps).
        """
        # X = X.permute(0, 3, 2, 1)
        # P = self.conv_1(X)
        # Q = torch.sigmoid(self.conv_2(X))
        # PQ = P * Q
        # H = F.relu(PQ + self.conv_3(X))
        # H = H.permute(0, 3, 2, 1)
        # return H
        X = X.permute(0, 3, 2, 1)
        P = self.conv_1(X)
        Q = torch.sigmoid(self.conv_2(X))
        PQ = P * Q
        H = F.relu(PQ + self.conv_3(X))
        H = H.permute(0, 3, 2, 1)
        return H

In [23]:
class STConv(nn.Module):
    r"""Spatio-temporal convolution block using ChebConv Graph Convolutions.
    For details see: `"Spatio-Temporal Graph Convolutional Networks:
    A Deep Learning Framework for Traffic Forecasting"
    <https://arxiv.org/abs/1709.04875>`_

    NB. The ST-Conv block contains two temporal convolutions (TemporalConv)
    with kernel size k. Hence for an input sequence of length m,
    the output sequence will be length m-2(k-1).

    Args:
        in_channels (int): Number of input features.
        hidden_channels (int): Number of hidden units output by graph convolution block
        out_channels (int): Number of output features.
        kernel_size (int): Size of the kernel considered.
        K (int): Chebyshev filter size :math:`K`.
        normalization (str, optional): The normalization scheme for the graph
            Laplacian (default: :obj:`"sym"`):

            1. :obj:`None`: No normalization
            :math:`\mathbf{L} = \mathbf{D} - \mathbf{A}`

            2. :obj:`"sym"`: Symmetric normalization
            :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A}
            \mathbf{D}^{-1/2}`

            3. :obj:`"rw"`: Random-walk normalization
            :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}`

            You need to pass :obj:`lambda_max` to the :meth:`forward` method of
            this operator in case the normalization is non-symmetric.
            :obj:`\lambda_max` should be a :class:`torch.Tensor` of size
            :obj:`[num_graphs]` in a mini-batch scenario and a
            scalar/zero-dimensional tensor when operating on single graphs.
            You can pre-compute :obj:`lambda_max` via the
            :class:`torch_geometric.transforms.LaplacianLambdaMax` transform.
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)

    """

    def __init__(
        self,
        graph_conv: nn.Module,
        num_nodes: int,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        kernel_size: int,
        K: int,
        normalization: str = "sym",
        bias: bool = True,
    ):
        super(STConv, self).__init__()
        self.num_nodes = num_nodes
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.K = K
        self.normalization = normalization
        self.bias = bias

        self._temporal_conv1 = TemporalConv(
            in_channels=in_channels,
            out_channels=hidden_channels,
            kernel_size=kernel_size,
        )

        #self._graph_conv = graph_conv
        
        self._graph_conv = nng.ChebConv(
            in_channels=hidden_channels,
            out_channels=hidden_channels,
            K=K,
            normalization=normalization,
            bias=bias, 
        )

        self._temporal_conv2 = TemporalConv(
            in_channels=hidden_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
        )

        self._batch_norm = nn.BatchNorm2d(num_nodes)

    def forward(
        self,
        X: torch.FloatTensor,
        edge_index: torch.LongTensor,
        edge_weight: torch.FloatTensor = None,
    ) -> torch.FloatTensor:

        r"""Forward pass. If edge weights are not present the forward pass
        defaults to an unweighted graph.

        Arg types:
            * **X** (PyTorch FloatTensor) - Sequence of node features of shape (Batch size X Input time steps X Num nodes X In channels).
            * **edge_index** (PyTorch LongTensor) - Graph edge indices.
            * **edge_weight** (PyTorch LongTensor, optional)- Edge weight vector.

        Return types:
            * **T** (PyTorch FloatTensor) - Sequence of node features.
        """
        T_0 = self._temporal_conv1(X)
        T = torch.zeros_like(T_0).to(T_0.device)
        for b in range(T_0.size(0)):
            for t in range(T_0.size(1)):
                T[b][t] = self._graph_conv(T_0[b][t], edge_index, edge_weight) #original used in fst template loop
                #T[b][t] = self._graph_conv(T_0[b][t], edge_index[t], edge_weight[t]) #used in snd template loop

        T = F.relu(T)
        T = self._temporal_conv2(T)
        T = T.permute(0, 2, 1, 3)
        T = self._batch_norm(T)
        T = T.permute(0, 2, 1, 3)
        return T

In [24]:
# Our final classifier applies the dot-product between source and destination
# node embeddings to derive edge-level predictions:
class Classifier(torch.nn.Module):
    def forward(self, x:torch.Tensor, edge_index: torch.LongTensor) -> torch.Tensor:
        # Get node embeddings
        u_feat = x[edge_index[0]]
        v_feat = x[edge_index[1]]

        # Apply dot-product to get a prediction per supervision edge:
        return (u_feat * v_feat).sum(dim=-1)

In [25]:
class STGNN(nn.Module):
    def __init__(self):
        super(STGNN, self).__init__()
        pass
    
    def forward(self):
        pass

## Train the STGNN

In [26]:
def compute_loss(pos_score, neg_score):
    scores = torch.cat([pos_score, neg_score])
    labels = torch.cat(
        [torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]
    )
    return F.binary_cross_entropy_with_logits(scores, labels)

## Template for training loop when iterating through snapshots

In [27]:
tmp_data = next(iter(train_pos_dataset))
tmp_data

Data(x=[801, 37], edge_index=[2, 165], edge_attr=[165])

In [31]:
tmp_neg_data = next(iter(train_neg_dataset))
tmp_neg_data

Data(x=[801, 37], edge_index=[2, 165], edge_attr=[165])

In [33]:
tmp_model = STConv(None, num_nodes, nodes_df.shape[1]-1, 12, 2, 1, 2)
tmp_cls = Classifier()
torch.no_grad()

tmp_h = tmp_model(tmp_data.x.unsqueeze(0).unsqueeze(0), tmp_data.edge_index, tmp_data.edge_attr)

pos_score = tmp_cls(tmp_h.squeeze(), tmp_data.edge_index)
neg_score = tmp_cls(tmp_h.squeeze(), tmp_neg_data.edge_index)

print(f'Pos score: {pos_score}')
print(f'Neg score: {neg_score}')
print(f'Loss: {compute_loss(pos_score, neg_score)}')
print(f'Pos probabilities: {torch.sigmoid(pos_score)}')

Pos score: tensor([ 2.0000,  2.0000,  2.0000,  2.0000,  2.0000,  0.0000, -2.0000,  2.0000,
         2.0000,  2.0000,  2.0000,  2.0000,  2.0000,  2.0000,  2.0000,  2.0000,
         2.0000,  2.0000,  0.0000,  2.0000,  2.0000,  2.0000,  0.0000,  2.0000,
        -2.0000, -2.0000,  2.0000,  2.0000, -2.0000,  2.0000, -2.0000, -2.0000,
        -2.0000,  2.0000,  2.0000, -2.0000,  2.0000,  0.0000, -2.0000, -2.0000,
         0.0000,  2.0000,  0.0000,  2.0000, -2.0000,  2.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  2.0000,  0.0000, -2.0000, -2.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  2.0000,  2.0000,  0.0000,
         2.0000, -2.0000,  0.0000,  2.0000,  2.0000,  0.0000,  0.0000,  2.0000,
         2.0000,  2.0000,  0.0000,  2.0000,  2.0000,  2.0000,  2.0000,  0.0000,
         2.0000, -2.0000, -2.0000,  2.0000,  2.0000,  0.0000,  2.0000, -2.0000,
         2.0000,  2.0000,  2.0000,  0.0000,  0.0000,  2.0000,  2.0000,  2.0000,
         2.0000, -2.0000,  2.

## Template for training loop when feeding all snapshots in an epoch

In [110]:
num_timestamps = len(train_timestamps) # ((num_timestamps-1)/2)+1 = 1 feature matrix for one timestamp
tmp_model = STConv(None, num_nodes, nodes_df.shape[1]-1, 16, 12, int(((num_timestamps-1)/2)+1), 3).double()
tmp_cls = Classifier()
torch.no_grad()

tmp_data = {'x':torch.stack([train_data[i].x for i in range(num_timestamps)], dim=0).unsqueeze(0),
            'edge_index':[train_data[i].edge_index for i in range(num_timestamps)],
            'edge_attr':[train_data[i].edge_attr for i in range(num_timestamps)]}
tmp_h = tmp_model(tmp_data['x'], tmp_data['edge_index'], tmp_data['edge_attr']).squeeze()

for t in range(num_timestamps):
    pos_score = tmp_cls(tmp_h, train_pos_edges_data[t].edge_index)
    neg_score = tmp_cls(tmp_h, train_neg_edges_data[t].edge_index)
    print('#'*100)
    print(f'Timestamp: {train_timestamps[t]}')
    print(f'Pos score: {pos_score}')
    print(f'Neg score: {neg_score}')
    print(f'Loss: {compute_loss(pos_score, neg_score)}')
    print(f'Pos probabilities: {torch.sigmoid(pos_score)}')
    print(f'Neg probabilities: {torch.sigmoid(pos_score)}')

####################################################################################################
Timestamp: 2023-12-28
Pos score: tensor([11.8570,  9.9837, 10.7632], dtype=torch.float64,
       grad_fn=<SumBackward1>)
Neg score: tensor([10.7834,  6.2852,  7.7267], dtype=torch.float64,
       grad_fn=<SumBackward1>)
Loss: 4.1329498291015625
Pos probabilities: tensor([1.0000, 1.0000, 1.0000], dtype=torch.float64,
       grad_fn=<SigmoidBackward0>)
Neg probabilities: tensor([1.0000, 1.0000, 1.0000], dtype=torch.float64,
       grad_fn=<SigmoidBackward0>)
####################################################################################################
Timestamp: 2023-12-29
Pos score: tensor([11.9501, 11.4129,  3.9887, 11.8570, 10.0561,  9.9837, 11.6577],
       dtype=torch.float64, grad_fn=<SumBackward1>)
Neg score: tensor([5.1827, 8.8211, 1.8530, 7.1052, 6.8421, 7.4928, 5.0643],
       dtype=torch.float64, grad_fn=<SumBackward1>)
Loss: 3.0385549068450928
Pos probabilities: tensor([

In [None]:
model = None
cls = None
model_name = '' # example:stgnn-cnn-gnn

lr = 0.01
optimizer = torch.optim.Adam(itertools.chain(model.parameters(), cls.parameters()), lr=lr)

print(model)
print(cls)

In [None]:
epochs = 777

## Training loop

In [None]:
loss_per_epoch = []
acc_per_epoch = []
for e in range(epochs):
    model.train()
    # forward
    h = model()
    pos_score = cls()
    neg_score = cls()
    loss = compute_loss(pos_score, neg_score)
    acc = (pos_score > 0.5).mean
    
    loss_per_epoch.append()
    acc_per_epoch.append()#(pos_score > 0.5).mean
    
    # backward
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if e % 5 == 0:
        print("In epoch {}, loss: {}".format(e, loss))

### Plot metrics

In [None]:
def plot_train_metric(values_per_epoch, name):
    plt.figure(figsize=(10, 5))
    plt.plot(values_per_epoch, label='Train ' + name)
    plt.xlabel('Epoch')
    plt.ylabel(name)
    plt.title('Training ' + name)
    plt.legend()
    plt.show()
    
plot_train_metric(loss_per_epoch, 'Loss')
plot_train_metric(loss_per_epoch, 'Accuracy')

In [None]:
# Plotting
plt.figure(figsize=(10, 8))

plt.subplot(2, 1, 1)
plt.plot(loss_per_epoch, label='Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.legend()

plt.subplot(2, 1, 2)
plt.plot(acc_per_epoch, label='Accuracy', color='orange')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Accuracy')
plt.legend()

plt.tight_layout()
plt.show()

## Evaluate the STGNN

In [None]:
def compute_auc(pos_score, neg_score):
    scores = torch.cat([pos_score, neg_score]).numpy()
    labels = torch.cat(
        [torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]
    ).numpy()
    return roc_auc_score(labels, scores)

In [None]:
model.eval()
test_size = test_timestamps.shape[0]
loss = 0
acc = 0
auc = 0
for pos_snapshot, neg_snapshot in zip(test_pos_dataset, test_neg_dataset):
    # forward
    h = model()
    pos_score = cls()
    neg_score = cls()
    loss += compute_loss(pos_score, neg_score)
    acc += (pos_score > 0.5).mean
    auc += compute_auc(pos_score, neg_score)

print(f'Loss:{loss/test_size}\nAccuracy:{acc/test_size}\nAuc:{auc/test_size}\n')

# Save learned parameters of the model

In [None]:
if reduced:
    model_savepath = f"../models/{model_name}-reduced-{int(frac1 * 100)}.pth"
elif reduced_sample_alt_e:
    if sampled1:
        model_savepath = f"../models/{model_name}-reduced-{int(frac2 * 100)}-h-{min_alt}-{max_alt}-e-{int(e_thres * 100)}.pth"
    else:
        model_savepath = f"../models/{model_name}-reduced-h-{min_alt}-{max_alt}-e-{int(e_thres * 100)}.pth"
elif reduced_sample_leos:
    if sampled2:
        model_savepath = f"../models/{model_name}-{leo}-reduced-{int(frac3 * 100)}.pth"
    else:
        model_savepath = f"../models/{model_name}-{leo}.pth"
else:
    model_savepath = '../models/{model_name}.pth'

torch.save(model.state_dict(), model_savepath)