In [None]:
import pandas as pd
from graphert.processing_data  import *
from graphert.create_random_walks  import *
from graphert.train_model  import *
from graphert.train_tokenizer  import *
from graphert.temporal_embeddings import *

import numpy as np
import torch
import scipy as sp
import transformers
from sklearn.manifold import TSNE
import seaborn as sns
import matplotlib.pyplot as plt


In [2]:
from sklearn.metrics import make_scorer, precision_score, recall_score, f1_score, accuracy_score
from sklearn.model_selection import cross_val_score, StratifiedKFold
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

In [3]:
from matplotlib.colors import ListedColormap
import plotly.graph_objs as go

In [4]:
import os

In [5]:
# 2-D t-SNE projector; random_state is pinned so the projection is repeatable.
# NOTE(review): scikit-learn 1.5 renamed TSNE's `n_iter` to `max_iter` (the old
# name is deprecated) — confirm against the installed version.
tsne = TSNE(n_components=2, perplexity=40, n_iter=1000, random_state=4)


In [22]:
# Dataset identifier; later cells interpolate it into paths under datasets_res/.
dataset_name = "hippocampus_rat"

In [23]:
# NOTE(review): this points at a Facebook file although dataset_name is
# "hippocampus_rat"; presumably left over from another example — confirm it is
# still needed.
graph_path = 'data/facebook/facebook-wall.txt'

In [9]:
# Load the tab-separated, header-less edge list and label its three columns
# directly in the read call.
graph_df = pd.read_table(
    graph_path,
    sep='\t',
    header=None,
    names=['source', 'target', 'time'],
)

In [10]:
# Rich display of the loaded edge list (pandas elides the middle rows).
graph_df

Unnamed: 0,source,target,time
0,28,28,1095135831
1,1015,1017,1097725406
2,959,959,1098387569
3,991,991,1098425204
4,1015,1017,1098489762
...,...,...,...
876988,1715,17995,1232597482
876989,18616,18616,1232598051
876990,28549,31056,1232598370
876991,24830,59912,1232598672


## Constructing temporal graphs from provided adjacency matrices

In [11]:
# Per-time-step connectivity matrices for the selected dataset. A later cell
# calls `.toarray()` on each entry, so they are presumably scipy sparse
# matrices — TODO confirm.
# NOTE(review): unpickling can execute arbitrary code — only load trusted files.
adj_time_list = pd.read_pickle(f'../../Dataset/adj_time_list_{dataset_name}.pkl')

In [12]:
# Entries whose magnitude is strictly below this cut-off are zeroed; with 0
# nothing is removed.
threshold = 0  # choose accordingly

adj_matrices = []
for connectivity_matrix in adj_time_list:
    # Densify once and reuse the result for both the mask and the values.
    dense = connectivity_matrix.toarray()
    # Zero the weak entries, then keep magnitudes only (the sign is discarded).
    adj_matrices.append(np.abs(np.where(np.abs(dense) < threshold, 0, dense)))

In [13]:
# Year label given to time step 0; step i is labelled BASE_YEAR + i. The years
# are placeholders — adjust them if the real time axis is known.
BASE_YEAR = 2000
EDGE_COLUMNS = ["source", "target", "year", "weight"]

# Build one edge-list frame per time step with array operations rather than a
# Python-level loop over every individual edge.
edge_frames = []
for time_step, adj_matrix in enumerate(adj_matrices):
    # Row/column indices of the non-zero entries, i.e. this step's edges.
    source, target = np.nonzero(adj_matrix)
    if source.size == 0:
        # A step without edges contributes no rows.
        continue
    edge_frames.append(
        pd.DataFrame(
            {
                "source": source,
                "target": target,
                "year": BASE_YEAR + time_step,
                "weight": adj_matrix[source, target],
            }
        )
    )

# pd.concat raises on an empty list, so fall back to an empty frame that still
# carries the expected column names.
df = (
    pd.concat(edge_frames, ignore_index=True)
    if edge_frames
    else pd.DataFrame(columns=EDGE_COLUMNS)
)


In [14]:

# Print the edge-list DataFrame; pandas elides the middle rows of a long frame,
# so only the head and the tail appear in the output.
print(df)

        source  target  year    weight
0            0       0  2000  1.000000
1            1       1  2000  1.000000
2            1       2  2000  0.042241
3            1       3  2000  0.033968
4            1       5  2000  0.027462
...        ...     ...   ...       ...
191559     115     115  2084  1.000000
191560     116     116  2084  1.000000
191561     117     117  2084  1.000000
191562     118     118  2084  1.000000
191563     119     119  2084  1.000000

[191564 rows x 4 columns]


In [16]:
# Build the overall graph and its temporal wrapper from the edge list.
# time_granularity='years' presumably makes the loader bucket edges by the
# `year` column — TODO confirm against load_dataset.
graph_nx, temporal_graph = load_dataset(df, dataset_name, time_granularity='years')

In [17]:
# Show the repr of the graph object returned by load_dataset.
graph_nx

<networkx.classes.multidigraph.MultiDiGraph at 0x7b03b796eb10>

In [18]:

# Mapping from timestamp to a snapshot graph. `min_degree=5` presumably filters
# out low-degree nodes — TODO confirm against get_temporal_graphs.
graphs = temporal_graph.get_temporal_graphs(min_degree=5)
# NOTE(review): printing the dict emits one very long line of object reprs; a
# summary such as len(graphs) would be easier to read.
print(graphs)

{datetime.datetime(2000, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03b601d950>, datetime.datetime(2001, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03b606e6d0>, datetime.datetime(2002, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03b6c69550>, datetime.datetime(2003, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03b609e490>, datetime.datetime(2004, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03aa389dd0>, datetime.datetime(2005, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03b60a2990>, datetime.datetime(2006, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03aa1ee890>, datetime.datetime(2007, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03aa11b090>, datetime.datetime(2008, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03aa1ee350>, datetime.datetime(2009, 1, 1, 0, 0): <networkx.classes.digraph.DiGraph object at 0x7b03aa1ede90>, datetime.datetime(2

In [19]:
# Report the size of the overall graph. graph_nx is a multigraph, so parallel
# edges between the same pair of nodes are counted separately.
print(f"Number of nodes: {graph_nx.number_of_nodes()}")
print(f"Number of edges: {graph_nx.number_of_edges()}")

Number of nodes: 120
Number of edges: 191564


In [20]:
# Inspect the TemporalGraph wrapper: its edge table and its time configuration.
print("TemporalGraph attributes:")
print("Data:")
print(temporal_graph.data.head())  # first rows of the table held by the wrapper
print("Time Granularity:", temporal_graph.time_granularity)
print("Time Columns:", temporal_graph.time_columns)
print("Step:", temporal_graph.step)  # presumably the interval between snapshots — verify


TemporalGraph attributes:
Data:
   source  target  year    weight time_index       time
0       0       0  2000  1.000000 2000-01-01 2000-01-01
1       1       1  2000  1.000000 2000-01-01 2000-01-01
2       1       2  2000  0.042241 2000-01-01 2000-01-01
3       1       3  2000  0.033968 2000-01-01 2000-01-01
4       1       5  2000  0.027462 2000-01-01 2000-01-01
Time Granularity: years
Time Columns: ['year']
Step: relativedelta(years=+1)


## GraphERT: Transformers-based Temporal Dynamic Graph Embedding

![Alt text](GraphERT.png)

Moran Beladev, Gilad Katz, Lior Rokach, Uriel Singer, Kira Radinsky.
CIKM’23 – October 2023, Birmingham, United Kingdom.

In [21]:
# Node set of the largest connected component of the undirected view of the graph.
# max(..., key=len) picks it in a single pass instead of sorting every component;
# on ties it returns the first largest component, as the sorted form did.
# NOTE(review): `nx` has no explicit import in the cells shown; presumably it is
# provided by one of the `from graphert... import *` lines — confirm.
cc_nodes = max(nx.connected_components(graph_nx.to_undirected()), key=len)  # biggest cc


In [None]:
# Re-key the snapshots by consecutive integer position instead of by timestamp.
graphs = dict(enumerate(graphs.values()))

# Random-walk settings handed to create_random_walks: the p and q grids, the
# walk lengths and the walk counts (presumably node2vec-style — confirm).
ps = [0.25, 0.5, 1, 2, 4]
qs = [0.25, 0.5, 1, 2, 4]
walk_lengths = [32]
num_walks_list = [10]

create_random_walks(
    graphs,
    ps,
    qs,
    walk_lengths,
    num_walks_list,
    dataset_name,
    cc_nodes,
)

walk_len=32, num_walks=10
0


100%|██████████| 25/25 [03:39<00:00,  8.77s/it]


1


100%|██████████| 25/25 [00:20<00:00,  1.25it/s]


2


100%|██████████| 25/25 [00:06<00:00,  3.74it/s]


3


100%|██████████| 25/25 [00:53<00:00,  2.14s/it]


4


100%|██████████| 25/25 [01:23<00:00,  3.33s/it]


5


100%|██████████| 25/25 [00:35<00:00,  1.43s/it]


6


100%|██████████| 25/25 [03:01<00:00,  7.28s/it]


7


100%|██████████| 25/25 [03:16<00:00,  7.84s/it]


8


100%|██████████| 25/25 [01:43<00:00,  4.13s/it]


9


100%|██████████| 25/25 [02:07<00:00,  5.09s/it]


10


100%|██████████| 25/25 [03:42<00:00,  8.90s/it]


11


  0%|          | 0/25 [00:00<?, ?it/s]

In [49]:
# Use the first configured walk length and walk count.
walk_len, num_walks = walk_lengths[0], num_walks_list[0]

# CSV holding the walks for that configuration (read back with pd.read_csv later).
random_walk_path = (
    f'datasets_res/{dataset_name}/paths_walk_len_{walk_len}_num_walks_{num_walks}.csv'
)

In [None]:
# Train a node-level tokenizer from the random-walk file.
train_graph_tokenizer(random_walk_path, dataset_name, walk_len)
# Of the three training entry points only train_mlm_temporal_model is enabled;
# the other two are kept commented out as alternatives.
# train_only_temporal_model(random_walk_path, dataset_name, walk_len)
train_mlm_temporal_model(random_walk_path, dataset_name, walk_len)
# train_2_steps_model(random_walk_path, dataset_name, walk_len)


In [None]:
# Location of the trained model; the folder name presumably corresponds to the
# output of train_mlm_temporal_model — TODO confirm.
model_path = f'datasets_res/{dataset_name}/models/mlm_and_temporal_model'
# Next cell: get temporal embeddings by the last layer.


In [None]:
# Temporal embeddings read from the trained model at model_path.
# NOTE(review): the structure of the returned object is not visible here — confirm.
cls_temporal_embeddings = get_temporal_embeddings(model_path)

In [None]:
 # get temporal embeddings by averaging the paths embeddings per time
# Reload the generated walks, then compute the four per-time-step embedding
# summaries returned by get_embeddings_by_paths_average.
data_df = pd.read_csv(random_walk_path, index_col=None)
(
    t_cls_emb_mean,
    t_cls_emb_weighted_mean,
    t_prob,
    t_nodes_emb_mean,
) = get_embeddings_by_paths_average(data_df, model_path, dataset_name, walk_len)

In [65]:
# For every time point, stack the values of its inner mapping and average them
# along the first axis, giving one mean embedding per time point.
average_embeddings = {
    time_point: np.mean(np.array(list(node_embeddings.values())), axis=0)
    for time_point, node_embeddings in t_nodes_emb_mean.items()
}


In [67]:
# Flatten each per-time-point mapping into a list of its values, in dict order.
cls_emb_mean_list = [*t_cls_emb_mean.values()]
cls_emb_weighted_mean_list = [*t_cls_emb_weighted_mean.values()]
average_embeddings_list = [*average_embeddings.values()]