# Setting Up the Environment

In [1]:
!pip install relbench[full]

Collecting relbench[full]
  Downloading relbench-1.1.0-py3-none-any.whl.metadata (12 kB)
Collecting pytorch_frame>=0.2.3 (from relbench[full])
  Downloading pytorch_frame-0.2.5-py3-none-any.whl.metadata (20 kB)
Collecting torch_geometric (from relbench[full])
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->pytorch_frame>=0.2.3->relbench[full])
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->pytorch_frame>=0.2.3->relbench[full])
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (f

In [2]:
import relbench
from relbench.datasets import get_dataset_names, get_dataset
from relbench.modeling.utils import get_stype_proposal
from relbench.modeling.graph import make_pkey_fkey_graph
from relbench.tasks import get_task_names, get_task
from relbench.base import TaskType


import torch
from torch_geometric.seed import seed_everything
from torch import Tensor
from torch_frame import stype
from torch_frame.config.text_embedder import TextEmbedderConfig
from torch_frame.data.multi_embedding_tensor import MultiEmbeddingTensor


from typing import List, Optional


from sentence_transformers import SentenceTransformer


import os


import pandas as pd

import numpy as np

import random

import pickle

import requests

In [3]:
class GloveTextEmbedding:
    """Callable sentence encoder wrapping the GloVe 6B/300d SentenceTransformer.

    An instance is handed to torch_frame's ``TextEmbedderConfig``; calling it
    with a list of sentences returns their embeddings as a tensor.
    """

    def __init__(self, device: Optional[torch.device] = None):
        model_name = "sentence-transformers/average_word_embeddings_glove.6B.300d"
        self.model = SentenceTransformer(model_name, device=device)

    def __call__(self, sentences: List[str]) -> Tensor:
        # convert_to_tensor=True makes encode() return a torch tensor instead of numpy.
        embeddings = self.model.encode(sentences, convert_to_tensor=True)
        return embeddings

In [4]:
# Use the GPU when available -- on CPU the embedding steps take far longer.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if device.type == "cuda":
    # Restrict PyTorch to a single CPU thread when running on the GPU.
    torch.set_num_threads(1)
print(device)

# Seed the random number generators so runs are reproducible.
SEED = 42
seed_everything(SEED)

# Directory used to cache the materialized graph data.
root_dir = "./data"

# Text columns are encoded with GloVe in batches of 256 sentences.
text_embedder_cfg = TextEmbedderConfig(
    text_embedder=GloveTextEmbedding(device=device),
    batch_size=256,
)

cuda


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/248 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/122 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/2.19k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/480M [00:00<?, ?B/s]

(…)rdEmbeddings%2Fwordembedding_config.json:   0%|          | 0.00/164 [00:00<?, ?B/s]

(…)ddings%2Fwhitespacetokenizer_config.json:   0%|          | 0.00/4.61M [00:00<?, ?B/s]

1_Pooling%2Fconfig.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

In [5]:
relbench_version = relbench.__version__
print(f"The RelBench version is {relbench_version}")
available_datasets = get_dataset_names()
print(f"The RelBench datasets are {available_datasets}")

The RelBench version is 1.1.0
The RelBench datasets are ['rel-amazon', 'rel-avito', 'rel-event', 'rel-f1', 'rel-hm', 'rel-stack', 'rel-trial']


# F1 Dataset Creation

## Downloading the `driver-dnf` binary classification task

In [6]:
# Tasks available for rel-f1 (kept for reference).
f1_task_names = get_task_names("rel-f1")

task = get_task("rel-f1", "driver-dnf", download=True)
# The rest of the notebook treats driver-dnf as a binary classification task.
assert task.task_type == TaskType.BINARY_CLASSIFICATION

Downloading file 'rel-f1/tasks/driver-dnf.zip' from 'https://relbench.stanford.edu/download/rel-f1/tasks/driver-dnf.zip' to '/root/.cache/relbench'.
100%|█████████████████████████████████████| 37.3k/37.3k [00:00<00:00, 3.44MB/s]
Unzipping contents of '/root/.cache/relbench/rel-f1/tasks/driver-dnf.zip' to '/root/.cache/relbench/rel-f1/tasks/.'


In [7]:
train_table, val_table = (task.get_table(split) for split in ("train", "val"))
# By default RelBench masks the test table's values to prevent test leakage;
# we need them here, so the test table is loaded with mask_input_cols=False.
test_table = task.get_table("test", mask_input_cols=False)

In [8]:
# Inspect the test split; its label column is visible because it was loaded with mask_input_cols=False.
test_table

Table(df=
          date  driverId  did_not_finish
0   2013-03-16       814               0
1   2012-11-16         9               1
2   2012-11-16        17               0
3   2012-10-17         0               1
4   2012-09-17       816               0
..         ...       ...             ...
697 2011-08-24        14               1
698 2011-05-26        14               1
699 2011-05-26       154               0
700 2010-09-28        14               1
701 2010-09-28       154               0

[702 rows x 3 columns],
  fkey_col_to_pkey_table={'driverId': 'drivers'},
  pkey_col=None,
  time_col=date)

In [9]:
# Foreign-key column(s) of the task table and the node table each one points to.
test_table.fkey_col_to_pkey_table

{'driverId': 'drivers'}

## Downloading the RelBench database and building the graph

In [10]:
# Download rel-f1 and load the whole database, test period included
# (upto_test_timestamp=False).
f1_dataset = get_dataset(name="rel-f1", download=True)
f1_db = f1_dataset.get_db(upto_test_timestamp=False)

# Let RelBench propose a semantic type for every column.
f1_col_to_stype_dict = get_stype_proposal(f1_db)

# Materialize the primary-key/foreign-key graph; the result is cached on disk.
f1_cache_dir = os.path.join(root_dir, "rel-f1_materialized_cache")
f1_data, f1_col_stats_dict = make_pkey_fkey_graph(
    f1_db,
    col_to_stype_dict=f1_col_to_stype_dict,
    text_embedder_cfg=text_embedder_cfg,
    cache_dir=f1_cache_dir,
)

Downloading file 'rel-f1/db.zip' from 'https://relbench.stanford.edu/download/rel-f1/db.zip' to '/root/.cache/relbench'.
100%|████████████████████████████████████████| 704k/704k [00:00<00:00, 532MB/s]
Unzipping contents of '/root/.cache/relbench/rel-f1/db.zip' to '/root/.cache/relbench/rel-f1/.'


Loading Database object from /root/.cache/relbench/rel-f1/db...
Done in 0.05 seconds.


Embedding raw data in mini-batch: 100%|██████████| 1/1 [00:00<00:00,  3.44it/s]
Embedding raw data in mini-batch: 100%|██████████| 1/1 [00:00<00:00, 271.92it/s]
Embedding raw data in mini-batch: 100%|██████████| 1/1 [00:00<00:00, 271.53it/s]
Embedding raw data in mini-batch: 100%|██████████| 1/1 [00:00<00:00, 182.35it/s]
  ser = pd.to_datetime(ser, format=time_format)
Embedding raw data in mini-batch: 100%|██████████| 5/5 [00:00<00:00, 177.83it/s]
  ser = pd.to_datetime(ser, format=self.format, errors='coerce')
Embedding raw data in mini-batch: 100%|██████████| 1/1 [00:00<00:00, 55.33it/s]
Embedding raw data in mini-batch: 100%|██████████| 1/1 [00:00<00:00, 86.51it/s]
Embedding raw data in mini-batch: 100%|██████████| 1/1 [00:00<00:00, 165.36it/s]
Embedding raw data in mini-batch: 100%|██████████| 4/4 [00:00<00:00, 119.21it/s]
Embedding raw data in mini-batch: 100%|██████████| 4/4 [00:00<00:00, 98.61it/s]
Embedding raw data in mini-batch: 100%|██████████| 4/4 [00:00<00:00, 121.13it/s]


In [11]:
# Split boundaries defined by the dataset.
f1_val_timestep, f1_test_timestep = f1_dataset.val_timestamp, f1_dataset.test_timestamp

print(f"The validation timestep is: {f1_val_timestep}")
print(f"The test timestep is: {f1_test_timestep}")

The validation timestep is: 2005-01-01 00:00:00
The test timestep is: 2010-01-01 00:00:00


In [12]:
# Overview of the materialized heterogeneous graph: node tables, edge types and their sizes.
f1_data

HeteroData(
  constructor_results={
    tf=TensorFrame([12290, 2]),
    time=[12290],
  },
  results={
    tf=TensorFrame([26080, 11]),
    time=[26080],
  },
  circuits={ tf=TensorFrame([77, 7]) },
  races={
    tf=TensorFrame([1101, 5]),
    time=[1101],
  },
  constructor_standings={
    tf=TensorFrame([13051, 4]),
    time=[13051],
  },
  constructors={ tf=TensorFrame([211, 3]) },
  standings={
    tf=TensorFrame([34124, 4]),
    time=[34124],
  },
  qualifying={
    tf=TensorFrame([9815, 3]),
    time=[9815],
  },
  drivers={ tf=TensorFrame([857, 6]) },
  (constructor_results, f2p_raceId, races)={ edge_index=[2, 12290] },
  (races, rev_f2p_raceId, constructor_results)={ edge_index=[2, 12290] },
  (constructor_results, f2p_constructorId, constructors)={ edge_index=[2, 12290] },
  (constructors, rev_f2p_constructorId, constructor_results)={ edge_index=[2, 12290] },
  (results, f2p_raceId, races)={ edge_index=[2, 26080] },
  (races, rev_f2p_raceId, results)={ edge_index=[2, 26080] },

In [13]:
node_without_timestamp = ['drivers', 'circuits', 'constructors']

# (foreign-key table, foreign-key column, primary-key table) for every link in rel-f1.
f1_fkey_links = [
    ('constructor_standings', 'raceId', 'races'),
    ('constructor_standings', 'constructorId', 'constructors'),
    ('standings', 'raceId', 'races'),
    ('standings', 'driverId', 'drivers'),
    ('constructor_results', 'raceId', 'races'),
    ('constructor_results', 'constructorId', 'constructors'),
    ('results', 'raceId', 'races'),
    ('results', 'driverId', 'drivers'),
    ('results', 'constructorId', 'constructors'),
    ('qualifying', 'raceId', 'races'),
    ('qualifying', 'driverId', 'drivers'),
    ('qualifying', 'constructorId', 'constructors'),
    ('races', 'circuitId', 'circuits'),
]

# Each link yields a forward (f2p_) and a reverse (rev_f2p_) edge type. The value is
# the endpoint table whose timestamps date the edge: the source if it has a
# timestamp, otherwise the destination.
f1_edges_dict = {
    edge: edge[0] if edge[0] not in node_without_timestamp else edge[2]
    for fkey_table, fkey_col, pkey_table in f1_fkey_links
    for edge in (
        (fkey_table, f"f2p_{fkey_col}", pkey_table),
        (pkey_table, f"rev_f2p_{fkey_col}", fkey_table),
    )
}

In [14]:
def pick_pairs(KG_data, edge):
    """Return the edges of type ``edge`` as a ``[num_edges, 2]`` tensor.

    ``KG_data[edge].edge_index`` stores sources in row 0 and destinations in
    row 1; the result has one ``(source_index, target_index)`` row per edge.
    """
    edge_index = KG_data[edge].edge_index
    # Transpose to one row per edge and materialize a fresh, contiguous copy.
    return edge_index.t().clone(memory_format=torch.contiguous_format)

In [15]:
def train_inference_split_pairs(data, pairs, time_node, val_timestamp, test_timestamp, time_index: int = 0):
    """Split node pairs into train/val/test lists by the timestamp of one endpoint.

    The timestamp of a pair is ``data[time_node].time[pair[time_index]]``,
    interpreted as Unix seconds. Pairs earlier than ``val_timestamp`` go to
    train, pairs in ``[val_timestamp, test_timestamp)`` go to val, and all
    remaining pairs go to test.

    Args:
        data: graph whose ``data[time_node].time`` holds integer timestamps.
        pairs: integer tensor of shape ``[num_pairs, 2]`` (e.g. from ``pick_pairs``).
        time_node: node type whose timestamps decide the split.
        val_timestamp: start of the validation period (pandas-Timestamp comparable).
        test_timestamp: start of the test period (pandas-Timestamp comparable).
        time_index: column of ``pairs`` that indexes ``time_node``. Use 0 (the
            default) when the source carries the timestamp, and 1 when it is the
            destination (e.g. ``rev_`` edges whose source table has no timestamp).

    Returns:
        ``(train, val, test)`` lists of ``(source, target)`` int tuples, in input order.

    Raises:
        ValueError: if ``time_index`` is not 0 or 1.
    """
    if time_index not in (0, 1):
        raise ValueError(f"time_index must be 0 or 1, got {time_index}")

    time_tensor = data[time_node].time
    # Gather all timestamps at once instead of one .item() per pair.
    time_ints = time_tensor[pairs[:, time_index].to(time_tensor.device)].cpu().numpy()
    times = pd.to_datetime(time_ints, unit='s')

    is_train = np.asarray(times < val_timestamp)
    is_val = ~is_train & np.asarray(times >= val_timestamp) & np.asarray(times < test_timestamp)
    is_test = ~(is_train | is_val)

    pairs_np = pairs.cpu().numpy()

    def _select(mask):
        # Rows of pairs_np selected by mask, as (source, target) tuples of Python ints.
        return [tuple(row) for row in pairs_np[mask].tolist()]

    return _select(is_train), _select(is_val), _select(is_test)

In [None]:
# Split the pairs of every edge type into train / val / test by timestamp.
# f1_split_pairs[edge] = (train_pairs, val_pairs, test_pairs), each a list of (src, dst) tuples.
f1_split_pairs = {}
for edge, time_node in f1_edges_dict.items():
    pairs_of_nodes = pick_pairs(f1_data, edge)

    # train_inference_split_pairs reads the timestamp through column 0 of `pairs`.
    # When the timestamped table is the destination (e.g. drivers -> results),
    # flip the columns so column 0 indexes time_node, then flip the results back.
    time_on_target = time_node == edge[2] and time_node != edge[0]
    if time_on_target:
        pairs_of_nodes = pairs_of_nodes.flip(1)

    splits = train_inference_split_pairs(data=f1_data,
                                         pairs=pairs_of_nodes,
                                         time_node=time_node,
                                         val_timestamp=f1_val_timestep,
                                         test_timestamp=f1_test_timestep)

    if time_on_target:
        splits = tuple([(src, dst) for dst, src in part] for part in splits)

    f1_split_pairs[edge] = splits

In [19]:
def retrieve_edges_features(KG_data, edges_dict):
    """Collect per-node feature rows and re-indexed edges for a set of edge types.

    Every node touched by an edge in ``edges_dict`` is registered once under the
    key ``f"{node_type}_{local_index}"``. Keys are numbered in first-seen order,
    and each edge is rewritten as a pair of these global positions, so
    ``output_edges`` indexes directly into ``list(node_features)``.

    Args:
        KG_data: heterogeneous graph providing ``KG_data[edge].edge_index``
            (row 0 = sources, row 1 = destinations) and ``KG_data[node_type].tf``.
        edges_dict: mapping whose keys are ``(src_type, relation, dst_type)``
            edge types. Its values are not used here.

    Returns:
        ``(node_features, output_edges)``. ``node_features`` maps each node key
        to ``KG_data[node_type].tf[local_index]``. ``output_edges`` is a list of
        ``(src_position, dst_position)`` tuples.
    """
    node_features = {}
    # key -> global position. A dict lookup replaces the former linear scan over
    # node_features, which was O(nodes) per edge and could overwrite an index it
    # had just resolved.
    node_position = {}
    output_edges = []

    def _position(node_type, local_index):
        # Register the node on first sight and return its global position.
        key = f"{node_type}_{local_index}"
        if key not in node_position:
            node_position[key] = len(node_position)
            node_features[key] = KG_data[node_type].tf[local_index]
        return node_position[key]

    for edge in edges_dict:
        src_name, _, dst_name = edge
        src_indices, dst_indices = KG_data[edge].edge_index.tolist()
        for src_index, dst_index in zip(src_indices, dst_indices):
            # Source is registered before destination, as in the original ordering.
            src_pos = _position(src_name, src_index)
            dst_pos = _position(dst_name, dst_index)
            output_edges.append((src_pos, dst_pos))

    return node_features, output_edges

In [None]:
# Flatten the heterogeneous graph into one node dictionary plus a global edge list.
node_features, output_edges = retrieve_edges_features(KG_data=f1_data, edges_dict=f1_edges_dict)

In [None]:
def build_masks_and_labels(KG_data, node_without_timestamp, node_features, train_table, val_table, test_table, class_value, val_timestamp, test_timestamp, target_col="did_not_finish"):
    """Build exactly one label and one train/val/test mask entry per node.

    The four returned lists are aligned with the iteration order of
    ``node_features``, whose keys have the form ``"<node_type>_<index>"``.

    Nodes are handled in three ways:

    * Target nodes: nodes of the type referenced by the task table's foreign key
      (e.g. ``drivers``). The node index is looked up in the task tables. A node
      present in several splits is assigned to the latest one (test, then val,
      then train) so the held-out splits stay populated. Its label is
      ``target_col`` from its first matching row in that table. Target nodes
      absent from every task table get ``class_value`` and are excluded from
      every split.
    * Other nodes listed in ``node_without_timestamp``: labelled ``class_value``
      and visible in every split.
    * All remaining nodes: labelled ``class_value`` and placed in a split by
      comparing ``KG_data[node_type].time[index]`` (Unix seconds) with
      ``val_timestamp`` and ``test_timestamp``.

    Args:
        KG_data: graph providing ``KG_data[node_type].time`` for timestamped node types.
        node_without_timestamp: node types that carry no ``time`` attribute.
        node_features: dict keyed by ``"<node_type>_<index>"`` (values unused).
        train_table, val_table, test_table: RelBench task tables exposing
            ``fkey_col_to_pkey_table`` and ``df``.
        class_value: label given to every node that is not a labelled target node.
        val_timestamp, test_timestamp: split boundaries.
        target_col: label column of the task tables.

    Returns:
        ``(labels, train_mask, val_mask, test_mask)``, each of length ``len(node_features)``.
    """
    # Foreign-key column of the task table and the node type it references.
    id_col, target_type = next(iter(train_table.fkey_col_to_pkey_table.items()))

    def _first_labels(table):
        # Map each referenced id to the label of its first row (dataframe order).
        df = table.df.drop_duplicates(subset=id_col, keep="first")
        return dict(zip(df[id_col].tolist(), df[target_col].tolist()))

    # Latest split first: a node seen in several splits lands in the held-out one.
    split_lookups = [
        ("test", _first_labels(test_table)),
        ("val", _first_labels(val_table)),
        ("train", _first_labels(train_table)),
    ]

    labels, train_mask, val_mask, test_mask = [], [], [], []

    def _append(label, in_train, in_val, in_test):
        labels.append(label)
        train_mask.append(in_train)
        val_mask.append(in_val)
        test_mask.append(in_test)

    for node_key in node_features:
        # rsplit: node types such as "constructor_results" contain underscores themselves.
        node_type, idx_str = node_key.rsplit("_", 1)
        idx = int(idx_str)

        if node_type == target_type:
            # NOTE(review): assumes the graph node index equals the task table's id
            # (RelBench re-indexes primary keys to 0..N-1) -- confirm.
            for split, lookup in split_lookups:
                if idx in lookup:
                    _append(lookup[idx], split == "train", split == "val", split == "test")
                    break
            else:
                # Unlabelled target node: keep the lists aligned but use it in no split.
                _append(class_value, False, False, False)
        elif node_type in node_without_timestamp:
            # No timestamp: the node is visible in every split.
            _append(class_value, True, True, True)
        else:
            time_value = pd.to_datetime(KG_data[node_type].time[idx].item(), unit="s")
            if time_value < val_timestamp:
                _append(class_value, True, False, False)
            elif val_timestamp <= time_value < test_timestamp:
                _append(class_value, False, True, False)
            else:
                _append(class_value, False, False, True)

    return labels, train_mask, val_mask, test_mask

In [None]:
labels, train_mask, val_mask, test_mask = build_masks_and_labels(
    KG_data=f1_data,
    node_without_timestamp=node_without_timestamp,
    node_features=node_features,
    train_table=train_table,
    val_table=val_table,
    test_table=test_table,
    class_value=2,  # label for nodes that are not labelled task targets
    val_timestamp=f1_val_timestep,
    test_timestamp=f1_test_timestep,
)

# Report sizes; labels and the three masks should all match len(node_features).
for collection in (node_features, output_edges, labels, train_mask, val_mask, test_mask):
    print(f"{len(collection)}")

In [None]:
def flatten_multi_embedding(met: MultiEmbeddingTensor, device=None, flatten_extra_dims=True) -> torch.Tensor:
    """
    Convert a MultiEmbeddingTensor to a dense tensor by probing its attributes.

    The strategies below are tried in order; the first one that yields a tensor wins:
      1. a callable ``to_tensor()`` returning a tensor;
      2. the first attribute in ``dict_candidates`` that is a tensor (returned
         directly) or a dict (its tensors are concatenated in step 4);
      3. a callable ``values()``;
      4. concatenation along dim 1 of the tensors in the dict found in step 2.

    NOTE(review): for torch_frame's MultiEmbeddingTensor, ``values`` is presumably
    the concatenated [num_rows, total_dim] tensor, so step 2 normally returns it -- confirm.

    Args:
        met: the MultiEmbeddingTensor to unpack.
        device: target device; defaults to CPU.
        flatten_extra_dims: if True, tensors with more than 2 dims are flattened
            from dim 1 onward.

    Returns:
        A dense (strided) tensor on ``device``.

    Raises:
        ValueError: if no strategy finds a tensor or dict to unpack.
        TypeError: if the dict found in step 2 contains a non-tensor value.
    """
    if device is None:
        device = torch.device("cpu")

    # 1. First check for direct tensor conversion methods
    if hasattr(met, 'to_tensor') and callable(met.to_tensor):
        tensor = met.to_tensor()
        if isinstance(tensor, torch.Tensor):
            # Same processing as _process_tensor: densify, move, optionally flatten.
            if tensor.layout != torch.strided:
                tensor = tensor.to_dense()
            tensor = tensor.to(device)
            if flatten_extra_dims and tensor.dim() > 2:
                tensor = tensor.flatten(start_dim=1)
            return tensor

    # 2. Look for embedding storage in attributes
    dict_candidates = ["_data", "embeddings", "_embeddings", "_tensor_dict", "values"]
    embedding_dict = None

    for candidate in dict_candidates:
        if hasattr(met, candidate):
            candidate_val = getattr(met, candidate)
            # Handle both direct tensors and dictionaries
            if isinstance(candidate_val, torch.Tensor):
                return _process_tensor(candidate_val, device, flatten_extra_dims)
            elif isinstance(candidate_val, dict):
                embedding_dict = candidate_val
                break
            # Any other attribute type (e.g. a bound method) is skipped here.

    # 3. Handle case where MultiEmbeddingTensor wraps a single tensor
    if embedding_dict is None:
        if hasattr(met, 'values') and callable(met.values):
            tensor = met.values()
            return _process_tensor(tensor, device, flatten_extra_dims)
        else:
            raise ValueError(
                f"Failed to unpack MultiEmbeddingTensor. Available attributes: {dir(met)}\n"
                "Consider inspecting the object structure with: "
                "print(dir(your_multi_embedding_tensor))"
            )

    # 4. Process dictionary of embeddings
    sub_tensors = []
    for emb in embedding_dict.values():
        if isinstance(emb, torch.Tensor):
            if emb.layout != torch.strided:
                emb = emb.to_dense()
            emb = emb.to(device)
            if flatten_extra_dims and emb.dim() > 2:
                emb = emb.flatten(start_dim=1)
            sub_tensors.append(emb)
        else:
            raise TypeError(f"Unexpected embedding type: {type(emb)}")

    # Sub-tensors are joined along the feature dimension (dim 1).
    return torch.cat(sub_tensors, dim=1)

def _process_tensor(tensor: torch.Tensor, device, flatten_extra_dims) -> torch.Tensor:
    """Helper for consistent tensor processing"""
    if tensor.layout != torch.strided:
        tensor = tensor.to_dense()
    tensor = tensor.to(device)
    if flatten_extra_dims and tensor.dim() > 2:
        tensor = tensor.flatten(start_dim=1)
    return tensor


def torchframe_to_tensor(tf, device=None, flatten_extra_dims=True):
    """
    Concatenate every feature group of a TorchFrame into one dense tensor.

    Groups are processed in ``tf.feat_dict`` order. MultiEmbeddingTensors are
    unpacked with ``flatten_multi_embedding``. Other features are densified, moved
    to ``device`` (CPU by default) and, when ``flatten_extra_dims`` is true,
    flattened to 2-D. All results are concatenated along dim 1.

    Raises:
        TypeError: if a feature is neither a tensor nor exposes a callable ``values()``.
    """
    target_device = torch.device("cpu") if device is None else device

    def _to_2d(stype_key, feat):
        # Resolve lazily-provided features first.
        feat = feat() if callable(feat) else feat

        if isinstance(feat, MultiEmbeddingTensor):
            return flatten_multi_embedding(
                feat, device=target_device, flatten_extra_dims=flatten_extra_dims
            )

        # Sparse or otherwise densifiable features.
        to_dense = getattr(feat, "to_dense", None)
        if callable(to_dense):
            feat = to_dense()

        if not isinstance(feat, torch.Tensor):
            values_fn = getattr(feat, "values", None)
            if not callable(values_fn):
                raise TypeError(
                    f"Feature {stype_key} is not a tensor. Got {type(feat)}"
                )
            feat = values_fn()

        feat = feat.to(target_device)
        if flatten_extra_dims and feat.dim() > 2:
            feat = feat.flatten(start_dim=1)
        return feat

    return torch.cat(
        [_to_2d(stype_key, typed_feat) for stype_key, typed_feat in tf.feat_dict.items()],
        dim=1,
    )

In [None]:
def remove_last_sep(s: str) -> str:
    """Strip a trailing ``[SEP]`` marker from ``s``.

    The last ``[SEP]`` is removed only when nothing but whitespace follows it,
    so the text before it is kept exactly (including any trailing space).
    A string whose last ``[SEP]`` is followed by real content is returned
    unchanged instead of being truncated.
    """
    sep = "[SEP]"
    last_index = s.rfind(sep)
    if last_index != -1 and not s[last_index + len(sep):].strip():
        return s[:last_index]
    return s


def linearize_features(node_features: list, device=None) -> list:
    """
    Linearize single-row TensorFrames into text.

    Each frame becomes ``"<name_feature_1> <val_1> [SEP] <name_feature_2> <val_2> ..."``
    (the trailing separator is stripped). Values are read from the frame's first row.

    Args:
        node_features: an iterable of TensorFrames, or a dict whose values are
            TensorFrames (as returned by ``retrieve_edges_features``).
        device: device each frame is moved to before reading; defaults to CPU.

    Returns:
        One linearized string per frame, in iteration order.
    """
    if device is None:
        device = torch.device("cpu")

    # Iterating a dict directly would yield its string keys, not the TensorFrames.
    frames = node_features.values() if isinstance(node_features, dict) else node_features

    linearized_tensors = []
    for tensor_frame in frames:
        tf_on_device = tensor_frame.to(device)
        feats = []
        for stype_key, typed_feat in tf_on_device.feat_dict.items():
            is_multi_embedding = isinstance(typed_feat, MultiEmbeddingTensor)
            for col_idx, feature_name in enumerate(tf_on_device.col_names_dict[stype_key]):
                if is_multi_embedding:
                    # Columns may have different widths. offset[col_idx]:offset[col_idx + 1]
                    # delimits this column inside the concatenated values tensor.
                    start = int(typed_feat.offset[col_idx])
                    end = int(typed_feat.offset[col_idx + 1])
                    feature_value = typed_feat.values[0, start:end].tolist()
                else:
                    feature_value = typed_feat[0][col_idx].tolist()
                feats.append(f"{feature_name} {feature_value} [SEP] ")

        linearized_tensors.append(remove_last_sep(''.join(feats)))
    return linearized_tensors

In [None]:
def text_embedding(linearize_features: list, embedder_model, device=None) -> list:
    """
    Embed every linearized feature string with ``embedder_model``.

    Args:
        linearize_features: strings, e.g. produced by ``linearize_features``.
        embedder_model: callable invoked once per string.
        device: defaults to CPU; it is not otherwise used by this function.

    Returns:
        ``[embedder_model(s) for s in linearize_features]``, in input order.
    """
    device = torch.device("cpu") if device is None else device
    return [embedder_model(text) for text in linearize_features]

In [None]:
# Linearize every node's features as "<name_feature_1> <val_1> [SEP] <name_feature_2> <val_2> ...".
# Pass the TensorFrames themselves: iterating the node_features dict would yield its keys.
f1_linearized_features = linearize_features(list(node_features.values()), device=device)

# Reuse the GloVe encoder already loaded for text_embedder_cfg instead of
# loading a second copy of the same model.
embedding_model = text_embedder_cfg.text_embedder
f1_emb_features = text_embedding(f1_linearized_features, embedding_model, device)

In [None]:
# Convert every artefact to a NumPy array before pickling.
f1_emb_features = torch.stack(f1_emb_features).cpu().numpy()
f1_labels = np.array(labels)
f1_edges = np.array(output_edges)
f1_train_mask, f1_val_mask, f1_test_mask = (
    np.array(mask) for mask in (train_mask, val_mask, test_mask)
)

data_to_save = dict(
    node_features=f1_emb_features,
    labels=f1_labels,
    edges=f1_edges,
    train_mask=f1_train_mask,
    val_mask=f1_val_mask,
    test_mask=f1_test_mask,
)

with open('f1_data.pkl', 'wb') as out_file:
    pickle.dump(data_to_save, out_file)