# HM Purchase Recommendations: Link Prediction example using the GNN learning engine under the RelationalAI Snowflake Native App

### 🔐 Connecting to the RelationalAI Native App via Snowflake

To connect to the RelationalAI native app, we first set up a dictionary that defines the settings needed to establish a Snowflake connection. Here we use
the `active_session` authentication method, which takes all the required credentials from the
current Snowflake session. We only need to provide the application name (`RELATIONALAI` in the configuration below).

#### 🔐 About Authentication

In this tutorial, we’ll use **active_session**. However, other authentication methods—such as **Key Pair Authentication**, **Password** or **OAuth Token Authentication**—are also supported. 


In [None]:
# load the snowflake configuration to a python dict
snowflake_config = {
    "app_name": "RELATIONALAI",
    "auth_method": "active_session"
}

## ⚙️ Managing Your GNN Engines

The `Provider` class in the RelationalAI GNN SDK allows you to manage your GNN engines seamlessly. Below, we walk through common operations you can perform with the `Provider`:

* ✅ Create a new GNN engine
* 📋 List all available engines
* 🔍 Check the status of an engine
* 🔄 Resume a paused engine
* ❌ Delete an engine

Each of these operations can be done with simple method calls, as shown in the following examples.

In [1]:
from rai_gnns_experimental import GNNTable, ForeignKey
from rai_gnns_experimental import ColumnDType
from rai_gnns_experimental import EvaluationMetric
from rai_gnns_experimental import LinkTask, TaskType
from rai_gnns_experimental import Dataset
from rai_gnns_experimental import TrainerConfig
from rai_gnns_experimental import Trainer
from rai_gnns_experimental import JobManager
from rai_gnns_experimental import OutputConfig, SnowflakeConnector, Provider

from graphviz import Source
import time

In [None]:
# engine setup
engine_name = "my_engine"
engine_size = "GPU_NV_S" # Available sizes: "GPU_NV_S" or "HIGHMEM_X64_S"

# database / schema
DB_NAME = "HM_DB"
DB_SCHEMA_NAME = "HM_SCHEMA"
TASK_SCHEMA_NAME = "HM_PURCHASE"
DB_SCHEMA = f"{DB_NAME}.{DB_SCHEMA_NAME}"
TASK_SCHEMA = f"{DB_NAME}.{TASK_SCHEMA_NAME}"

# dataset name
DATASET_NAME = "hm_purchase_dataset"

# customer data
CUSTOMER_SOURCE_TABLE = f"{DB_SCHEMA}.CUSTOMERS"
CUSTOMER_NAME = "customers"
CUSTOMER_PRIMARY_KEY = "customer_id"

# article data
ARTICLE_SOURCE_TABLE = f"{DB_SCHEMA}.ARTICLES"
ARTICLE_NAME = "articles"
ARTICLE_PRIMARY_KEY = "article_id"

# transaction data
TRANSACTION_SOURCE_TABLE = f"{DB_SCHEMA}.TRANSACTIONS"
TRANSACTION_NAME = "transactions"
TIME_COLUMN = "t_dat"

# link prediction task
LINK_TASK_NAME = "purchase_task"
TASK_TIME_COLUMN_NAME = "timestamp"
TASK_TRAIN_TABLE = f"{TASK_SCHEMA}.TRAIN"
TASK_TEST_TABLE  = f"{TASK_SCHEMA}.TEST"
TASK_VALIDATION_TABLE = f"{TASK_SCHEMA}.VALIDATION"

# model params
MODEL_DEVICE = "cuda" # either 'cuda' or 'cpu'
MODEL_N_EPOCHS = 3
MODEL_MAX_ITERS = 200
MODEL_TEXT_EMBEDDER = "model2vec-potion-base-4M"


OUTPUT_ALIAS = "PURCHASE_TEST_PREDS_1"
OUTPUT_TABLE = f"PREDICTIONS_{OUTPUT_ALIAS}"
TEST_BATCH_SIZE = 128

In [None]:
# initialize the provider using the snowflake configuration
# (note: you might be prompted by your MFA app at this point)
provider = Provider(**snowflake_config)

In [None]:
# Create a new GNN engine. You must provide:
# - A custom engine name (via the `name` parameter)
# - The engine type (via the `size` parameter)
#
# Currently supported engine types:
# - For Snowflake accounts hosted on AWS: "GPU_NV_S" and "HIGHMEM_X64_S"
# - For Snowflake accounts hosted on Azure: "HIGHMEM_X64_S" only
#
# provider.create_gnn(
#     name=engine_name,
#     size="GPU_NV_S",  # or "HIGHMEM_X64_S"
# )
#
# Creating a new engine may take several minutes (e.g., ~4 minutes).
# Be patient and do not interrupt the process.


# check whether the engine exists: resume it if it does, create it otherwise
if not provider.get_gnn(engine_name):
    print(f"Creating Engine {engine_name}")
    provider.create_gnn(name=engine_name, size=engine_size)
else:
    print(f"Resuming Engine {engine_name}")
    provider.resume_gnn(name=engine_name)

In [None]:
# HINT: We can always resume a GNN engine that has been suspended.
# Note: Engine provisioning can take a few minutes. Please
# check the engine status using provider.get_gnn(name="my_engine")

# provider.resume_gnn(name=engine_name)

# If needed, we can also delete a GNN engine:

# provider.delete_gnn(name=engine_name)

# We can also list the existing engines. If no engines are listed here,
# create one using provider.create_gnn(name="my_engine", size="GPU_NV_S")
# provider.list_gnns()


In [None]:
# After resuming an engine, we can check its status directly.
# A status of 'READY' means that the engine is ready to be used.
# A `PENDING` status means that the engine has been
# automatically suspended.
# Notice also that, under the settings,
# the provider exposes a URL for the MLflow endpoint
# that we can use to track our experiments
# NOTE: we should wait until the status is READY

engine_data = provider.get_gnn(engine_name)
# guard against a missing engine record before reading its state
if engine_data and engine_data["state"] == 'SUSPENDED':
    print(f'ENGINE {engine_name} IS SUSPENDED, YOU HAVE TO RESUME IT FIRST')
else:
    while not engine_data or engine_data["state"] != 'READY':
        time.sleep(10)
        engine_data = provider.get_gnn(engine_name)

    print(f'ENGINE {engine_name} READY')

⚠️ **Warning:** Before proceeding, check the engine status (e.g., via `provider.get_gnn(engine_name)`) and make sure that it is `READY`.
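
The polling loop in the cell above waits indefinitely. As a minimal sketch, the same check can be wrapped in a helper that gives up after a timeout. The helper name `wait_until_ready` and its parameters are hypothetical and not part of the SDK; the only assumption taken from this notebook is that the engine record is a dict with a `"state"` key whose values include `READY` and `SUSPENDED`.

```python
import time
from typing import Callable, Optional


def wait_until_ready(
    get_engine: Callable[[], Optional[dict]],
    timeout_s: float = 600.0,
    poll_interval_s: float = 10.0,
    sleep: Callable[[float], None] = time.sleep,
) -> dict:
    """Poll `get_engine` until it reports state 'READY'.

    Raises RuntimeError if the engine is 'SUSPENDED' and TimeoutError
    if it is still not ready once `timeout_s` seconds have passed.
    """
    deadline = time.monotonic() + timeout_s
    while True:
        engine = get_engine()
        # a missing record is treated like any other "not ready" state
        state = engine["state"] if engine else None
        if state == "READY":
            return engine
        if state == "SUSPENDED":
            raise RuntimeError("engine is suspended; resume it first")
        if time.monotonic() >= deadline:
            raise TimeoutError(
                f"engine not ready after {timeout_s} s (last state: {state})"
            )
        sleep(poll_interval_s)
```

In this notebook it could be called as `wait_until_ready(lambda: provider.get_gnn(engine_name))`. Passing the lookup as a callable keeps the helper independent of the SDK and easy to exercise with a stub.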

## 🔌 Connector Setup

The `Connector` class, like the `Provider` class, is used to communicate with Snowflake. However, while the `Provider` is responsible for managing GNN engines, the `Connector` is specifically used to interface with the **GNN learning engine** itself.

You’ll use the `Connector` instance as an input to all SDK components that need to send requests to the GNN engine—such as loading data, running training jobs, or performing inference.

In short:

* `Provider` → Manages GNN engine instances (create, list, delete, etc.)
* `Connector` → Sends requests to a specific GNN engine for processing tasks

Let’s now walk through how to create and use a `Connector`.


In [None]:
# we initialize the connector, passing in our Snowflake configuration
connector = SnowflakeConnector(
    **snowflake_config,
    engine_name=engine_name,
)
# the connector also provides access to MLFLOW that you can
# use to monitor your experiments and register trained GNN models
# connector.mlflow_session_url

## 📈 MLflow: Monitor Training

You can visit MLflow to monitor the training process in real time, including loss trends and evaluation metrics, using the `mlflowendpoint` ingress URL from the engine settings. For a detailed example of how to use MLflow, see https://github.com/RelationalAI/rai-gnns-tutorial/blob/main/HM/MLflow.md

In [None]:
provider.get_gnn(engine_name)["settings"]["mlflowendpoint"]

## 📊 Preparing the data: Creating the GNN tables

In this section, we will define the GNN tables and the associated learning task. These components will then be used to construct a GNN dataset suitable for training.

For this tutorial, we’ll use the H&M database as our working example. This database includes three tables: CUSTOMERS, ARTICLES, and TRANSACTIONS. The TRANSACTIONS table links CUSTOMERS to ARTICLES. Our objective is to predict which ARTICLES a given CUSTOMER is going to purchase. We handle this problem as a link prediction task, where the model predicts future links between CUSTOMERS (source entities) and ARTICLES (target entities).
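
To make the role of the TRANSACTIONS table concrete, here is a toy sketch in plain Python of how rows carrying a `customer_id` and an `article_id` can be read as customer-to-article edges. The sample rows and the helper `build_purchase_edges` are made up for illustration; the GNN engine builds its graph internally from the primary and foreign keys declared on the GNN tables.

```python
from collections import defaultdict

# Toy rows standing in for the TRANSACTIONS table (values are invented).
transactions = [
    {"customer_id": "c1", "article_id": "a1", "t_dat": "2020-09-01"},
    {"customer_id": "c1", "article_id": "a2", "t_dat": "2020-09-03"},
    {"customer_id": "c2", "article_id": "a2", "t_dat": "2020-09-05"},
]


def build_purchase_edges(rows):
    """Map each customer to the set of articles they are linked to."""
    edges = defaultdict(set)
    for row in rows:
        # each transaction row contributes one customer -> article edge
        edges[row["customer_id"]].add(row["article_id"])
    return dict(edges)


purchases = build_purchase_edges(transactions)
print(purchases)
```

Link prediction then amounts to ranking, for each customer, the articles most likely to appear as new edges.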


In [None]:
# create a table for the customers and set the 
# customer_id as a primary key (primary and 
# foreign keys are used to construct the edges of the graph)
customers_table = GNNTable(
    connector=connector,
    source=CUSTOMER_SOURCE_TABLE,
    name=CUSTOMER_NAME,
    primary_key=CUSTOMER_PRIMARY_KEY,
)
customers_table.show_table()

In [None]:
# in a similar manner we can create the ARTICLES table
articles_table = GNNTable(
    connector=connector,
    source=ARTICLE_SOURCE_TABLE,
    name=ARTICLE_NAME,
    primary_key=ARTICLE_PRIMARY_KEY,
)
articles_table.show_table()

In [None]:
# and finally we will link the two tables using the foreign
# keys from the TRANSACTIONS table. Note: the transactions
# table also has a special "time column" that is used
# to prevent data leakage (see the documentation for more details)
transactions_table = GNNTable(
    connector=connector,
    source=TRANSACTION_SOURCE_TABLE,
    name=TRANSACTION_NAME,
    foreign_keys=[
        ForeignKey(
            column_name=CUSTOMER_PRIMARY_KEY, link_to=CUSTOMER_NAME+"."+CUSTOMER_PRIMARY_KEY),
        ForeignKey(
            column_name=ARTICLE_PRIMARY_KEY, link_to=ARTICLE_NAME+"."+ARTICLE_PRIMARY_KEY),
    ],
    time_column=TIME_COLUMN,
)
transactions_table.show_table()
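
The leakage that the time column guards against can be illustrated with a toy filter: when making a prediction at a given point in time, only transactions that happened earlier should be visible to the model. The sketch below is plain Python with invented rows. The helper `visible_transactions` is hypothetical, and the strict `<` comparison is an assumption, since the engine's exact cutoff rule is described in the documentation rather than here.

```python
from datetime import date

# Toy rows standing in for the TRANSACTIONS table (values are invented).
transactions = [
    {"customer_id": "c1", "article_id": "a1", "t_dat": date(2020, 9, 1)},
    {"customer_id": "c1", "article_id": "a2", "t_dat": date(2020, 9, 8)},
    {"customer_id": "c2", "article_id": "a2", "t_dat": date(2020, 9, 15)},
]


def visible_transactions(rows, seed_time):
    """Keep only rows dated strictly before `seed_time` (assumed cutoff rule)."""
    return [row for row in rows if row["t_dat"] < seed_time]


# A prediction made for 2020-09-08 may only use what happened before that day.
history = visible_transactions(transactions, date(2020, 9, 8))
print([row["article_id"] for row in history])
```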


## 🔧 Preparing the Data: Creating the Task

To define the task, we begin by specifying the locations of the training, validation, and test datasets. We also identify the source and destination entity tables, along with the corresponding columns that uniquely identify each entity.

Since this is a **link prediction** task, our objective is to predict future connections between a source entity and a destination entity.

Additionally, we define a **timestamp column** to avoid information leakage by ensuring that future data doesn't influence past predictions. Lastly, we specify the **evaluation metric**—in this case, **Mean Average Precision (MAP)**—to assess the performance of the model.
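
As a rough illustration of the metric, the sketch below computes MAP at a cutoff k in plain Python. It uses one common definition, normalizing each customer's average precision by `min(k, number of relevant items)`; the exact formula the engine uses for `link_prediction_map` is not spelled out in this notebook and may differ. The function names and sample data are made up.

```python
from typing import Dict, Hashable, List, Set


def average_precision_at_k(predicted: List[Hashable], relevant: Set[Hashable], k: int) -> float:
    """Average precision over the top-k predictions for a single entity."""
    if not relevant:
        return 0.0
    hits = 0
    score = 0.0
    seen = set()
    for rank, item in enumerate(predicted[:k], start=1):
        # count each relevant item once, at the rank where it first appears
        if item in relevant and item not in seen:
            seen.add(item)
            hits += 1
            score += hits / rank
    return score / min(k, len(relevant))


def mean_average_precision_at_k(
    predictions: Dict[Hashable, List[Hashable]],
    ground_truth: Dict[Hashable, Set[Hashable]],
    k: int,
) -> float:
    """Mean of the per-entity average precision over all entities with ground truth."""
    if not ground_truth:
        return 0.0
    total = sum(
        average_precision_at_k(predictions.get(entity, []), relevant, k)
        for entity, relevant in ground_truth.items()
    )
    return total / len(ground_truth)


# Invented example: ranked article predictions vs. actual purchases.
predictions = {"c1": ["a1", "a2", "a3"], "c2": ["a3", "a1"]}
ground_truth = {"c1": {"a1", "a3"}, "c2": {"a2"}}
print(mean_average_precision_at_k(predictions, ground_truth, k=12))
```

The metric rewards placing purchased articles near the top of each customer's ranked list, so a correct article at rank 1 contributes more than the same article at rank 10.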


In [None]:
link_pred_task = LinkTask(
    connector=connector,
    name=LINK_TASK_NAME,
    task_data_source={
        "train": TASK_TRAIN_TABLE, 
        "test": TASK_TEST_TABLE, 
        "validation": TASK_VALIDATION_TABLE
    },
    # name of the source entity column that we want to make predictions for
    source_entity_column=CUSTOMER_PRIMARY_KEY,
    # name of the GNN table that contains that column
    source_entity_table=CUSTOMER_NAME,
    # name of the target entity column that we want to predict
    target_entity_column=ARTICLE_PRIMARY_KEY,
    # name of the GNN table that contains that column
    target_entity_table=ARTICLE_NAME,
    time_column=TASK_TIME_COLUMN_NAME,
    task_type=TaskType.LINK_PREDICTION,
    evaluation_metric=EvaluationMetric(name="link_prediction_map", eval_at_k=12),
)

link_pred_task.show_task()

## 🧩 Preparing the Data: Creating the Dataset

Finally, we combine all the components by constructing a dataset object that encapsulates both the GNN tables and the task definition. This dataset will serve as the input to the model training pipeline, ensuring that the task and its associated data are tightly integrated and ready for downstream processing.


In [None]:
dataset = Dataset(
    connector=connector,
    dataset_name=DATASET_NAME,
    tables=[articles_table, customers_table, transactions_table],
    task_description=link_pred_task,
)

In [None]:
# we can also visualize the dataset
graph = dataset.visualize_dataset(show_dtypes=True)
# play with font size and plot size to get a good visualization
for node in graph.get_nodes():
    node.set('fontsize', "16")

graph.set_graph_defaults(size="10,10!")  # Increase graph size

src = Source(graph.to_string())
src  # Display in notebook

## 🚀 GNN Model Training

Now that our dataset is ready, we can train our first GNN model. We’ll begin by defining a **configuration** that specifies the training parameters, such as model architecture, optimizer settings, and training duration.

Next, we’ll instantiate a **trainer** using this configuration. The trainer will consume the dataset we previously created and manage the entire training process. By calling the `fit()` method on the trainer, we initiate a training job—whose progress and status can be monitored throughout execution.

In [None]:
# the first step will be to define a configuration for our Trainer.
# the configuration includes many parameters that are explained in
# detail in the documentation. It does not only provide parameters
# for the graph neural network but also parameters for other components
# of the model (such as feature extractors, prediction head parameters,
# training parameters etc.)
model_config = TrainerConfig(
    connector=connector,
    device=MODEL_DEVICE,  # either 'cuda' or 'cpu'
    n_epochs=MODEL_N_EPOCHS,
    max_iters=MODEL_MAX_ITERS,
    text_embedder=MODEL_TEXT_EMBEDDER,
)

In [None]:
# we initialize now our trainer object with the trainer configuration
# the trainer object can be used to train a model, to perform inference
# or to perform training & inference.
trainer = Trainer(connector=connector, config=model_config)

In [None]:
# in our first example we will use the trainer to perform training only.
# every time the trainer is "executed" (calling fit(), predict() or fit_predict())
# it returns a job object that can be used to monitor the current running job.
# See the documentation for the meaning of the job statuses
train_job = trainer.fit(dataset=dataset)

In [None]:
# we can also stream the logs of the training job in real time
# Hint: You can stop the cell execution to stop monitoring the logs
# Hint: At this point you can also open the MLflow URL to monitor your experiments
train_job.stream_logs()

In [None]:
# hint: one can cancel a running job as well
# train_job.cancel()

In [None]:
# now we can monitor the job status
# observe that once the job is running we also get back an experiment name
# we will see later how we can use that to perform inference
train_job.get_status()

## 🔍 Inference Using a Trained Model

Finally, we’ll demonstrate how to perform inference using the model we’ve just trained. In this example, we'll directly use the recently trained model to generate predictions.

For more advanced use cases—such as registering a model for reuse or automatically selecting the best-performing model—please refer to the churn prediction notebook.


⚠️ **Warning:** Existing Snowflake tables are never overwritten. To run inference and save predictions to a Snowflake table, you must either:

* Change the OUTPUT_TABLE to a new, non-existent table, or

* Delete the existing OUTPUT_TABLE (if you have the necessary permissions). You can do this by running the cell below.

In [None]:
# remove the previous predictions, if the table exists
# NOTE: this drops the table from DB_SCHEMA, while the inference cell below
# writes to the "PUBLIC" schema; adjust the schema if your table lives elsewhere
df = provider._session.sql(f"SELECT * FROM {DB_NAME}.information_schema.tables WHERE table_name = '{OUTPUT_TABLE}';")
if len(df.collect()) > 0:  # table exists
    provider._session.sql(f"GRANT OWNERSHIP ON {DB_SCHEMA}.{OUTPUT_TABLE} TO ROLE ACCOUNTADMIN REVOKE CURRENT GRANTS;").collect()
    provider._session.sql(f"DROP TABLE IF EXISTS {DB_SCHEMA}.{OUTPUT_TABLE};").collect()
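
As an alternative to dropping the existing table, the first option from the warning (writing to a new, non-existent table) can be handled by generating a fresh alias for each run. The sketch below appends a UTC timestamp to a base alias. The helper `make_unique_alias` is hypothetical, and the `PREDICTIONS_` prefix simply mirrors how `OUTPUT_TABLE` is derived from `OUTPUT_ALIAS` earlier in this notebook.

```python
from datetime import datetime, timezone
from typing import Optional


def make_unique_alias(base_alias: str, now: Optional[datetime] = None) -> str:
    """Return `base_alias` with a UTC timestamp suffix (second resolution)."""
    now = now or datetime.now(timezone.utc)
    return f"{base_alias}_{now.strftime('%Y%m%d_%H%M%S')}"


output_alias = make_unique_alias("PURCHASE_TEST_PREDS")
output_table = f"PREDICTIONS_{output_alias}"
print(output_table)
```

Two runs started within the same second would still collide, so this is a convenience rather than a guarantee of uniqueness.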

In [None]:
output_config = OutputConfig.snowflake(database_name=DB_NAME, schema_name="PUBLIC")
# make sure that the table with the same alias does not already exist
# we never overwrite tables

inference_job = trainer.predict(
    output_alias=OUTPUT_ALIAS,
    output_config=output_config,
    test_batch_size=TEST_BATCH_SIZE,
    dataset=dataset,
    model_run_id=train_job.model_run_id,
    extract_embeddings=True,
)

inference_job.stream_logs()

In [None]:
inference_job.get_status()

In [None]:
# Finally, let's take a look at some of the predictions done by our GNN model
# These predictions are saved in the OUTPUT_TABLE 

df = provider._session.sql(f"SELECT * FROM {OUTPUT_TABLE} LIMIT 100;") ; df.collect()

## 📋 Job Manager

If we lose track of the jobs that we are running, the `JobManager` object can report the status of all jobs.

In [None]:
# Let's see an example:
job_manager = JobManager(connector=connector)
job_manager.show_jobs()

In [None]:
# You can retrieve job details using its job ID.
# For example, use the job ID of the training job above to access its details
# and use the trained model for inference.

# NOTE: To run this cell, replace the job ID below with the actual ID
# from your training job output in the previous cell.

retrieved_job = job_manager.fetch_job("01bdaf68-020e-076c-000a-1dc701bcbba6")
# hint: the job manager can be used to cancel any job as well

In [None]:
# remove the previous predictions, if the table exists
# NOTE: this drops the table from DB_SCHEMA, while the inference cell below
# writes to the "PUBLIC" schema; adjust the schema if your table lives elsewhere
df = provider._session.sql(f"SELECT * FROM {DB_NAME}.information_schema.tables WHERE table_name = '{OUTPUT_TABLE}';")
if len(df.collect()) > 0:  # table exists
    provider._session.sql(f"GRANT OWNERSHIP ON {DB_SCHEMA}.{OUTPUT_TABLE} TO ROLE ACCOUNTADMIN REVOKE CURRENT GRANTS;").collect()
    provider._session.sql(f"DROP TABLE IF EXISTS {DB_SCHEMA}.{OUTPUT_TABLE};").collect()

In [None]:
output_config = OutputConfig.snowflake(database_name=DB_NAME, schema_name="PUBLIC")
# make sure that the table with the same alias does not already exist
# we never overwrite tables
# use the retrieved job to get the trained model for inference

inference_job = trainer.predict(
    output_alias=OUTPUT_ALIAS,
    output_config=output_config,
    test_batch_size=TEST_BATCH_SIZE,
    dataset=dataset,
    model_run_id=retrieved_job.model_run_id,
    extract_embeddings=False,
)

inference_job.stream_logs()