In [1]:
from corvic.model import Source, Column, feature_type, FeatureView, FeatureViewEdgeTableMetadata
from corvic import system_sqlite
from corvic import system
from pathlib import Path
import tempfile

In [2]:
from TemporalDataHandler import TemporalDataHandler
from TGNEncoder import TGNEncoder

In [None]:
import numpy as np

from corvic.model import FeatureView, FeatureViewEdgeTableMetadata
import polars as pl

import tempfile

import deepgnn.graph_engine.snark.convert as convert
from deepgnn.graph_engine.snark.decoders import EdgeListDecoder

from deepgnn.graph_engine.snark.distributed import Server, Client as DistributedClient


class DeepGnnGraph:
    """Export a corvic ``FeatureView`` as a DeepGNN edge-list graph and serve it.

    Nodes come from every feature-view source that declares a primary key;
    edges come from the first output edge table of the feature view. All nodes
    and edges are given the single type/label ``'None'`` and weight ``1``.

    Args:
        feature_view: Feature view providing ``sources`` and
            ``output_edge_tables()``.
        feature_column: Name of the source column holding per-node feature
            vectors, or ``None`` to generate random ``float32`` features.
        compute_server_url: Address handed to both the DeepGNN ``Server`` and
            the ``DistributedClient``.
        random_feature_dim: Length of the random feature vector used when
            ``feature_column`` is ``None``.
        partitions: ``partition_count`` passed to ``MultiWorkersConverter``.
    """

    def __init__(self, feature_view: "FeatureView", feature_column: str, compute_server_url: str, random_feature_dim: int, partitions: int = 2):
        self.feature_view = feature_view
        self.partitions = partitions
        self.graph_engine = compute_server_url
        self.random_feature_dim = random_feature_dim
        self.feature_column_name = feature_column
        # Populated by create(); holds the Server so it outlives that call.
        self._server = None

    def _edge_generator(self):
        """Return edges as ``(start_id, 'None', end_id, 1)`` tuples.

        Only the first output edge table of the feature view is read. All of
        its batches are concatenated (the same way ``_node_generator`` reads
        node tables) so edges beyond the first batch are not dropped.
        """
        table = self.feature_view.output_edge_tables()[0]
        edge_table_info = table.get_typed_metadata(FeatureViewEdgeTableMetadata)
        edges_df = pl.concat(table.to_polars().unwrap_or_raise()).select(
            pl.col(edge_table_info.start_source_column_name).alias("start_id"),
            pl.col(edge_table_info.end_source_column_name).alias("end_id"),
        )
        columns = edges_df.to_dict(as_series=False)
        return [
            (start_id, 'None', end_id, 1)
            for start_id, end_id in zip(columns["start_id"], columns["end_id"])
        ]

    def _node_generator(self):
        """Return ``{entity_id: ('None', 1, {'float32': features})}`` for all keyed sources.

        Sources without a primary key are skipped. When no feature column was
        configured, each node receives a fresh random ``float32`` vector of
        length ``random_feature_dim``.
        """
        nodes = {}
        for nodes_source in self.feature_view.sources:
            primary_key = nodes_source.table.schema.get_primary_key()
            if primary_key is None:
                continue
            nodes_df = pl.concat(nodes_source.table.to_polars().unwrap_or_raise())
            nodes_dict = nodes_df.to_dict(as_series=False)
            entity_ids = nodes_dict[primary_key.name]
            if self.feature_column_name is None:
                print("Feature column not specified, using random features.")
                for entity_id in entity_ids:
                    nodes[entity_id] = ('None', 1, {'float32': np.random.rand(self.random_feature_dim).astype('float32')})
            else:
                for entity_id, features in zip(entity_ids, nodes_dict[self.feature_column_name]):
                    nodes[entity_id] = ('None', 1, {'float32': features})
        return nodes

    def _to_deepgnn_format(self, delimiter=","):
        """Render nodes and edges as edge-list text, one record per line.

        Node line: ``<id>,-1,<type>,<weight>,<dtype>,<len>,<v1>,...``
        Edge line: ``<src>,<label>,<dst>,<weight>``
        (shown with the default delimiter; ``delimiter`` is used for every
        field separator). Entity ids are replaced by dense integer indices in
        node insertion order. Lines are ordered by their first field with a
        stable sort, so each node line precedes the edges leaving that node.

        Raises:
            KeyError: if an edge references an id that is not a known node.
        """
        nodes, edges = self._node_generator(), self._edge_generator()
        node_id_map = {node_id: idx for idx, node_id in enumerate(nodes)}
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # integer ids are reproducible (set iteration order is not).
        node_type_map = {
            node_type: idx
            for idx, node_type in enumerate(dict.fromkeys(info[0] for info in nodes.values()))
        }
        edge_label_map = {
            edge_label: idx
            for idx, edge_label in enumerate(dict.fromkeys(edge[1] for edge in edges))
        }

        def join_fields(*fields):
            return delimiter.join(map(str, fields))

        def format_features(features):
            return delimiter.join(
                join_fields(dtype, len(values), *values)
                for dtype, values in features.items()
            )

        # Nodes are appended before edges; the stable sort below relies on
        # that to keep a node's own line ahead of its outgoing edges.
        keyed_lines = []
        for node_id, (node_type, node_weight, features) in nodes.items():
            idx = node_id_map[node_id]
            keyed_lines.append(
                (idx, join_fields(idx, -1, node_type_map[node_type], node_weight, format_features(features)))
            )
        for src, edge_label, dst, edge_weight in edges:
            idx = node_id_map[src]
            keyed_lines.append(
                (idx, join_fields(idx, edge_label_map[edge_label], node_id_map[dst], edge_weight))
            )
        keyed_lines.sort(key=lambda item: item[0])
        return "\n".join(line for _, line in keyed_lines)

    def create(self, working_dir):
        """Convert the graph into ``working_dir`` and return a client for it.

        Args:
            working_dir: Object whose ``.name`` attribute is the output
                directory path (e.g. a ``tempfile.TemporaryDirectory``).

        Returns:
            A ``DistributedClient`` connected to ``compute_server_url``.
        """
        edgelist = self._to_deepgnn_format()
        with tempfile.NamedTemporaryFile(delete=True) as temp_file:
            temp_file.write(edgelist.encode())
            # Flush so the converter sees the full content when it opens the path.
            temp_file.flush()

            convert.MultiWorkersConverter(
                graph_path=temp_file.name,
                partition_count=self.partitions,
                output_dir=working_dir.name,
                decoder=EdgeListDecoder(),
            ).convert()

        # Keep the server referenced on the instance rather than in a local
        # that disappears on return.
        # NOTE(review): presumably the server stops when its object is
        # garbage-collected — confirm against the DeepGNN implementation.
        self._server = Server(self.graph_engine, working_dir.name, 0, 1)
        return DistributedClient([self.graph_engine])

In [None]:
import torch
import numpy as np

from deepgnn.graph_engine import SamplingStrategy



from deepgnn.graph_engine.graph_ops import sub_graph

from torch.utils.data import IterableDataset

class Sampler(IterableDataset):
    """Iterable dataset yielding ``(features, edge_index)`` tensors per batch.

    Each iteration step draws random seed nodes from ``graph``, expands them
    with ``sub_graph`` over ``hops`` hops and fetches the node features.

    Args:
        batch_size: Number of seed nodes requested per batch; also the step
            between batch start offsets.
        hops: ``num_hops`` passed to ``sub_graph``.
        graph: Graph client exposing ``node_count``, ``sample_nodes`` and
            ``node_features``.
        feature_dim: Second entry of the ``[0, feature_dim]`` feature request.
    """

    def __init__(self, batch_size: int, hops: int, graph, feature_dim: int):
        super().__init__()
        self.graph = graph
        self.hops = hops
        self.feature_dim = feature_dim
        self.batch_size = batch_size

    def __iter__(self):
        """Return a lazy iterator calling ``query`` once per batch offset."""
        type_zero = np.array([0], dtype=np.int32)
        total_nodes = self.graph.node_count(type_zero)
        batch_starts = range(0, total_nodes, self.batch_size)
        return map(self.query, batch_starts)

    def _sampler(self,seed):
        """Expand ``seed`` into a sub-graph; return ``(nodes, transposed edges)``."""
        type_zero = np.array([0], dtype=np.int32)
        nodes, edges, _ = sub_graph(
            self.graph,
            seed,
            type_zero,
            num_hops=self.hops,
            return_edges=True,
        )
        return nodes, edges.transpose()

    def query(self, batch_id: int) -> tuple:
        """Build one batch; ``batch_id`` is accepted but not used for sampling.

        Returns:
            ``(features, edge_index)`` as a float tensor and a long tensor.
        """
        type_zero = np.array([0], dtype=np.int32)
        drawn = self.graph.sample_nodes(
            self.batch_size, type_zero, strategy=SamplingStrategy.Random
        )
        # De-duplicate the sampled seeds; the inverse mapping is not needed.
        seed, _ = np.unique(drawn, return_inverse=True)
        nodes, edge_index = self._sampler(seed)
        feature_request = np.array([[0, self.feature_dim]], dtype=np.int32)
        feats = self.graph.node_features(nodes, feature_request, np.float32)
        feature_tensor = torch.tensor(feats, dtype=torch.float)
        edge_tensor = torch.tensor(edge_index, dtype=torch.long)
        return feature_tensor, edge_tensor

In [None]:
import polars as pl

def construct_sources(client: system.Client, data_path=None):
    """Load the node/edge parquet files and register them as corvic sources.

    Args:
        client: Corvic client the sources are created with.
        data_path: Directory/prefix containing ``nodes.parquet`` and
            ``edges.parquet``. When omitted, the notebook-level global
            ``path`` is used, which must be defined before this is called.

    Returns:
        ``(nodes_source, edges_source)`` — the registered "concepts" dimension
        table and the registered "trends" fact table whose ``source_id`` and
        ``destination_id`` are foreign keys into the nodes source.
    """
    if data_path is None:
        # Backward-compatible fallback to the notebook global.
        data_path = path

    # .unique() drops fully duplicated rows before registration.
    node_df = pl.read_parquet(f"{data_path}/nodes.parquet").unique()
    edge_df = pl.read_parquet(f"{data_path}/edges.parquet").unique()

    nodes_source = (
            Source.from_polars("concepts", node_df, client)
            .with_feature_types(
                {
                    "node_id": feature_type.primary_key(),
                    "node_label": feature_type.text(),
                }
            ).as_dimension_table()
        ).register()

    edges_source = (
            Source.from_polars("trends", edge_df, client)
            .with_feature_types(
                {
                    "source_id": feature_type.foreign_key(nodes_source.id),
                    "destination_id": feature_type.foreign_key(nodes_source.id),
                    "weight": feature_type.numerical(),
                    "year": feature_type.numerical(),
                }
            )
            .as_fact_table()
        ).register()

    return nodes_source, edges_source

In [None]:
path = "gs://datasets-dev-ded86f66/benchmarks/scientific_trend_prediction/feature_view_data"

# `client`, `nodes_source` and `edges_source` are used by later cells, so the
# directory holding the SQLite file must outlive this cell. A `with` block
# would delete it as soon as the cell finishes; instead keep the
# TemporaryDirectory object referenced for the whole session and call
# `corvic_workdir.cleanup()` once the notebook is done with the client.
corvic_workdir = tempfile.TemporaryDirectory()
tdir = corvic_workdir.name
client = system_sqlite.Client(Path(tdir) / "corvic_data.sqlite3")
nodes_source, edges_source = construct_sources(client)

In [None]:
# Build the temporal data handler from the registered node and edge sources.
# NOTE(review): 1980, 2023 and 1 look like start year, end year and step size —
# confirm against TemporalDataHandler's signature.
handler = TemporalDataHandler(nodes_source, edges_source, 1980, 2023, 1, client)

In [None]:
# Create the TGN encoder on top of the temporal handler.
# NOTE(review): the meaning of the second argument (10) is not visible here —
# confirm against TGNEncoder's signature.
gnn = TGNEncoder(handler,10)

In [None]:
# Train the encoder; hyperparameters are whatever TGNEncoder was built with.
gnn.train()

In [None]:
# Generate node embeddings from the trained encoder. The upload cell indexes
# this result by integer offset and labels offset 0 as year 1980.
embeddings = gnn.generate_node_embeddings()

In [None]:
import torch
import io
from google.cloud import storage

def upload_tensor_to_gcs(tensor, bucket_name, destination_blob_name):
    """Serialize ``tensor`` with ``torch.save`` and upload it to a GCS object.

    Args:
        tensor: Object to serialize.
        bucket_name: Name of the target bucket.
        destination_blob_name: Object path inside the bucket.
    """
    target_blob = storage.Client().bucket(bucket_name).blob(destination_blob_name)

    # Serialize in memory, then rewind so the upload reads from the start.
    payload = io.BytesIO()
    torch.save(tensor, payload)
    payload.seek(0)

    target_blob.upload_from_file(payload, content_type='application/octet-stream')

# Destination for the per-year embedding files.
bucket_name = 'datasets-dev-ded86f66'
prefix = 'benchmarks/scientific_trend_prediction/Tgn_embeddings'

# Upload one file per entry of `embeddings`; offset 0 is labelled 1980.
for year in range(len(embeddings)):
    calendar_year = 1980 + year
    gcs_file = f'{prefix}/{calendar_year}.pt'
    node_embeddings = embeddings[year]

    upload_tensor_to_gcs(node_embeddings, bucket_name, gcs_file)

    print(f'Saved embeddings for year {calendar_year} to {gcs_file}')