### Temporal NNDR

This notebook show how to step by step get the temporal NNDR 

- Loading a checkpoint of an embedder trained on original dataset.
- Getting temporal embedding from original and generated sets.
- Applying DTW to calculate a distance matrix between these embeddings.
- Applying NNDR on this matrix to assess privacy.

In [1]:
%cd ..

/home/houssem.souid/brainiac-1-temporal


In [2]:
from brainiac_temporal.data.datasets import (
    fetch_insecta_dataset,
    load_imdb_dynamic_tgt,
    load_tigger_datasets_tgt,
)
from torch_geometric.data import Data
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
import torch
import torch_geometric_temporal as tgt  # type: ignore
from brainiac_temporal.data import LightningTemporalDataset
from brainiac_temporal.models import LinkPredictor

Start by loading the dataset and the corresponding embedder checkpoint

In [3]:
dataset_name= "wiki_small"

reference_dataset = load_tigger_datasets_tgt(dataset_name)


In [4]:
from sota_paper_implementation.convertors.dymond2tgt import convert_dymond_sample_to_tgt

# Load generated graph, works if you are on Tartarus
path = "/home/houssem.souid/brainiac-1-temporal/generated_graphs/dymond/generated_graph.pklz"
syn_data= convert_dymond_sample_to_tgt(path)

In [5]:
reference_dataset.snapshot_count

50

In [6]:
syn_data.snapshot_count

42

In [7]:
reference_dataset[0]

Data(x=[1616, 1], edge_index=[2, 48], edge_attr=[48], y=[1616])

In [8]:
syn_data[0]

Data(x=[1616, 1], edge_index=[2, 2], edge_attr=[2], y=[1616])

In [9]:
node_features = reference_dataset.features[0].shape[1]

Load the corresponding embedder from the checkpoints folder available on tartars:

- /home/houssem.souid/brainiac1_temporal/checkpoints_lp

In [10]:
import torch 
checkpoint = torch.load("/home/houssem.souid/brainiac-1-temporal/checkpoints_lp/best_wiki_small.ckpt")

In [11]:
embedder = LinkPredictor(
                node_features=node_features,
                embedding_size=checkpoint["embedding_size"],
                mlp_hidden_sizes= checkpoint["mlp_hidden_sizes"],
                message_passing_class="GConvGRU",
                message_passing_kwargs= checkpoint["message_passing_kwargs"]
            
            )

In [12]:
# Transfer weights from the checkpoint to the model
embedder.load_state_dict(checkpoint["state_dict"])

<All keys matched successfully>

In [13]:
embedder.eval()

LinkPredictor(
  (recurrent): GConvGRU(
    (conv_x_z): ChebConv(1, 32, K=2, normalization=sym)
    (conv_h_z): ChebConv(32, 32, K=2, normalization=sym)
    (conv_x_r): ChebConv(1, 32, K=2, normalization=sym)
    (conv_h_r): ChebConv(32, 32, K=2, normalization=sym)
    (conv_x_h): ChebConv(1, 32, K=2, normalization=sym)
    (conv_h_h): ChebConv(32, 32, K=2, normalization=sym)
  )
  (mlp): ModuleList(
    (0): Linear(in_features=64, out_features=64, bias=True)
    (1): Linear(in_features=64, out_features=32, bias=True)
    (2): Linear(in_features=32, out_features=1, bias=True)
  )
  (auc): BinaryAUROC()
)

Get the embeddings on both original and generated datasets

In [14]:
def get_embeddings(model, data: Data, prev_embedding = None):
    with torch.no_grad():
        x, edge_index = data.x, data.edge_index
        # Apply the model to get embeddings
        embeddings = model(x, edge_index, prev_embedding)
    return embeddings


In [15]:
# Iterate through the dataset and compute embeddings

embeddings =  get_embeddings(embedder, reference_dataset[0])
orig_embeddings_list= [embeddings.cpu().numpy()]
for data in reference_dataset[1:]:
    embeddings = get_embeddings(embedder, data, torch.tensor(embeddings))
    orig_embeddings_list.append(embeddings.cpu().numpy())


orig_embeddings = np.stack(orig_embeddings_list)

  embeddings = get_embeddings(embedder, data, torch.tensor(embeddings))


In [16]:

# Iterate through your dataset and compute embeddings
embeddings =  get_embeddings(embedder, syn_data[0])
gen_embeddings_list= [embeddings.cpu().numpy()]
for data in syn_data[1:]:
    if data.edge_index.size(0) != 0:
        embeddings = get_embeddings(embedder, data, torch.tensor(embeddings))
    gen_embeddings_list.append(embeddings.cpu().numpy())

gen_embeddings = np.stack(gen_embeddings_list)

  embeddings = get_embeddings(embedder, data, torch.tensor(embeddings))


In [17]:
gen_embeddings.shape

(42, 1616, 32)

In [18]:

orig_embeddings.shape

(50, 1616, 32)

Apply dtw on temporal embeddings to obtain the embeddings matrix

In [19]:
from dtaidistance.dtw_ndim import distance_fast
#distance_fast requires conversion to double
gen_embeddings = np.array(gen_embeddings, dtype=np.float64)
orig_embeddings = np.array(orig_embeddings, dtype=np.float64)
#get matrix of embeddings
emb_matrix = [
   [distance_fast(gen_embeddings[:, i, :], orig_embeddings[:, j, :]) 
   for j in range(orig_embeddings.shape[1])]
   for i in range(gen_embeddings.shape[1])
]


Calculate the NNDR score

In [20]:
from brainiac_temporal.metrics.nndr import get_nndr
(nndr_sore, hist, edges) = get_nndr(torch.Tensor(emb_matrix))

In [21]:
nndr_sore

tensor([1., 1., 1.,  ..., 1., 1., 1.])

In [22]:
hist

tensor([1584.,    0.,    0.,  ...,   30.,    0., 1584.])

In [23]:
edges

tensor([0.0000e+00, 3.1736e-09, 6.3471e-09,  ..., 2.4406e-07, 2.4644e-07,
        2.4882e-07])

Using integrated nndr in brainiac_temporal evaluator module

In [24]:
from brainiac_temporal.metrics.evaluator import MetricEvaluator

In [25]:
metric_evaluator = MetricEvaluator(utility_metrics = None,get_privacy_metric=True, embedder_path= "/home/houssem.souid/brainiac-1-temporal/checkpoints_lp/best_wiki_small.ckpt" )

In [26]:
metrics = metric_evaluator(reference_dataset, syn_data)



spectral
degree
degree_centrality
clustering
closeness_centrality
katz_centrality
eigenvector_centrality
avg_clust_coeff
transitivity
diameter
average_shortest_path_length


  embedder, data, torch.tensor(embeddings)


In [27]:
(metrics["nndr_score"] == nndr_sore.cpu().numpy()).any()

True

In [28]:
(metrics["nndr_histogram"] == hist.cpu().numpy()).any()

True

In [29]:
(metrics["nndr_edges"]== edges.cpu().numpy()).any()

True