In [1]:
import os
from dotenv import load_dotenv

from rai_gnns_experimental import GNNTable, ForeignKey
from rai_gnns_experimental import ColumnDType
from rai_gnns_experimental import EvaluationMetric
from rai_gnns_experimental import LinkTask, TaskType
from rai_gnns_experimental import Dataset
from rai_gnns_experimental import TrainerConfig
from rai_gnns_experimental import Trainer
from rai_gnns_experimental import SnowflakeConnectorDirectAccess

from IPython.display import Image, display


# Jupyter Magic Commands
%load_ext autoreload
%load_ext jupyter_black
%autoreload 2

In [None]:
# Populate os.environ from a local .env file, if one is found. Variables that are
# already set in the environment keep their values (override=False is the default).
load_dotenv(override=False)

### 1. Set up the Snowflake connection and input data

In [None]:
# Map each connector argument to the environment variable that supplies it.
# Fail fast here: os.getenv() silently returns None for unset variables, and
# passing None credentials to the connector only fails later with an opaque error.
# NOTE(review): all of these are treated as required — relax if, e.g., APP_NAME
# is optional for your deployment.
_SNOWFLAKE_ENV_VARS = {
    "account": "ACCOUNT_NAME",
    "user": "USER_NAME",
    "password": "PASSWORD",
    "warehouse": "WAREHOUSE",
    "app_name": "APP_NAME",
}
_missing_env_vars = [var for var in _SNOWFLAKE_ENV_VARS.values() if not os.getenv(var)]
if _missing_env_vars:
    raise RuntimeError(
        "Missing required environment variable(s) for the Snowflake connection: "
        + ", ".join(_missing_env_vars)
        + ". Set them in your environment or in the .env file."
    )

snowflake_config = {key: os.getenv(var) for key, var in _SNOWFLAKE_ENV_VARS.items()}
snowflake_config["auth_method"] = "password"

# The app endpoint is deployment-specific; allow overriding it via GNN_ENDPOINT_URL
# while keeping the demo endpoint as the default.
connector = SnowflakeConnectorDirectAccess(
    **snowflake_config,
    endpoint_url=os.getenv(
        "GNN_ENDPOINT_URL",
        "https://bmg4y5qc-ndsoebe-rai-gnns-test.snowflakecomputing.app",
    ),
)

In [None]:
# Fully qualified Snowflake locations (presumably DATABASE.SCHEMA.TABLE):
# the raw source tables, and the train/validation/test splits for the link task.
buyers_pth, senders_pth, transactions_pth = (
    f"GNN_DEMO_TF.TF_DATA.{table}" for table in ("BUYERS", "SENDERS", "TRANSACTIONS")
)
train_pth, val_pth, test_pth = (
    f"GNN_DEMO_TF.TF_LINK_PRED.{split}" for split in ("TRAIN", "VAL", "TEST")
)

### Create GNN Tables (Graph Initialization)

We create three tables: one for the buyers, one for the senders, and one for
the transactions between them. The transactions table links to the other two
through foreign keys on `TX_SENDER_ADDRESS` and `BUY_TOKEN_ADDRESS`.

In [None]:
# Buyers table: one row per BUY_TOKEN_ADDRESS, the key that the transactions
# table references through its BUY_TOKEN_ADDRESS foreign key.
buyers_table = GNNTable(
    name="buyers",
    source=buyers_pth,
    primary_key="BUY_TOKEN_ADDRESS",
    connector=connector,
)
buyers_table.show_table()

In [None]:
# Senders table: one row per TX_SENDER_ADDRESS, the key that the transactions
# table references through its TX_SENDER_ADDRESS foreign key.
senders_table = GNNTable(
    name="senders",
    source=senders_pth,
    primary_key="TX_SENDER_ADDRESS",
    connector=connector,
)
senders_table.show_table()

In [None]:
# Transactions table: no primary key of its own; it connects senders and buyers
# through two foreign keys (each link_to is written as "<table name>.<column>")
# and carries BLOCK_TIMESTAMP as its time column.
transactions_table = GNNTable(
    name="transactions",
    source=transactions_pth,
    connector=connector,
    time_column="BLOCK_TIMESTAMP",
    foreign_keys=[
        ForeignKey(
            column_name="TX_SENDER_ADDRESS",
            link_to="senders.TX_SENDER_ADDRESS",
        ),
        ForeignKey(
            column_name="BUY_TOKEN_ADDRESS",
            link_to="buyers.BUY_TOKEN_ADDRESS",
        ),
    ],
)

# Override the inferred dtype of the amount columns to integer.
# NOTE(review): if these are raw on-chain token amounts they may exceed a 64-bit
# integer range — confirm the width of ColumnDType.integer_t against the data.
for amount_column in ("BUY_AMOUNT", "SELL_AMOUNT"):
    transactions_table.update_column_dtype(
        col_name=amount_column, dtype=ColumnDType.integer_t
    )
transactions_table.show_table()

### Preparing the link prediction task

In [None]:
# A classic link prediction task: predict whether an edge exists between a
# sender (source entity) and a buyer (target entity). Train/validation/test
# splits come from separate tables, and results are scored with the
# "link_prediction_map" metric at k=10.
link_pred_task = LinkTask(
    name="link_prediction_example",
    task_type=TaskType.LINK_PREDICTION,
    connector=connector,
    task_data_source=dict(train=train_pth, test=test_pth, validation=val_pth),
    source_entity_table="senders",
    source_entity_column="TX_SENDER_ADDRESS",
    target_entity_table="buyers",
    target_entity_column="BUY_TOKEN_ADDRESS",
    time_column="BLOCK_TIMESTAMP",
    evaluation_metric=EvaluationMetric(name="link_prediction_map", eval_at_k=10),
)
link_pred_task.show_task()

### Putting it all together (dataset creation)

In [None]:
# Bundle the three tables and the link prediction task into one named dataset.
dataset = Dataset(
    dataset_name="tokenflow",
    connector=connector,
    task_description=link_pred_task,
    tables=[buyers_table, transactions_table, senders_table],
)

In [None]:
# Render the dataset schema as an image (show_dtypes=True presumably adds the
# column dtypes to the drawing).
graph = dataset.visualize_dataset(show_dtypes=True)
# Enlarge the drawing. Assuming `graph` is a pydot/graphviz graph, the trailing
# "!" scales the drawing up *uniformly*, so its aspect ratio is kept and is
# generally not square.
graph.set_graph_defaults(size="50,50!")
# Fix only the display width. Giving both width and height would stretch a
# non-square drawing into a 600x600 box. Named schema_image (not `plt`) to avoid
# shadowing the conventional matplotlib.pyplot alias.
schema_image = Image(graph.create_png(), width=600)
display(schema_image)

### Train a GNN on the link prediction task

In [None]:
# create the GNN model configuration
train_config = TrainerConfig(
    connector=connector,
    device="cuda",
    n_epochs=5,
    id_awareness=True,
    temporal_strategy="last",
    max_iters=2000,
    train_batch_size=128,
    shallow_embeddings_list=["buyers"],
)
# initialize the Trainer
trainer = Trainer(connector=connector, config=train_config)

In [None]:
# submit a training job
# Submit a training job for `dataset`. The returned job handle is used in the
# following cells to check status and stream logs.
train_job = trainer.fit(dataset=dataset)

In [None]:
# job status
train_job.get_status()

In [None]:
# Stream the training job's logs.
# NOTE(review): this may block until the job (or log stream) finishes — confirm.
train_job.stream_logs()

In [None]:
# when the model is ready we can get a model run ID
# this can be used to run inference
train_job.get_status()