In [1]:
from relbench.datasets import get_dataset, get_dataset_names, register_dataset

In [2]:
import os
import pandas as pd
import numpy as np
from relbench.base import Database, Dataset, Table

class TransactionalDataset(Dataset):
    """RelBench dataset built from the Hyper (ARAS) retail CSV exports.

    ``make_db`` reads four CSV files (customers, articles, branches,
    transactions) from ``data_dir`` and returns a :class:`Database` with the
    tables ``article``, ``customer``, ``branches`` and ``transactions``. The
    ``transactions`` table references the other three through foreign keys and
    uses ``datetime`` as its time column.
    """

    # Split timestamps consumed by RelBench's Dataset/Task machinery.
    val_timestamp = pd.Timestamp("2022-02-15")
    test_timestamp = pd.Timestamp("2022-02-22")

    # Folder holding the CSV exports. Defaults to the original author's local
    # path; override it with the HYPER_DATA_DIR environment variable, or by
    # setting ``data_dir`` on a subclass or instance.
    data_dir = os.environ.get(
        "HYPER_DATA_DIR",
        os.path.join("C:/Users/KN2C/Desktop/Dani/relbench/relbench/", "hyper_data"),
    )

    # File name of each raw CSV inside ``data_dir``.
    _csv_files = {
        "customers": "Customers.csv",
        "articles": "Articles.csv",
        "branches": "Branches.csv",
        "transactions": "Transactions.csv",
    }

    @staticmethod
    def _drop_duplicate_keys(df: pd.DataFrame, key: str) -> pd.DataFrame:
        """Return ``df`` keeping only the first row for each value of ``key``."""
        if df.duplicated(subset=[key]).any():
            print(f"Duplicates found in the '{key}' column. Removing duplicates...")
            df = df.drop_duplicates(subset=[key], keep="first")
        return df

    def make_db(self) -> Database:
        """Read, clean and assemble the relational database from the CSVs.

        Returns:
            A :class:`Database` with tables ``article``, ``customer``,
            ``branches`` and ``transactions``.

        Raises:
            RuntimeError: if any of the four CSV files is missing from
                ``data_dir``.
        """
        path = self.data_dir
        csv_paths = {
            name: os.path.join(path, file_name)
            for name, file_name in self._csv_files.items()
        }

        # Check every input file up front, not just one of them, so a missing
        # file fails here with a clear message instead of inside read_csv.
        missing = [p for p in csv_paths.values() if not os.path.exists(p)]
        if missing:
            raise RuntimeError(
                f"Dataset not found at '{path}'. Please make sure the CSV files "
                f"are in the correct folder. Missing: {missing}"
            )

        customers_df = pd.read_csv(csv_paths["customers"])
        articles_df = pd.read_csv(csv_paths["articles"])
        branches_df = pd.read_csv(csv_paths["branches"])
        transactions_df = pd.read_csv(csv_paths["transactions"])

        # Primary keys must be unique; keep the first occurrence of each key.
        articles_df = self._drop_duplicate_keys(articles_df, "articles_id")
        customers_df = self._drop_duplicate_keys(customers_df, "customers_id")
        branches_df = self._drop_duplicate_keys(branches_df, "BranchCode")

        # Drop columns that are not used as features.
        transactions_df = transactions_df.drop(columns=["Return Amount"])
        articles_df = articles_df.drop(columns=["Item Barcode", "External Item Number"])

        # The export marks missing values with the literal string \N.
        # A missing sales time becomes midnight; any other \N becomes NaN.
        transactions_df["salesTime"] = transactions_df["salesTime"].replace(
            r"^\\N$", "00:00:00", regex=True
        )
        transactions_df = transactions_df.replace(r"^\\N$", np.nan, regex=True)

        # The event time is the calendar date only (salesTime is kept as a
        # separate column and is not merged into it).
        transactions_df["datetime"] = pd.to_datetime(
            transactions_df["d_dat"], format="%Y-%m-%d"
        )
        transactions_df = transactions_df.drop(columns=["d_dat"])

        # Unparseable numbers become NaN rather than raising.
        for col in ("price_purchase", "Discount_ratio", "Quantity"):
            transactions_df[col] = pd.to_numeric(transactions_df[col], errors="coerce")

        tables = {
            "article": Table(
                df=articles_df,
                fkey_col_to_pkey_table={},
                pkey_col="articles_id",
                time_col=None,
            ),
            "customer": Table(
                df=customers_df,
                fkey_col_to_pkey_table={},
                pkey_col="customers_id",
                time_col=None,
            ),
            "branches": Table(
                df=branches_df,
                fkey_col_to_pkey_table={},
                pkey_col="BranchCode",
                time_col=None,
            ),
            # Fact table: each row links an article, a customer and a branch.
            "transactions": Table(
                df=transactions_df,
                fkey_col_to_pkey_table={
                    "articles_id": "article",
                    "customers_id": "customer",
                    "BranchCode": "branches",
                },
                pkey_col=None,
                time_col="datetime",
            ),
        }

        return Database(tables)


In [3]:
# Build the database straight from the CSVs by calling make_db() directly.
# NOTE(review): this bypasses RelBench's get_db() (used later via
# dataset.get_db()), so whatever caching/post-processing get_db adds is not
# applied to this `db`.
transactional_dataset = TransactionalDataset()
db = transactional_dataset.make_db()

Duplicates found in the 'articles_id' column. Removing duplicates...
Duplicates found in the 'customers_id' column. Removing duplicates...


In [4]:
# Handle on the transactions Table, used by the min/max-date checks below.
table = db.table_dict["transactions"]

In [5]:
db.table_dict["transactions"]

Table(df=
         factor_id   articles_id  customers_id  BranchCode salesTime  \
0                0  3.515792e+07    9150651948           9  14:35:06   
1                0  2.172070e+08    9150651948           9  14:35:06   
2                0  9.892272e+08    9150651948           9  14:35:06   
3                0  2.011346e+06    9150651948           9  14:35:06   
4                0  2.006610e+06    9150651948           9  14:35:06   
...            ...           ...           ...         ...       ...   
6950117    1273663  3.214409e+07    9157962188         356  11:56:45   
6950118    1273664  1.234744e+09    9153858979         356  18:53:39   
6950119    1273665  2.013925e+06    9123311767         275  10:29:11   
6950120    1273666  1.682033e+08    9393251711           4  09:16:39   
6950121    1273667  1.234744e+09    9151875866         356  19:16:49   

         price_purchase  Discount_ratio  Quantity   datetime  
0               19610.1        0.266504     1.000 2021-03-22  

In [6]:
db.table_dict["branches"]

Table(df=
   BranchCode             BranchName
0           9   فروشگاه شاندیز(مشهد)
1         275   فروشگاه سمنان(ققنوس)
2         509  فروشگاه یزد(امام علی)
3           4     فروشگاه گنبد کاووس
4         356         فروشگاه بجنورد,
  fkey_col_to_pkey_table={},
  pkey_col=BranchCode,
  time_col=None)

In [7]:
db.table_dict["article"]

Table(df=
        articles_id                                           art_name  \
0      3.515792e+07                              اصالت آبلیمو 900 گرمی   
1      2.172070e+08                  رامک پنیر سفید پروبیوتیک 400 گرمی   
2      9.892272e+08                                   تخم مرغ طلقی فله   
3      2.011346e+06                      سس گلوریا 88 گرم فلفل زرد تند   
4      2.006610e+06                   فانتا نوشابه لیمویی1500 سی سی پت   
...             ...                                                ...   
44876  1.073731e+09              سحرکمپوت زردالو قوطی ایزی اپن430 گرمی   
44877  1.083732e+09                        رزگلد سینی تخت گلدار سایز 1   
44894  1.063731e+09                                فامیلی کایل اصلی A5   
44899  1.091735e+09  سام آرشیت مانابلوز بچه گانه مد کد 950 بغل مشکی...   
44900  1.061730e+09                            فامیلی کیف دسته قهوه ای   

                           Department Name                   group  \
0                   FMCG - کالا

In [8]:
db.table_dict["customer"]

Table(df=
        customers_id  customers_no
0         9150651948     508206249
1         9155143265      36124744
2         9144518134      42737591
3         9151739848      24102811
4         9138537082      72155679
...              ...           ...
300316    9152264499    1394869309
300317    9166036927      72736526
300318    9131583129      72736530
300319    9132515574      72836534
300320    9308737056    1193953900

[299536 rows x 2 columns],
  fkey_col_to_pkey_table={},
  pkey_col=customers_id,
  time_col=None)

In [9]:
# Latest transaction. idxmax() returns an index *label*, so look the row up
# with .loc; .iloc is positional and picks the wrong row whenever the index is
# not a default 0..n-1 RangeIndex.
table.df.loc[table.df["datetime"].idxmax()]


factor_id                     1269935
articles_id                 2000568.0
customers_id               9151861977
BranchCode                        356
salesTime                    18:16:41
price_purchase                30737.6
Discount_ratio               0.043444
Quantity                          1.0
datetime          2022-03-12 00:00:00
Name: 6930864, dtype: object

In [10]:
# Earliest transaction. idxmin() returns an index *label*, so use .loc (label
# lookup) rather than .iloc (positional lookup).
table.df.loc[table.df["datetime"].idxmin()]

factor_id                           0
articles_id                35157919.0
customers_id               9150651948
BranchCode                          9
salesTime                    14:35:06
price_purchase                19610.1
Discount_ratio               0.266504
Quantity                          1.0
datetime          2021-03-22 00:00:00
Name: 0, dtype: object

In [11]:
# Register the dataset class under the name used by get_dataset()/register_task().
register_dataset("hyperr-aras", TransactionalDataset)
get_dataset_names()

['rel-amazon',
 'rel-avito',
 'rel-event',
 'rel-f1',
 'rel-hm',
 'rel-stack',
 'rel-trial',
 'hyperr-aras']

In [13]:
# Instantiate the registered dataset through the RelBench registry.
hyper_dataset = get_dataset("hyperr-aras")
hyper_dataset

TransactionalDataset()

In [14]:
hyper_dataset.val_timestamp, hyper_dataset.test_timestamp

(Timestamp('2022-02-15 00:00:00'), Timestamp('2022-02-22 00:00:00'))

In [15]:
import relbench

relbench.__version__

'1.1.0'

In [23]:
import duckdb
import pandas as pd
from relbench.tasks import get_task, get_task_names, register_task
from relbench.base import Database, EntityTask, RecommendationTask, Table, TaskType
from relbench.metrics import (
    accuracy,
    average_precision,
    f1,
    link_prediction_map,
    link_prediction_precision,
    link_prediction_recall,
    mae,
    r2,
    rmse,
    roc_auc,
)
from metrics import link_prediction_top
class UserItemPurchaseTask(RecommendationTask):
    r"""Predict the list of articles each customer will purchase in the next seven
    days.

    The task table exposes this database's ``customers_id`` / ``articles_id``
    columns under the names ``customer_id`` / ``article_id``. Each row holds a
    cut-off ``timestamp``, a customer, and the distinct articles that customer
    bought in ``(timestamp, timestamp + timedelta]``.
    """

    task_type = TaskType.LINK_PREDICTION
    src_entity_col = "customer_id"
    src_entity_table = "customer"
    dst_entity_col = "article_id"
    dst_entity_table = "article"
    time_col = "timestamp"
    timedelta = pd.Timedelta(days=7)
    metrics = [link_prediction_precision, link_prediction_recall, link_prediction_map, link_prediction_top]
    eval_k = 12

    def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Table:
        """Build the customer -> purchased-articles table for each cut-off.

        Args:
            db: database providing the ``transactions`` table.
            timestamps: cut-off times for which targets are computed.

        Returns:
            A :class:`Table` with columns ``timestamp``, ``customer_id`` and
            ``article_id`` (a list of article ids).
        """
        # duckdb resolves ``transactions`` and ``timestamp_df`` in the SQL to
        # these local DataFrames by name.
        transactions = db.table_dict["transactions"].df
        timestamp_df = pd.DataFrame({"timestamp": timestamps})

        # The transactions table stores customers_id / articles_id / datetime;
        # alias them to the column names this task declares. `.days` yields a
        # plain integer for the INTERVAL literal.
        df = duckdb.sql(
            f"""
            SELECT
                t.timestamp,
                transactions.customers_id AS customer_id,
                LIST(DISTINCT transactions.articles_id) AS article_id
            FROM
                timestamp_df t
            LEFT JOIN
                transactions
            ON
                transactions.datetime > t.timestamp AND
                transactions.datetime <= t.timestamp + INTERVAL '{self.timedelta.days} days'
            GROUP BY
                t.timestamp,
                transactions.customers_id
            """
        ).df()

        return Table(
            df=df,
            fkey_col_to_pkey_table={
                self.src_entity_col: self.src_entity_table,
                self.dst_entity_col: self.dst_entity_table,
            },
            pkey_col=None,
            time_col=self.time_col,
        )

# Task 1: Predict articles each customer will purchase in the next 7 days
class CustomerArticlePurchaseTask(RecommendationTask):
    r"""Recommend, for every customer, the articles they will buy over the next week.

    Each task-table row pairs a cut-off ``timestamp`` and a ``customers_id``
    with the list of distinct ``articles_id`` values bought in
    ``(timestamp, timestamp + timedelta]``.
    """

    task_type = TaskType.LINK_PREDICTION
    src_entity_col = "customers_id"
    src_entity_table = "customer"
    dst_entity_col = "articles_id"
    dst_entity_table = "article"
    time_col = "timestamp"
    timedelta = pd.Timedelta(days=7)
    metrics = [link_prediction_precision, link_prediction_recall, link_prediction_map, link_prediction_top]
    # k used by the @k link-prediction metrics (RelBench convention).
    eval_k = 4

    def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Table:
        """Return the purchase-list table for the given cut-off timestamps.

        Args:
            db: database providing the ``transactions`` table.
            timestamps: cut-off times for which targets are computed.

        Returns:
            A :class:`Table` with ``timestamp``, ``customers_id`` and a list
            column ``articles_id``.
        """
        # The local names below are referenced by name inside the SQL (duckdb
        # scans the DataFrames bound to them), so they must not be renamed.
        timestamp_df = pd.DataFrame({"timestamp": timestamps})
        transactions = db.table_dict["transactions"].df

        query = f"""
            SELECT
                t.timestamp,
                transactions.customers_id,
                LIST(DISTINCT transactions.articles_id) AS articles_id
            FROM
                timestamp_df t
            LEFT JOIN
                transactions
            ON
                transactions.datetime > t.timestamp AND
                transactions.datetime <= t.timestamp + INTERVAL '{self.timedelta.days} days'
            GROUP BY
                t.timestamp,
                transactions.customers_id
            """
        purchases = duckdb.sql(query).df()

        links = {
            self.src_entity_col: self.src_entity_table,
            self.dst_entity_col: self.dst_entity_table,
        }
        return Table(
            df=purchases,
            fkey_col_to_pkey_table=links,
            pkey_col=None,
            time_col=self.time_col,
        )


# Task 2: Predict customer churn (no purchases in the next week)
class CustomerChurnTask(EntityTask):
    r"""Predict whether an active customer churns, i.e. makes no transaction in
    the next ``timedelta`` (7 days).

    For each cut-off ``timestamp``, only customers with at least one
    transaction in ``(timestamp - timedelta, timestamp]`` are included. The
    ``churn`` target is 1 when that customer has no transaction in
    ``(timestamp, timestamp + timedelta]`` and 0 otherwise.
    """

    task_type = TaskType.BINARY_CLASSIFICATION
    entity_col = "customers_id"
    entity_table = "customer"
    time_col = "timestamp"
    target_col = "churn"
    timedelta = pd.Timedelta(days=7)
    metrics = [average_precision, accuracy, f1, roc_auc]

    def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Table:
        """Build the churn label table for the given cut-off timestamps.

        Args:
            db: database providing the ``customer`` and ``transactions`` tables.
            timestamps: cut-off times at which labels are computed.

        Returns:
            A :class:`Table` with columns ``timestamp``, ``customers_id`` and
            ``churn`` (0/1 integer).
        """
        # duckdb resolves ``customer``, ``transactions`` and ``timestamp_df``
        # in the SQL to these local DataFrames by name.
        customer = db.table_dict["customer"].df
        transactions = db.table_dict["transactions"].df
        timestamp_df = pd.DataFrame({"timestamp": timestamps})

        # Columns are qualified with table aliases so it is explicit which
        # `timestamp` / `customers_id` each correlated subquery refers to.
        df = duckdb.sql(
            f"""
            SELECT
                t.timestamp,
                c.customers_id,
                CAST(
                    NOT EXISTS (
                        SELECT 1
                        FROM transactions
                        WHERE
                            transactions.customers_id = c.customers_id AND
                            transactions.datetime > t.timestamp AND
                            transactions.datetime <= t.timestamp + INTERVAL '{self.timedelta}'
                    ) AS INTEGER
                ) AS churn
            FROM
                timestamp_df t,
                customer c
            WHERE
                EXISTS (
                    SELECT 1
                    FROM transactions
                    WHERE
                        transactions.customers_id = c.customers_id AND
                        transactions.datetime > t.timestamp - INTERVAL '{self.timedelta}' AND
                        transactions.datetime <= t.timestamp
                )
            """
        ).df()

        return Table(
            df=df,
            fkey_col_to_pkey_table={self.entity_col: self.entity_table},
            pkey_col=None,
            time_col=self.time_col,
        )


# Task 3: Predict article sales in the next 7 days
class ArticleSalesTask(EntityTask):
    r"""Predict the total sales for an article (sum of `price_purchase`) in the next 7 days.

    Every article gets one row per cut-off timestamp. Articles with no
    transaction in ``(timestamp, timestamp + timedelta]`` get ``sales = 0``.
    """

    task_type = TaskType.REGRESSION
    entity_col = "articles_id"
    entity_table = "article"
    time_col = "datetime"
    target_col = "sales"
    timedelta = pd.Timedelta(days=7)
    metrics = [r2, mae, rmse]

    def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Table:
        """Build the per-article sales target table for each cut-off.

        Args:
            db: database providing the ``article`` and ``transactions`` tables.
            timestamps: cut-off times for which targets are computed.

        Returns:
            A :class:`Table` with the cut-off column (named ``self.time_col``),
            ``articles_id`` and ``sales``.
        """
        # duckdb resolves ``transactions``, ``articles`` and ``timestamp_df`` in
        # the SQL to these local DataFrames by name, so the SQL must use the
        # same names as these variables.
        transactions = db.table_dict["transactions"].df
        articles = db.table_dict["article"].df
        timestamp_df = pd.DataFrame({"timestamp": timestamps})

        # An explicit CROSS JOIN (not a comma) is needed: a comma binds looser
        # than LEFT JOIN, which would leave `t` out of scope in the ON clause.
        # The cut-off column is emitted under `self.time_col` so it matches the
        # time column declared on the returned Table.
        df = duckdb.sql(
            f"""
            SELECT
                t.timestamp AS "{self.time_col}",
                a.articles_id,
                COALESCE(SUM(transactions.price_purchase), 0) AS sales
            FROM
                timestamp_df t
            CROSS JOIN
                articles a
            LEFT JOIN
                transactions
            ON
                transactions.articles_id = a.articles_id AND
                transactions.datetime > t.timestamp AND
                transactions.datetime <= t.timestamp + INTERVAL '{self.timedelta.days} days'
            GROUP BY
                t.timestamp,
                a.articles_id
            """
        ).df()

        return Table(
            df=df,
            fkey_col_to_pkey_table={self.entity_col: self.entity_table},
            pkey_col=None,
            time_col=self.time_col,
        )



In [24]:
# Instantiate the recommendation task directly on the dataset.
# cache_dir is passed through to RelBench (presumably where task tables are cached).
aras_recom_task = CustomerArticlePurchaseTask(hyper_dataset, cache_dir="./cache/hyper_aras390111d195")
aras_recom_task

CustomerArticlePurchaseTask(dataset=TransactionalDataset())

In [58]:
# Register CustomerArticlePurchaseTask under the name later loaded with get_task().
register_task("hyperr-aras", "aras_recom_task2", CustomerArticlePurchaseTask)
get_task_names("hyperr-aras")

['aras_recom_task', 'aras_recom_task1', 'aras_recom_task2']

In [59]:
get_task_names("hyperr-aras")

['aras_recom_task', 'aras_recom_task1', 'aras_recom_task2']

In [60]:
import numpy as np

# NOTE(review): L1Loss is imported but not used in the cells shown.
from torch.nn import BCEWithLogitsLoss, L1Loss
from relbench.datasets import get_dataset
from relbench.tasks import get_task

# Load the registered dataset and the link-prediction task by name.
dataset = get_dataset("hyperr-aras")
task = get_task("hyperr-aras", "aras_recom_task2")


# Build (or load) the task table for each split.
train_table = task.get_table("train")
val_table = task.get_table("val")
test_table = task.get_table("test")

# Training configuration.
out_channels = 1
# NOTE(review): train() below calls F.binary_cross_entropy_with_logits
# directly, so this loss object is not what the training loop uses.
loss_fn = BCEWithLogitsLoss()
# Presumably the validation metric used to pick the best model (higher = better).
tune_metric = "link_prediction_map"
higher_is_better = True

Making task table for train split from scratch...
(You can also use `get_task(..., download=True)` for tasks prepared by the RelBench team.)
Done in 18.05 seconds.
Making task table for val split from scratch...
(You can also use `get_task(..., download=True)` for tasks prepared by the RelBench team.)
Done in 0.21 seconds.
Making task table for test split from scratch...
(You can also use `get_task(..., download=True)` for tasks prepared by the RelBench team.)
Done in 0.37 seconds.


In [61]:
train_table

Table(df=
        timestamp  customers_id  \
0      2021-04-13         30946   
1      2021-04-13          9379   
2      2021-04-13         53821   
3      2021-04-13         53823   
4      2021-04-13         53825   
...           ...           ...   
720863 2022-02-08         14613   
720864 2022-02-08         96407   
720865 2022-02-08        122007   
720866 2022-02-08        284085   
720867 2022-02-08        271058   

                                              articles_id  
0       [4186, 7751, 1230, 1565, 7, 11495, 1793, 5503,...  
1       [5856, 1868, 925, 2192, 1450, 7804, 314, 8471,...  
2       [204, 1895, 323, 288, 1710, 1326, 4309, 1485, ...  
3       [10486, 2589, 230, 38, 3388, 126, 182, 56, 505...  
4       [478, 930, 2445, 1641, 1280, 1241, 818, 1017, ...  
...                                                   ...  
720863                                             [1257]  
720864                                            [19049]  
720865                       

In [29]:
val_table

Table(df=
       timestamp  customers_id  \
0     2022-02-15        179203   
1     2022-02-15          6980   
2     2022-02-15        192267   
3     2022-02-15         16392   
4     2022-02-15        126384   
...          ...           ...   
15086 2022-02-15        287773   
15087 2022-02-15          4373   
15088 2022-02-15         12133   
15089 2022-02-15        228877   
15090 2022-02-15         21118   

                                             articles_id  
0      [170, 2783, 139, 3912, 5395, 180, 1450, 3100, ...  
1      [695, 16568, 18163, 1788, 5409, 1913, 3523, 83...  
2      [722, 2200, 19518, 15992, 7067, 2922, 700, 618...  
3      [3799, 1938, 3186, 11819, 4041, 5709, 19379, 1...  
4      [1025, 16657, 1369, 2702, 256, 10591, 12741, 8...  
...                                                  ...  
15086                                             [4000]  
15087                                             [2748]  
15088                                             

In [30]:
test_table

Table(df=
       timestamp  customers_id  \
0     2022-02-22           810   
1     2022-02-22         55934   
2     2022-02-22         31213   
3     2022-02-22        288753   
4     2022-02-22        288754   
...          ...           ...   
15537 2022-02-22        292744   
15538 2022-02-22        292791   
15539 2022-02-22        292845   
15540 2022-02-22        292852   
15541 2022-02-22        292882   

                                             articles_id  
0      [430, 4862, 18183, 1079, 1990, 1443, 2959, 245...  
1                                     [1661, 526, 19210]  
2      [1272, 3047, 7450, 4812, 7138, 4835, 10475, 13...  
3                             [16081, 3012, 16117, 1661]  
4      [19374, 6629, 897, 3635, 2336, 19254, 2267, 11...  
...                                                  ...  
15537                                             [3616]  
15538                                             [1286]  
15539                                            [

In [62]:
import os
import math
import numpy as np
from tqdm import tqdm

import torch
import torch_geometric
import torch_frame

# Some bookkeeping
from torch_geometric.seed import seed_everything

# Fix random seeds (via PyG's seed_everything) so runs are reproducible.
seed_everything(42)


# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)  # check that it's cuda if you want it to run in reasonable time!
# Directory for the materialized-graph cache used further down.
root_dir = "./data_ARAS"

cuda


In [63]:
from relbench.modeling.utils import get_stype_proposal

# Load the database through RelBench and let it propose a torch_frame
# semantic type (stype) for every column of every table.
db = dataset.get_db()
col_to_stype_dict = get_stype_proposal(db)
col_to_stype_dict

{'article': {'articles_id': <stype.numerical: 'numerical'>,
  'art_name': <stype.text_embedded: 'text_embedded'>,
  'Department Name': <stype.categorical: 'categorical'>,
  'group': <stype.text_embedded: 'text_embedded'>,
  'Subgroup Name': <stype.text_embedded: 'text_embedded'>,
  'group_id': <stype.numerical: 'numerical'>,
  'whole_Branch_Name': <stype.text_embedded: 'text_embedded'>},
 'branches': {'BranchCode': <stype.numerical: 'numerical'>,
  'BranchName': <stype.text_embedded: 'text_embedded'>},
 'customer': {'customers_id': <stype.numerical: 'numerical'>,
  'customers_no': <stype.numerical: 'numerical'>},
 'transactions': {'factor_id': <stype.numerical: 'numerical'>,
  'articles_id': <stype.numerical: 'numerical'>,
  'customers_id': <stype.numerical: 'numerical'>,
  'BranchCode': <stype.categorical: 'categorical'>,
  'salesTime': <stype.timestamp: 'timestamp'>,
  'price_purchase': <stype.numerical: 'numerical'>,
  'Discount_ratio': <stype.numerical: 'numerical'>,
  'Quantity': 

In [64]:
from typing import List, Optional
from sentence_transformers import SentenceTransformer
from torch import Tensor
import torch

class BertPersianTextEmbedding:
    """Callable text embedder for Persian text, backed by sentence-transformers.

    Calling the instance on a list of strings returns their embeddings as a
    ``torch.Tensor`` built from the NumPy array that ``SentenceTransformer.encode``
    returns.

    Args:
        device: device forwarded to ``SentenceTransformer``.
        model_name: Hugging Face model id to load. Defaults to the Persian
            BERT checkpoint ``HooshvareLab/bert-fa-zwnj-base``, which is the
            model this class always loaded before the parameter was added.
    """

    def __init__(
        self,
        device: Optional[torch.device] = None,
        model_name: str = "HooshvareLab/bert-fa-zwnj-base",
    ):
        self.model = SentenceTransformer(model_name, device=device)

    def __call__(self, sentences: List[str]) -> Tensor:
        """Encode ``sentences`` and return the embeddings as a tensor."""
        return torch.from_numpy(self.model.encode(sentences))


In [65]:
from torch_frame.config.text_embedder import TextEmbedderConfig
from relbench.modeling.graph import make_pkey_fkey_graph

# Text columns are embedded with the Persian BERT wrapper, 64 strings per batch.
text_embedder_cfg = TextEmbedderConfig(
    text_embedder=BertPersianTextEmbedding(device=device),
    batch_size=64,
)

# Turn the database into a heterogeneous graph (one node type per table, edges
# along primary/foreign keys). The materialized result is stored on disk under
# root_dir for reuse.
data, col_stats_dict = make_pkey_fkey_graph(
    db,
    col_to_stype_dict=col_to_stype_dict,
    text_embedder_cfg=text_embedder_cfg,
    cache_dir=os.path.join(root_dir, "rel-aras_recom_materialized_cache"),
)

Some weights of BertModel were not initialized from the model checkpoint at HooshvareLab/bert-fa-zwnj-base and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Embedding raw data in mini-batch: 100%|██████████| 318/318 [01:03<00:00,  5.03it/s]
Embedding raw data in mini-batch: 100%|██████████| 318/318 [02:08<00:00,  2.47it/s]
Embedding raw data in mini-batch: 100%|██████████| 318/318 [00:41<00:00,  7.60it/s]
Embedding raw data in mini-batch: 100%|██████████| 318/318 [00:43<00:00,  7.29it/s]
Embedding raw data in mini-batch: 100%|██████████| 1/1 [00:00<00:00, 76.96it/s]
  ser = pd.to_datetime(ser, format=time_format)
  ser = pd.to_datetime(ser, format=self.format, errors='coerce')


In [66]:
data

HeteroData(
  article={ tf=TensorFrame([20298, 6]) },
  branches={ tf=TensorFrame([5, 1]) },
  customer={ tf=TensorFrame([299536, 1]) },
  transactions={
    tf=TensorFrame([6581467, 6]),
    time=[6581467],
  },
  (transactions, f2p_articles_id, article)={ edge_index=[2, 6581467] },
  (article, rev_f2p_articles_id, transactions)={ edge_index=[2, 6581467] },
  (transactions, f2p_customers_id, customer)={ edge_index=[2, 6581467] },
  (customer, rev_f2p_customers_id, transactions)={ edge_index=[2, 6581467] },
  (transactions, f2p_BranchCode, branches)={ edge_index=[2, 6581467] },
  (branches, rev_f2p_BranchCode, transactions)={ edge_index=[2, 6581467] }
)

## RelBench Recommendation

In [None]:
import argparse
import copy
import json
import os
import warnings
from pathlib import Path
from typing import Dict, Tuple

import numpy as np
import torch
import torch.nn.functional as F
# from model import Model
# from text_embedder import GloveTextEmbedding
from torch import Tensor
from torch_frame import stype
from torch_frame.config.text_embedder import TextEmbedderConfig
from torch_geometric.loader import NeighborLoader
from torch_geometric.seed import seed_everything
from tqdm import tqdm

from relbench.base import Dataset, RecommendationTask, TaskType
from relbench.datasets import get_dataset
from relbench.modeling.graph import get_link_train_table_input, make_pkey_fkey_graph
from relbench.modeling.loader import SparseTensor
from relbench.modeling.utils import get_stype_proposal
from relbench.tasks import get_task

In [None]:
# NodeType is needed for the annotation below: module-level variable
# annotations are evaluated, so leaving it unimported raises NameError.
from torch_geometric.typing import NodeType

# One temporal NeighborLoader per split, seeded at the task's source nodes.
# dst_nodes_dict keeps each split's ground-truth destination nodes.
loader_dict: Dict[str, NeighborLoader] = {}
dst_nodes_dict: Dict[str, Tuple[NodeType, Tensor]] = {}

# `split_table` (not `table`) so the global `table` from earlier cells is not
# silently overwritten by this loop.
for split, split_table in [
    ("train", train_table),
    ("val", val_table),
    ("test", test_table),
]:
    table_input = get_link_train_table_input(
        table=split_table,
        task=task,
    )

    dst_nodes_dict[split] = table_input.dst_nodes

    loader_dict[split] = NeighborLoader(
        data,
        num_neighbors=[128, 128],  # 2 hops, up to 128 neighbors per hop
        time_attr="time",
        input_nodes=table_input.src_nodes,
        input_time=table_input.src_time,  # seed time per source node
        subgraph_type="bidirectional",
        batch_size=512,
        temporal_strategy="last",  # sample the most recent neighbors (not uniform)
        shuffle=split == "train",  # shuffle only during training
        num_workers=0,
        persistent_workers=False,
    )


In [None]:
# NOTE(review): `Model` was never brought into scope (the `from model import
# Model` line in the imports cell is commented out), so this cell raised
# NameError on a fresh kernel. It is presumably RelBench's examples/model.py —
# TODO confirm that file is on the import path.
from model import Model

# Link-prediction model over the heterogeneous graph.
model = Model(
    data=data,
    col_stats_dict=col_stats_dict,
    num_layers=2,  # depth of the GNN
    channels=128,  # hidden channels in GNN layers
    out_channels=1,  # one score per candidate link
    aggr="sum",  # neighbor aggregation
    norm="layer_norm",
    id_awareness=True,
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Ground-truth destinations for the train split; element [1] of the
# (node type, tensor) pair is the tensor wrapped for fast lookup by input id.
train_sparse_tensor = SparseTensor(dst_nodes_dict["train"][1], device=device)


In [None]:
def train() -> float:
    """Run one training epoch of the link-prediction model.

    Iterates over at most 2000 mini-batches of ``loader_dict["train"]``. For each
    batch, every sampled destination node gets one logit from
    ``model.forward_dst_readout``. The target is 1 when that (seed source,
    destination) pair appears in ``train_sparse_tensor`` and 0 otherwise.
    Binary cross-entropy with logits is minimized.

    Returns:
        Loss averaged over all scored pairs of the epoch, or ``nan`` (with a
        warning) if no destination node was sampled in any mini-batch.
    """
    import warnings  # local import keeps this function self-contained
    from itertools import islice

    model.train()  # Set model to training mode

    loss_accum = count_accum = 0
    # Change the max_steps_per_epoch cap (2000) to your preferred value.
    total_steps = min(len(loader_dict["train"]), 2000)

    # islice stops after exactly `total_steps` batches. A `break` inside the loop
    # body instead exits before tqdm counts the final batch, so the bar ends one short.
    for batch in tqdm(islice(loader_dict["train"], total_steps), total=total_steps):
        batch = batch.to(device)  # Move batch data to device (GPU or CPU)

        # One logit per sampled destination node, flattened to 1-D.
        out = model.forward_dst_readout(
            batch, task.src_entity_table, task.dst_entity_table
        ).flatten()

        batch_size = batch[task.src_entity_table].batch_size

        # Ground truth: (row-in-batch, destination id) pairs of the seed nodes' true links.
        input_id = batch[task.src_entity_table].input_id
        src_batch, dst_index = train_sparse_tensor[input_id]

        # Encode each (batch row, node id) pair as one integer key (row < batch_size),
        # so that membership of sampled pairs in the true pairs is a single isin call.
        target = torch.isin(
            batch[task.dst_entity_table].batch
            + batch_size * batch[task.dst_entity_table].n_id,
            src_batch + batch_size * dst_index,
        ).float()

        # Optimization
        optimizer.zero_grad()
        loss = F.binary_cross_entropy_with_logits(out, target)
        loss.backward()
        optimizer.step()

        # Weight each batch by its number of scored pairs so the epoch mean is per pair.
        loss_accum += float(loss) * out.numel()
        count_accum += out.numel()

    # Handle the case where no data was sampled
    if count_accum == 0:
        warnings.warn(
            f"Did not sample a single '{task.dst_entity_table}' node in any mini-batch. "
            "Try increasing the number of layers/hops or reducing the batch size."
        )

    # Return average loss for the epoch
    return loss_accum / count_accum if count_accum > 0 else float("nan")


In [None]:
@torch.no_grad()  # evaluation only: autograd records nothing
def test(loader: NeighborLoader) -> np.ndarray:
    """Rank destination nodes for every source node served by ``loader``.

    Sampled destination nodes are scored with the sigmoid of the model's
    destination readout. They are scattered into a dense
    (batch_size, task.num_dst_nodes) matrix in which unsampled destinations stay
    at 0. The ``task.eval_k`` highest-scoring indices per row are then kept.

    Returns:
        Array of shape (total source nodes, task.eval_k) of destination indices.
    """
    model.eval()

    dst_table = task.dst_entity_table
    topk_chunks: list[Tensor] = []
    for batch in tqdm(loader):
        batch = batch.to(device)

        logits = model.forward_dst_readout(batch, task.src_entity_table, dst_table)
        logits = logits.detach().flatten()

        num_src = batch[task.src_entity_table].batch_size
        dst_store = batch[dst_table]

        # Dense score matrix: destinations not sampled for a row keep score 0.
        scores = torch.zeros(num_src, task.num_dst_nodes, device=logits.device)
        scores[dst_store.batch, dst_store.n_id] = torch.sigmoid(logits)

        topk_chunks.append(torch.topk(scores, k=task.eval_k, dim=1)[1])

    return torch.cat(topk_chunks, dim=0).cpu().numpy()


In [None]:
import copy

# Best-checkpoint tracking and loop configuration for this section.
state_dict = None  # deep copy of the best model weights seen so far
# Start at -inf so the first evaluation always stores a checkpoint, even when the
# tuned metric is exactly 0 (with 0 and a strict `>`, nothing would ever be saved).
best_val_metric = float("-inf")
epochs = 10  # Set the number of epochs (you can adjust as needed)
eval_epochs_interval = 1  # Evaluate every 'n' epochs
tune_metric = "link_prediction_map"  # Validation metric used to select the checkpoint

# Training and evaluation loop
for epoch in range(1, epochs + 1):
    train_loss = train()

    if epoch % eval_epochs_interval == 0:
        val_pred = test(loader_dict["val"])
        val_metrics = task.evaluate(val_pred, task.get_table("val"))

        print(
            f"Epoch: {epoch:02d}, Train loss: {train_loss}, "
            f"Val metrics: {val_metrics}"
        )

        # Keep an independent copy of the best weights (deepcopy detaches them from
        # the live parameters that later optimizer steps keep updating).
        if val_metrics[tune_metric] > best_val_metric:
            best_val_metric = val_metrics[tune_metric]
            state_dict = copy.deepcopy(model.state_dict())

# Fail with a clear message instead of an opaque error from load_state_dict(None),
# e.g. when epochs < eval_epochs_interval or every metric was NaN.
if state_dict is None:
    raise RuntimeError(
        "No validation checkpoint was saved; check `epochs`, "
        "`eval_epochs_interval` and the validation metrics."
    )
model.load_state_dict(state_dict)

# Evaluate the model on the validation set with the best weights
val_pred = test(loader_dict["val"])
val_metrics = task.evaluate(val_pred, task.get_table("val"))
print(f"Best Val metrics: {val_metrics}")

# Evaluate the model on the test set with the best weights
test_pred = test(loader_dict["test"])
test_metrics = task.evaluate(test_pred)
print(f"Best test metrics: {test_metrics}")

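## ContextGNN

Train one of IDGNN, ContextGNN or ShallowRHSGNN (chosen via `model_name`) on the link-prediction `task` defined above.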

In [38]:
import argparse
import json
import os
import warnings
from pathlib import Path
from typing import Dict, List, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from relbench.base import Dataset, RecommendationTask, TaskType
from relbench.datasets import get_dataset
from relbench.modeling.graph import (
    get_link_train_table_input,
    make_pkey_fkey_graph,
)
from relbench.modeling.loader import SparseTensor
from relbench.modeling.utils import get_stype_proposal
from relbench.tasks import get_task
from torch import Tensor
from torch_frame import stype
from torch_frame.config.text_embedder import TextEmbedderConfig
from torch_geometric.loader import NeighborLoader
from torch_geometric.seed import seed_everything
from torch_geometric.typing import NodeType
from torch_geometric.utils.cross_entropy import sparse_cross_entropy
from tqdm import tqdm

from contextgnn.nn.models import IDGNN, ContextGNN, ShallowRHSGNN
from contextgnn.utils import GloveTextEmbedding, RHSEmbeddingMode
import random

# Static configuration parameters
learning_rate = 0.001
epochs = 20
eval_epochs_interval = 1
batch_size = 512
channels = 128
aggregation_method = "sum"
num_layers = 4
num_neighbors = 128  # neighbor budget of the first hop; later hops halve it
temporal_strategy = "last"
share_same_time = True
max_steps_per_epoch = 2000
num_workers = 0
seed = 42
model_name = "contextgnn"  # one of 'idgnn', 'contextgnn' or 'shallowrhsgnn'
tune_metric = "link_prediction_map"  # Metric used to tune the model
cache_dir = os.path.expanduser("~/.cache/relbench_examples")

# Seed every RNG the pipeline may draw from: Python's `random`, NumPy's global
# generator, PyTorch on CPU and PyTorch on all CUDA devices. The last statement
# returns None, so the cell no longer echoes a torch.Generator repr.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)


<torch._C.Generator at 0x2df166d5b70>

In [39]:
# Per-hop fan-out for NeighborLoader: hop i samples num_neighbors // 2**i neighbors.
# The list gets its own name. Rebinding `num_neighbors` would turn the config int
# into a list, and re-running this cell would then fail on `list // int`.
num_neighbors_list = [num_neighbors // 2**i for i in range(num_layers)]

# Loader dictionary for train, validation, and test sets
loader_dict: Dict[str, NeighborLoader] = {}
dst_nodes_dict: Dict[str, Tuple[NodeType, Tensor]] = {}
num_dst_nodes_dict: Dict[str, int] = {}

# NOTE(review): assumes `task` and `data` were built by an earlier cell.
for split in ["train", "val", "test"]:
    table = task.get_table(split)
    table_input = get_link_train_table_input(table, task)
    dst_nodes_dict[split] = table_input.dst_nodes
    num_dst_nodes_dict[split] = table_input.num_dst_nodes
    loader_dict[split] = NeighborLoader(
        data,
        num_neighbors=num_neighbors_list,
        time_attr="time",
        input_nodes=table_input.src_nodes,
        input_time=table_input.src_time,
        subgraph_type="bidirectional",
        batch_size=batch_size,
        temporal_strategy=temporal_strategy,
        shuffle=split == "train",  # only the training split is shuffled
        num_workers=num_workers,
        persistent_workers=num_workers > 0,
    )


  dst_node_indices = sparse_coo.to_sparse_csr()


In [40]:
# Arguments shared by all three architectures. `aggregation_method` from the config
# cell now applies to every model; previously ContextGNN/ShallowRHSGNN hard-coded "sum".
_shared_gnn_kwargs = dict(
    data=data,
    col_stats_dict=col_stats_dict,
    num_layers=num_layers,
    channels=channels,
    aggr=aggregation_method,
    norm="layer_norm",
    torch_frame_model_kwargs={
        "channels": 128,
        "num_layers": 4,
    },
)

if model_name == "idgnn":
    # One output logit per scored (source, destination) pair.
    model = IDGNN(out_channels=1, **_shared_gnn_kwargs).to(device)
elif model_name in ("contextgnn", "shallowrhsgnn"):
    # Both classes take the same extra arguments for the destination side.
    _model_cls = ContextGNN if model_name == "contextgnn" else ShallowRHSGNN
    model = _model_cls(
        rhs_emb_mode=RHSEmbeddingMode.FUSION,
        dst_entity_table=task.dst_entity_table,
        num_nodes=num_dst_nodes_dict["train"],
        embedding_dim=64,
        **_shared_gnn_kwargs,
    ).to(device)
else:
    raise ValueError(f"Unsupported model type {model_name}.")

In [41]:
# Adam over every registered parameter of the selected model; step size from the config cell.
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)


In [42]:
# Training function
def train() -> float:
    """Run one training epoch for the model selected by ``model_name``.

    - ``idgnn``: one logit per sampled destination node, trained with binary
      cross-entropy against membership in the ground-truth links.
    - ``contextgnn`` / ``shallowrhsgnn``: logits trained with
      ``sparse_cross_entropy`` on the ground-truth (batch row, destination) pairs.

    Exactly ``min(len(train loader), max_steps_per_epoch)`` mini-batches are processed.

    Returns:
        The epoch loss averaged with per-batch weights (scored pairs for idgnn,
        sampled destination nodes otherwise), or ``nan`` if nothing was sampled.

    Raises:
        ValueError: if ``model_name`` is not one of the supported models.
    """
    from itertools import islice

    # Validate up front; otherwise an unknown name would leave `loss` unbound.
    if model_name not in ("idgnn", "contextgnn", "shallowrhsgnn"):
        raise ValueError(f"Unsupported model type: {model_name}.")

    model.train()

    loss_accum = count_accum = 0
    total_steps = min(len(loader_dict["train"]), max_steps_per_epoch)
    sparse_tensor = SparseTensor(dst_nodes_dict["train"][1], device=device)

    # islice yields exactly `total_steps` batches. The previous
    # `steps > max_steps_per_epoch` check ran one batch too many, and a `break`
    # inside the body leaves the tqdm bar one short.
    for batch in tqdm(islice(loader_dict["train"], total_steps), total=total_steps, desc="Train"):
        batch = batch.to(device)

        # Get ground-truth (row-in-batch, destination id) pairs for the seed nodes
        input_id = batch[task.src_entity_table].input_id
        src_batch, dst_index = sparse_tensor[input_id]

        # Optimization
        optimizer.zero_grad()

        if model_name == "idgnn":
            out = model(batch, task.src_entity_table, task.dst_entity_table).flatten()
            batch_size = batch[task.src_entity_table].batch_size

            # Label each sampled pair by encoding (row, node id) as one integer key.
            target = torch.isin(
                batch[task.dst_entity_table].batch +
                batch_size * batch[task.dst_entity_table].n_id,
                src_batch + batch_size * dst_index,
            ).float()

            loss = F.binary_cross_entropy_with_logits(out, target)
            numel = out.numel()
        else:  # 'contextgnn' or 'shallowrhsgnn'
            logits = model(batch, task.src_entity_table, task.dst_entity_table)
            edge_label_index = torch.stack([src_batch, dst_index], dim=0)
            loss = sparse_cross_entropy(logits, edge_label_index)
            numel = len(batch[task.dst_entity_table].batch)

        loss.backward()
        optimizer.step()

        loss_accum += float(loss) * numel
        count_accum += numel

    if count_accum == 0:
        warnings.warn(f"Did not sample a single '{task.dst_entity_table}' node in any mini-batch.")

    return loss_accum / count_accum if count_accum > 0 else float("nan")


In [43]:
# Test function
@torch.no_grad()
def test(loader: NeighborLoader, desc: str) -> np.ndarray:
    """Predict the top-``task.eval_k`` destination nodes for each source node.

    Args:
        loader: NeighborLoader over the source nodes of one split.
        desc: Label for the tqdm progress bar.

    Returns:
        Array of shape (num_source_nodes, task.eval_k) of destination indices,
        highest-scoring first.

    Raises:
        ValueError: if ``model_name`` is not a supported model type.
    """
    model.eval()

    pred_list: List[Tensor] = []
    for batch in tqdm(loader, desc=desc):
        batch = batch.to(device)
        batch_size = batch[task.src_entity_table].batch_size

        if model_name == "idgnn":
            # Call the module rather than `.forward` so registered nn.Module hooks run.
            out = model(batch, task.src_entity_table, task.dst_entity_table).detach().flatten()
            # Only sampled destinations are scored; scatter them into a dense
            # (batch_size, num_dst_nodes) matrix where unsampled nodes stay at 0.
            scores = torch.zeros(batch_size, task.num_dst_nodes, device=out.device)
            scores[batch[task.dst_entity_table].batch, batch[task.dst_entity_table].n_id] = torch.sigmoid(out)
        elif model_name in ['contextgnn', 'shallowrhsgnn']:
            # The output is ranked directly as a per-source score matrix (topk over dim 1).
            out = model(batch, task.src_entity_table, task.dst_entity_table).detach()
            scores = torch.sigmoid(out)
        else:
            raise ValueError(f"Unsupported model type: {model_name}.")

        _, pred_mini = torch.topk(scores, k=task.eval_k, dim=1)
        pred_list.append(pred_mini)

    pred = torch.cat(pred_list, dim=0).cpu().numpy()
    return pred


In [44]:
# Training and evaluation loop
state_dict = None
# Start at -inf so the first evaluation always stores a checkpoint, even if the
# tuned metric is exactly 0 (a strict `>` against 0 would never save one).
best_val_metric = float("-inf")
for epoch in range(1, epochs + 1):
    train_loss = train()

    if epoch % eval_epochs_interval == 0:
        val_pred = test(loader_dict["val"], desc="Val")
        val_metrics = task.evaluate(val_pred, task.get_table("val"))
        print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}")

        if val_metrics[tune_metric] > best_val_metric:
            best_val_metric = val_metrics[tune_metric]
            # `copy=True` forces a real copy. `v.cpu()` returns the same tensor when
            # already on CPU, so the snapshot would alias the live weights and be
            # overwritten by later optimizer steps.
            state_dict = {k: v.to("cpu", copy=True) for k, v in model.state_dict().items()}

# Load the best model and evaluate on validation and test sets.
# Raise explicitly: an `assert` would be stripped when Python runs with -O.
if state_dict is None:
    raise RuntimeError(
        "No validation checkpoint was saved; check `epochs`, "
        "`eval_epochs_interval` and the validation metrics."
    )
model.load_state_dict(state_dict)

val_pred = test(loader_dict["val"], desc="Best val")
val_metrics = task.evaluate(val_pred, task.get_table("val"))
print(f"Best val metrics: {val_metrics}")

test_pred = test(loader_dict["test"], desc="Test")
test_metrics = task.evaluate(test_pred)
print(f"Best test metrics: {test_metrics}")

Train:   0%|          | 1/1408 [00:45<17:56:41, 45.91s/it]


OutOfMemoryError: CUDA out of memory. Tried to allocate 192.00 MiB. GPU 

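## ContextGNN with sampled softmax

Train ContextGNN on `rel-trial` / `site-sponsor-run` with the sampled-softmax training objective (`forward_sample_softmax`, controlled by `rhs_sample_size`).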

In [76]:
import os
import json
import warnings
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import torch
from relbench.base import Dataset, RecommendationTask, TaskType
from relbench.datasets import get_dataset
from relbench.modeling.graph import (
    get_link_train_table_input,
    make_pkey_fkey_graph,
)
from relbench.modeling.loader import SparseTensor
from relbench.modeling.utils import get_stype_proposal
from relbench.tasks import get_task
from torch import Tensor  # used by module-level type annotations in the loader cell
from torch_frame import stype
from torch_frame.config.text_embedder import TextEmbedderConfig
from torch_geometric.loader import NeighborLoader
from torch_geometric.seed import seed_everything
from torch_geometric.typing import NodeType
from torch_geometric.utils.cross_entropy import sparse_cross_entropy
from tqdm import tqdm
from contextgnn.nn.models import ContextGNN
from contextgnn.utils import GloveTextEmbedding, RHSEmbeddingMode

# Static Configuration
dataset_name = "rel-trial"
task_name = "site-sponsor-run"
learning_rate = 0.001
epochs = 20
eval_epochs_interval = 1
batch_size = 128
channels = 128
aggregation_method = "sum"
num_layers = 6
num_neighbors = 64  # neighbor budget of the first hop; later hops halve it
rhs_sample_size = 1000  # Use -1 for sampling all RHS
temporal_strategy = "last"
max_steps_per_epoch = 200
num_workers = 0
seed = 42
cache_dir = os.path.expanduser("~/.cache/relbench_examples")

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# os.cpu_count() can return None when the CPU count is undetermined; fall back to 1.
torch.set_num_threads(1 if torch.cuda.is_available() else (os.cpu_count() or 1))
torch.manual_seed(seed)
seed_everything(seed)


In [77]:
# Load dataset and task
dataset: Dataset = get_dataset(dataset_name, download=True)
task: RecommendationTask = get_task(dataset_name, task_name, download=True)

# Ensure task type is LINK_PREDICTION
assert task.task_type == TaskType.LINK_PREDICTION

# Tune metric
tune_metric = "link_prediction_map"

# Column semantic types: reuse the cached JSON when it is readable; otherwise
# infer them from the database and rewrite the cache.
stypes_cache_path = Path(f"{cache_dir}/{dataset_name}/stypes.json")
try:
    with open(stypes_cache_path, "r") as f:
        col_to_stype_dict = json.load(f)
    for table, col_to_stype in col_to_stype_dict.items():
        for col, stype_str in col_to_stype.items():
            col_to_stype[col] = stype(stype_str)
except (FileNotFoundError, json.JSONDecodeError):
    # A missing cache and a truncated/corrupted one both fall back to re-inference.
    col_to_stype_dict = get_stype_proposal(dataset.get_db())
    stypes_cache_path.parent.mkdir(parents=True, exist_ok=True)
    with open(stypes_cache_path, "w") as f:
        json.dump(col_to_stype_dict, f, indent=2, default=str)

# Prepare graph data and column stats.
# NOTE(review): this cell has failed with "PytorchStreamReader failed reading zip
# archive", which suggests a corrupted file under the `materialized` cache dir;
# deleting that directory should force re-materialization — confirm.
data, col_stats_dict = make_pkey_fkey_graph(
    dataset.get_db(),
    col_to_stype_dict=col_to_stype_dict,
    text_embedder_cfg=TextEmbedderConfig(
        text_embedder=GloveTextEmbedding(device=device), batch_size=256
    ),
    cache_dir=f"{cache_dir}/{dataset_name}/materialized",
)


RuntimeError: PytorchStreamReader failed reading zip archive: failed finding central directory

In [75]:
# Release cached-but-unused GPU memory held by PyTorch's CUDA caching allocator
# (e.g. after an out-of-memory error); memory of tensors still referenced is not freed.
torch.cuda.empty_cache()


In [68]:
# Per-hop neighbor budget: hop h samples num_neighbors // 2**h neighbors.
num_neighbors_list = [int(num_neighbors // 2**hop) for hop in range(num_layers)]

# Per-split containers filled by the loop below.
loader_dict: Dict[str, NeighborLoader] = {}
dst_nodes_dict: Dict[str, Tuple[NodeType, Tensor]] = {}
num_dst_nodes_dict: Dict[str, int] = {}

for split in ("train", "val", "test"):
    table = task.get_table(split)
    table_input = get_link_train_table_input(table, task)

    dst_nodes_dict[split] = table_input.dst_nodes
    num_dst_nodes_dict[split] = table_input.num_dst_nodes

    # Temporal subgraph sampling around each source node; only "train" is shuffled.
    loader_dict[split] = NeighborLoader(
        data,
        input_nodes=table_input.src_nodes,
        input_time=table_input.src_time,
        time_attr="time",
        num_neighbors=num_neighbors_list,
        subgraph_type="bidirectional",
        temporal_strategy=temporal_strategy,
        batch_size=batch_size,
        shuffle=(split == "train"),
        num_workers=num_workers,
        persistent_workers=num_workers > 0,
    )


In [69]:
# ContextGNN in RHSEmbeddingMode.FUSION, sized from the config cell.
model: ContextGNN = ContextGNN(
    data=data,
    col_stats_dict=col_stats_dict,
    dst_entity_table=task.dst_entity_table,
    num_nodes=num_dst_nodes_dict["train"],
    rhs_emb_mode=RHSEmbeddingMode.FUSION,
    rhs_sample_size=rhs_sample_size,  # -1 samples all RHS (per the config cell)
    channels=channels,
    num_layers=num_layers,
    aggr=aggregation_method,
    norm="layer_norm",
    embedding_dim=64,
    torch_frame_model_kwargs={"channels": 128, "num_layers": 4},
).to(device)

# Adam over every registered model parameter.
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)


In [70]:
def train() -> float:
    """Run one sampled-softmax training epoch of the ContextGNN model.

    Processes ``min(len(train loader), max_steps_per_epoch)`` mini-batches. For
    each one, ``model.forward_sample_softmax`` returns logits together with the
    positive (lhs_y_batch, rhs_y_index) pairs. These are used as labels for
    ``sparse_cross_entropy``.

    Returns:
        The epoch loss averaged with per-batch weights equal to the number of
        sampled destination nodes, or ``nan`` if none were sampled.
    """
    from itertools import islice

    model.train()

    loss_accum = count_accum = 0.0
    total_steps = min(len(loader_dict["train"]), max_steps_per_epoch)
    sparse_tensor = SparseTensor(dst_nodes_dict["train"][1], device=device)

    # islice yields exactly `total_steps` batches, so the progress bar reaches its
    # total. A `break` inside the loop body leaves it one short (e.g. 199/200).
    for batch in tqdm(islice(loader_dict["train"], total_steps), total=total_steps, desc="Train"):
        batch = batch.to(device)

        # Get ground-truth (row-in-batch, destination id) pairs for the seed nodes
        input_id = batch[task.src_entity_table].input_id
        src_batch, dst_index = sparse_tensor[input_id]

        # Optimization
        optimizer.zero_grad()

        logits, lhs_y_batch, rhs_y_index = model.forward_sample_softmax(
            batch, task.src_entity_table, task.dst_entity_table, src_batch, dst_index
        )
        edge_label_index = torch.stack([lhs_y_batch, rhs_y_index], dim=0)
        loss = sparse_cross_entropy(logits, edge_label_index)

        numel = len(batch[task.dst_entity_table].batch)
        loss.backward()
        optimizer.step()

        loss_accum += float(loss) * numel
        count_accum += numel

    if count_accum == 0:
        warnings.warn(f"Did not sample a single '{task.dst_entity_table}' node in any mini-batch.")

    return loss_accum / count_accum if count_accum > 0 else float("nan")


In [71]:
@torch.no_grad()
def test(loader: NeighborLoader, desc: str) -> np.ndarray:
    """Return the top-``task.eval_k`` destination indices for each source node.

    Args:
        loader: NeighborLoader over the source nodes of one split.
        desc: Label for the tqdm progress bar.

    Returns:
        Array with one row of ``task.eval_k`` destination indices per source
        node, ranked by sigmoid score (highest first).
    """
    model.eval()

    ranked: List[Tensor] = []
    for batch in tqdm(loader, desc=desc):
        logits = model(batch.to(device), task.src_entity_table, task.dst_entity_table)
        probs = logits.detach().sigmoid()
        ranked.append(probs.topk(task.eval_k, dim=1).indices)
    return torch.cat(ranked, dim=0).cpu().numpy()


In [72]:
# Initialize variables for tracking the best model and validation metrics.
state_dict = None
# Start at -inf so the first evaluation always stores a checkpoint, even if the
# tuned metric is exactly 0 (a strict `>` against 0 would never save one).
best_val_metric = float("-inf")

# Training and evaluation loop
for epoch in range(1, epochs + 1):
    train_loss = train()

    if epoch % eval_epochs_interval == 0:
        # Run validation
        val_pred = test(loader_dict["val"], desc="Val")
        val_metrics = task.evaluate(val_pred, task.get_table("val"))
        print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}")

        # Save best model. `copy=True` forces a real copy: `v.cpu()` returns the same
        # tensor when already on CPU, so the snapshot would alias the live weights.
        if val_metrics[tune_metric] > best_val_metric:
            best_val_metric = val_metrics[tune_metric]
            state_dict = {k: v.to("cpu", copy=True) for k, v in model.state_dict().items()}

# Load the best model weights. Raise explicitly: an `assert` would be stripped under -O.
if state_dict is None:
    raise RuntimeError(
        "No validation checkpoint was saved; check `epochs`, "
        "`eval_epochs_interval` and the validation metrics."
    )
model.load_state_dict(state_dict)

# Evaluate on validation and test sets
val_pred = test(loader_dict["val"], desc="Best val")
val_metrics = task.evaluate(val_pred, task.get_table("val"))
print(f"Best val metrics: {val_metrics}")

test_pred = test(loader_dict["test"], desc="Test")
test_metrics = task.evaluate(test_pred)
print(f"Best test metrics: {test_metrics}")


Train: 100%|█████████▉| 199/200 [52:32<00:15, 15.84s/it]
Val: 100%|██████████| 118/118 [17:12<00:00,  8.75s/it]


Epoch: 01, Train loss: 63.64355062029784, Val metrics: {'link_prediction_precision': np.float64(0.08001457822543237), 'link_prediction_recall': np.float64(0.06289739234376616), 'link_prediction_map': np.float64(0.06324225623808157), 'link_prediction_top': np.float64(0.25995626532370286)}


Train: 100%|█████████▉| 199/200 [49:45<00:15, 15.00s/it]
Val: 100%|██████████| 118/118 [27:39<00:00, 14.06s/it]


Epoch: 02, Train loss: 58.35857021517497, Val metrics: {'link_prediction_precision': np.float64(0.08079318799284342), 'link_prediction_recall': np.float64(0.06300215009859066), 'link_prediction_map': np.float64(0.06365502985591116), 'link_prediction_top': np.float64(0.2613478232058843)}


Train: 100%|█████████▉| 199/200 [47:15<00:14, 14.25s/it]
Val: 100%|██████████| 118/118 [28:21<00:00, 14.42s/it]


Epoch: 03, Train loss: 57.68884758799559, Val metrics: {'link_prediction_precision': np.float64(0.08203565038764826), 'link_prediction_recall': np.float64(0.06379913579833485), 'link_prediction_map': np.float64(0.06478797149147027), 'link_prediction_top': np.float64(0.2639321449870784)}


Train: 100%|█████████▉| 199/200 [47:34<00:14, 14.35s/it]
Val:   8%|▊         | 9/118 [02:25<29:24, 16.19s/it]


KeyboardInterrupt: 