In [None]:
import torch
import torch.nn.functional as F
import pandas as pd
from sklearn.preprocessing import OneHotEncoder
import numpy as np
from torch_geometric.data import Data
import xxhash

## Loading edge files, keeping only edges whose weight is at or above a quantile threshold (`percentile=0.9` keeps the top 10% of edges). A higher threshold keeps only stronger relations; adjust it to control the strength of the trends we want to predict in the future.

In [None]:
def load_data(year, data_dir, percentile=0.9):
    """Read one year's edge and node tables and keep only the strongest edges.

    Args:
        year: Year sub-folder to read; files are ``{year}_edges.parquet`` and
            ``{year}_nodes.parquet`` inside ``{data_dir}/{year}``.
        data_dir: Root location (local path or ``gs://`` URL) of the yearly folders.
        percentile: Quantile in [0, 1] of the ``weight`` column used as the cut-off;
            edges with ``weight`` >= that quantile are kept.

    Returns:
        Tuple ``(filtered_edges, nodes)``: the thresholded edge DataFrame and the
        unfiltered node DataFrame.
    """
    edges = pd.read_parquet(f'{data_dir}/{year}/{year}_edges.parquet', engine='pyarrow')
    nodes = pd.read_parquet(f'{data_dir}/{year}/{year}_nodes.parquet', engine='pyarrow')
    cutoff = edges['weight'].quantile(percentile)
    strong = edges['weight'] >= cutoff
    return edges[strong], nodes

In [None]:
data_dir = "gs://datasets-dev-ded86f66/benchmarks/scientific_trend_prediction/new_parquet_data"
years = range(1980, 2024)

all_node_ids = set()
id_to_label = {}
for i in years:
    _, n = load_data(i, data_dir)
    all_node_ids = all_node_ids.union(set(n['node_id'].tolist()))
    keys , vals = n['node_id'].tolist() , n['node_label'].tolist()
    entries = {key: value for key, value in zip(keys, vals)}
    id_to_label.update(entries)

## Extracting node features (out-degree broken down by destination-node type) and converting each year's data into a graph object

In [None]:
import numpy as np
import torch
from torch_geometric.data import Data
import networkx as nx

def featurizer(edges, node_ids, id_to_label, label_order=('phenotype', 'gene', 'compound')):
    """Count each node's outgoing edges, broken down by the destination node's label.

    Args:
        edges: DataFrame with ``source_id`` and ``destination_id`` columns.
        node_ids: Ordered node ids; row ``i`` of the result belongs to ``node_ids[i]``.
        id_to_label: Mapping node_id -> node label (one of ``label_order``).
        label_order: Labels defining the feature columns, in order
            (default: phenotype, gene, compound).

    Returns:
        Float tensor of shape ``(len(node_ids), len(label_order))``; entry
        ``[i, j]`` is the number of edges from ``node_ids[i]`` to a node labelled
        ``label_order[j]``.

    Raises:
        KeyError: if an edge source is not in ``node_ids``, or a destination is
            missing from ``id_to_label`` or has a label outside ``label_order``.
    """
    row_of = {node: i for i, node in enumerate(node_ids)}
    col_of = {label: j for j, label in enumerate(label_order)}

    node_features = np.zeros((len(node_ids), len(col_of)), dtype=float)
    # Accumulate straight into the matrix instead of a per-node dict of counters.
    for src, dest in zip(edges['source_id'], edges['destination_id']):
        node_features[row_of[src], col_of[id_to_label[dest]]] += 1

    return torch.tensor(node_features, dtype=torch.float)

In [None]:
# Fix one global node ordering shared by every yearly graph. Sorting makes the
# node -> row mapping reproducible across kernel restarts (the iteration order
# of a set of string ids depends on per-process hash randomization).
node_ids = sorted(all_node_ids)
node_id_to_index = {node_id: idx for idx, node_id in enumerate(node_ids)}

graphs = []

for year in years:
    edges, _ = load_data(year, data_dir)
    node_feature = featurizer(edges, node_ids, id_to_label)
    # Translate raw node ids into row indices of `node_feature`; shape (2, num_edges).
    edge_index = torch.tensor(
        np.array([edges['source_id'].map(node_id_to_index).values,
                  edges['destination_id'].map(node_id_to_index).values]),
        dtype=torch.long,
    )
    edge_weights = torch.tensor(edges['weight'].values, dtype=torch.float)
    # Edge weights are used both as message-passing weights and as regression targets.
    graphs.append(Data(x=node_feature, edge_index=edge_index, edge_attr=edge_weights, y=edge_weights))

## Normalizing edge weights (global min-max scaling to [0, 1] across all years)

In [None]:
def normalize_edge_weights_min_max(graph_list):
    """Min-max scale the edge weights of all graphs into [0, 1] with one shared range.

    The minimum and maximum are taken over the ``edge_attr`` of every graph
    combined, so weights stay comparable across graphs. Each graph is modified
    in place: ``edge_attr`` and ``y`` are both set to the scaled tensor.

    Args:
        graph_list: Sequence of graph objects exposing an ``edge_attr`` tensor.

    Returns:
        The same ``graph_list`` object. An empty list, or graphs with no edges
        at all, are returned unchanged. If every weight is identical, all
        scaled weights become 0 instead of NaN.
    """
    flat_weights = [graph.edge_attr.reshape(-1) for graph in graph_list]
    if not flat_weights:
        return graph_list
    all_weights = torch.cat(flat_weights)
    if all_weights.numel() == 0:
        return graph_list

    min_weight = all_weights.min()
    max_weight = all_weights.max()
    weight_range = max_weight - min_weight
    # A zero range would give 0/0 = NaN; divide by 1 so every weight maps to 0.
    if weight_range == 0:
        weight_range = torch.ones_like(weight_range)

    for graph in graph_list:
        edge_attr_normalized = (graph.edge_attr - min_weight) / weight_range
        graph.edge_attr = edge_attr_normalized
        graph.y = edge_attr_normalized

    return graph_list

In [None]:
# Rescale every year's weights with one shared min/max; the function edits the
# graphs in place and hands back the same list.
graphs = normalize_edge_weights_min_max(graph_list=graphs)
print(f"Number of graphs: {len(graphs)}")

## GNN-LSTM layer implementation (adapted from the torch-geometric-temporal library)

In [None]:
from torch.nn import Parameter
from torch_geometric.nn import ChebConv
from torch_geometric.nn.inits import glorot, zeros

class GConvLSTM(torch.nn.Module):
    r"""An implementation of the Chebyshev Graph Convolutional Long Short Term Memory
    Cell. For details see this paper: `"Structured Sequence Modeling with Graph
    Convolutional Recurrent Networks." <https://arxiv.org/abs/1612.07659>`_

    Every gate sums a :class:`ChebConv` of the node features ``X``, a
    :class:`ChebConv` of the previous hidden state ``H`` and a bias ``b_*``.
    The input, forget and output gates additionally add an elementwise
    (peephole-style) term ``w_c_* * C`` on the cell state. All convolutions
    share the same ``edge_index`` / ``edge_weight``.

    Args:
        in_channels (int): Number of input features.
        out_channels (int): Number of output features.
        K (int): Chebyshev filter size :math:`K`.
        normalization (str, optional): The normalization scheme for the graph
            Laplacian (default: :obj:`"sym"`):

            1. :obj:`None`: No normalization
            :math:`\mathbf{L} = \mathbf{D} - \mathbf{A}`

            2. :obj:`"sym"`: Symmetric normalization
            :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A}
            \mathbf{D}^{-1/2}`

            3. :obj:`"rw"`: Random-walk normalization
            :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}`

            You need to pass :obj:`lambda_max` to the :meth:`forward` method of
            this operator in case the normalization is non-symmetric.
            :obj:`\lambda_max` should be a :class:`torch.Tensor` of size
            :obj:`[num_graphs]` in a mini-batch scenario and a
            scalar/zero-dimensional tensor when operating on single graphs.
            You can pre-compute :obj:`lambda_max` via the
            :class:`torch_geometric.transforms.LaplacianLambdaMax` transform.
        bias (bool, optional): If set to :obj:`False`, the ChebConv layers will
            not learn an additive bias. The gate biases ``b_i``, ``b_f``,
            ``b_c`` and ``b_o`` are always created. (default: :obj:`True`)
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        K: int,
        normalization: str = "sym",
        bias: bool = True,
    ):
        super(GConvLSTM, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.K = K
        self.normalization = normalization
        self.bias = bias
        # Creation order below fixes both the state_dict layout and the order in
        # which random initialization is drawn; keep it stable for checkpoints.
        self._create_parameters_and_layers()
        self._set_parameters()

    def _create_input_gate_parameters_and_layers(self):
        # Input gate: ChebConv on X, ChebConv on H, peephole weight on C, bias.

        self.conv_x_i = ChebConv(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            K=self.K,
            normalization=self.normalization,
            bias=self.bias,
        )

        self.conv_h_i = ChebConv(
            in_channels=self.out_channels,
            out_channels=self.out_channels,
            K=self.K,
            normalization=self.normalization,
            bias=self.bias,
        )

        # torch.Tensor(...) allocates uninitialized memory; values are set in _set_parameters.
        self.w_c_i = Parameter(torch.Tensor(1, self.out_channels))
        self.b_i = Parameter(torch.Tensor(1, self.out_channels))

    def _create_forget_gate_parameters_and_layers(self):
        # Forget gate: same structure as the input gate.

        self.conv_x_f = ChebConv(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            K=self.K,
            normalization=self.normalization,
            bias=self.bias,
        )

        self.conv_h_f = ChebConv(
            in_channels=self.out_channels,
            out_channels=self.out_channels,
            K=self.K,
            normalization=self.normalization,
            bias=self.bias,
        )

        self.w_c_f = Parameter(torch.Tensor(1, self.out_channels))
        self.b_f = Parameter(torch.Tensor(1, self.out_channels))

    def _create_cell_state_parameters_and_layers(self):
        # Candidate cell state: ChebConv on X and on H plus bias; no peephole weight.

        self.conv_x_c = ChebConv(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            K=self.K,
            normalization=self.normalization,
            bias=self.bias,
        )

        self.conv_h_c = ChebConv(
            in_channels=self.out_channels,
            out_channels=self.out_channels,
            K=self.K,
            normalization=self.normalization,
            bias=self.bias,
        )

        self.b_c = Parameter(torch.Tensor(1, self.out_channels))

    def _create_output_gate_parameters_and_layers(self):
        # Output gate: same structure as the input gate.

        self.conv_x_o = ChebConv(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            K=self.K,
            normalization=self.normalization,
            bias=self.bias,
        )

        self.conv_h_o = ChebConv(
            in_channels=self.out_channels,
            out_channels=self.out_channels,
            K=self.K,
            normalization=self.normalization,
            bias=self.bias,
        )

        self.w_c_o = Parameter(torch.Tensor(1, self.out_channels))
        self.b_o = Parameter(torch.Tensor(1, self.out_channels))

    def _create_parameters_and_layers(self):
        self._create_input_gate_parameters_and_layers()
        self._create_forget_gate_parameters_and_layers()
        self._create_cell_state_parameters_and_layers()
        self._create_output_gate_parameters_and_layers()

    def _set_parameters(self):
        # Glorot-initialize the peephole weights and zero the gate biases.
        # The ChebConv layers are not touched here.
        glorot(self.w_c_i)
        glorot(self.w_c_f)
        glorot(self.w_c_o)
        zeros(self.b_i)
        zeros(self.b_f)
        zeros(self.b_c)
        zeros(self.b_o)

    def _set_hidden_state(self, X, H):
        # Default to an all-zero hidden state with one row per node in X.
        if H is None:
            H = torch.zeros(X.shape[0], self.out_channels).to(X.device)
        return H

    def _set_cell_state(self, X, C):
        # Default to an all-zero cell state with one row per node in X.
        if C is None:
            C = torch.zeros(X.shape[0], self.out_channels).to(X.device)
        return C

    def _calculate_input_gate(self, X, edge_index, edge_weight, H, C, lambda_max):
        # I = sigmoid(conv_x(X) + conv_h(H) + w_c_i * C + b_i)
        I = self.conv_x_i(X, edge_index, edge_weight, lambda_max=lambda_max)
        I = I + self.conv_h_i(H, edge_index, edge_weight, lambda_max=lambda_max)
        I = I + (self.w_c_i * C)
        I = I + self.b_i
        I = torch.sigmoid(I)
        return I

    def _calculate_forget_gate(self, X, edge_index, edge_weight, H, C, lambda_max):
        # F = sigmoid(conv_x(X) + conv_h(H) + w_c_f * C + b_f).
        # The local name F shadows the notebook's `torch.nn.functional as F` alias
        # only inside this method.
        F = self.conv_x_f(X, edge_index, edge_weight, lambda_max=lambda_max)
        F = F + self.conv_h_f(H, edge_index, edge_weight, lambda_max=lambda_max)
        F = F + (self.w_c_f * C)
        F = F + self.b_f
        F = torch.sigmoid(F)
        return F

    def _calculate_cell_state(self, X, edge_index, edge_weight, H, C, I, F, lambda_max):
        # Candidate T = tanh(conv_x(X) + conv_h(H) + b_c); new C = F * C + I * T.
        T = self.conv_x_c(X, edge_index, edge_weight, lambda_max=lambda_max)
        T = T + self.conv_h_c(H, edge_index, edge_weight, lambda_max=lambda_max)
        T = T + self.b_c
        T = torch.tanh(T)
        C = F * C + I * T
        return C

    def _calculate_output_gate(self, X, edge_index, edge_weight, H, C, lambda_max):
        # O = sigmoid(conv_x(X) + conv_h(H) + w_c_o * C + b_o); forward passes the
        # already-updated cell state as C.
        O = self.conv_x_o(X, edge_index, edge_weight, lambda_max=lambda_max)
        O = O + self.conv_h_o(H, edge_index, edge_weight, lambda_max=lambda_max)
        O = O + (self.w_c_o * C)
        O = O + self.b_o
        O = torch.sigmoid(O)
        return O

    def _calculate_hidden_state(self, O, C):
        # H = O * tanh(C)
        H = O * torch.tanh(C)
        return H

    def forward(
        self,
        X: torch.FloatTensor,
        edge_index: torch.LongTensor,
        edge_weight: torch.FloatTensor = None,
        H: torch.FloatTensor = None,
        C: torch.FloatTensor = None,
        lambda_max: torch.Tensor = None,
    ) -> torch.FloatTensor:
        """
        Making a forward pass. If edge weights are not present the forward pass
        defaults to an unweighted graph. If the hidden state and cell state
        matrices are not present when the forward pass is called these are
        initialized with zeros.

        Arg types:
            * **X** *(PyTorch Float Tensor)* - Node features.
            * **edge_index** *(PyTorch Long Tensor)* - Graph edge indices.
            * **edge_weight** *(PyTorch Float Tensor, optional)* - Edge weight vector.
            * **H** *(PyTorch Float Tensor, optional)* - Hidden state matrix for all nodes.
            * **C** *(PyTorch Float Tensor, optional)* - Cell state matrix for all nodes.
            * **lambda_max** *(PyTorch Tensor, optional but mandatory if normalization is not sym)* - Largest eigenvalue of Laplacian.

        Return types:
            A tuple ``(H, C)`` (despite the single-tensor return annotation):
            * **H** *(PyTorch Float Tensor)* - Hidden state matrix for all nodes.
            * **C** *(PyTorch Float Tensor)* - Cell state matrix for all nodes.
        """
        H = self._set_hidden_state(X, H)
        C = self._set_cell_state(X, C)
        I = self._calculate_input_gate(X, edge_index, edge_weight, H, C, lambda_max)
        F = self._calculate_forget_gate(X, edge_index, edge_weight, H, C, lambda_max)
        C = self._calculate_cell_state(X, edge_index, edge_weight, H, C, I, F, lambda_max)
        # Output gate sees the updated cell state (peephole on the new C).
        O = self._calculate_output_gate(X, edge_index, edge_weight, H, C, lambda_max)
        H = self._calculate_hidden_state(O, C)
        return H, C

## Temporal link predictor model architecture

In [None]:
class TemporalGNN(torch.nn.Module):
    """Recurrent graph encoder with an MLP edge-weight decoder.

    A ``GConvLSTM`` consumes a sequence of graph snapshots. Its final hidden
    state goes through ReLU, a linear layer and a row-wise log-softmax to give
    one embedding per node. ``predict_edge_weight`` scores node pairs by
    concatenating their embeddings and applying a two-layer MLP with a
    sigmoid output.

    Args:
        num_nodes: Number of nodes per snapshot (kept for interface
            compatibility; not used by any layer).
        node_features: Input feature size per node.
        hidden_channels: Size of the recurrent state and of the MLP hidden layer.
        output_channels: Size of each node embedding.
        K: Chebyshev filter size of the ``GConvLSTM`` convolutions (default 3).
    """

    def __init__(self, num_nodes, node_features, hidden_channels, output_channels, K=3):
        super(TemporalGNN, self).__init__()
        self.recurrent = GConvLSTM(node_features, hidden_channels, K)
        self.linear = torch.nn.Linear(hidden_channels, output_channels)
        self.edge_mlp = torch.nn.Sequential(
            torch.nn.Linear(2 * output_channels, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_channels, 1),
        )

    def forward(self, seq):
        """Encode a chronological sequence of snapshots into node embeddings.

        Args:
            seq: Non-empty sequence of graphs, each with ``x``, ``edge_index``
                and ``edge_attr`` (the latter used as edge weights).

        Returns:
            Tensor of shape ``(num_nodes, output_channels)``: log-softmax over
            each row of the projected final hidden state.

        Raises:
            ValueError: if ``seq`` is empty.
        """
        H, C = None, None
        for snapshot in seq:
            H, C = self.recurrent(snapshot.x, snapshot.edge_index, snapshot.edge_attr, H, C)
        if H is None:
            raise ValueError("seq must contain at least one graph snapshot")

        H = F.relu(H)
        H = self.linear(H)
        # Downstream heads (edge_mlp and the decoupled MLP/LSTM) consume this
        # exact log-softmax output, so changing it invalidates trained weights.
        return F.log_softmax(H, dim=1)

    def predict_edge_weight(self, node_embeddings, edge_index):
        """Predict a weight in (0, 1) for every (src, dst) column of ``edge_index``.

        Returns:
            1-D tensor with one prediction per edge (shape ``(E,)``, also when E == 1).
        """
        edge_features = self.get_edge_embeddings(node_embeddings, edge_index)
        probs = torch.sigmoid(self.edge_mlp(edge_features))
        # Squeeze only the trailing size-1 dim; a bare squeeze() turns a single
        # edge into a 0-d tensor that no longer matches a (1,) target.
        return probs.squeeze(-1)

    def get_edge_embeddings(self, node_embeddings, edge_index):
        """Concatenate source and destination node embeddings for each edge.

        Returns:
            Tensor of shape ``(E, 2 * embedding_dim)``.
        """
        src, dst = edge_index
        return torch.cat([node_embeddings[src], node_embeddings[dst]], dim=1)

## Initializing hyperparameters, the model, the optimizer and the loss

In [None]:
import random

node_dim = graphs[0].x.shape[1]
num_nodes = graphs[0].x.shape[0]
hidden_channels = 64
output_channels = 64
learning_rate = 0.0001
epochs = 30
time_window = 10
weight_decay = 0.0001
SEED = 42

# Seed every RNG the notebook draws from before anything random happens:
# torch (model init, randperm shuffles), Python's `random` (which
# torch_geometric's negative_sampling may use) and numpy.
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

tgn_model = TemporalGNN(num_nodes, node_dim, hidden_channels, output_channels)
optimizer = torch.optim.Adam(tgn_model.parameters(), lr=learning_rate, weight_decay=weight_decay)
criterion = torch.nn.MSELoss()

In [None]:
import torch
import gcsfs

path = 'gs://datasets-dev-ded86f66/benchmarks/scientific_trend_prediction/model_weights/TGN_90th_precentile_checkpoint.pth'
fs = gcsfs.GCSFileSystem()

# map_location="cpu" lets a checkpoint saved from a GPU session load on a
# CPU-only kernel. NOTE(review): torch.load unpickles the file; only load trusted
# checkpoints (prefer weights_only=True where the installed torch supports it).
with fs.open(path, 'rb') as f:
    state_dict = torch.load(f, map_location="cpu")

tgn_model.load_state_dict(state_dict)

## Creating an 80:20 train-test split, where the test data follows the training data in chronological order.

In [None]:
from sklearn.model_selection import train_test_split
import copy

def create_sequences(data, time_step, forecast_length=1):
    """Slice a chronological sequence into (input window, target) pairs.

    Window ``i`` is ``data[i:i + time_step]``. Its target is the element
    ``forecast_length`` steps after the window's last element,
    ``data[i + time_step + forecast_length - 1]``. Every window whose target is
    in range is produced. The previous ``- 1`` bound silently dropped the last
    valid window.

    Args:
        data: Chronologically ordered sequence (e.g. yearly graphs).
        time_step: Number of consecutive elements per input window.
        forecast_length: Steps past the window at which the target lies
            (default 1: the element right after the window).

    Returns:
        Tuple ``(X, Y)`` of input windows and their targets, aligned by index.
    """
    X, Y = [], []
    # The last valid start i satisfies i + time_step + forecast_length - 1 <= len(data) - 1.
    for i in range(len(data) - time_step - forecast_length + 1):
        X.append(data[i:i + time_step])
        Y.append(data[i + time_step + forecast_length - 1])
    return X, Y

x, y = create_sequences(graphs, time_window)

# Chronological 80/20 split: the first 80% of windows train, the later ones test.
split_index = int(len(x) * 0.8)
train_part, test_part = slice(None, split_index), slice(split_index, None)

# Deep copies so later in-place edits of the targets (negative sampling) leave `graphs` intact.
x_train, x_test = copy.deepcopy(x[train_part]), copy.deepcopy(x[test_part])
y_train, y_test = copy.deepcopy(y[train_part]), copy.deepcopy(y[test_part])

for caption, part in (("Size of x_train:", x_train), ("Size of x_test:", x_test),
                      ("Size of y_train:", y_train), ("Size of y_test:", y_test)):
    print(caption, len(part))

In [None]:
import copy

def create_sequences(data, time_step, forecast_length=1):
    """Slice a chronological sequence into (input window, target) pairs.

    Window ``i`` is ``data[i:i + time_step]``. Its target is the element
    ``forecast_length`` steps after the window's last element,
    ``data[i + time_step + forecast_length - 1]``. Every window whose target is
    in range is produced.

    Args:
        data: Chronologically ordered sequence (e.g. yearly graphs).
        time_step: Number of consecutive elements per input window.
        forecast_length: Steps past the window at which the target lies
            (default 1: the element right after the window).

    Returns:
        Tuple ``(X, Y)`` of input windows and their targets, aligned by index.
    """
    X, Y = [], []
    # The last valid start i satisfies i + time_step + forecast_length - 1 <= len(data) - 1.
    for i in range(len(data) - time_step - forecast_length + 1):
        X.append(data[i:i + time_step])
        Y.append(data[i + time_step + forecast_length - 1])
    return X, Y

time_window = 10
forecast_length = 5

# Each target is the graph `forecast_length` years after its `time_window`-year input window.
x, y = create_sequences(graphs, time_window, forecast_length)

# Chronological 80/20 split: the first 80% of windows train, the later ones test.
split_index = int(len(x) * 0.8)
train_part, test_part = slice(None, split_index), slice(split_index, None)

# Deep copies so later in-place edits of the targets (negative sampling) leave `graphs` intact.
x_train, x_test = copy.deepcopy(x[train_part]), copy.deepcopy(x[test_part])
y_train, y_test = copy.deepcopy(y[train_part]), copy.deepcopy(y[test_part])

for caption, part in (("Size of x_train:", x_train), ("Size of x_test:", x_test),
                      ("Size of y_train:", y_train), ("Size of y_test:", y_test)):
    print(caption, len(part))

In [None]:
import os

# Optional override with a locally trained checkpoint. Nothing in this notebook
# writes 'test_model.pth', so an unguarded load crashes on a fresh kernel.
local_checkpoint = 'test_model.pth'
if os.path.exists(local_checkpoint):
    tgn_model.load_state_dict(torch.load(local_checkpoint, map_location="cpu"))
else:
    print(f"{local_checkpoint} not found; keeping the previously loaded TGN weights.")

## Generating negative samples (non-edges with target weight 0), up to as many as there are positive edges

In [None]:
import torch
import random
from torch_geometric.utils import negative_sampling

def add_negative_samples(data):
    """Append sampled non-edges with target weight 0, then shuffle all edges.

    Modifies ``data`` in place and returns it. ``edge_index`` gains the sampled
    negative pairs, and ``edge_attr`` / ``y`` become the original weights
    followed by zeros. All three are permuted with one shared random permutation.

    Args:
        data: Graph with ``edge_index`` (2, E), ``edge_attr`` (E,) and ``num_nodes``.

    Returns:
        The same ``data`` object.
    """
    num_pos_samples = data.edge_index.size(1)
    neg_edge_index = negative_sampling(data.edge_index, num_nodes=data.num_nodes, num_neg_samples=num_pos_samples)
    # negative_sampling may return fewer pairs than requested (e.g. dense graphs).
    # Size the zero targets from what was actually sampled so edge_index and y stay aligned.
    num_neg_samples = neg_edge_index.size(1)
    neg_weights = torch.zeros(num_neg_samples, dtype=data.edge_attr.dtype, device=data.edge_attr.device)

    edge_index = torch.cat([data.edge_index, neg_edge_index], dim=1)
    targets = torch.cat([data.edge_attr, neg_weights])

    perm = torch.randperm(edge_index.size(1), device=edge_index.device)
    data.edge_index = edge_index[:, perm]
    data.y = targets[perm]
    data.edge_attr = data.y

    return data

# Augment every target graph with negatives, train first then test (same RNG order).
# Slice assignment keeps the existing list objects and replaces their contents.
y_train[:] = [add_negative_samples(target) for target in y_train]
y_test[:] = [add_negative_samples(target) for target in y_test]

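## Training loop for the TGN: given the last N yearly graphs, predict the edge weights of the graph `forecast_length` years after the window (with `forecast_length = 1` this is the (N+1)-th graph). Note: to save time, load the saved checkpoint 'TGN_90th_precentile_checkpoint.pth' instead of training.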

In [None]:
import matplotlib.pyplot as plt

train_losses = []
val_losses = []

for epoch in range(epochs):
    # --- training pass ---
    tgn_model.train()
    total_loss = 0.0
    for seq, target in zip(x_train, y_train):
        optimizer.zero_grad()
        node_embeddings = tgn_model(seq)
        probs = tgn_model.predict_edge_weight(node_embeddings, target.edge_index)
        loss = criterion(probs, target.y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_train_loss = total_loss / len(x_train)
    train_losses.append(avg_train_loss)

    # --- validation pass: score the held-out target edges without gradients ---
    tgn_model.eval()
    total_val_loss = 0.0
    with torch.no_grad():
        for seq, target in zip(x_test, y_test):
            node_embeddings = tgn_model(seq)
            # The original passed an undefined `weights` here (NameError); predict the target edges.
            probs = tgn_model.predict_edge_weight(node_embeddings, target.edge_index)
            total_val_loss += criterion(probs, target.y).item()

    avg_val_loss = total_val_loss / len(x_test)
    val_losses.append(avg_val_loss)

    print(f'Epoch {epoch+1}, Loss: {avg_train_loss}, Val Loss: {avg_val_loss}')

fig, ax = plt.subplots()
epoch_axis = range(1, len(train_losses) + 1)
ax.plot(epoch_axis, train_losses, label="train")
ax.plot(epoch_axis, val_losses, label="validation")
ax.set(title="TGN edge-weight MSE per epoch", xlabel="epoch", ylabel="MSE loss")
ax.legend()
plt.show()

In [None]:
# Uncomment to save the freshly trained TGN weights to a local file; the file name
# matches the checkpoint that the GCS load cell reads.
#torch.save(tgn_model.state_dict(), 'TGN_90th_precentile_checkpoint.pth')

## Decoupled MLP on in-memory TGN embeddings

In [None]:
class EdgeMLP(torch.nn.Module):
    """Two-layer MLP that maps per-edge feature vectors to a weight in (0, 1).

    Args:
        input_dim: Size of each edge feature vector.
        hidden_dim: Width of the hidden layer.
    """

    def __init__(self, input_dim, hidden_dim):
        super(EdgeMLP, self).__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, 1)
        )

    def forward(self, edge_features):
        """Score a batch of edges.

        Args:
            edge_features: Tensor of shape ``(E, input_dim)``.

        Returns:
            1-D tensor of shape ``(E,)`` with sigmoid outputs, also when E == 1.
        """
        probs = torch.sigmoid(self.mlp(edge_features))
        # Drop only the trailing size-1 dim; squeeze() would make a single edge 0-d.
        return probs.squeeze(-1)

In [None]:
def extract_edge_embeddings(graphs, edge_index, model=None):
    """Run the TGN over a window of graphs and return one feature vector per edge.

    Args:
        graphs: Sequence of graph snapshots passed as ``model(graphs)``.
        edge_index: (2, E) tensor of (source, destination) node indices to embed.
        model: Object providing ``__call__`` and ``get_edge_embeddings``;
            defaults to the notebook-level ``tgn_model``.

    Returns:
        NumPy array with one row per edge. For ``TemporalGNN`` this is
        ``(E, 2 * output_channels)``: source and destination embeddings concatenated.
    """
    if model is None:
        model = tgn_model
    # Inference only: skip building an autograd graph (the result is detached anyway).
    with torch.no_grad():
        node_embeddings = model(graphs)
        edge_embeddings = model.get_edge_embeddings(node_embeddings, edge_index)
    # .cpu() so the NumPy conversion also works when the model runs on a GPU.
    return edge_embeddings.detach().cpu().numpy()

In [None]:
x_train_mlp = []
y_train_mlp = []

# One row of edge embeddings per target edge, paired with that edge's target weight.
for i, window in enumerate(x_train):
    target = y_train[i]
    x_train_mlp.append(extract_edge_embeddings(window, target.edge_index))
    y_train_mlp.append(target.y)

In [None]:
x_test_mlp = []
y_test_mlp = []

# Same featurization as the training set, applied to the held-out windows.
for i, window in enumerate(x_test):
    target = y_test[i]
    x_test_mlp.append(extract_edge_embeddings(window, target.edge_index))
    y_test_mlp.append(target.y)

In [None]:
input_dim = x_train_mlp[0].shape[1]
hidden_dim = 64
mlp_model = EdgeMLP(input_dim, hidden_dim)
criterion = torch.nn.MSELoss()
# Separate name: rebinding `optimizer` here would make a re-run of the TGN
# training cell step the MLP's optimizer instead of the TGN's.
mlp_optimizer = torch.optim.Adam(mlp_model.parameters(), lr=0.0001)

num_epochs = 100

# Convert once, outside the epoch loop. torch.as_tensor avoids the copy-construct
# warning torch.tensor() raises for inputs that are already tensors (the targets).
train_batches = [(torch.as_tensor(feats, dtype=torch.float32), torch.as_tensor(target, dtype=torch.float32))
                 for feats, target in zip(x_train_mlp, y_train_mlp)]
test_batches = [(torch.as_tensor(feats, dtype=torch.float32), torch.as_tensor(target, dtype=torch.float32))
                for feats, target in zip(x_test_mlp, y_test_mlp)]

for epoch in range(num_epochs):
    mlp_model.train()
    train_loss_total = 0.0
    for feats, target in train_batches:
        mlp_optimizer.zero_grad()

        # reshape(-1) keeps a 1-D (E,) prediction even for a single-edge batch.
        outputs = mlp_model(feats).reshape(-1)
        loss = criterion(outputs, target)

        loss.backward()
        mlp_optimizer.step()
        train_loss_total += loss.item()

    avg_train_loss = train_loss_total / len(train_batches)

    mlp_model.eval()
    val_loss_total = 0.0
    with torch.no_grad():
        for feats, target in test_batches:
            outputs = mlp_model(feats).reshape(-1)
            val_loss_total += criterion(outputs, target).item()

    avg_val_loss = val_loss_total / len(test_batches)

    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss}, Val Loss: {avg_val_loss}')

## Decoupled LSTM on in-memory TGN embeddings

In [None]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense
from sklearn.metrics import mean_squared_error
from tensorflow.keras.callbacks import Callback
import numpy as np

In [None]:
def extract_edge_embeddings(graphs, edge_index, model=None):
    """Run the TGN over a window of graphs and return one feature vector per edge.

    Args:
        graphs: Sequence of graph snapshots passed as ``model(graphs)``.
        edge_index: (2, E) tensor of (source, destination) node indices to embed.
        model: Object providing ``__call__`` and ``get_edge_embeddings``;
            defaults to the notebook-level ``tgn_model``.

    Returns:
        NumPy array with one row per edge. For ``TemporalGNN`` this is
        ``(E, 2 * output_channels)``: source and destination embeddings concatenated.
    """
    if model is None:
        model = tgn_model
    # Inference only: skip building an autograd graph (the result is detached anyway).
    with torch.no_grad():
        node_embeddings = model(graphs)
        edge_embeddings = model.get_edge_embeddings(node_embeddings, edge_index)
    # .cpu() so the NumPy conversion also works when the model runs on a GPU.
    return edge_embeddings.detach().cpu().numpy()

In [None]:
x_train_lstm = []
y_train_lstm = []

for i, window in enumerate(x_train):
    target_edges = y_train[i].edge_index
    # One single-snapshot TGN pass per year in the window, each giving one row per
    # target edge; stacking on axis 1 yields (num_edges, window_len, emb_dim).
    per_year = [extract_edge_embeddings([snapshot], target_edges) for snapshot in window]
    x_train_lstm.append(np.stack(per_year, axis=1))
    y_train_lstm.append(y_train[i].y)

In [None]:
# Concatenate the per-window arrays into one (total_edges, window_len, emb_dim)
# array and a flat target vector. The isinstance guards make the cell idempotent:
# re-running np.vstack on the already-stacked 3-D array would silently reshape it
# to (total_edges * window_len, emb_dim).
if isinstance(x_train_lstm, list):
    x_train_lstm = np.vstack(x_train_lstm)
if isinstance(y_train_lstm, list):
    y_train_lstm = np.hstack(y_train_lstm)

In [None]:
x_test_lstm = []
y_test_lstm = []

for i, window in enumerate(x_test):
    target_edges = y_test[i].edge_index
    # Same per-year featurization as the training set: (num_edges, window_len, emb_dim).
    per_year = [extract_edge_embeddings([snapshot], target_edges) for snapshot in window]
    x_test_lstm.append(np.stack(per_year, axis=1))
    y_test_lstm.append(y_test[i].y)

In [None]:
# Concatenate the per-window test arrays. The guards keep the cell idempotent:
# a second np.vstack over the stacked 3-D array would silently flatten the time axis.
if isinstance(x_test_lstm, list):
    x_test_lstm = np.vstack(x_test_lstm)
if isinstance(y_test_lstm, list):
    y_test_lstm = np.hstack(y_test_lstm)

In [None]:
# Two stacked LSTMs over each edge's per-year embeddings (input is
# (window_len, emb_dim) per sample), then a sigmoid head so predictions lie in
# [0, 1] like the normalized edge-weight targets.
lstm_model = Sequential([
    LSTM(50, return_sequences=True, input_shape=(x_train_lstm.shape[1], x_train_lstm.shape[2])),
    LSTM(50),
    Dense(1, activation='sigmoid'),
])

lstm_model.compile(optimizer='adam', loss='mean_squared_error')

history = lstm_model.fit(
    x_train_lstm,
    y_train_lstm,
    batch_size=1024,
    epochs=3,
    verbose=1,
)

In [None]:
# Predict a weight for every held-out edge, 1024 edges per batch.
y_pred = lstm_model.predict(
    x_test_lstm,
    batch_size=1024,
    verbose=1,
)

In [None]:
from sklearn.metrics import mean_squared_error

# Test-set MSE between the true edge weights (zeros for negatives) and the LSTM predictions.
mse = mean_squared_error(y_true=y_test_lstm, y_pred=y_pred)
print("Mean Squared Error:", mse)