In [1]:
from relbench.datasets import get_dataset, get_dataset_names, register_dataset

In [131]:
import os
import pandas as pd
import numpy as np
from relbench.base import Database, Dataset, Table

class TransactionalDataset(Dataset):
    # Set timestamps or other relevant information if needed
    val_timestamp = pd.Timestamp("2022-02-15")
    test_timestamp = pd.Timestamp("2022-02-22")

    def make_db(self) -> Database:
        # Path to your CSVs folder
        path = os.path.join("D:/Dani/relbench/relbench/", "hyper_data")
        customers = os.path.join(path, "Customers.csv")
        articles = os.path.join(path, "Articles.csv")
        branches = os.path.join(path, "Branches.csv")
        transactions = os.path.join(path, "Transactions.csv")

        # Ensure that CSV files exist in the specified path
        if not os.path.exists(customers):
            raise RuntimeError(f"Dataset not found at '{path}'. Please make sure the CSV files are in the correct folder.")

        # Read the CSV data into pandas DataFrames
        customers_df = pd.read_csv(customers)
        articles_df = pd.read_csv(articles)
        branches_df = pd.read_csv(branches)
        transactions_df = pd.read_csv(transactions)
        transactions_df['d_dat'] = pd.to_datetime(transactions_df['d_dat'])        
        split_date = pd.to_datetime('2021-08-01')
        transactions_df = transactions_df[transactions_df['d_dat'] >= split_date]
        transactions_df = transactions_df.reset_index(drop=True)
        ################################################################################
        # Check for and handle duplicate primary keys in articles, customers, and branches tables
        ################################################################################

        # Handle duplicates in the articles table
        if articles_df.duplicated(subset=['articles_id']).any():
            print("Duplicates found in the 'articles_id' column. Removing duplicates...")
            articles_df = articles_df.drop_duplicates(subset=['articles_id'], keep='first')

        # Handle duplicates in the customers table
        if customers_df.duplicated(subset=['customers_id']).any():
            print("Duplicates found in the 'customers_id' column. Removing duplicates...")
            customers_df = customers_df.drop_duplicates(subset=['customers_id'], keep='first')

        # Handle duplicates in the branches table
        if branches_df.duplicated(subset=['BranchCode']).any():
            print("Duplicates found in the 'BranchCode' column. Removing duplicates...")
            branches_df = branches_df.drop_duplicates(subset=['BranchCode'], keep='first')

        ################################################################################
        # Clean and process the data (drop unnecessary columns, handle missing data)
        ################################################################################
        # Drop unnecessary columns
        transactions_df.drop(columns=["Return Amount"], inplace=True)
        articles_df.drop(columns=["Item Barcode", "External Item Number"], inplace=True)

        # Replace any missing or invalid values
        transactions_df["salesTime"] = transactions_df["salesTime"].replace(r"^\\N$", "00:00:00", regex=True)
        transactions_df = transactions_df.replace(r"^\\N$", np.nan, regex=True)

        # Combine date and time into a single 'datetime' column
        # transactions_df['datetime'] = pd.to_datetime(transactions_df['d_dat'] + ' ' + transactions_df['salesTime'])
        # transactions_df.drop(columns=["d_dat"], inplace=True)        
        # Convert date column to pd.Timestamp
        # transactions_df["datetime"] = pd.to_datetime(transactions_df["datetime"])

        transactions_df["datetime"] = pd.to_datetime(
        transactions_df["d_dat"], format="%Y-%m-%d"
        )
        transactions_df.drop(columns=["d_dat"], inplace=True)          
        # Convert other fields if necessary
        transactions_df['price_purchase'] = pd.to_numeric(transactions_df['price_purchase'], errors='coerce')
        transactions_df['Discount_ratio'] = pd.to_numeric(transactions_df['Discount_ratio'], errors='coerce')
        transactions_df['Quantity'] = pd.to_numeric(transactions_df['Quantity'], errors='coerce')

        ################################################################################
        # Now we define the table structure and relationships.
        ################################################################################

        tables = {}

        # Articles table
        tables["article"] = Table(
            df=pd.DataFrame(articles_df),
            fkey_col_to_pkey_table={},
            pkey_col="articles_id",
            time_col=None,
        )

        # Customers table
        tables["customer"] = Table(
            df=pd.DataFrame(customers_df),
            fkey_col_to_pkey_table={},
            pkey_col="customers_id",
            time_col=None,
        )

        # Branches table (renamed from "branche" to "branches")
        tables["branches"] = Table(
            df=pd.DataFrame(branches_df),
            fkey_col_to_pkey_table={},
            pkey_col="BranchCode",
            time_col=None,
        )

        # Transactions table
        tables["transactions"] = Table(
            df=pd.DataFrame(transactions_df),
            fkey_col_to_pkey_table={
                "articles_id": "article",    # Foreign key to articles
                "customers_id": "customer",  # Foreign key to customers
                "BranchCode": "branches",    # Foreign key to branches
            },
            pkey_col=None,
            time_col="datetime",  # Use the combined datetime column for time-based operations
        )

        return Database(tables)


In [9]:
import os
import pandas as pd
import numpy as np
from relbench.base import Database, Dataset, Table
from datetime import datetime, timedelta

class TransactionalDataset(Dataset):
    # Set timestamps or other relevant information if needed
    val_timestamp = pd.Timestamp("2022-02-15")
    test_timestamp = pd.Timestamp("2022-02-22")

    def make_db(self) -> Database:
        # Path to your CSVs folder
        path = os.path.join("D:/Dani/relbench/relbench/", "hyper_data")
        customers = os.path.join(path, "Customers.csv")
        articles = os.path.join(path, "Articles.csv")
        branches = os.path.join(path, "Branches.csv")
        transactions = os.path.join(path, "Transactions.csv")

        # Ensure that CSV files exist in the specified path
        if not os.path.exists(customers):
            raise RuntimeError(f"Dataset not found at '{path}'. Please make sure the CSV files are in the correct folder.")

        # Read the CSV data into pandas DataFrames
        customers_df = pd.read_csv(customers)
        articles_df = pd.read_csv(articles)
        branches_df = pd.read_csv(branches)
        transactions_df = pd.read_csv(transactions)
        
        ################################################################################
        # Check for and handle duplicate primary keys in articles, customers, and branches tables
        ################################################################################

        # Handle duplicates in the articles table
        if articles_df.duplicated(subset=['articles_id']).any():
            print("Duplicates found in the 'articles_id' column. Removing duplicates...")
            articles_df = articles_df.drop_duplicates(subset=['articles_id'], keep='first')

        # Handle duplicates in the customers table
        if customers_df.duplicated(subset=['customers_id']).any():
            print("Duplicates found in the 'customers_id' column. Removing duplicates...")
            customers_df = customers_df.drop_duplicates(subset=['customers_id'], keep='first')

        # Handle duplicates in the branches table
        if branches_df.duplicated(subset=['BranchCode']).any():
            print("Duplicates found in the 'BranchCode' column. Removing duplicates...")
            branches_df = branches_df.drop_duplicates(subset=['BranchCode'], keep='first')

        ################################################################################
        # Clean and process the data (drop unnecessary columns, handle missing data)
        ################################################################################
        # Drop unnecessary columns
        transactions_df.drop(columns=["Return Amount"], inplace=True)
        articles_df.drop(columns=["Item Barcode", "External Item Number"], inplace=True)

        # Replace any missing or invalid values
        transactions_df["salesTime"] = transactions_df["salesTime"].replace(r"^\\N$", "00:00:00", regex=True)
        transactions_df = transactions_df.replace(r"^\\N$", np.nan, regex=True)

        # Combine date and time into a single 'datetime' column
        transactions_df["datetime"] = pd.to_datetime(
            transactions_df["d_dat"], format="%Y-%m-%d"
        )
        transactions_df.drop(columns=["d_dat"], inplace=True)          

        # Convert other fields if necessary
        transactions_df['price_purchase'] = pd.to_numeric(transactions_df['price_purchase'], errors='coerce')
        transactions_df['Discount_ratio'] = pd.to_numeric(transactions_df['Discount_ratio'], errors='coerce')
        transactions_df['Quantity'] = pd.to_numeric(transactions_df['Quantity'], errors='coerce')

        ################################################################################
        # Article Removal Logic: Remove articles not in transactions in the last 6 months
        ################################################################################
        # Calculate the date 6 months ago
        last_transaction_date = transactions_df["datetime"].max()

        # Calculate the date 6 months ago from the last transaction date
        six_months_ago = last_transaction_date - timedelta(days=6*30)  # Approximate 6 months

        print(f"Removing articles not in transactions since {six_months_ago.date()}")

        # Find articles that have been in transactions in the last 6 months
        recent_transactions = transactions_df[transactions_df['datetime'] >= six_months_ago]
        articles_in_recent_transactions = recent_transactions['articles_id'].unique()

        # Filter out articles that are in recent transactions
        articles_to_remove = articles_df[~articles_df['articles_id'].isin(articles_in_recent_transactions)]

        # Remove the articles from the articles DataFrame
        articles_df = articles_df[articles_df['articles_id'].isin(articles_in_recent_transactions)]

        # Remove the articles from the transactions DataFrame as well
        transactions_df = transactions_df[transactions_df['articles_id'].isin(articles_in_recent_transactions)]

        # Optionally, save the removed articles to a separate CSV
        articles_to_remove.to_csv(os.path.join(path, 'articles_to_remove.csv'), index=False)

        print(f"Removed {len(articles_to_remove)} articles that haven't been in a transaction in the last 6 months.")

        ################################################################################
        # Now we define the table structure and relationships.
        ################################################################################

        tables = {}

        # Articles table
        tables["article"] = Table(
            df=pd.DataFrame(articles_df),
            fkey_col_to_pkey_table={},
            pkey_col="articles_id",
            time_col=None,
        )

        # Customers table
        tables["customer"] = Table(
            df=pd.DataFrame(customers_df),
            fkey_col_to_pkey_table={},
            pkey_col="customers_id",
            time_col=None,
        )

        # Branches table (renamed from "branche" to "branches")
        tables["branches"] = Table(
            df=pd.DataFrame(branches_df),
            fkey_col_to_pkey_table={},
            pkey_col="BranchCode",
            time_col=None,
        )

        # Transactions table
        tables["transactions"] = Table(
            df=pd.DataFrame(transactions_df),
            fkey_col_to_pkey_table={
                "articles_id": "article",    # Foreign key to articles
                "customers_id": "customer",  # Foreign key to customers
                "BranchCode": "branches",    # Foreign key to branches
            },
            pkey_col=None,
            time_col="datetime",  # Use the combined datetime column for time-based operations
        )

        return Database(tables)


In [25]:
import os
import pandas as pd
import numpy as np
from relbench.base import Database, Dataset, Table


# Path to your CSVs folder
path = os.path.join("D:/Dani/relbench/relbench/burger_data/dorsa", "new")   

customers = os.path.join(path, "customer_data.csv")
articles = os.path.join(path, "article_data.csv")
branches = os.path.join(path, "branch_data.csv")
transactions = os.path.join(path, "transaction_data.csv")
calendar = os.path.join(path, "calendar.xlsx")
# Ensure that CSV files exist in the specified path
if not os.path.exists(customers):
    raise RuntimeError(f"Dataset not found at '{path}'. Please make sure the CSV files are in the correct folder.")

# Read the CSV data into pandas DataFrames
customers_df = pd.read_csv(customers)
articles_df = pd.read_csv(articles)
branches_df = pd.read_csv(branches)
transactions_df = pd.read_csv(transactions)
calendar_df = pd.read_excel(calendar)
transactions_df['d_dat'] = pd.to_datetime(transactions_df['Datetime'])
transactions_df.drop(columns=["Datetime"], inplace=True)  
transactions_df['Date'] = transactions_df['d_dat'].dt.date
transactions_df['Date'] = pd.to_datetime(transactions_df['Date'])
transactions_df = transactions_df.reset_index(drop=True)
transactions_df = pd.merge(transactions_df, calendar_df, how="inner", on='Date')

In [26]:
transactions_df

Unnamed: 0,Transaction Type,BranchCode,Transaction Serial,Document Barcode,Time,customers_id,articles_id,Style Code,Quantity,Base Total Price,...,Weekday,Weekend,Week_of_month,Week_of_year,wm_yr_wk,Holiday,event_name_1,event_type_1,event_name_2,event_type_2
0,فاکتور فروش,10284.0,27631.0,10284-27631,12:07:41,10121860.0,54821.0,3416.0,1,1.174312e+08,...,Tuesday,0,3,47,140247,0,Fajr,National,Sh Imam Mousa,Religious
1,فاکتور فروش,10124.0,75949.0,10124-75949,10:53:54,10447769.0,53612.0,3919.0,1,8.874312e+07,...,Saturday,0,2,47,140247,0,Fajr,National,,
2,فاکتور فروش,10706.0,21517.0,10706-21517,21:28:09,10469286.0,59459.0,2400.0,1,1.040909e+08,...,Sunday,0,2,16,140316,0,,,,
3,فاکتور فروش,11008.0,4992.0,11008-4992,21:36:02,10075749.0,54126.0,3117.0,1,4.500000e+07,...,Sunday,0,2,29,140329,0,,,,
4,فاکتور فروش,10002.0,58726.0,10002-58726,18:30:10,10432889.0,44301.0,3421.0,1,5.500000e+07,...,Friday,1,1,40,140140,1,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
495844,فاکتور فروش,10563.0,5418.0,10563-5418,22:52:44,10478973.0,62211.0,4542.0,1,1.936060e+09,...,Thursday,1,2,37,140337,0,,,,
495845,فاکتور فروش,10634.0,3188.0,10634-3188,18:13:09,10477956.0,62193.0,4542.0,1,1.955145e+09,...,Saturday,0,2,38,140338,0,,,,
495846,فاکتور فروش,10657.0,1308.0,10657-1308,10:43:27,10467100.0,54284.0,3305.0,26,2.027523e+09,...,Saturday,0,5,10,140310,0,,,,
495847,فاکتور فروش,10396.0,9557.0,10396-9557,10:55:21,10453553.0,62876.0,4648.0,1,2.190258e+09,...,Thursday,1,3,43,140343,0,,,,


In [32]:
import pandas as pd
import numpy as np
from sklearn.feature_selection import VarianceThreshold

def identify_invaluable_features(df, missing_threshold=0.8, variance_threshold=0.01, cardinality_threshold=0.5, correlation_threshold=0.95):
    """
    Analyzes a DataFrame and identifies features that are likely invaluable for ML modeling.
    
    Parameters:
    - df: pandas DataFrame
    - missing_threshold: float (features with more missing values than this ratio are considered invaluable)
    - variance_threshold: float (features with variance below this threshold are considered invaluable)
    - cardinality_threshold: float (categorical features with unique values exceeding this ratio of total rows are invaluable)
    - correlation_threshold: float (features with correlation above this with another feature are invaluable)

    Returns:
    - Pandas DataFrame with features and reasons why they are considered invaluable.
    """
    
    invaluable_features = []
    
    # High missing values
    missing_ratios = df.isnull().mean()
    high_missing_cols = missing_ratios[missing_ratios > missing_threshold].index
    for col in high_missing_cols:
        invaluable_features.append((col, "High missing values"))

    # Low variance features
    numeric_df = df.select_dtypes(include=[np.number])
    selector = VarianceThreshold(threshold=variance_threshold)
    if not numeric_df.empty:
        selector.fit(numeric_df.fillna(0))  # Replace NaN with 0 for variance check
        low_variance_cols = numeric_df.columns[~selector.get_support()]
        for col in low_variance_cols:
            invaluable_features.append((col, "Low variance"))

    # Single unique value
    unique_counts = df.nunique()
    single_value_cols = unique_counts[unique_counts == 1].index
    for col in single_value_cols:
        invaluable_features.append((col, "Single unique value"))

    # High cardinality categorical features
    categorical_df = df.select_dtypes(include=['object', 'category'])
    for col in categorical_df.columns:
        unique_ratio = categorical_df[col].nunique() / len(df)
        if unique_ratio > cardinality_threshold:
            invaluable_features.append((col, "High cardinality"))

    # High correlation features
    if not numeric_df.empty:
        corr_matrix = numeric_df.corr().abs()
        upper_tri = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))
        high_corr_pairs = [(col, "High correlation") for col in upper_tri.columns if any(upper_tri[col] > correlation_threshold)]
        invaluable_features.extend(high_corr_pairs)

    return pd.DataFrame(invaluable_features, columns=["Feature", "Reason"])

# Example Usage:
# df = pd.read_csv("your_dataset.csv")  # Load your dataset
invaluable_features_df = identify_invaluable_features(articles_df)
# import ace_tools as tools; tools.display_dataframe_to_user(name="Invaluable Features", dataframe=invaluable_features_df)


In [34]:
articles_df

Unnamed: 0,articles_id,Article Description,Base Unit Price,Style Code,Style Description,Collection Code,Collection Name,Idea Code,Idea Description,Initial Pattern Code,...,User,Usage Type,Usage Space,Brand,Size,Season,Year,Model,Combined Feature,Dimensions
0,54821.0,چلسي بوت نقش‌دال,1.174312e+08,1152,چلسی بوت نقش‌دال,0,عمومي,0,عمومي,218,...,MEN,SMART CASUAL,Unknown Usage Space,DORSA,42,FALL-WINTER,1402,AKM 6408-BOOT,پنجه متوسط گرد - فرم رویه صاف - فرم زیره پاشنه...,Unknown Dimensions
1,53612.0,کيف پوچ زنانه کلکسيون آهنگ,8.874312e+07,1495,پوچ آهنگ,0,عمومي,0,عمومي,438,...,WOMEN,Unknown Usage Type,Unknown Usage Space,DORSA,Unknown Size,SPRING-SUMMER,1402,Unknown Model,چرم گاوی ناپا - تریم زرد - آستر نقش دال مشکی -...,L 38 * H 28 * D 3 CM
2,59459.0,کيف دستي زنانه کوچک,1.040909e+08,510,کیف زنانه مربعی کوچک با لوگو 65میل درسا,0,عمومي,0,عمومي,262,...,WOMEN,SMART CASUAL,Unknown Usage Space,DORSA,Unknown Size,S2,1403,Unknown Model,لوگو فلزی پنجره ای 9-65میل - رنگ فلز طلایی - چ...,L 24 * H 20.5 * D 9 CM
3,54126.0,کيف پول عمودي آکو,4.500000e+07,958,کیف پول عمودی اکو,44,AKO,0,AKO,328,...,MEN,SMART CASUAL,Unknown Usage Space,DORSA,Unknown Size,SPRING-SUMMER,1402,Unknown Model,آستر نقش دال مشکی - فینیش مشکی - چرم گاوی ناپا...,L 10 * H 12 * D 1.5 CM
4,44301.0,اسنيکر بندي زنانه ساقدار با زيره زرد,5.500000e+07,1155,اسنیکر بندی زنانه ساقدار,0,عمومي,0,عمومي,220,...,WOMEN,CASUAL,Unknown Usage Space,DORSA,37,FALL-WINTER,1399,AKM 4939-WOMEN BOOT,پنجه متوسط گرد - فرم رویه بندی - فرم زیره پاشن...,Unknown Dimensions
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
45534,50733.0,دستبند خرد طلا 29/18 سايز L,2.004857e+09,1906,دستبند خرد طلا سایز L,68,WISDOM,0,WISDOM,409,...,UNISEX,Unknown Usage Type,-,DORSA JEWELRY,29,SPRING-SUMMER,1402,4543-1,محاسبه وزن طلا - 18,Unknown Dimensions
45535,50731.0,دستبند خرد طلا 28/96 سايز L,2.003945e+09,1906,دستبند خرد طلا سایز L,68,WISDOM,0,WISDOM,409,...,UNISEX,Unknown Usage Type,-,DORSA JEWELRY,28,SPRING-SUMMER,1402,4543-1,محاسبه وزن طلا - 96,Unknown Dimensions
45536,62211.0,دستبند خرد طلا 28/76 سايز M,1.936060e+09,1905,دستبند خرد طلا سایز M,68,WISDOM,0,WISDOM,409,...,UNISEX,Unknown Usage Type,-,DORSA JEWELRY,28,SPRING-SUMMER,1402,4542-1,محاسبه وزن طلا - 76,Unknown Dimensions
45537,62193.0,دستبند خرد طلا 28/57 سايز M,1.955145e+09,1905,دستبند خرد طلا سایز M,68,WISDOM,0,WISDOM,409,...,UNISEX,Unknown Usage Type,-,DORSA JEWELRY,28,SPRING-SUMMER,1402,4542-1,محاسبه وزن طلا - 57,Unknown Dimensions


In [5]:
invaluable_features_df

Unnamed: 0,Feature,Reason
0,Transaction Type,Single unique value
1,Document Barcode,High cardinality
2,Sale Total Price,High correlation
3,Net Sales,High correlation
4,Tax,High correlation
5,Net Sales Minus Sales Voucher,High correlation


In [31]:
transactions_df.columns

Index(['Transaction Type', 'BranchCode', 'Transaction Serial',
       'Document Barcode', 'Time', 'customers_id', 'articles_id', 'Style Code',
       'Quantity', 'Base Total Price', 'Sale Total Price', 'Net Sales',
       'Year_x', 'Month_x', 'Day', 'Season', 'Allocated Credit',
       'Sales Voucher', 'Gift Voucher', 'Cash Card', 'Special Discount',
       'Promotion Discount', 'Small Change Discount', 'Tax', 'Total Discounts',
       'Net Sales Minus Sales Voucher', 'Discount Percentage', 'd_dat', 'Date',
       'Date_S', 'd', 'Year_y', 'Day_of_year', 'Month_y', 'Day_of_month',
       'Day_of_week', 'even_odd', 'Weekday', 'Weekend', 'Week_of_month',
       'Week_of_year', 'wm_yr_wk', 'Holiday', 'event_name_1', 'event_type_1',
       'event_name_2', 'event_type_2'],
      dtype='object')

In [35]:
import os
import pandas as pd
import numpy as np
from relbench.base import Database, Dataset, Table

class TransactionalDataset(Dataset):
    # Set timestamps or other relevant information if needed
    # val_timestamp = pd.Timestamp("2024-11-17")
    # test_timestamp = pd.Timestamp("2024-12-18")
    val_timestamp = pd.Timestamp("2025-01-3")
    test_timestamp = pd.Timestamp("2025-01-11")
    def make_db(self) -> Database:
        # Path to your CSVs folder
        path = os.path.join("D:/Dani/relbench/relbench/burger_data/dorsa", "new")   

        customers = os.path.join(path, "customer_data.csv")
        articles = os.path.join(path, "article_data.csv")
        branches = os.path.join(path, "branch_data.csv")
        transactions = os.path.join(path, "transaction_data.csv")
        calendar = os.path.join(path, "calendar.xlsx")
        # Ensure that CSV files exist in the specified path
        if not os.path.exists(customers):
            raise RuntimeError(f"Dataset not found at '{path}'. Please make sure the CSV files are in the correct folder.")

        # Read the CSV data into pandas DataFrames
        customers_df = pd.read_csv(customers)
        articles_df = pd.read_csv(articles)
        branches_df = pd.read_csv(branches)
        transactions_df = pd.read_csv(transactions)
        calendar_df = pd.read_excel(calendar)
        transactions_df['d_dat'] = pd.to_datetime(transactions_df['Datetime'])
        transactions_df.drop(columns=["Datetime"], inplace=True)  
        transactions_df['Date'] = transactions_df['d_dat'].dt.date
        transactions_df['Date'] = pd.to_datetime(transactions_df['Date'])
        transactions_df = transactions_df.reset_index(drop=True)
        transactions_df = pd.merge(transactions_df, calendar_df, how="inner", on='Date')
        
        ################################################################################
        # Check for and handle duplicate primary keys in articles, customers, and branches tables
        ################################################################################
        # customers_df['customers_id'] = customers_df['Customer ID']
        # customers_df.drop(columns=["Customer ID"], inplace=True)
        articles_df.drop(columns=['articles_id'], inplace=True)
        articles_df['articles_id'] = articles_df["Style Code"]


        transactions_df.drop(columns=['articles_id'], inplace=True)
        transactions_df['articles_id'] = transactions_df["Style Code"]


        # transactions_df['customers_id'] = transactions_df['Customer ID']
        # transactions_df.drop(columns=["Customer ID"], inplace=True)

        # transactions_df['BranchCode'] = transactions_df['Branch ID']
        # transactions_df.drop(columns=["Branch ID"], inplace=True)

        # branches_df['BranchCode'] = branches_df['Branch ID']
        # branches_df.drop(columns=["Branch ID"], inplace=True)        

        # Handle duplicates in the articles table
        articles_df = articles_df.drop_duplicates(subset=['articles_id'], keep='first')
        
        # Handle duplicates in the customers table
        customers_df = customers_df.drop_duplicates(subset=['customers_id'], keep='first')
        
        # Handle duplicates in the branches table
        branches_df = branches_df.drop_duplicates(subset=['BranchCode'], keep='first')
        
        ################################################################################
        # Filter customers with 3 or more transactions
        ################################################################################
        unique_transactions = transactions_df.drop_duplicates(subset=['customers_id', 'Transaction Serial'])

        # Count the unique customers
        transaction_count = unique_transactions['customers_id'].value_counts()    
        valid_customers = transaction_count[transaction_count >= 3].index
        transactions_df = transactions_df[transactions_df['customers_id'].isin(valid_customers)]
        customers_df = customers_df[customers_df['customers_id'].isin(valid_customers)]
        
        ################################################################################
        # Clean and process the data (drop unnecessary columns, handle missing data)
        ################################################################################
        transactions_df.drop(columns=["event_name_2", "event_type_2", "Transaction Type", "Document Barcode", "Sale Total Price", "Net Sales", "Tax", 
"Net Sales Minus Sales Voucher", "d", "Year_y", "Day_of_year", "Month_y", "Day_of_month", "Week_of_month", 
"Week_of_year", "wm_yr_wk", "Date_S", "Date"], inplace=True)
        transactions_df = transactions_df.replace(r"^\\N$", np.nan, regex=True)

        # Convert date column to pd.Timestamp
        transactions_df["datetime"] = pd.to_datetime(transactions_df["d_dat"], format="%Y-%m-%d")
        transactions_df.drop(columns=["d_dat"], inplace=True)          
        
        ################################################################################
        # Now we define the table structure and relationships.
        ################################################################################
        tables = {}

        # Articles table
        tables["article"] = Table(
            df=pd.DataFrame(articles_df),
            fkey_col_to_pkey_table={},
            pkey_col="articles_id",
            time_col=None,
        )

        # Customers table
        tables["customer"] = Table(
            df=pd.DataFrame(customers_df),
            fkey_col_to_pkey_table={},
            pkey_col="customers_id",
            time_col=None,
        )

        # Branches table
        tables["branches"] = Table(
            df=pd.DataFrame(branches_df),
            fkey_col_to_pkey_table={},
            pkey_col="BranchCode",
            time_col=None,
        )

        # Transactions table
        tables["transactions"] = Table(
            df=pd.DataFrame(transactions_df),
            fkey_col_to_pkey_table={
                "articles_id": "article",    # Foreign key to articles
                "customers_id": "customer",  # Foreign key to customers
                "BranchCode": "branches",    # Foreign key to branches
            },
            pkey_col=None,
            time_col="datetime",  # Use the combined datetime column for time-based operations
        )

        return Database(tables)


In [36]:
transactional_dataset = TransactionalDataset()
db = transactional_dataset.make_db()

In [37]:
table = db.table_dict["transactions"]

In [38]:
db.table_dict["transactions"]

Table(df=
        BranchCode  Transaction Serial      Time  customers_id  Style Code  \
0          10284.0             27631.0  12:07:41    10121860.0      3416.0   
1          10124.0             75949.0  10:53:54    10447769.0      3919.0   
2          10706.0             21517.0  21:28:09    10469286.0      2400.0   
3          11008.0              4992.0  21:36:02    10075749.0      3117.0   
4          10002.0             58726.0  18:30:10    10432889.0      3421.0   
...            ...                 ...       ...           ...         ...   
495838     10486.0              4804.0  18:42:41    10443557.0      4541.0   
495841     10317.0             11390.0  19:02:00    10183163.0      4543.0   
495843     10563.0              5333.0  17:28:55    10026437.0      4543.0   
495844     10563.0              5418.0  22:52:44    10478973.0      4542.0   
495848     11019.0               500.0  21:07:05    10462501.0      3471.0   

        Quantity  Base Total Price  Year_x  Month_x  

In [39]:
db.table_dict["branches"]

Table(df=
    BranchCode                  Branch Name   City
0      10284.0              درسا اطلس تبريز  تبريز
1      10124.0                     درسا ارگ  تهران
2      10706.0            فروشگاه درسا قلهک  تهران
3      11008.0                درسا سام سنتر  تهران
4      10002.0                   درسا گاندي  تهران
..         ...                          ...    ...
70     10435.0     درسا احمد آباد - جواهرات   مشهد
71     10311.0           درسا ارگ - جواهرات  تهران
72     10431.0         درسا کرمان - جواهرات  کرمان
73     10743.0  درسا فرودگاه تبريز -جواهرات  تبريز
74     13827.0      درسا حيات سبز - جواهرات  تهران

[75 rows x 3 columns],
  fkey_col_to_pkey_table={},
  pkey_col=BranchCode,
  time_col=None)

In [242]:
r = db.table_dict["article"].df

In [243]:
r

Unnamed: 0,Article Description,Base Unit Price,Style Code,Style Description,Collection Code,Collection Name,Idea Code,Idea Description,Initial Pattern Code,Initial Pattern Description,...,Usage Type,Usage Space,Brand,Size,Season,Year,Model,Combined Feature,Dimensions,articles_id
0,چلسي بوت نقش‌دال,1.174312e+08,1152,چلسی بوت نقش‌دال,0,عمومي,0,عمومي,218,قالب 9717 مردانه,...,SMART CASUAL,Unknown Usage Space,DORSA,42,FALL-WINTER,1402,AKM 6408-BOOT,پنجه متوسط گرد - فرم رویه صاف - فرم زیره پاشنه...,Unknown Dimensions,1152
1,کيف پوچ زنانه کلکسيون آهنگ,8.874312e+07,1495,پوچ آهنگ,0,عمومي,0,عمومي,438,پوچ آهنگ_3206,...,Unknown Usage Type,Unknown Usage Space,DORSA,Unknown Size,SPRING-SUMMER,1402,Unknown Model,چرم گاوی ناپا - تریم زرد - آستر نقش دال مشکی -...,L 38 * H 28 * D 3 CM,1495
2,کيف دستي زنانه کوچک,1.040909e+08,510,کیف زنانه مربعی کوچک با لوگو 65میل درسا,0,عمومي,0,عمومي,262,کیف دسته دار-2400,...,SMART CASUAL,Unknown Usage Space,DORSA,Unknown Size,S2,1403,Unknown Model,لوگو فلزی پنجره ای 9-65میل - رنگ فلز طلایی - چ...,L 24 * H 20.5 * D 9 CM,510
3,کيف پول عمودي آکو,4.500000e+07,958,کیف پول عمودی اکو,44,AKO,0,AKO,328,کیف پول کوچک بای فولد-2286,...,SMART CASUAL,Unknown Usage Space,DORSA,Unknown Size,SPRING-SUMMER,1402,Unknown Model,آستر نقش دال مشکی - فینیش مشکی - چرم گاوی ناپا...,L 10 * H 12 * D 1.5 CM,958
4,اسنيکر بندي زنانه ساقدار با زيره زرد,5.500000e+07,1155,اسنیکر بندی زنانه ساقدار,0,عمومي,0,عمومي,220,قالب 3-049 اسنیکر زنانه,...,CASUAL,Unknown Usage Space,DORSA,37,FALL-WINTER,1399,AKM 4939-WOMEN BOOT,پنجه متوسط گرد - فرم رویه بندی - فرم زیره پاشن...,Unknown Dimensions,1155
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
45487,دستبندنقش دال پرکريستال سواروسکي طلا 14/66سايز M,8.541356e+08,1275,دستبند طلا یک ردیف توری نقش دال پر کریستال سا...,63,DALL,3,DALL,234,النگوی یک ردیف توری نقش دال,...,Unknown Usage Type,Unknown Usage Space,DORSA JEWELRY,14,S1,1403,3586-1,محاسبه وزن طلا - 66 - کریستال 6862,Unknown Dimensions,1275
45522,انگشتر شأن طلاباکريستال سواروسکي دودي 18.26 سا...,1.240587e+09,1979,انگشتر شأن طلا سایز S,93,شأن,3,شأن,494,انگشتر سه وجهی دو دور,...,Unknown Usage Type,Unknown Usage Space,DORSA JEWELRY,18,S1,1403,4664-1,محاسبه وزن طلا - 26 - کریستال 6719 - کریستال 6...,Unknown Dimensions,1979
45523,دستبند خيال بغل دال با کريستال طلا 24/21 سايز S,1.261884e+09,1306,دستبند خیال پر کریستال طلا S,83,KHIYAL,0,KHIYAL,372,النگوی بغل دال پر کریستال,...,Unknown Usage Type,Unknown Usage Space,DORSA JEWELRY,24,SPRING-SUMMER,1402,3632-1,محاسبه وزن طلا - 21 - کریستال 6862 - کریستال 7223,Unknown Dimensions,1306
45524,انگشتر دژ با راکس و کريستال طلا 16.36 سايز L,1.276972e+09,1783,انگشتر دژ با راکس و کریستال طلا سایز L,72,AUDACIOUS,3,AUDACIOUS,498,Unknown Pattern,...,Unknown Usage Type,Unknown Usage Space,DORSA JEWELRY,16,SPRING-SUMMER,1400,4359-1,محاسبه وزن طلا - 36 - کریستال 6719 - کریستال 6669,Unknown Dimensions,1783


In [244]:
db.table_dict["customer"]

Table(df=
       Customer Category  customers_id Customer Process    City
0                  عمومي    10121860.0              طلا   تبريز
1                  عمومي    10447769.0             برنز   تهران
2                  عمومي    10469286.0             برنز   تهران
3                  عمومي    10075749.0        پلاتینیوم   تهران
4                  عمومي    10432889.0             برنز   تهران
...                  ...           ...              ...     ...
176963             عمومي    10467100.0             نقره   اهواز
176964             عمومي    10465852.0             نقره   تهران
176966             عمومي    10204307.0             نقره  اصفهان
176967             عمومي    10481987.0             نقره   تهران
176968             عمومي    10477956.0             نقره   شيراز

[156796 rows x 4 columns],
  fkey_col_to_pkey_table={},
  pkey_col=customers_id,
  time_col=None)

In [40]:
table.df.iloc[table.df["datetime"].idxmax()]


IndexError: single positional indexer is out-of-bounds

In [41]:
table.df.iloc[table.df["datetime"].idxmin()]

BranchCode                           10672.0
Transaction Serial                    9800.0
Time                                19:34:50
customers_id                      10359133.0
Style Code                            5034.0
Quantity                                   1
Base Total Price                 110000000.0
Year_x                                1403.0
Month_x                                  9.0
Day                                     27.0
Season                                 پاییز
Allocated Credit                  17600000.0
Sales Voucher                            0.0
Gift Voucher                             0.0
Cash Card                                0.0
Special Discount                         0.0
Promotion Discount                       0.0
Small Change Discount                    0.0
Total Discounts                   54727273.0
Discount Percentage                     20.0
Day_of_week                                3
even_odd                                   1
Weekday   

In [42]:
register_dataset("Dorsa-aras669", TransactionalDataset)
get_dataset_names()

['rel-amazon',
 'rel-avito',
 'rel-event',
 'rel-f1',
 'rel-hm',
 'rel-stack',
 'rel-trial',
 'Dorsa-aras669']

In [43]:
hyper_dataset = get_dataset("Dorsa-aras669")
hyper_dataset

TransactionalDataset()

In [44]:
hyper_dataset.val_timestamp, hyper_dataset.test_timestamp

(Timestamp('2025-01-03 00:00:00'), Timestamp('2025-01-11 00:00:00'))

In [45]:
import relbench

relbench.__version__

'1.1.0'

In [46]:
import duckdb
import pandas as pd
from relbench.tasks import get_task, get_task_names, register_task
from relbench.base import Database, EntityTask, RecommendationTask, Table, TaskType
from relbench.metrics import (
    accuracy,
    average_precision,
    f1,
    link_prediction_map,
    link_prediction_precision,
    link_prediction_recall,
    mae,
    r2,
    rmse,
    roc_auc,
)
from metrics import link_prediction_top
class UserItemPurchaseTask(RecommendationTask):
    r"""Predict the list of articles each customer will purchase in the next seven
    days."""

    task_type = TaskType.LINK_PREDICTION
    src_entity_col = "customer_id"
    src_entity_table = "customer"
    dst_entity_col = "article_id"
    dst_entity_table = "article"
    time_col = "timestamp"
    timedelta = pd.Timedelta(days=7)
    metrics = [link_prediction_precision, link_prediction_recall, link_prediction_map, link_prediction_top]
    eval_k = 12

    def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Table:
        customer = db.table_dict["customer"].df
        transactions = db.table_dict["transactions"].df
        timestamp_df = pd.DataFrame({"timestamp": timestamps})

        df = duckdb.sql(
            f"""
            SELECT
                t.timestamp,
                transactions.customer_id,
                LIST(DISTINCT transactions.article_id) AS article_id
            FROM
                timestamp_df t
            LEFT JOIN
                transactions
            ON
                transactions.t_dat > t.timestamp AND
                transactions.t_dat <= t.timestamp + INTERVAL '{self.timedelta} days'
            GROUP BY
                t.timestamp,
                transactions.customer_id
            """
        ).df()

        return Table(
            df=df,
            fkey_col_to_pkey_table={
                self.src_entity_col: self.src_entity_table,
                self.dst_entity_col: self.dst_entity_table,
            },
            pkey_col=None,
            time_col=self.time_col,
        )

# Task 1: Predict articles each customer will purchase in the next 7 days
class CustomerArticlePurchaseTask(RecommendationTask):
    r"""Predict the list of articles each customer will purchase in the next seven days."""
    
    task_type = TaskType.LINK_PREDICTION
    src_entity_col = "customers_id"
    src_entity_table = "customer"
    dst_entity_col = "articles_id"
    dst_entity_table = "article"
    time_col = "timestamp"
    timedelta = pd.Timedelta(days=7)
    metrics = [link_prediction_precision, link_prediction_recall, link_prediction_map, link_prediction_top]
    eval_k = 4

    def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Table:
        transactions = db.table_dict["transactions"].df
        timestamp_df = pd.DataFrame({"timestamp": timestamps})

        df = duckdb.sql(
            f"""
            SELECT
                t.timestamp,
                transactions.customers_id,
                LIST(DISTINCT transactions.articles_id) AS articles_id
            FROM
                timestamp_df t
            LEFT JOIN
                transactions
            ON
                transactions.datetime > t.timestamp AND
                transactions.datetime <= t.timestamp + INTERVAL '{self.timedelta.days} days'
            GROUP BY
                t.timestamp,
                transactions.customers_id
            """
        ).df()

        return Table(
            df=df,
            fkey_col_to_pkey_table={
                self.src_entity_col: self.src_entity_table,
                self.dst_entity_col: self.dst_entity_table,
            },
            pkey_col=None,
            time_col=self.time_col,
        )


# Task 2: Predict customer churn (no purchases in the next week)
class CustomerChurnTask(EntityTask):
    r"""Predict the churn for a customer (no transactions) in the next 6 days."""

    task_type = TaskType.BINARY_CLASSIFICATION
    entity_col = "customers_id"
    entity_table = "customer"
    time_col = "timestamp"
    target_col = "churn"
    timedelta = pd.Timedelta(days=7)
    metrics = [average_precision, accuracy, f1, roc_auc]

    def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Table:
        customer = db.table_dict["customer"].df
        transactions = db.table_dict["transactions"].df
        timestamp_df = pd.DataFrame({"timestamp": timestamps})

        df = duckdb.sql(
            f"""
            SELECT
                timestamp,
                customers_id,
                CAST(
                    NOT EXISTS (
                        SELECT 1
                        FROM transactions
                        WHERE
                            transactions.customers_id = customer.customers_id AND
                            transactions.datetime > timestamp AND
                            transactions.datetime <= timestamp + INTERVAL '{self.timedelta}'
                    ) AS INTEGER
                ) AS churn
            FROM
                timestamp_df,
                customer
            WHERE
                EXISTS (
                    SELECT 1
                    FROM transactions
                    WHERE
                        transactions.customers_id = customer.customers_id AND
                        transactions.datetime > timestamp - INTERVAL '{self.timedelta}' AND
                        transactions.datetime <= timestamp
                )
            """
        ).df()

        return Table(
            df=df,
            fkey_col_to_pkey_table={self.entity_col: self.entity_table},
            pkey_col=None,
            time_col=self.time_col,
        ) 
    # def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Table:
    #     transactions = db.table_dict["transactions"].df
    #     customer = db.table_dict["customer"].df
    #     timestamp_df = pd.DataFrame({"timestamp": timestamps})

    #     df = duckdb.sql(
    #         f"""
    #         SELECT
    #             t.timestamp,
    #             c.customers_id,
    #             CAST(
    #                 NOT EXISTS (
    #                     SELECT 1
    #                     FROM transactions
    #                     WHERE
    #                         transactions.customers_id = c.customers_id AND
    #                         transactions.datetime > t.timestamp AND
    #                         transactions.datetime <= t.timestamp + INTERVAL '{self.timedelta.days} days'
    #                 ) AS INTEGER
    #             ) AS churn
    #         FROM
    #             timestamp_df t,
    #             customer c
    #         WHERE
    #             EXISTS (
    #                 SELECT 1
    #                 FROM transactions
    #                 WHERE
    #                     transactions.customers_id = c.customers_id AND
    #                     transactions.datetime > t.timestamp - INTERVAL '{self.timedelta.days} days' AND
    #                     transactions.datetime <= t.timestamp
    #             )
    #         """
    #     ).df()

    #     return Table(
    #         df=df,
    #         fkey_col_to_pkey_table={self.entity_col: self.entity_table},
    #         pkey_col=None,
    #         time_col=self.time_col,
    #     )


# Task 3: Predict article sales in the next 7 days
class ArticleSalesTask(EntityTask):
    r"""Predict the total sales for an article (sum of `price_purchase`) in the next 7 days."""
    
    task_type = TaskType.REGRESSION
    entity_col = "articles_id"
    entity_table = "article"
    time_col = "datetime"
    target_col = "sales"
    timedelta = pd.Timedelta(days=7)
    metrics = [r2, mae, rmse]

    def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Table:
        transactions = db.table_dict["transactions"].df
        articles = db.table_dict["article"].df
        timestamp_df = pd.DataFrame({"timestamp": timestamps})

        df = duckdb.sql(
            f"""
            SELECT
                t.timestamp,
                a.articles_id,
                COALESCE(SUM(transactions.price_purchase), 0) AS sales
            FROM
                timestamp_df t,
                article a
            LEFT JOIN
                transactions
            ON
                transactions.articles_id = a.articles_id AND
                transactions.datetime > t.timestamp AND
                transactions.datetime <= t.timestamp + INTERVAL '{self.timedelta.days} days'
            GROUP BY
                t.timestamp,
                a.articles_id
            """
        ).df()

        return Table(
            df=df,
            fkey_col_to_pkey_table={self.entity_col: self.entity_table},
            pkey_col=None,
            time_col=self.time_col,
        )



In [47]:
aras_recom_task = CustomerArticlePurchaseTask(hyper_dataset, cache_dir="D:/Dani/relbench/relbench/cache/hyper_ar9564a656678ks4484386789207487")
aras_recom_task

CustomerArticlePurchaseTask(dataset=TransactionalDataset())

In [48]:
register_task("Dorsa-aras669", "aras_recom_task6669", CustomerArticlePurchaseTask)
get_task_names("Dorsa-aras669")

['aras_recom_task6669']

In [49]:
import numpy as np

from torch.nn import BCEWithLogitsLoss, L1Loss
from relbench.datasets import get_dataset
from relbench.tasks import get_task

dataset = get_dataset("Dorsa-aras669")
task = get_task("Dorsa-aras669", "aras_recom_task6669")


train_table = task.get_table("train")
val_table = task.get_table("val")
test_table = task.get_table("test")

out_channels = 1
loss_fn = BCEWithLogitsLoss()
tune_metric = "link_prediction_map"
higher_is_better = True

Making task table for train split from scratch...
(You can also use `get_task(..., download=True)` for tasks prepared by the RelBench team.)
Making Database object from scratch...
(You can also use `get_dataset(..., download=True)` for datasets prepared by the RelBench team.)
Done in 3.93 seconds.
Caching Database object to C:\Users\KN2C\AppData\Local\relbench\relbench\Cache/Dorsa-aras669/db...
Done in 0.52 seconds.
Loading Database object from C:\Users\KN2C\AppData\Local\relbench\relbench\Cache/Dorsa-aras669/db...
Done in 0.36 seconds.
Done in 9.75 seconds.
Making task table for val split from scratch...
(You can also use `get_task(..., download=True)` for tasks prepared by the RelBench team.)
Done in 0.03 seconds.
Making task table for test split from scratch...
(You can also use `get_task(..., download=True)` for tasks prepared by the RelBench team.)
Loading Database object from C:\Users\KN2C\AppData\Local\relbench\relbench\Cache/Dorsa-aras669/db...
Done in 0.12 seconds.
Done in 0.1

In [50]:
train_table

Table(df=
       timestamp  customers_id articles_id
0     2022-01-14         27706      [1933]
1     2021-12-24          4488        [80]
2     2021-12-17         12132       [556]
3     2021-10-22         20384       [556]
4     2021-07-23          2170      [1104]
...          ...           ...         ...
25086 2021-09-17         23967      [1612]
25087 2021-09-03          7373       [679]
25088 2023-08-18         42718      [1180]
25089 2024-01-26         43030      [1915]
25090 2022-09-02          1469       [556]

[25091 rows x 3 columns],
  fkey_col_to_pkey_table={'customers_id': 'customer', 'articles_id': 'article'},
  pkey_col=None,
  time_col=timestamp)

In [51]:
val_table

Table(df=
    timestamp  customers_id        articles_id
0  2025-01-03         32179             [2148]
1  2025-01-03         21862             [2127]
2  2025-01-03         19916              [361]
3  2025-01-03         28237             [1559]
4  2025-01-03         11875        [679, 1253]
5  2025-01-03         15920             [1481]
6  2025-01-03         27469       [1195, 1267]
7  2025-01-03         42781             [2127]
8  2025-01-03         18498              [361]
9  2025-01-03          7845             [2148]
10 2025-01-03         25184  [1517, 1616, 111]
11 2025-01-03         42928       [2127, 1481]
12 2025-01-03         20404              [176]
13 2025-01-03          1803       [2055, 1628]
14 2025-01-03         32628              [361]
15 2025-01-03         37515             [1788]
16 2025-01-03         43112              [366]
17 2025-01-03         18705  [111, 1517, 1650]
18 2025-01-03         36303              [177]
19 2025-01-03         10433       [1504, 2148]
20 

In [52]:
test_table

Table(df=
    timestamp  customers_id   articles_id
0  2025-01-11         30891  [1616, 1650]
1  2025-01-11         33440        [2127]
2  2025-01-11         26934         [554]
3  2025-01-11          3241        [2127]
4  2025-01-11         24667         [111]
5  2025-01-11         36365        [2127]
6  2025-01-11         43114        [1315]
7  2025-01-11         42928  [1383, 1808]
8  2025-01-11         39877        [2127]
9  2025-01-11         30074        [2148]
10 2025-01-11         20843        [1631]
11 2025-01-11           407        [2127]
12 2025-01-11          2356         [361]
13 2025-01-11         38833   [2127, 199]
14 2025-01-11         16398        [1517]
15 2025-01-11         21401         [554]
16 2025-01-11           378         [361]
17 2025-01-11         30227        [2148]
18 2025-01-11         10445         [199]
19 2025-01-11         41792  [1520, 1074]
20 2025-01-11         42129        [2127],
  fkey_col_to_pkey_table={'customers_id': 'customer', 'articles_i

In [53]:
import os
import math
import numpy as np
from tqdm import tqdm

import torch
import torch_geometric
import torch_frame

# Some book keeping
from torch_geometric.seed import seed_everything

seed_everything(42)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)  # check that it's cuda if you want it to run in reasonable time!
root_dir = "D:/Dani/relbench/relbench/data_ARAS"

  from .autonotebook import tqdm as notebook_tqdm


cuda


In [67]:
from relbench.modeling.utils import get_stype_proposal

db = dataset.get_db()
col_to_stype_dict = get_stype_proposal(db)
col_to_stype_dict

{'article': {'Article Description': <stype.text_embedded: 'text_embedded'>,
  'Base Unit Price': <stype.numerical: 'numerical'>,
  'Style Code': <stype.numerical: 'numerical'>,
  'Style Description': <stype.text_embedded: 'text_embedded'>,
  'Collection Code': <stype.numerical: 'numerical'>,
  'Collection Name': <stype.text_embedded: 'text_embedded'>,
  'Idea Code': <stype.categorical: 'categorical'>,
  'Idea Description': <stype.text_embedded: 'text_embedded'>,
  'Initial Pattern Code': <stype.numerical: 'numerical'>,
  'Initial Pattern Description': <stype.text_embedded: 'text_embedded'>,
  'Product Category': <stype.text_embedded: 'text_embedded'>,
  'Product Group': <stype.text_embedded: 'text_embedded'>,
  'Product Subgroup': <stype.text_embedded: 'text_embedded'>,
  'Product Type': <stype.text_embedded: 'text_embedded'>,
  'Material Category': <stype.text_embedded: 'text_embedded'>,
  'Material Type': <stype.text_embedded: 'text_embedded'>,
  'Processing': <stype.text_embedded: '

In [68]:
# from relbench.modeling.schema import Stype
from torch_frame import stype
# Modify specific column stype manually
col_to_stype_dict['article']['Style Code'] = stype.categorical  
col_to_stype_dict['article']['Collection Code'] = stype.categorical  
col_to_stype_dict['article']['Idea Code'] = stype.categorical  
col_to_stype_dict['article']['Initial Pattern Code'] = stype.categorical  
 
col_to_stype_dict['transactions']['Year_x'] = stype.categorical 
col_to_stype_dict['transactions']['Month_x'] = stype.categorical 
# col_to_stype_dict['transactions']['Day'] = stype.sequence_numerical
# Now you can use col_to_stype_dict in your modeling pipeline
print(col_to_stype_dict)

{'article': {'Article Description': <stype.text_embedded: 'text_embedded'>, 'Base Unit Price': <stype.numerical: 'numerical'>, 'Style Code': <stype.categorical: 'categorical'>, 'Style Description': <stype.text_embedded: 'text_embedded'>, 'Collection Code': <stype.categorical: 'categorical'>, 'Collection Name': <stype.text_embedded: 'text_embedded'>, 'Idea Code': <stype.categorical: 'categorical'>, 'Idea Description': <stype.text_embedded: 'text_embedded'>, 'Initial Pattern Code': <stype.categorical: 'categorical'>, 'Initial Pattern Description': <stype.text_embedded: 'text_embedded'>, 'Product Category': <stype.text_embedded: 'text_embedded'>, 'Product Group': <stype.text_embedded: 'text_embedded'>, 'Product Subgroup': <stype.text_embedded: 'text_embedded'>, 'Product Type': <stype.text_embedded: 'text_embedded'>, 'Material Category': <stype.text_embedded: 'text_embedded'>, 'Material Type': <stype.text_embedded: 'text_embedded'>, 'Processing': <stype.text_embedded: 'text_embedded'>, 'Co

In [69]:
from typing import List, Optional
from sentence_transformers import SentenceTransformer
from torch import Tensor
import torch

class BertPersianTextEmbedding:
    def __init__(self, device: Optional[torch.device] = None):
        # Replace the model with a Persian BERT model
        self.model = SentenceTransformer("HooshvareLab/bert-fa-zwnj-base",  # Example Persian BERT model
            device=device,
        )

    def __call__(self, sentences: List[str]) -> Tensor:
        # Encode the sentences using the Persian BERT model and return as a tensor
        return torch.from_numpy(self.model.encode(sentences))


In [70]:
from torch_frame.config.text_embedder import TextEmbedderConfig
from relbench.modeling.graph import make_pkey_fkey_graph

text_embedder_cfg = TextEmbedderConfig(
    text_embedder=BertPersianTextEmbedding(device=device), batch_size=64
)

data, col_stats_dict = make_pkey_fkey_graph(
    db,
    col_to_stype_dict=col_to_stype_dict,  # speficied column types
    text_embedder_cfg=text_embedder_cfg,  # our chosen text encoder
    cache_dir=os.path.join(
        root_dir, f"rel-aras_recom_materialized_cache7822322232"
    ),  # store materialized graph for convenience
)

No sentence-transformers model found with name HooshvareLab/bert-fa-zwnj-base. Creating a new one with mean pooling.
Some weights of BertModel were not initialized from the model checkpoint at HooshvareLab/bert-fa-zwnj-base and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  ser = pd.to_datetime(ser, format=time_format)
  ser = pd.to_datetime(ser, format=self.format, errors='coerce')
Embedding raw data in mini-batch: 100%|██████████| 4946/4946 [03:36<00:00, 22.85it/s]


In [71]:
data

HeteroData(
  article={ tf=TensorFrame([2168, 28]) },
  branches={ tf=TensorFrame([75, 2]) },
  customer={ tf=TensorFrame([43453, 3]) },
  transactions={
    tf=TensorFrame([316487, 26]),
    time=[316487],
  },
  (transactions, f2p_articles_id, article)={ edge_index=[2, 36410] },
  (article, rev_f2p_articles_id, transactions)={ edge_index=[2, 36410] },
  (transactions, f2p_customers_id, customer)={ edge_index=[2, 316487] },
  (customer, rev_f2p_customers_id, transactions)={ edge_index=[2, 316487] },
  (transactions, f2p_BranchCode, branches)={ edge_index=[2, 316487] },
  (branches, rev_f2p_BranchCode, transactions)={ edge_index=[2, 316487] }
)

## RelBench Recommendation

In [72]:
from typing import Any, Dict, List

import torch
from torch import Tensor
from torch.nn import Embedding, ModuleDict
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.nn import MLP
from torch_geometric.typing import NodeType

from relbench.modeling.nn import HeteroEncoder, HeteroGraphSAGE, HeteroTemporalEncoder


class Model(torch.nn.Module):

    def __init__(
        self,
        data: HeteroData,
        col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]],
        num_layers: int,
        channels: int,
        out_channels: int,
        aggr: str,
        norm: str,
        # List of node types to add shallow embeddings to input
        shallow_list: List[NodeType] = [],
        # ID awareness
        id_awareness: bool = False,
    ):
        super().__init__()

        self.encoder = HeteroEncoder(
            channels=channels,
            node_to_col_names_dict={
                node_type: data[node_type].tf.col_names_dict
                for node_type in data.node_types
            },
            node_to_col_stats=col_stats_dict,
        )
        self.temporal_encoder = HeteroTemporalEncoder(
            node_types=[
                node_type for node_type in data.node_types if "time" in data[node_type]
            ],
            channels=channels,
        )
        self.gnn = HeteroGraphSAGE(
            node_types=data.node_types,
            edge_types=data.edge_types,
            channels=channels,
            aggr=aggr,
            num_layers=num_layers,
        )
        self.head = MLP(
            channels,
            out_channels=out_channels,
            norm=norm,
            num_layers=1,
        )
        self.embedding_dict = ModuleDict(
            {
                node: Embedding(data.num_nodes_dict[node], channels)
                for node in shallow_list
            }
        )

        self.id_awareness_emb = None
        if id_awareness:
            self.id_awareness_emb = torch.nn.Embedding(1, channels)
        self.reset_parameters()

    def reset_parameters(self):
        self.encoder.reset_parameters()
        self.temporal_encoder.reset_parameters()
        self.gnn.reset_parameters()
        self.head.reset_parameters()
        for embedding in self.embedding_dict.values():
            torch.nn.init.normal_(embedding.weight, std=0.1)
        if self.id_awareness_emb is not None:
            self.id_awareness_emb.reset_parameters()

    def forward(
        self,
        batch: HeteroData,
        entity_table: NodeType,
    ) -> Tensor:
        seed_time = batch[entity_table].seed_time
        x_dict = self.encoder(batch.tf_dict)

        rel_time_dict = self.temporal_encoder(
            seed_time, batch.time_dict, batch.batch_dict
        )

        for node_type, rel_time in rel_time_dict.items():
            x_dict[node_type] = x_dict[node_type] + rel_time

        for node_type, embedding in self.embedding_dict.items():
            x_dict[node_type] = x_dict[node_type] + embedding(batch[node_type].n_id)

        x_dict = self.gnn(
            x_dict,
            batch.edge_index_dict,
            batch.num_sampled_nodes_dict,
            batch.num_sampled_edges_dict,
        )

        return self.head(x_dict[entity_table][: seed_time.size(0)])

    def forward_dst_readout(
        self,
        batch: HeteroData,
        entity_table: NodeType,
        dst_table: NodeType,
    ) -> Tensor:
        if self.id_awareness_emb is None:
            raise RuntimeError(
                "id_awareness must be set True to use forward_dst_readout"
            )
        seed_time = batch[entity_table].seed_time
        x_dict = self.encoder(batch.tf_dict)
        # Add ID-awareness to the root node
        x_dict[entity_table][: seed_time.size(0)] += self.id_awareness_emb.weight

        rel_time_dict = self.temporal_encoder(
            seed_time, batch.time_dict, batch.batch_dict
        )

        for node_type, rel_time in rel_time_dict.items():
            x_dict[node_type] = x_dict[node_type] + rel_time

        for node_type, embedding in self.embedding_dict.items():
            x_dict[node_type] = x_dict[node_type] + embedding(batch[node_type].n_id)

        x_dict = self.gnn(
            x_dict,
            batch.edge_index_dict,
        )

        return self.head(x_dict[dst_table])

In [73]:
import argparse
import copy
import json
import os
import warnings
from pathlib import Path
from typing import Dict, Tuple

import numpy as np
import torch
import torch.nn.functional as F
# from model import Model
# from text_embedder import GloveTextEmbedding
from torch import Tensor
from torch_frame import stype
from torch_frame.config.text_embedder import TextEmbedderConfig
from torch_geometric.loader import NeighborLoader
from torch_geometric.seed import seed_everything
from tqdm import tqdm

from relbench.base import Dataset, RecommendationTask, TaskType
from relbench.datasets import get_dataset
from relbench.modeling.graph import get_link_train_table_input, make_pkey_fkey_graph
from relbench.modeling.loader import SparseTensor
from relbench.modeling.utils import get_stype_proposal
from relbench.tasks import get_task

In [74]:
# Initialize the loader dictionary
loader_dict: Dict[str, NeighborLoader] = {}
dst_nodes_dict: Dict[str, Tuple[NodeType, Tensor]] = {}

# Loop over the train, val, and test splits
for split, table in [
    ("train", train_table),
    ("val", val_table),
    ("test", test_table),
]:
    # Get link train table input for link prediction task
    table_input = get_link_train_table_input(
        table=table,
        task=task,
    )
    
    # Save destination nodes for later use
    dst_nodes_dict[split] = table_input.dst_nodes

    # Create NeighborLoader for link prediction
    loader_dict[split] = NeighborLoader(
        data,
        num_neighbors=[128 for _ in range(2)],  # Sample subgraphs of depth 2, 128 neighbors per node
        time_attr="time",  # Use time attribute if available
        input_nodes=table_input.src_nodes,  # Source nodes for link prediction
        input_time=table_input.src_time,  # Use src_time if time data is available
        subgraph_type="bidirectional",
        batch_size=512,
        temporal_strategy="last",  # Uniform sampling strategy for time
        shuffle=split == "train",  # Shuffle only during training
        num_workers=0,
        persistent_workers=False,
    )


  dst_node_indices = sparse_coo.to_sparse_csr()


In [75]:
# Initialize the model for link prediction task
model = Model(
    data=data,  # Heterogeneous data object
    col_stats_dict=col_stats_dict,  # Column statistics dictionary
    num_layers=2,  # Adjust this to match your desired architecture (depth of GNN)
    channels=128,  # Number of hidden channels in GNN layers
    out_channels=1,  # Output size (for link prediction, usually a scalar per edge)
    aggr="sum",  # Aggregation method (can be "sum", "mean", etc.)
    norm="layer_norm",  # Normalization method
    id_awareness=True,  # Whether the model is aware of node IDs
).to(device)  # Move model to the appropriate device (e.g., GPU)

# Set up the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # Use the desired learning rate

# Handling sparse destination nodes for training
# dst_nodes_dict stores the destination nodes for the "train" split (in sparse format)
train_sparse_tensor = SparseTensor(dst_nodes_dict["train"][1], device=device)


In [76]:
def train() -> float:
    model.train()  # Set model to training mode

    loss_accum = count_accum = 0
    steps = 0
    total_steps = min(len(loader_dict["train"]), 2000)  # Change the max_steps_per_epoch to 2000 or your preferred value

    for batch in tqdm(loader_dict["train"], total=total_steps):
        batch = batch.to(device)  # Move batch data to device (GPU or CPU)

        # Forward pass through the model for link prediction (source and destination tables)
        out = model.forward_dst_readout(
            batch, task.src_entity_table, task.dst_entity_table
        ).flatten()  # Flatten the output

        batch_size = batch[task.src_entity_table].batch_size  # Get batch size for the source entity table

        # Get ground-truth labels
        input_id = batch[task.src_entity_table].input_id  # Input IDs for the batch
        src_batch, dst_index = train_sparse_tensor[input_id]  # Get the source and destination indices

        # Get the target labels by checking if source-destination pairs exist
        target = torch.isin(
            batch[task.dst_entity_table].batch
            + batch_size * batch[task.dst_entity_table].n_id,
            src_batch + batch_size * dst_index,
        ).float()  # Convert the result to float for loss computation

        # Optimization
        optimizer.zero_grad()  # Clear previous gradients
        loss = F.binary_cross_entropy_with_logits(out, target)  # Compute binary cross-entropy loss
        loss.backward()  # Backpropagation to compute gradients

        optimizer.step()  # Update model parameters

        # Accumulate the total loss and count for averaging later
        loss_accum += float(loss) * out.numel()
        count_accum += out.numel()

        steps += 1
        if steps >= total_steps:
            break  # Break the loop if max steps per epoch is reached

    # Handle the case where no data was sampled
    if count_accum == 0:
        warnings.warn(
            f"Did not sample a single '{task.dst_entity_table}' node in any mini-batch. "
            "Try increasing the number of layers/hops or reducing the batch size."
        )

    # Return average loss for the epoch
    return loss_accum / count_accum if count_accum > 0 else float("nan")


In [77]:
@torch.no_grad()  # No gradient computation for evaluation
def test(loader: NeighborLoader) -> np.ndarray:
    model.eval()  # Set model to evaluation mode

    pred_list: list[Tensor] = []  # Store predictions
    for batch in tqdm(loader):  # Iterate over batches in the test loader
        batch = batch.to(device)  # Move the batch data to the device (GPU or CPU)

        # Forward pass through the model for link prediction
        out = (
            model.forward_dst_readout(
                batch, task.src_entity_table, task.dst_entity_table
            )
            .detach()
            .flatten()  # Detach the output from the computational graph
        )

        batch_size = batch[task.src_entity_table].batch_size  # Get the batch size for source nodes

        # Prepare a tensor to hold the scores for the source-destination pairs
        scores = torch.zeros(batch_size, task.num_dst_nodes, device=out.device)

        # Fill the scores with sigmoid activations for the destination nodes in the current batch
        scores[
            batch[task.dst_entity_table].batch, batch[task.dst_entity_table].n_id
        ] = torch.sigmoid(out)  # Apply sigmoid activation to get probabilities

        # Use top-k (e.g., top recommended items) based on the scores
        _, pred_mini = torch.topk(scores, k=task.eval_k, dim=1)  # Get top-k predictions
        pred_list.append(pred_mini)  # Append predictions to the list

    # Concatenate all predictions and move to CPU for further processing
    pred = torch.cat(pred_list, dim=0).cpu().numpy()

    return pred  # Return the final predictions as a NumPy array


In [78]:
import copy

# Initialize variables for tracking the best model and best validation metrics
state_dict = None  # This will hold the best model state
best_val_metric = 0  # This will store the best validation metric
epochs = 10  # Set the number of epochs (you can adjust as needed)
eval_epochs_interval = 1  # Evaluate every 'n' epochs (change this based on your needs)
# tune_metric = "link_prediction_map"  # Define the metric you are tuning

# Training and evaluation loop
for epoch in range(1, epochs + 1):
    # Run the training function
    train_loss = train()
    
    # Perform evaluation every 'eval_epochs_interval' epochs
    if epoch % eval_epochs_interval == 0:
        # Run the validation on the validation dataset
        val_pred = test(loader_dict["val"])  # Get the predictions from the model
        val_metrics = task.evaluate(val_pred, task.get_table("val"))  # Evaluate predictions
        
        # Print the training loss and validation metrics
        print(
            f"Epoch: {epoch:02d}, Train loss: {train_loss}, "
            f"Val metrics: {val_metrics}"
        )

        # Check if the current validation metric is the best
        if val_metrics[tune_metric] > best_val_metric:
            best_val_metric = val_metrics[tune_metric]  # Update best metric
            state_dict = copy.deepcopy(model.state_dict())  # Save the best model state

# After training, load the best model weights
model.load_state_dict(state_dict)

# Evaluate the model on the validation set with the best weights
val_pred = test(loader_dict["val"])
val_metrics = task.evaluate(val_pred, task.get_table("val"))
print(f"Best Val metrics: {val_metrics}")

# Evaluate the model on the test set with the best weights
test_pred = test(loader_dict["test"])
test_metrics = task.evaluate(test_pred)
print(f"Best test metrics: {test_metrics}")

 98%|█████████▊| 49/50 [00:04<00:00, 12.18it/s]
100%|██████████| 1/1 [00:00<00:00, 17.64it/s]


Epoch: 01, Train loss: 0.3260339167547758, Val metrics: {'link_prediction_precision': np.float64(0.09523809523809523), 'link_prediction_recall': np.float64(0.30952380952380953), 'link_prediction_map': np.float64(0.2420634920634921), 'link_prediction_top': np.float64(0.38095238095238093)}


 98%|█████████▊| 49/50 [00:02<00:00, 18.15it/s]
100%|██████████| 1/1 [00:00<00:00, 52.77it/s]


Epoch: 02, Train loss: 0.29487549611980124, Val metrics: {'link_prediction_precision': np.float64(0.10714285714285714), 'link_prediction_recall': np.float64(0.3333333333333333), 'link_prediction_map': np.float64(0.28174603174603174), 'link_prediction_top': np.float64(0.42857142857142855)}


 98%|█████████▊| 49/50 [00:02<00:00, 18.12it/s]
100%|██████████| 1/1 [00:00<00:00, 41.83it/s]


Epoch: 03, Train loss: 0.2902663164754045, Val metrics: {'link_prediction_precision': np.float64(0.10714285714285714), 'link_prediction_recall': np.float64(0.3333333333333333), 'link_prediction_map': np.float64(0.28174603174603174), 'link_prediction_top': np.float64(0.42857142857142855)}


 98%|█████████▊| 49/50 [00:02<00:00, 18.40it/s]
100%|██████████| 1/1 [00:00<00:00, 43.54it/s]


Epoch: 04, Train loss: 0.3096420116579359, Val metrics: {'link_prediction_precision': np.float64(0.10714285714285714), 'link_prediction_recall': np.float64(0.3333333333333333), 'link_prediction_map': np.float64(0.23809523809523808), 'link_prediction_top': np.float64(0.42857142857142855)}


 98%|█████████▊| 49/50 [00:02<00:00, 18.41it/s]
100%|██████████| 1/1 [00:00<00:00, 52.82it/s]


Epoch: 05, Train loss: 0.2868384764465908, Val metrics: {'link_prediction_precision': np.float64(0.10714285714285714), 'link_prediction_recall': np.float64(0.3333333333333333), 'link_prediction_map': np.float64(0.246031746031746), 'link_prediction_top': np.float64(0.42857142857142855)}


 98%|█████████▊| 49/50 [00:02<00:00, 18.40it/s]
100%|██████████| 1/1 [00:00<00:00, 47.75it/s]


Epoch: 06, Train loss: 0.2846905599187409, Val metrics: {'link_prediction_precision': np.float64(0.10714285714285714), 'link_prediction_recall': np.float64(0.3333333333333333), 'link_prediction_map': np.float64(0.246031746031746), 'link_prediction_top': np.float64(0.42857142857142855)}


 98%|█████████▊| 49/50 [00:02<00:00, 18.14it/s]
100%|██████████| 1/1 [00:00<00:00, 45.57it/s]


Epoch: 07, Train loss: nan, Val metrics: {'link_prediction_precision': np.float64(0.10714285714285714), 'link_prediction_recall': np.float64(0.3333333333333333), 'link_prediction_map': np.float64(0.2341269841269841), 'link_prediction_top': np.float64(0.42857142857142855)}


 98%|█████████▊| 49/50 [00:02<00:00, 18.29it/s]
100%|██████████| 1/1 [00:00<00:00, 43.54it/s]


Epoch: 08, Train loss: 0.2795940018904442, Val metrics: {'link_prediction_precision': np.float64(0.09523809523809523), 'link_prediction_recall': np.float64(0.30952380952380953), 'link_prediction_map': np.float64(0.21825396825396823), 'link_prediction_top': np.float64(0.38095238095238093)}


 98%|█████████▊| 49/50 [00:02<00:00, 18.18it/s]
100%|██████████| 1/1 [00:00<00:00, 53.93it/s]


Epoch: 09, Train loss: 0.2782855828034849, Val metrics: {'link_prediction_precision': np.float64(0.10714285714285714), 'link_prediction_recall': np.float64(0.3333333333333333), 'link_prediction_map': np.float64(0.2341269841269841), 'link_prediction_top': np.float64(0.42857142857142855)}


 98%|█████████▊| 49/50 [00:02<00:00, 18.25it/s]
100%|██████████| 1/1 [00:00<00:00, 52.77it/s]


Epoch: 10, Train loss: 0.2808184058676759, Val metrics: {'link_prediction_precision': np.float64(0.09523809523809523), 'link_prediction_recall': np.float64(0.30952380952380953), 'link_prediction_map': np.float64(0.21825396825396823), 'link_prediction_top': np.float64(0.38095238095238093)}


100%|██████████| 1/1 [00:00<00:00, 51.30it/s]


Best Val metrics: {'link_prediction_precision': np.float64(0.10714285714285714), 'link_prediction_recall': np.float64(0.3333333333333333), 'link_prediction_map': np.float64(0.28174603174603174), 'link_prediction_top': np.float64(0.42857142857142855)}


100%|██████████| 1/1 [00:00<00:00, 52.79it/s]

Best test metrics: {'link_prediction_precision': np.float64(0.023809523809523808), 'link_prediction_recall': np.float64(0.09523809523809523), 'link_prediction_map': np.float64(0.09523809523809523), 'link_prediction_top': np.float64(0.09523809523809523)}





## Context GNN

In [80]:
import argparse
import json
import os
import warnings
from pathlib import Path
from typing import Dict, List, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from relbench.base import Dataset, RecommendationTask, TaskType
from relbench.datasets import get_dataset
from relbench.modeling.graph import (
    get_link_train_table_input,
    make_pkey_fkey_graph,
)
from relbench.modeling.loader import SparseTensor
from relbench.modeling.utils import get_stype_proposal
from relbench.tasks import get_task
from torch import Tensor
from torch_frame import stype
from torch_frame.config.text_embedder import TextEmbedderConfig
from torch_geometric.loader import NeighborLoader
from torch_geometric.seed import seed_everything
from torch_geometric.typing import NodeType
from torch_geometric.utils.cross_entropy import sparse_cross_entropy
from tqdm import tqdm

from contextgnn.nn.models import IDGNN, ContextGNN, ShallowRHSGNN
from contextgnn.utils import GloveTextEmbedding, RHSEmbeddingMode

# Static configuration parameters
learning_rate = 0.001
epochs = 10
eval_epochs_interval = 1
batch_size = 256
channels = 128
aggregation_method = "sum"
num_layers = 3
num_neighbors = 256
temporal_strategy = "last"
share_same_time = True
max_steps_per_epoch = 2000
num_workers = 0
seed = 42
model_name = "idgnn"  # For example, can be 'idgnn', 'contextgnn', or 'shallowrhsgnn'
tune_metric = "link_prediction_map"  # Metric used to tune the model
cache_dir = os.path.expanduser("~/.cache/relbench_examples")

# Set random seed for reproducibility
torch.manual_seed(seed)


<torch._C.Generator at 0x2163eeb32f0>

In [81]:
# Define static num_neighbors for NeighborLoader
num_neighbors = [num_neighbors // 2**i for i in range(num_layers)]

# Loader dictionary for train, validation, and test sets
loader_dict: Dict[str, NeighborLoader] = {}
dst_nodes_dict: Dict[str, Tuple[NodeType, Tensor]] = {}
num_dst_nodes_dict: Dict[str, int] = {}

# Assuming `task` is already defined and provides the dataset information
for split in ["train", "val", "test"]:
    table = task.get_table(split)
    table_input = get_link_train_table_input(table, task)
    dst_nodes_dict[split] = table_input.dst_nodes
    num_dst_nodes_dict[split] = table_input.num_dst_nodes
    loader_dict[split] = NeighborLoader(
        data,
        num_neighbors=num_neighbors,
        time_attr="time",
        input_nodes=table_input.src_nodes,
        input_time=table_input.src_time,
        subgraph_type="bidirectional",
        batch_size=batch_size,
        temporal_strategy=temporal_strategy,
        shuffle=split == "train",
        num_workers=num_workers,
        persistent_workers=num_workers > 0,
    )


In [82]:
if model_name == "idgnn":
    model = IDGNN(
        data=data,
        col_stats_dict=col_stats_dict,
        num_layers=num_layers,
        channels=channels,
        out_channels=1,
        aggr=aggregation_method,
        norm="layer_norm",
        torch_frame_model_kwargs={
            "channels": 128,
            "num_layers": 4,
        },
    ).to(device)
elif model_name == "contextgnn":
    model = ContextGNN(
        data=data,
        col_stats_dict=col_stats_dict,
        rhs_emb_mode=RHSEmbeddingMode.FUSION,
        dst_entity_table=task.dst_entity_table,
        num_nodes=num_dst_nodes_dict["train"],
        num_layers=num_layers,
        channels=channels,
        aggr="sum",
        norm="layer_norm",
        embedding_dim=128,
        torch_frame_model_kwargs={
            "channels": 128,
            "num_layers": 4,
        },
    ).to(device)
elif model_name == 'shallowrhsgnn':
    model = ShallowRHSGNN(
        data=data,
        col_stats_dict=col_stats_dict,
        rhs_emb_mode=RHSEmbeddingMode.FUSION,
        dst_entity_table=task.dst_entity_table,
        num_nodes=num_dst_nodes_dict["train"],
        num_layers=num_layers,
        channels=channels,
        aggr="sum",
        norm="layer_norm",
        embedding_dim=64,
        torch_frame_model_kwargs={
            "channels": 128,
            "num_layers": 4,
        },
    ).to(device)

elif model_name == "contexttransgnn":
    model = ContextTransGNN(
        data=data,
        col_stats_dict=col_stats_dict,
        rhs_emb_mode=RHSEmbeddingMode.FUSION,
        dst_entity_table=task.dst_entity_table,
        num_nodes=num_dst_nodes_dict["train"],
        num_layers=num_layers,
        channels=channels,
        embedding_dim=64,
        transformer_heads=4,
        aggr="sum",
        norm="layer_norm",
    ).to(device)

else:
    raise ValueError(f"Unsupported model type {model_name}.")
   

In [83]:
# Initialize optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)


In [279]:
def train() -> float:
    model.train()

    loss_accum = count_accum = 0
    steps = 0
    total_steps = min(len(loader_dict["train"]), max_steps_per_epoch)
    sparse_tensor = SparseTensor(dst_nodes_dict["train"][1], device=device)
    
    for batch in tqdm(loader_dict["train"], total=total_steps, desc="Train"):
        batch = batch.to(device)

        # Get ground-truth
        input_id = batch[task.src_entity_table].input_id
        src_batch, dst_index = sparse_tensor[input_id]

        # Optimization
        optimizer.zero_grad()

        logits = model(batch, task.src_entity_table, task.dst_entity_table)
        edge_label_index = torch.stack([src_batch, dst_index], dim=0)
        loss = sparse_cross_entropy(logits, edge_label_index)
        numel = len(batch[task.dst_entity_table].batch)

        loss.backward()
        optimizer.step()

        loss_accum += float(loss) * numel
        count_accum += numel

        steps += 1
        if steps > max_steps_per_epoch:
            break

    if count_accum == 0:
        warnings.warn(f"Did not sample a single '{task.dst_entity_table}' node in any mini-batch.")

    return loss_accum / count_accum if count_accum > 0 else float("nan")

# Test function
@torch.no_grad()
def test(loader: NeighborLoader, desc: str) -> np.ndarray:
    model.eval()

    pred_list: List[Tensor] = []
    for batch in tqdm(loader, desc=desc):
        batch = batch.to(device)

        out = model(batch, task.src_entity_table, task.dst_entity_table).detach()
        scores = torch.sigmoid(out)

        _, pred_mini = torch.topk(scores, k=task.eval_k, dim=1)
        pred_list.append(pred_mini)
    
    pred = torch.cat(pred_list, dim=0).cpu().numpy()
    return pred

# Training and evaluation loop
state_dict = None
best_val_metric = 0
for epoch in range(1, epochs + 1):
    train_loss = train()
    
    if epoch % eval_epochs_interval == 0:
        val_pred = test(loader_dict["val"], desc="Val")
        val_metrics = task.evaluate(val_pred, task.get_table("val"))
        print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}")

        if val_metrics[tune_metric] > best_val_metric:
            best_val_metric = val_metrics[tune_metric]
            state_dict = {k: v.cpu() for k, v in model.state_dict().items()}

# Load the best model and evaluate on validation and test sets
assert state_dict is not None
model.load_state_dict(state_dict)

val_pred = test(loader_dict["val"], desc="Best val")
val_metrics = task.evaluate(val_pred, task.get_table("val"))
print(f"Best val metrics: {val_metrics}")

test_pred = test(loader_dict["test"], desc="Test")
test_metrics = task.evaluate(test_pred)
print(f"Best test metrics: {test_metrics}")

Train: 100%|██████████| 155/155 [02:30<00:00,  1.03it/s]
Val: 100%|██████████| 1/1 [00:00<00:00, 15.75it/s]


Epoch: 01, Train loss: 8.351733575017283, Val metrics: {'link_prediction_precision': np.float64(0.07758620689655173), 'link_prediction_recall': np.float64(0.25862068965517243), 'link_prediction_map': np.float64(0.1810344827586207), 'link_prediction_top': np.float64(0.3103448275862069)}


Train: 100%|██████████| 155/155 [02:27<00:00,  1.05it/s]
Val: 100%|██████████| 1/1 [00:00<00:00, 17.31it/s]


Epoch: 02, Train loss: 5.607086999546708, Val metrics: {'link_prediction_precision': np.float64(0.05172413793103448), 'link_prediction_recall': np.float64(0.14942528735632182), 'link_prediction_map': np.float64(0.1221264367816092), 'link_prediction_top': np.float64(0.20689655172413793)}


Train: 100%|██████████| 155/155 [02:30<00:00,  1.03it/s]
Val: 100%|██████████| 1/1 [00:00<00:00, 18.64it/s]


Epoch: 03, Train loss: 5.481380947872887, Val metrics: {'link_prediction_precision': np.float64(0.06896551724137931), 'link_prediction_recall': np.float64(0.17816091954022986), 'link_prediction_map': np.float64(0.09482758620689655), 'link_prediction_top': np.float64(0.27586206896551724)}


Train: 100%|██████████| 155/155 [02:35<00:00,  1.01s/it]
Val: 100%|██████████| 1/1 [00:00<00:00, 19.27it/s]


Epoch: 04, Train loss: 5.388827334616773, Val metrics: {'link_prediction_precision': np.float64(0.02586206896551724), 'link_prediction_recall': np.float64(0.08045977011494251), 'link_prediction_map': np.float64(0.023946360153256706), 'link_prediction_top': np.float64(0.10344827586206896)}


Train: 100%|██████████| 155/155 [02:30<00:00,  1.03it/s]
Val: 100%|██████████| 1/1 [00:00<00:00, 19.28it/s]


Epoch: 05, Train loss: 5.292511561220247, Val metrics: {'link_prediction_precision': np.float64(0.0), 'link_prediction_recall': np.float64(0.0), 'link_prediction_map': np.float64(0.0), 'link_prediction_top': np.float64(0.0)}


Train: 100%|██████████| 155/155 [02:30<00:00,  1.03it/s]
Val: 100%|██████████| 1/1 [00:00<00:00, 19.39it/s]


Epoch: 06, Train loss: 5.234641221247929, Val metrics: {'link_prediction_precision': np.float64(0.0), 'link_prediction_recall': np.float64(0.0), 'link_prediction_map': np.float64(0.0), 'link_prediction_top': np.float64(0.0)}


Train: 100%|██████████| 155/155 [02:28<00:00,  1.04it/s]
Val: 100%|██████████| 1/1 [00:00<00:00, 18.89it/s]


Epoch: 07, Train loss: 5.1853479167041865, Val metrics: {'link_prediction_precision': np.float64(0.017241379310344827), 'link_prediction_recall': np.float64(0.04597701149425287), 'link_prediction_map': np.float64(0.020114942528735632), 'link_prediction_top': np.float64(0.06896551724137931)}


Train: 100%|██████████| 155/155 [02:32<00:00,  1.01it/s]
Val: 100%|██████████| 1/1 [00:00<00:00, 18.93it/s]


Epoch: 08, Train loss: 5.142359530692381, Val metrics: {'link_prediction_precision': np.float64(0.017241379310344827), 'link_prediction_recall': np.float64(0.04597701149425287), 'link_prediction_map': np.float64(0.0210727969348659), 'link_prediction_top': np.float64(0.06896551724137931)}


Train: 100%|██████████| 155/155 [02:32<00:00,  1.02it/s]
Val: 100%|██████████| 1/1 [00:00<00:00, 20.05it/s]


Epoch: 09, Train loss: 5.13735747882375, Val metrics: {'link_prediction_precision': np.float64(0.017241379310344827), 'link_prediction_recall': np.float64(0.022988505747126436), 'link_prediction_map': np.float64(0.006704980842911877), 'link_prediction_top': np.float64(0.06896551724137931)}


Train: 100%|██████████| 155/155 [02:29<00:00,  1.04it/s]
Val: 100%|██████████| 1/1 [00:00<00:00, 19.28it/s]


Epoch: 10, Train loss: 5.071819476780299, Val metrics: {'link_prediction_precision': np.float64(0.008620689655172414), 'link_prediction_recall': np.float64(0.011494252873563218), 'link_prediction_map': np.float64(0.0038314176245210726), 'link_prediction_top': np.float64(0.034482758620689655)}


Best val: 100%|██████████| 1/1 [00:00<00:00, 22.79it/s]


Best val metrics: {'link_prediction_precision': np.float64(0.07758620689655173), 'link_prediction_recall': np.float64(0.25862068965517243), 'link_prediction_map': np.float64(0.1810344827586207), 'link_prediction_top': np.float64(0.3103448275862069)}


Test: 100%|██████████| 1/1 [00:00<00:00, 24.46it/s]

Best test metrics: {'link_prediction_precision': np.float64(0.05303030303030303), 'link_prediction_recall': np.float64(0.19696969696969696), 'link_prediction_map': np.float64(0.11742424242424243), 'link_prediction_top': np.float64(0.21212121212121213)}





In [84]:
# Training function
def train() -> float:
    model.train()

    loss_accum = count_accum = 0
    steps = 0
    total_steps = min(len(loader_dict["train"]), max_steps_per_epoch)
    sparse_tensor = SparseTensor(dst_nodes_dict["train"][1], device=device)
    
    for batch in tqdm(loader_dict["train"], total=total_steps, desc="Train"):
        batch = batch.to(device)

        # Get ground-truth
        input_id = batch[task.src_entity_table].input_id
        src_batch, dst_index = sparse_tensor[input_id]

        # Optimization
        optimizer.zero_grad()

        if model_name == 'idgnn':
            out = model(batch, task.src_entity_table, task.dst_entity_table).flatten()
            batch_size = batch[task.src_entity_table].batch_size

            # Get target label
            target = torch.isin(
                batch[task.dst_entity_table].batch +
                batch_size * batch[task.dst_entity_table].n_id,
                src_batch + batch_size * dst_index,
            ).float()

            loss = F.binary_cross_entropy_with_logits(out, target)
            numel = out.numel()
        elif model_name in ['contextgnn', 'shallowrhsgnn']:
            logits = model(batch, task.src_entity_table, task.dst_entity_table)
            edge_label_index = torch.stack([src_batch, dst_index], dim=0)
            loss = sparse_cross_entropy(logits, edge_label_index)
            numel = len(batch[task.dst_entity_table].batch)

        loss.backward()
        optimizer.step()

        loss_accum += float(loss) * numel
        count_accum += numel

        steps += 1
        if steps > max_steps_per_epoch:
            break

    if count_accum == 0:
        warnings.warn(f"Did not sample a single '{task.dst_entity_table}' node in any mini-batch.")

    return loss_accum / count_accum if count_accum > 0 else float("nan")


In [85]:
# Test function
@torch.no_grad()
def test(loader: NeighborLoader, desc: str) -> np.ndarray:
    model.eval()

    pred_list: List[Tensor] = []
    for batch in tqdm(loader, desc=desc):
        batch = batch.to(device)
        batch_size = batch[task.src_entity_table].batch_size

        if model_name == "idgnn":
            out = (model.forward(batch, task.src_entity_table, task.dst_entity_table).detach().flatten())
            scores = torch.zeros(batch_size, task.num_dst_nodes, device=out.device)
            scores[batch[task.dst_entity_table].batch, batch[task.dst_entity_table].n_id] = torch.sigmoid(out)
        elif model_name in ['contextgnn', 'shallowrhsgnn']:
            out = model(batch, task.src_entity_table, task.dst_entity_table).detach()
            scores = torch.sigmoid(out)
        else:
            raise ValueError(f"Unsupported model type: {model_name}.")

        _, pred_mini = torch.topk(scores, k=task.eval_k, dim=1)
        pred_list.append(pred_mini)
    
    pred = torch.cat(pred_list, dim=0).cpu().numpy()
    return pred


In [86]:
# Training and evaluation loop
state_dict = None
best_val_metric = 0
for epoch in range(1, epochs + 1):
    train_loss = train()
    
    if epoch % eval_epochs_interval == 0:
        val_pred = test(loader_dict["val"], desc="Val")
        val_metrics = task.evaluate(val_pred, task.get_table("val"))
        print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}")

        if val_metrics[tune_metric] > best_val_metric:
            best_val_metric = val_metrics[tune_metric]
            state_dict = {k: v.cpu() for k, v in model.state_dict().items()}

# Load the best model and evaluate on validation and test sets
assert state_dict is not None

model.load_state_dict(state_dict)

val_pred = test(loader_dict["val"], desc="Best val")
val_metrics = task.evaluate(val_pred, task.get_table("val"))
print(f"Best val metrics: {val_metrics}")

test_pred = test(loader_dict["test"], desc="Test")
test_metrics = task.evaluate(test_pred)
print(f"Best test metrics: {test_metrics}")

Train: 100%|██████████| 99/99 [01:18<00:00,  1.26it/s]
Val: 100%|██████████| 1/1 [00:00<00:00, 30.45it/s]


Epoch: 01, Train loss: 0.3077996562725104, Val metrics: {'link_prediction_precision': np.float64(0.08333333333333333), 'link_prediction_recall': np.float64(0.2619047619047619), 'link_prediction_map': np.float64(0.21428571428571427), 'link_prediction_top': np.float64(0.3333333333333333)}


Train: 100%|██████████| 99/99 [01:18<00:00,  1.26it/s]
Val: 100%|██████████| 1/1 [00:00<00:00, 32.31it/s]


Epoch: 02, Train loss: nan, Val metrics: {'link_prediction_precision': np.float64(0.08333333333333333), 'link_prediction_recall': np.float64(0.2619047619047619), 'link_prediction_map': np.float64(0.20833333333333334), 'link_prediction_top': np.float64(0.3333333333333333)}


Train: 100%|██████████| 99/99 [01:23<00:00,  1.19it/s]
Val: 100%|██████████| 1/1 [00:00<00:00, 23.88it/s]


Epoch: 03, Train loss: 0.28528702723000937, Val metrics: {'link_prediction_precision': np.float64(0.09523809523809523), 'link_prediction_recall': np.float64(0.2857142857142857), 'link_prediction_map': np.float64(0.21626984126984125), 'link_prediction_top': np.float64(0.38095238095238093)}


Train: 100%|██████████| 99/99 [01:18<00:00,  1.26it/s]
Val: 100%|██████████| 1/1 [00:00<00:00, 32.37it/s]


Epoch: 04, Train loss: nan, Val metrics: {'link_prediction_precision': np.float64(0.08333333333333333), 'link_prediction_recall': np.float64(0.2619047619047619), 'link_prediction_map': np.float64(0.2103174603174603), 'link_prediction_top': np.float64(0.3333333333333333)}


Train: 100%|██████████| 99/99 [01:17<00:00,  1.28it/s]
Val: 100%|██████████| 1/1 [00:00<00:00, 31.33it/s]


Epoch: 05, Train loss: 0.2793998677638418, Val metrics: {'link_prediction_precision': np.float64(0.08333333333333333), 'link_prediction_recall': np.float64(0.2619047619047619), 'link_prediction_map': np.float64(0.20833333333333334), 'link_prediction_top': np.float64(0.3333333333333333)}


Train: 100%|██████████| 99/99 [01:25<00:00,  1.16it/s]
Val: 100%|██████████| 1/1 [00:00<00:00, 32.34it/s]


Epoch: 06, Train loss: 0.2812479808445714, Val metrics: {'link_prediction_precision': np.float64(0.08333333333333333), 'link_prediction_recall': np.float64(0.2619047619047619), 'link_prediction_map': np.float64(0.2103174603174603), 'link_prediction_top': np.float64(0.3333333333333333)}


Train: 100%|██████████| 99/99 [01:19<00:00,  1.25it/s]
Val: 100%|██████████| 1/1 [00:00<00:00, 32.36it/s]


Epoch: 07, Train loss: 0.2762963330451993, Val metrics: {'link_prediction_precision': np.float64(0.08333333333333333), 'link_prediction_recall': np.float64(0.2619047619047619), 'link_prediction_map': np.float64(0.20833333333333334), 'link_prediction_top': np.float64(0.3333333333333333)}


Train: 100%|██████████| 99/99 [01:15<00:00,  1.30it/s]
Val: 100%|██████████| 1/1 [00:00<00:00, 31.65it/s]


Epoch: 08, Train loss: nan, Val metrics: {'link_prediction_precision': np.float64(0.08333333333333333), 'link_prediction_recall': np.float64(0.2619047619047619), 'link_prediction_map': np.float64(0.2103174603174603), 'link_prediction_top': np.float64(0.3333333333333333)}


Train: 100%|██████████| 99/99 [01:19<00:00,  1.25it/s]
Val: 100%|██████████| 1/1 [00:00<00:00, 27.87it/s]


Epoch: 09, Train loss: nan, Val metrics: {'link_prediction_precision': np.float64(0.07142857142857142), 'link_prediction_recall': np.float64(0.23809523809523808), 'link_prediction_map': np.float64(0.20634920634920634), 'link_prediction_top': np.float64(0.2857142857142857)}


Train: 100%|██████████| 99/99 [01:19<00:00,  1.25it/s]
Val: 100%|██████████| 1/1 [00:00<00:00, 32.35it/s]


Epoch: 10, Train loss: 0.26645567435515677, Val metrics: {'link_prediction_precision': np.float64(0.10714285714285714), 'link_prediction_recall': np.float64(0.3333333333333333), 'link_prediction_map': np.float64(0.25793650793650796), 'link_prediction_top': np.float64(0.42857142857142855)}


Best val: 100%|██████████| 1/1 [00:00<00:00, 31.30it/s]


Best val metrics: {'link_prediction_precision': np.float64(0.10714285714285714), 'link_prediction_recall': np.float64(0.3333333333333333), 'link_prediction_map': np.float64(0.25793650793650796), 'link_prediction_top': np.float64(0.42857142857142855)}


Test: 100%|██████████| 1/1 [00:00<00:00, 31.33it/s]

Best test metrics: {'link_prediction_precision': np.float64(0.03571428571428571), 'link_prediction_recall': np.float64(0.14285714285714285), 'link_prediction_map': np.float64(0.10714285714285714), 'link_prediction_top': np.float64(0.14285714285714285)}





# Explain

In [104]:
test_data = loader_dict["test"].data

In [107]:
test_data.tf_dict

{'article': TensorFrame(
   num_cols=28,
   num_rows=2168,
   numerical (1): ['Base Unit Price'],
   categorical (7): ['Collection Code', 'Idea Code', 'Initial Pattern Code', 'Season', 'Style Code', 'Usage Type', 'Year'],
   embedding (20): ['Article Description', 'Brand', 'Collection Name', 'Color', 'Combined Feature', 'Dimensions', 'Idea Description', 'Initial Pattern Description', 'Material Category', 'Material Type', 'Model', 'Processing', 'Product Category', 'Product Group', 'Product Subgroup', 'Product Type', 'Size', 'Style Description', 'Usage Space', 'User'],
   has_target=False,
   device='cpu',
 ),
 'branches': TensorFrame(
   num_cols=2,
   num_rows=75,
   embedding (2): ['Branch Name', 'City'],
   has_target=False,
   device='cpu',
 ),
 'customer': TensorFrame(
   num_cols=3,
   num_rows=43453,
   categorical (3): ['City', 'Customer Category', 'Customer Process'],
   has_target=False,
   device='cpu',
 ),
 'transactions': TensorFrame(
   num_cols=26,
   num_rows=316487,
   

In [None]:
if model_name == "idgnn":
    model = IDGNN(
        data=data,
        col_stats_dict=col_stats_dict,
        num_layers=num_layers,
        channels=channels,
        out_channels=1,
        aggr=aggregation_method,
        norm="layer_norm",
        torch_frame_model_kwargs={
            "channels": 128,
            "num_layers": 4,
        },
    ).to(device)



{('transactions', 'f2p_articles_id', 'article'): tensor([[    16,     24,     31,  ..., 311381, 311451, 312867],
        [   580,      8,    561,  ...,    587,   2133,    556]]), ('article', 'rev_f2p_articles_id', 'transactions'): tensor([[     1,      8,      8,  ...,   2162,   2162,   2162],
        [277007,     24,     91,  ..., 279713, 279714, 279715]]), ('transactions', 'f2p_customers_id', 'customer'): tensor([[     0,      1,      2,  ..., 316484, 316485, 316486],
        [     0,      1,      2,  ...,  23449,  39691,  35513]]), ('customer', 'rev_f2p_customers_id', 'transactions'): tensor([[     0,      0,      0,  ...,  43452,  43452,  43452],
        [     0,    179,    603,  ..., 315994, 316101, 316108]]), ('transactions', 'f2p_BranchCode', 'branches'): tensor([[     0,      1,      2,  ..., 316484, 316485, 316486],
        [     0,      1,      2,  ...,     54,     54,     65]]), ('branches', 'rev_f2p_BranchCode', 'transactions'): tensor([[     0,      0,      0,  ...,     74

In [116]:
# Ensure batch is correctly extracted
batch = loader_dict["test"].data  # Extract HeteroData batch

# Extract the correct edge index
edge_index = batch.edge_index_dict[("transactions", "f2p_customers_id", "customer")]

# Ensure correct input format
if edge_index is None:
    raise ValueError("Edge index not found for the given edge type.")


In [121]:
from torch_geometric.explain import Explainer, CaptumExplainer


# Define the Explainer
explainer = Explainer(
    model=model,
    algorithm=CaptumExplainer('IntegratedGradients'),
    explanation_type='model',
    model_config=dict(
        mode='regression',  # or 'classification'
        task_level='edge',
        return_type='raw',
    ),
    node_mask_type='attributes',  # Ensure attributes are explainable
    edge_mask_type='object',
)

# Running the explanation process
explanation = explainer(
    x=batch,  # Pass the entire batch (HeteroData object)
    entity_table="transactions",  # Correct entity node type
    dst_table="customer",  # Correct destination node type
    edge_index=edge_index,  # Pass edge relationships
    index=0  # Select a sample edge
).cpu().detach()

print(explanation)


TypeError: IDGNN.forward() got multiple values for argument 'entity_table'

In [182]:
# Training and evaluation loop
state_dict = None
best_val_metric = 0
for epoch in range(1, epochs + 1):
    train_loss = train()
    
    if epoch % eval_epochs_interval == 0:
        val_pred = test(loader_dict["val"], desc="Val")
        val_metrics = task.evaluate(val_pred, task.get_table("val"))
        print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}")

        if val_metrics[tune_metric] > best_val_metric:
            best_val_metric = val_metrics[tune_metric]
            state_dict = {k: v.cpu() for k, v in model.state_dict().items()}

# Load the best model and evaluate on validation and test sets
assert state_dict is not None
model.load_state_dict(state_dict)

val_pred = test(loader_dict["val"], desc="Best val")
val_metrics = task.evaluate(val_pred, task.get_table("val"))
print(f"Best val metrics: {val_metrics}")

test_pred = test(loader_dict["test"], desc="Test")
test_metrics = task.evaluate(test_pred)
print(f"Best test metrics: {test_metrics}")

Train: 100%|██████████| 1130/1130 [56:08<00:00,  2.98s/it] 
Val: 100%|██████████| 5/5 [00:07<00:00,  1.56s/it]


Epoch: 01, Train loss: 15.693553883077128, Val metrics: {'link_prediction_precision': np.float64(0.14345637583892618), 'link_prediction_recall': np.float64(0.21569590923617768), 'link_prediction_map': np.float64(0.15491004847129008), 'link_prediction_top': np.float64(0.47315436241610737)}


Train: 100%|██████████| 1130/1130 [57:53<00:00,  3.07s/it] 
Val: 100%|██████████| 5/5 [00:07<00:00,  1.53s/it]


Epoch: 02, Train loss: 14.202766350519356, Val metrics: {'link_prediction_precision': np.float64(0.10339765100671142), 'link_prediction_recall': np.float64(0.15851809683604984), 'link_prediction_map': np.float64(0.10915944258016405), 'link_prediction_top': np.float64(0.36325503355704697)}


Train: 100%|██████████| 1130/1130 [57:31<00:00,  3.05s/it] 
Val: 100%|██████████| 5/5 [00:07<00:00,  1.53s/it]


Epoch: 03, Train loss: 14.02156443823075, Val metrics: {'link_prediction_precision': np.float64(0.06040268456375839), 'link_prediction_recall': np.float64(0.09708872643016939), 'link_prediction_map': np.float64(0.06484200223713646), 'link_prediction_top': np.float64(0.21140939597315436)}


Train: 100%|██████████| 1130/1130 [56:41<00:00,  3.01s/it] 
Val: 100%|██████████| 5/5 [00:07<00:00,  1.47s/it]


Epoch: 04, Train loss: 13.868679549075418, Val metrics: {'link_prediction_precision': np.float64(0.061870805369127514), 'link_prediction_recall': np.float64(0.09752816395014381), 'link_prediction_map': np.float64(0.06422445935868755), 'link_prediction_top': np.float64(0.21728187919463088)}


Train: 100%|██████████| 1130/1130 [55:50<00:00,  2.97s/it] 
Val: 100%|██████████| 5/5 [00:07<00:00,  1.52s/it]


Epoch: 05, Train loss: 13.76965204429071, Val metrics: {'link_prediction_precision': np.float64(0.05641778523489933), 'link_prediction_recall': np.float64(0.08841083413231063), 'link_prediction_map': np.float64(0.06416620059656972), 'link_prediction_top': np.float64(0.19546979865771813)}


Train: 100%|██████████| 1130/1130 [55:39<00:00,  2.96s/it] 
Val: 100%|██████████| 5/5 [00:07<00:00,  1.47s/it]


Epoch: 06, Train loss: 13.730094240045654, Val metrics: {'link_prediction_precision': np.float64(0.05893456375838926), 'link_prediction_recall': np.float64(0.0952141259188239), 'link_prediction_map': np.float64(0.0637117822520507), 'link_prediction_top': np.float64(0.20553691275167785)}


Train: 100%|██████████| 1130/1130 [55:46<00:00,  2.96s/it] 
Val: 100%|██████████| 5/5 [00:07<00:00,  1.47s/it]


Epoch: 07, Train loss: 13.672291335220528, Val metrics: {'link_prediction_precision': np.float64(0.05725671140939597), 'link_prediction_recall': np.float64(0.09168264621284755), 'link_prediction_map': np.float64(0.061964019388516034), 'link_prediction_top': np.float64(0.1988255033557047)}


Train: 100%|██████████| 1130/1130 [56:02<00:00,  2.98s/it] 
Val: 100%|██████████| 5/5 [00:07<00:00,  1.48s/it]


Epoch: 08, Train loss: 13.573945471917492, Val metrics: {'link_prediction_precision': np.float64(0.057885906040268456), 'link_prediction_recall': np.float64(0.0920881271971876), 'link_prediction_map': np.float64(0.06313502050708426), 'link_prediction_top': np.float64(0.20134228187919462)}


Train: 100%|██████████| 1130/1130 [55:56<00:00,  2.97s/it] 
Val: 100%|██████████| 5/5 [00:07<00:00,  1.48s/it]


Epoch: 09, Train loss: 13.533757880756776, Val metrics: {'link_prediction_precision': np.float64(0.05683724832214765), 'link_prediction_recall': np.float64(0.09063398849472674), 'link_prediction_map': np.float64(0.06106683445190156), 'link_prediction_top': np.float64(0.19714765100671142)}


Train: 100%|██████████| 1130/1130 [56:42<00:00,  3.01s/it] 
Val: 100%|██████████| 5/5 [00:07<00:00,  1.48s/it]


Epoch: 10, Train loss: 13.484285618741263, Val metrics: {'link_prediction_precision': np.float64(0.057466442953020135), 'link_prediction_recall': np.float64(0.09096356663470757), 'link_prediction_map': np.float64(0.06195819351230424), 'link_prediction_top': np.float64(0.1988255033557047)}


Best val: 100%|██████████| 5/5 [00:07<00:00,  1.48s/it]


Best val metrics: {'link_prediction_precision': np.float64(0.14345637583892618), 'link_prediction_recall': np.float64(0.21569590923617768), 'link_prediction_map': np.float64(0.15491004847129008), 'link_prediction_top': np.float64(0.47315436241610737)}


Test: 100%|██████████| 10/10 [00:10<00:00,  1.07s/it]


Best test metrics: {'link_prediction_precision': np.float64(0.13122721749696234), 'link_prediction_recall': np.float64(0.18079086657458468), 'link_prediction_map': np.float64(0.14091681292471084), 'link_prediction_top': np.float64(0.41271769947347103)}


In [280]:
model_save_path = "best_fullDorsa_model1.pth"
torch.save(state_dict, model_save_path)

In [63]:
model_save_path = "best_Dorsa_model.pth"
state_dict = torch.load(model_save_path)
model.load_state_dict(state_dict)

<All keys matched successfully>

In [64]:
test_pred = test(loader_dict["test"], desc="Test")
test_metrics = task.evaluate(test_pred)
print(f"Best test metrics: {test_metrics}")

Test: 100%|██████████| 10/10 [00:06<00:00,  1.49it/s]

Best test metrics: {'link_prediction_precision': np.float64(0.12829080599432968), 'link_prediction_recall': np.float64(0.17587147787633814), 'link_prediction_map': np.float64(0.1295818144997975), 'link_prediction_top': np.float64(0.40785743215876874)}





In [184]:
test_pred

array([[  72,  145,  105,  249],
       [  72,  145,  105, 1503],
       [  72,  145,  105, 1503],
       ...,
       [  72,  145,  105, 1073],
       [  72,  145,  105, 1503],
       [  72,  105,  145,  323]])

In [185]:
customer_ids = test_table.df["customers_id"].values  # Extract the customer_id column
predicted_labels = test_pred  # This contains the model's predicted labels

# Now zip customer_ids with the predictions
results = list(zip(customer_ids, predicted_labels))

# If you'd like to show or process the results, you can format them as a dataframe
results_df = pd.DataFrame(results, columns=["customers_id", "predicted_articles"])


In [186]:
print(results_df)

      customers_id    predicted_articles
0             9806   [72, 145, 105, 249]
1           126374  [72, 145, 105, 1503]
2            12345  [72, 145, 105, 1503]
3           126357  [72, 145, 105, 1503]
4           132536  [72, 145, 105, 1073]
...            ...                   ...
2464        102668   [105, 145, 72, 140]
2465         98453  [72, 145, 105, 1678]
2466         96365  [72, 145, 105, 1073]
2467        140302  [72, 145, 105, 1503]
2468          1489   [72, 105, 145, 323]

[2469 rows x 2 columns]


In [190]:
index_map_dict1 = db.reindex_pkeys_and_fkeys()
print(index_map_dict1)

{'article': articles_id
54821.0       0
53612.0       1
59459.0       2
54126.0       3
44301.0       4
           ... 
17315.0    7435
23028.0    7436
46775.0    7437
39866.0    7438
42931.0    7439
Name: index, Length: 7440, dtype: Int64, 'customer': customers_id
10121860.0         0
10447769.0         1
10469286.0         2
10075749.0         3
10432889.0         4
               ...  
10461949.0    146774
10477874.0    146775
10411981.0    146776
10440367.0    146777
10455252.0    146778
Name: index, Length: 146779, dtype: Int64, 'branches': BranchCode
10284.0     0
10124.0     1
10706.0     2
11008.0     3
10002.0     4
10100.0     5
10131.0     6
11086.0     7
10562.0     8
10593.0     9
10168.0    10
10267.0    11
10887.0    12
10008.0    13
10006.0    14
10003.0    15
10099.0    16
10656.0    17
10672.0    18
10536.0    19
10594.0    20
10113.0    21
10595.0    22
10732.0    23
11155.0    24
10642.0    25
10803.0    26
10001.0    27
10643.0    28
10709.0    29
10748.0    30
100

In [191]:
# Access the 'article' Series from index_map_dict1
article_series = index_map_dict1['article']

# Get the data type of the 'article' Series
article_dtype = article_series.dtype

# Print the data type
print(article_dtype)


Int64


In [192]:
reverse_mapping_articles

{np.int64(0): 54821.0,
 np.int64(1): 53612.0,
 np.int64(2): 59459.0,
 np.int64(3): 54126.0,
 np.int64(4): 44301.0,
 np.int64(5): 28378.0,
 np.int64(6): 44023.0,
 np.int64(7): 47122.0,
 np.int64(8): 54025.0,
 np.int64(9): 41372.0,
 np.int64(10): 42282.0,
 np.int64(11): 57538.0,
 np.int64(12): 44592.0,
 np.int64(13): 59476.0,
 np.int64(14): 34528.0,
 np.int64(15): 53435.0,
 np.int64(16): 54161.0,
 np.int64(17): 35142.0,
 np.int64(18): 47567.0,
 np.int64(19): 49489.0,
 np.int64(20): 31794.0,
 np.int64(21): 18213.0,
 np.int64(22): 49482.0,
 np.int64(23): 41415.0,
 np.int64(24): 24837.0,
 np.int64(25): 50084.0,
 np.int64(26): 47535.0,
 np.int64(27): 53614.0,
 np.int64(28): 44382.0,
 np.int64(29): 22012.0,
 np.int64(30): 24294.0,
 np.int64(31): 32147.0,
 np.int64(32): 59518.0,
 np.int64(33): 34656.0,
 np.int64(34): 41609.0,
 np.int64(35): 22439.0,
 np.int64(36): 38374.0,
 np.int64(37): 32658.0,
 np.int64(38): 54052.0,
 np.int64(39): 59457.0,
 np.int64(40): 59486.0,
 np.int64(41): 54293.0,
 n

In [193]:
# Assuming reverse_mapping_articles is already defined
reverse_mapping_articles = {v: k for k, v in index_map_dict1['article'].items()}

# Apply reverse mapping to each article in the list
results_df['mapped_articles_id'] = results_df['predicted_articles'].apply(
    lambda x: [reverse_mapping_articles.get(article, article) for article in x]
)


In [194]:
reverse_mapping = {v: k for k, v in index_map_dict1['customer'].items()}

# Map back to original customers_id
results_df['mapped_customers_id'] = results_df['customers_id'].map(reverse_mapping)

In [134]:
# Assuming reverse_mapping_articles is already defined
reverse_mapping_articles = {v: k for k, v in index_map_dict1['article'].items()}

# Apply reverse mapping to each article in the list and cast the result to float64
results_df['mapped_articles_id'] = results_df['predicted_articles'].apply(
    lambda x: [float(reverse_mapping_articles.get(article, article)) for article in x]
)

# Convert the mapped_articles_id column to the desired format (float64)
results_df['mapped_articles_id'] = results_df['mapped_articles_id'].apply(pd.Series).stack().astype('float64')

# Check the data type of the 'mapped_articles_id' column
print(results_df['mapped_articles_id'].dtype)




TypeError: incompatible index of inserted column with frame index

In [137]:
results_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2469 entries, 0 to 2468
Data columns (total 4 columns):
 #   Column               Non-Null Count  Dtype  
---  ------               --------------  -----  
 0   customers_id         2469 non-null   int64  
 1   predicted_articles   2469 non-null   object 
 2   mapped_articles_id   2469 non-null   object 
 3   mapped_customers_id  1436 non-null   float64
dtypes: float64(1), int64(1), object(2)
memory usage: 77.3+ KB


In [195]:
results_df

Unnamed: 0,customers_id,predicted_articles,mapped_articles_id,mapped_customers_id
0,9806,"[72, 145, 105, 249]","[42413.0, 38471.0, 55102.0, 35592.0]",10421548.0
1,126374,"[72, 145, 105, 1503]","[42413.0, 38471.0, 55102.0, 41647.0]",10377320.0
2,12345,"[72, 145, 105, 1503]","[42413.0, 38471.0, 55102.0, 41647.0]",10292410.0
3,126357,"[72, 145, 105, 1503]","[42413.0, 38471.0, 55102.0, 41647.0]",10450448.0
4,132536,"[72, 145, 105, 1073]","[42413.0, 38471.0, 55102.0, 53840.0]",10436446.0
...,...,...,...,...
2464,102668,"[105, 145, 72, 140]","[55102.0, 38471.0, 42413.0, 38470.0]",10213421.0
2465,98453,"[72, 145, 105, 1678]","[42413.0, 38471.0, 55102.0, 29925.0]",10444305.0
2466,96365,"[72, 145, 105, 1073]","[42413.0, 38471.0, 55102.0, 53840.0]",10125017.0
2467,140302,"[72, 145, 105, 1503]","[42413.0, 38471.0, 55102.0, 41647.0]",10465791.0


In [196]:
nan_count = results_df['mapped_customers_id'].isna().sum()
print(f"Number of NaN values: {nan_count}")

Number of NaN values: 106


In [197]:
results_df = results_df.dropna(subset=['mapped_customers_id'])

In [198]:
test_data1 = task.get_table("test", mask_input_cols=False)
test_data1

Table(df=
      timestamp  customers_id                              articles_id
0    2025-01-11          9806                     [981, 1491, 72, 105]
1    2025-01-11        126374                 [2391, 3426, 6226, 1830]
2    2025-01-11         12345  [6892, 1106, 840, 5991, 1780, 977, 145]
3    2025-01-11        126357     [1949, 7190, 5265, 5849, 1290, 7344]
4    2025-01-11        132536                                    [869]
...         ...           ...                                      ...
2464 2025-01-11        102668                                    [879]
2465 2025-01-11         98453                                   [1235]
2466 2025-01-11         96365                                   [1140]
2467 2025-01-11        140302                                    [168]
2468 2025-01-11          1489                                    [934]

[2469 rows x 3 columns],
  fkey_col_to_pkey_table={'customers_id': 'customer', 'articles_id': 'article'},
  pkey_col=None,
  time_col=tim

In [199]:
customer_ids = test_data1.df["customers_id"].values  # Extract the customer_id column
predicted_labels = test_data1.df["articles_id"].values  # This contains the model's predicted labels

# Now zip customer_ids with the predictions
results = list(zip(customer_ids, predicted_labels))

# If you'd like to show or process the results, you can format them as a dataframe
results_df1 = pd.DataFrame(results, columns=["customers_id", "groundtruth_articles"])

In [200]:
# Assuming reverse_mapping_articles is already defined
reverse_mapping_articles = {v: k for k, v in index_map_dict1['article'].items()}

# Apply reverse mapping to each article in the list
results_df1['mapped_articles_id'] = results_df1['groundtruth_articles'].apply(
    lambda x: [reverse_mapping_articles.get(article, article) for article in x]
)

In [201]:
reverse_mapping = {v: k for k, v in index_map_dict1['customer'].items()}

# Map back to original customers_id
results_df1['mapped_customers_id'] = results_df1['customers_id'].map(reverse_mapping)

In [202]:
results_df1

Unnamed: 0,customers_id,groundtruth_articles,mapped_articles_id,mapped_customers_id
0,9806,"[981, 1491, 72, 105]","[53407.0, 52402.0, 42413.0, 55102.0]",10421548.0
1,126374,"[2391, 3426, 6226, 1830]","[53183.0, 32590.0, 45394.0, 47141.0]",10377320.0
2,12345,"[6892, 1106, 840, 5991, 1780, 977, 145]","[16274.0, 59028.0, 52452.0, 30373.0, 11176.0, ...",10292410.0
3,126357,"[1949, 7190, 5265, 5849, 1290, 7344]","[41781.0, 59161.0, 35600.0, 16144.0, 38860.0, ...",10450448.0
4,132536,[869],[53251.0],10436446.0
...,...,...,...,...
2464,102668,[879],[52390.0],10213421.0
2465,98453,[1235],[48204.0],10444305.0
2466,96365,[1140],[55073.0],10125017.0
2467,140302,[168],[32655.0],10465791.0


In [203]:
results_df1 = results_df1.dropna(subset=['mapped_customers_id'])

In [204]:
merged_df = pd.merge(results_df, results_df1, on='mapped_customers_id', how='inner')

# Print the result
print(merged_df)

      customers_id_x    predicted_articles  \
0               9806   [72, 145, 105, 249]   
1             126374  [72, 145, 105, 1503]   
2              12345  [72, 145, 105, 1503]   
3             126357  [72, 145, 105, 1503]   
4             132536  [72, 145, 105, 1073]   
...              ...                   ...   
2358          102668   [105, 145, 72, 140]   
2359           98453  [72, 145, 105, 1678]   
2360           96365  [72, 145, 105, 1073]   
2361          140302  [72, 145, 105, 1503]   
2362            1489   [72, 105, 145, 323]   

                      mapped_articles_id_x  mapped_customers_id  \
0     [42413.0, 38471.0, 55102.0, 35592.0]           10421548.0   
1     [42413.0, 38471.0, 55102.0, 41647.0]           10377320.0   
2     [42413.0, 38471.0, 55102.0, 41647.0]           10292410.0   
3     [42413.0, 38471.0, 55102.0, 41647.0]           10450448.0   
4     [42413.0, 38471.0, 55102.0, 53840.0]           10436446.0   
...                                    ...   

In [205]:
merged_df

Unnamed: 0,customers_id_x,predicted_articles,mapped_articles_id_x,mapped_customers_id,customers_id_y,groundtruth_articles,mapped_articles_id_y
0,9806,"[72, 145, 105, 249]","[42413.0, 38471.0, 55102.0, 35592.0]",10421548.0,9806,"[981, 1491, 72, 105]","[53407.0, 52402.0, 42413.0, 55102.0]"
1,126374,"[72, 145, 105, 1503]","[42413.0, 38471.0, 55102.0, 41647.0]",10377320.0,126374,"[2391, 3426, 6226, 1830]","[53183.0, 32590.0, 45394.0, 47141.0]"
2,12345,"[72, 145, 105, 1503]","[42413.0, 38471.0, 55102.0, 41647.0]",10292410.0,12345,"[6892, 1106, 840, 5991, 1780, 977, 145]","[16274.0, 59028.0, 52452.0, 30373.0, 11176.0, ..."
3,126357,"[72, 145, 105, 1503]","[42413.0, 38471.0, 55102.0, 41647.0]",10450448.0,126357,"[1949, 7190, 5265, 5849, 1290, 7344]","[41781.0, 59161.0, 35600.0, 16144.0, 38860.0, ..."
4,132536,"[72, 145, 105, 1073]","[42413.0, 38471.0, 55102.0, 53840.0]",10436446.0,132536,[869],[53251.0]
...,...,...,...,...,...,...,...
2358,102668,"[105, 145, 72, 140]","[55102.0, 38471.0, 42413.0, 38470.0]",10213421.0,102668,[879],[52390.0]
2359,98453,"[72, 145, 105, 1678]","[42413.0, 38471.0, 55102.0, 29925.0]",10444305.0,98453,[1235],[48204.0]
2360,96365,"[72, 145, 105, 1073]","[42413.0, 38471.0, 55102.0, 53840.0]",10125017.0,96365,[1140],[55073.0]
2361,140302,"[72, 145, 105, 1503]","[42413.0, 38471.0, 55102.0, 41647.0]",10465791.0,140302,[168],[32655.0]


In [206]:
merged_df.to_csv('C:\\Users\\KN2C\Desktop\\Dani\\contextgnn\\result_recom2.csv', index=False)

  merged_df.to_csv('C:\\Users\\KN2C\Desktop\\Dani\\contextgnn\\result_recom2.csv', index=False)


In [37]:
# Training and evaluation loop
state_dict = None
best_val_metric = 0
for epoch in range(1, epochs + 1):
    train_loss = train()
    
    if epoch % eval_epochs_interval == 0:
        val_pred = test(loader_dict["val"], desc="Val")
        val_metrics = task.evaluate(val_pred, task.get_table("val"))
        print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}")

        if val_metrics[tune_metric] > best_val_metric:
            best_val_metric = val_metrics[tune_metric]
            state_dict = {k: v.cpu() for k, v in model.state_dict().items()}

# Load the best model and evaluate on validation and test sets
assert state_dict is not None
model.load_state_dict(state_dict)

val_pred = test(loader_dict["val"], desc="Best val")
val_metrics = task.evaluate(val_pred, task.get_table("val"))
print(f"Best val metrics: {val_metrics}")

test_pred = test(loader_dict["test"], desc="Test")
test_metrics = task.evaluate(test_pred)
print(f"Best test metrics: {test_metrics}")

Train: 100%|██████████| 2000/2000 [4:26:32<00:00,  8.00s/it]  
Val: 100%|██████████| 59/59 [10:40<00:00, 10.86s/it]


Epoch: 01, Train loss: 72.34259055215553, Val metrics: {'link_prediction_precision': np.float64(0.11947518388443443), 'link_prediction_recall': np.float64(0.08714295934837625), 'link_prediction_map': np.float64(0.10646246106951163), 'link_prediction_top': np.float64(0.3480219998674707)}


Train: 100%|██████████| 2000/2000 [4:21:52<00:00,  7.86s/it]  
Val: 100%|██████████| 59/59 [10:39<00:00, 10.84s/it]


Epoch: 02, Train loss: 69.44414806568167, Val metrics: {'link_prediction_precision': np.float64(0.0778444105758399), 'link_prediction_recall': np.float64(0.059326035200976276), 'link_prediction_map': np.float64(0.060867772550232284), 'link_prediction_top': np.float64(0.25134185938638925)}


Train: 100%|██████████| 2000/2000 [4:20:35<00:00,  7.82s/it]  
Val: 100%|██████████| 59/59 [10:41<00:00, 10.88s/it]


Epoch: 03, Train loss: 68.62711653797692, Val metrics: {'link_prediction_precision': np.float64(0.04976476045325028), 'link_prediction_recall': np.float64(0.03974686011645201), 'link_prediction_map': np.float64(0.03744229452433018), 'link_prediction_top': np.float64(0.167517063150222)}


Train: 100%|██████████| 2000/2000 [4:21:12<00:00,  7.84s/it]  
Val: 100%|██████████| 59/59 [10:42<00:00, 10.89s/it]


Epoch: 04, Train loss: 68.28201239375534, Val metrics: {'link_prediction_precision': np.float64(0.04057053873169439), 'link_prediction_recall': np.float64(0.033005813262233834), 'link_prediction_map': np.float64(0.030213464242852622), 'link_prediction_top': np.float64(0.13809555364124312)}


Train: 100%|██████████| 2000/2000 [4:20:12<00:00,  7.81s/it]  
Val: 100%|██████████| 59/59 [10:39<00:00, 10.84s/it]


Epoch: 05, Train loss: 67.99987211956321, Val metrics: {'link_prediction_precision': np.float64(0.03801934928102843), 'link_prediction_recall': np.float64(0.03166580717627368), 'link_prediction_map': np.float64(0.029117336307880336), 'link_prediction_top': np.float64(0.13074017626399842)}


Train:  35%|███▌      | 702/2000 [1:31:24<2:49:00,  7.81s/it]


KeyboardInterrupt: 

## Sample Softmax

In [79]:
import os
import json
import warnings
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import torch
from relbench.base import Dataset, RecommendationTask, TaskType
from relbench.datasets import get_dataset
from relbench.modeling.graph import (
    get_link_train_table_input,
    make_pkey_fkey_graph,
)
from relbench.modeling.loader import SparseTensor
from relbench.modeling.utils import get_stype_proposal
from relbench.tasks import get_task
from torch_frame import stype
from torch_frame.config.text_embedder import TextEmbedderConfig
from torch_geometric.loader import NeighborLoader
from torch_geometric.seed import seed_everything
from torch_geometric.typing import NodeType
from torch_geometric.utils.cross_entropy import sparse_cross_entropy
from tqdm import tqdm
from contextgnn.nn.models import ContextGNN
from contextgnn.utils import GloveTextEmbedding, RHSEmbeddingMode

# Static Configuration
# dataset_name = "rel-hm"
# task_name = "user-item-purchase"
learning_rate = 0.01
epochs = 10
eval_epochs_interval = 1
batch_size = 64
channels = 128
aggregation_method = "sum"
num_layers = 4
num_neighbors = 128
rhs_sample_size = 1000  # Use -1 for sampling all RHS
temporal_strategy = "last"
max_steps_per_epoch = 200
num_workers = 0
seed = 42
cache_dir = os.path.expanduser("D:/Dani/relbench/relbench/.cache/relbench_examples")

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_num_threads(1 if torch.cuda.is_available() else os.cpu_count())
torch.manual_seed(seed)
seed_everything(seed)


In [38]:
# Load dataset and task
dataset: Dataset = get_dataset(dataset_name, download=True)
task: RecommendationTask = get_task(dataset_name, task_name, download=True)

# Ensure task type is LINK_PREDICTION
assert task.task_type == TaskType.LINK_PREDICTION

# Tune metric
tune_metric = "link_prediction_map"

# Handle column type mappings
stypes_cache_path = Path(f"{cache_dir}/{dataset_name}/stypes.json")
try:
    with open(stypes_cache_path, "r") as f:
        col_to_stype_dict = json.load(f)
    for table, col_to_stype in col_to_stype_dict.items():
        for col, stype_str in col_to_stype.items():
            col_to_stype[col] = stype(stype_str)
except FileNotFoundError:
    col_to_stype_dict = get_stype_proposal(dataset.get_db())
    Path(stypes_cache_path).parent.mkdir(parents=True, exist_ok=True)
    with open(stypes_cache_path, "w") as f:
        json.dump(col_to_stype_dict, f, indent=2, default=str)

# Prepare graph data and column stats
data, col_stats_dict = make_pkey_fkey_graph(
    dataset.get_db(),
    col_to_stype_dict=col_to_stype_dict,
    text_embedder_cfg=TextEmbedderConfig(
        text_embedder=GloveTextEmbedding(device=device), batch_size=256
    ),
    cache_dir=f"{cache_dir}/{dataset_name}/materialized2",
)


NameError: name 'dataset_name' is not defined

In [62]:
torch.cuda.empty_cache()


In [81]:
# Define number of neighbors for NeighborLoader
num_neighbors_list = [int(num_neighbors // 2**i) for i in range(num_layers)]

# Loader dictionaries
loader_dict: Dict[str, NeighborLoader] = {}
dst_nodes_dict: Dict[str, Tuple[NodeType, Tensor]] = {}
num_dst_nodes_dict: Dict[str, int] = {}

# Initialize data loaders for train, val, and test splits
for split in ["train", "val", "test"]:
    table = task.get_table(split)
    table_input = get_link_train_table_input(table, task)
    dst_nodes_dict[split] = table_input.dst_nodes
    num_dst_nodes_dict[split] = table_input.num_dst_nodes
    loader_dict[split] = NeighborLoader(
        data,
        num_neighbors=num_neighbors_list,
        time_attr="time",
        input_nodes=table_input.src_nodes,
        input_time=table_input.src_time,
        subgraph_type="bidirectional",
        batch_size=batch_size,
        temporal_strategy=temporal_strategy,
        shuffle=split == "train",
        num_workers=num_workers,
        persistent_workers=num_workers > 0,
    )


In [82]:
# Initialize ContextGNN model
model: ContextGNN = ContextGNN(
    data=data,
    col_stats_dict=col_stats_dict,
    rhs_emb_mode=RHSEmbeddingMode.FUSION,
    dst_entity_table=task.dst_entity_table,
    num_nodes=num_dst_nodes_dict["train"],
    num_layers=num_layers,
    channels=channels,
    aggr=aggregation_method,
    norm="layer_norm",
    embedding_dim=64,
    torch_frame_model_kwargs={"channels": 128, "num_layers": 4},
    rhs_sample_size=rhs_sample_size,
).to(device)

# Set up the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)


In [83]:
def train() -> float:
    model.train()

    loss_accum = count_accum = 0.0
    steps = 0
    total_steps = min(len(loader_dict["train"]), max_steps_per_epoch)
    sparse_tensor = SparseTensor(dst_nodes_dict["train"][1], device=device)

    for batch in tqdm(loader_dict["train"], total=total_steps, desc="Train"):
        batch = batch.to(device)

        # Get ground-truth
        input_id = batch[task.src_entity_table].input_id
        src_batch, dst_index = sparse_tensor[input_id]

        # Optimization
        optimizer.zero_grad()

        logits, lhs_y_batch, rhs_y_index = model.forward_sample_softmax(
            batch, task.src_entity_table, task.dst_entity_table, src_batch, dst_index
        )
        edge_label_index = torch.stack([lhs_y_batch, rhs_y_index], dim=0)
        loss = sparse_cross_entropy(logits, edge_label_index)

        numel = len(batch[task.dst_entity_table].batch)
        loss.backward()
        optimizer.step()

        loss_accum += float(loss) * numel
        count_accum += numel

        steps += 1
        if steps >= max_steps_per_epoch:
            break

    if count_accum == 0:
        warnings.warn(f"Did not sample a single '{task.dst_entity_table}' node in any mini-batch.")

    return loss_accum / count_accum if count_accum > 0 else float("nan")


In [84]:
@torch.no_grad()
def test(loader: NeighborLoader, desc: str) -> np.ndarray:
    model.eval()

    pred_list: List[Tensor] = []
    for batch in tqdm(loader, desc=desc):
        batch = batch.to(device)
        out = model(batch, task.src_entity_table, task.dst_entity_table).detach()
        scores = torch.sigmoid(out)
        _, pred_mini = torch.topk(scores, k=task.eval_k, dim=1)
        pred_list.append(pred_mini)
    pred = torch.cat(pred_list, dim=0).cpu().numpy()
    return pred


In [85]:
# Initialize variables for tracking the best model and validation metrics
state_dict = None
best_val_metric = 0

# Training and evaluation loop
for epoch in range(1, epochs + 1):
    train_loss = train()

    if epoch % eval_epochs_interval == 0:
        # Run validation
        val_pred = test(loader_dict["val"], desc="Val")
        val_metrics = task.evaluate(val_pred, task.get_table("val"))
        print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}")

        # Save best model
        if val_metrics[tune_metric] > best_val_metric:
            best_val_metric = val_metrics[tune_metric]
            state_dict = {k: v.cpu() for k, v in model.state_dict().items()}

# Load the best model weights
assert state_dict is not None
model.load_state_dict(state_dict)

# Evaluate on validation and test sets
val_pred = test(loader_dict["val"], desc="Best val")
val_metrics = task.evaluate(val_pred, task.get_table("val"))
print(f"Best val metrics: {val_metrics}")

test_pred = test(loader_dict["test"], desc="Test")
test_metrics = task.evaluate(test_pred)
print(f"Best test metrics: {test_metrics}")


Train: 100%|█████████▉| 199/200 [02:30<00:00,  1.32it/s]
Val: 100%|██████████| 236/236 [06:25<00:00,  1.63s/it]


Epoch: 01, Train loss: 72.70253252920064, Val metrics: {'link_prediction_precision': np.float64(0.09749188257902061), 'link_prediction_recall': np.float64(0.06706510539651768), 'link_prediction_map': np.float64(0.08634469404133442), 'link_prediction_top': np.float64(0.28182360347226826)}


Train: 100%|█████████▉| 199/200 [04:43<00:01,  1.42s/it]
Val: 100%|██████████| 236/236 [03:54<00:00,  1.01it/s]


Epoch: 02, Train loss: 59.54639587311882, Val metrics: {'link_prediction_precision': np.float64(0.00357829169703797), 'link_prediction_recall': np.float64(0.005774022671159914), 'link_prediction_map': np.float64(0.004895301835531112), 'link_prediction_top': np.float64(0.012524020939632894)}


Train: 100%|█████████▉| 199/200 [06:20<00:01,  1.91s/it]
Val: 100%|██████████| 236/236 [03:54<00:00,  1.01it/s]


Epoch: 03, Train loss: 57.7290588755016, Val metrics: {'link_prediction_precision': np.float64(0.016185143462991186), 'link_prediction_recall': np.float64(0.020551927445840258), 'link_prediction_map': np.float64(0.014822576370021868), 'link_prediction_top': np.float64(0.05552978596514479)}


Train: 100%|█████████▉| 199/200 [05:53<00:01,  1.78s/it]
Val: 100%|██████████| 236/236 [03:47<00:00,  1.04it/s]


Epoch: 04, Train loss: 57.93324835977943, Val metrics: {'link_prediction_precision': np.float64(0.02725134185938639), 'link_prediction_recall': np.float64(0.03162931530839053), 'link_prediction_map': np.float64(0.02378303109285151), 'link_prediction_top': np.float64(0.09310184878404347)}


Train: 100%|█████████▉| 199/200 [06:00<00:01,  1.81s/it]
Val: 100%|██████████| 236/236 [03:46<00:00,  1.04it/s]


Epoch: 05, Train loss: 57.70648078049331, Val metrics: {'link_prediction_precision': np.float64(0.012126432973295341), 'link_prediction_recall': np.float64(0.017434615671957702), 'link_prediction_map': np.float64(0.012312342161258732), 'link_prediction_top': np.float64(0.04247564773706183)}


Train: 100%|█████████▉| 199/200 [05:34<00:01,  1.68s/it]
Val: 100%|██████████| 236/236 [03:46<00:00,  1.04it/s]


Epoch: 06, Train loss: 58.250239113912855, Val metrics: {'link_prediction_precision': np.float64(0.010933669074282684), 'link_prediction_recall': np.float64(0.0175141116371128), 'link_prediction_map': np.float64(0.011881621864393052), 'link_prediction_top': np.float64(0.03889735604002385)}


Train: 100%|█████████▉| 199/200 [05:41<00:01,  1.72s/it]
Val: 100%|██████████| 236/236 [03:44<00:00,  1.05it/s]


Epoch: 07, Train loss: 57.74961981015315, Val metrics: {'link_prediction_precision': np.float64(0.009840302166854416), 'link_prediction_recall': np.float64(0.01664239200228313), 'link_prediction_map': np.float64(0.013073465420891038), 'link_prediction_top': np.float64(0.03618050493671725)}


Train: 100%|█████████▉| 199/200 [05:01<00:01,  1.51s/it]
Val: 100%|██████████| 236/236 [03:46<00:00,  1.04it/s]


Epoch: 08, Train loss: 57.77018339028004, Val metrics: {'link_prediction_precision': np.float64(0.07794380756742429), 'link_prediction_recall': np.float64(0.06456611555428791), 'link_prediction_map': np.float64(0.06902660894278413), 'link_prediction_top': np.float64(0.2522032999801206)}


Train: 100%|█████████▉| 199/200 [05:12<00:01,  1.57s/it]
Val: 100%|██████████| 236/236 [03:46<00:00,  1.04it/s]


Epoch: 09, Train loss: 57.02791413712222, Val metrics: {'link_prediction_precision': np.float64(0.08634285335630508), 'link_prediction_recall': np.float64(0.068243259848566), 'link_prediction_map': np.float64(0.08089902738203049), 'link_prediction_top': np.float64(0.26254058710489697)}


Train: 100%|█████████▉| 199/200 [04:59<00:01,  1.51s/it]
Val: 100%|██████████| 236/236 [03:44<00:00,  1.05it/s]


Epoch: 10, Train loss: 57.83727987871257, Val metrics: {'link_prediction_precision': np.float64(0.08056126167914651), 'link_prediction_recall': np.float64(0.06190801027918744), 'link_prediction_map': np.float64(0.07615604223267732), 'link_prediction_top': np.float64(0.24299251209330064)}


Best val: 100%|██████████| 236/236 [03:44<00:00,  1.05it/s]


Best val metrics: {'link_prediction_precision': np.float64(0.08634285335630508), 'link_prediction_recall': np.float64(0.068243259848566), 'link_prediction_map': np.float64(0.08089902738203049), 'link_prediction_top': np.float64(0.26254058710489697)}


Test: 100%|██████████| 243/243 [03:27<00:00,  1.17it/s]


Best test metrics: {'link_prediction_precision': np.float64(0.09036803500193025), 'link_prediction_recall': np.float64(0.07055906015060619), 'link_prediction_map': np.float64(0.0829043881096384), 'link_prediction_top': np.float64(0.27705571998455797)}


In [45]:
# Initialize variables for tracking the best model and validation metrics
state_dict = None
best_val_metric = 0

# Training and evaluation loop
for epoch in range(1, epochs + 1):
    train_loss = train()

    if epoch % eval_epochs_interval == 0:
        # Run validation
        val_pred = test(loader_dict["val"], desc="Val")
        val_metrics = task.evaluate(val_pred, task.get_table("val"))
        print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}")

        # Save best model
        if val_metrics[tune_metric] > best_val_metric:
            best_val_metric = val_metrics[tune_metric]
            state_dict = {k: v.cpu() for k, v in model.state_dict().items()}

# Load the best model weights
assert state_dict is not None
model.load_state_dict(state_dict)

# Evaluate on validation and test sets
val_pred = test(loader_dict["val"], desc="Best val")
val_metrics = task.evaluate(val_pred, task.get_table("val"))
print(f"Best val metrics: {val_metrics}")

test_pred = test(loader_dict["test"], desc="Test")
test_metrics = task.evaluate(test_pred)
print(f"Best test metrics: {test_metrics}")


Train:   0%|          | 0/200 [00:19<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 2.43 GiB. GPU 