# 2. KGAT: Knowledge Graph Attention Network for Recommendation

## 2.1. Aggregator and KGAT 

In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# L2 regularization loss: half the squared L2 norm of each row, averaged over all rows
def _L2_loss_mean(x):
    return torch.mean(torch.sum(torch.pow(x, 2), dim=1, keepdim=False) / 2.)

# Aggregator class for message-passing in GNN
class Aggregator(nn.Module):
    """One KGAT message-passing layer.

    Each node's own embedding is combined with the ``A_in``-weighted sum of its
    neighbours, projected to ``out_dim``, passed through LeakyReLU and then
    through message dropout.

    Supported ``aggregator_type`` values:
    * ``'gcn'``            -- LeakyReLU(W (e + e_N))
    * ``'graphsage'``      -- LeakyReLU(W [e || e_N])
    * ``'bi-interaction'`` -- LeakyReLU(W1 (e + e_N)) + LeakyReLU(W2 (e * e_N))
    Any other value raises ``NotImplementedError`` at construction time.
    """

    def __init__(self, in_dim, out_dim, dropout, aggregator_type):
        super(Aggregator, self).__init__()
        self.in_dim, self.out_dim = in_dim, out_dim
        self.dropout = dropout
        self.aggregator_type = aggregator_type

        self.message_dropout = nn.Dropout(dropout)
        self.activation = nn.LeakyReLU()

        kind = self.aggregator_type
        if kind in ('gcn', 'graphsage'):
            # GraphSAGE sees [self || neighbours], hence twice the input width.
            fan_in = in_dim * 2 if kind == 'graphsage' else in_dim
            self.linear = nn.Linear(fan_in, out_dim)
            nn.init.xavier_uniform_(self.linear.weight)
        elif kind == 'bi-interaction':
            # Both projections are created first, then initialised, in this order.
            self.linear1 = nn.Linear(in_dim, out_dim)
            self.linear2 = nn.Linear(in_dim, out_dim)
            for projection in (self.linear1, self.linear2):
                nn.init.xavier_uniform_(projection.weight)
        else:
            raise NotImplementedError

    def forward(self, ego_embeddings, A_in):
        """Aggregate neighbour messages into new node embeddings.

        Parameters:
        ego_embeddings: current user/entity embeddings, one row per node
        A_in: (sparse or dense) adjacency matrix multiplied into ego_embeddings

        Returns the aggregated embeddings with ``out_dim`` columns.
        """
        neighbour_sum = torch.matmul(A_in, ego_embeddings)

        kind = self.aggregator_type
        if kind == 'gcn':
            out = self.activation(self.linear(ego_embeddings + neighbour_sum))
        elif kind == 'graphsage':
            stacked = torch.cat([ego_embeddings, neighbour_sum], dim=1)
            out = self.activation(self.linear(stacked))
        elif kind == 'bi-interaction':
            additive = self.activation(self.linear1(ego_embeddings + neighbour_sum))
            multiplicative = self.activation(self.linear2(ego_embeddings * neighbour_sum))
            out = multiplicative + additive

        return self.message_dropout(out)

# Knowledge Graph Attention Network (KGAT) model
class KGAT(nn.Module):
    """Knowledge Graph Attention Network (KGAT) recommender.

    Users and KG entities share one embedding table, ``entity_user_embed``.
    When pre-trained embeddings are supplied its rows are laid out as
    [items, remaining entities, users]; the ids passed to the CF methods are
    presumably already remapped into that space by the data loader -- confirm.

    Two objectives are trained:
    * ``calc_kg_loss`` -- a TransR-style loss on KG triples, using a per-relation
      projection ``trans_M[r]`` and a relation vector from ``relation_embed``;
    * ``calc_cf_loss`` -- a BPR loss on (user, positive item, negative item)
      triples, using embeddings from stacked ``Aggregator`` layers over the
      attention-weighted adjacency ``A_in``.
    ``update_attention`` recomputes ``A_in`` from the current embeddings.
    """

    def __init__(self, args, n_users, n_entities, n_relations, A_in=None, user_pre_embed=None, item_pre_embed=None):
        """
        Parameters:
        args: namespace providing embed_dim, relation_dim, use_pretrain,
              aggregation_type, conv_dim_list, mess_dropout,
              kg_l2loss_lambda and cf_l2loss_lambda. conv_dim_list and
              mess_dropout may be list literals given as strings
              (e.g. "[64, 32, 16]") or already-built sequences.
        n_users, n_entities, n_relations: graph sizes
        A_in: optional initial adjacency matrix (sparse tensor) over the
              n_users + n_entities nodes
        user_pre_embed, item_pre_embed: optional pre-trained embeddings, used
              only when args.use_pretrain == 1; presumably torch tensors with
              embed_dim columns (they are passed to torch.cat) -- confirm
        """
        super(KGAT, self).__init__()

        # Graph sizes and hyper-parameters.
        self.n_users = n_users
        self.n_entities = n_entities
        self.n_relations = n_relations
        self.embed_dim = args.embed_dim
        self.relation_dim = args.relation_dim
        self.use_pretrain = args.use_pretrain
        self.aggregation_type = args.aggregation_type
        self.kg_l2loss_lambda = args.kg_l2loss_lambda
        self.cf_l2loss_lambda = args.cf_l2loss_lambda

        # Parse list-valued hyper-parameters safely instead of eval()-ing
        # strings that come from the command line.
        conv_dims = list(self._parse_list_arg(args.conv_dim_list))
        self.conv_dim_list = [args.embed_dim] + conv_dims
        self.mess_dropout = self._parse_list_arg(args.mess_dropout)
        self.n_layers = len(conv_dims)

        # NOTE(review): this table is never read by the model; it is kept only
        # so that state_dict keys (and hence saved checkpoints) stay compatible.
        self.user_entity_embed = nn.Embedding(n_users + n_entities, self.embed_dim)

        # Shared user/entity embeddings, relation vectors and the per-relation
        # projection matrices (embed_dim -> relation_dim).
        self.entity_user_embed = nn.Embedding(self.n_entities + self.n_users, self.embed_dim)
        self.relation_embed = nn.Embedding(self.n_relations, self.relation_dim)
        self.trans_M = nn.Parameter(torch.Tensor(self.n_relations, self.embed_dim, self.relation_dim))

        if (self.use_pretrain == 1) and (user_pre_embed is not None) and (item_pre_embed is not None):
            # Entities without a pre-trained vector get a fresh Xavier initialisation.
            other_entity_embed = nn.Parameter(torch.Tensor(self.n_entities - item_pre_embed.shape[0], self.embed_dim))
            nn.init.xavier_uniform_(other_entity_embed)
            entity_user_embed = torch.cat([item_pre_embed, other_entity_embed, user_pre_embed], dim=0)
            self.entity_user_embed.weight = nn.Parameter(entity_user_embed)
        else:
            nn.init.xavier_uniform_(self.entity_user_embed.weight)

        nn.init.xavier_uniform_(self.relation_embed.weight)
        nn.init.xavier_uniform_(self.trans_M)

        # Layer k maps conv_dim_list[k] -> conv_dim_list[k + 1].
        self.aggregator_layers = nn.ModuleList()
        for k in range(self.n_layers):
            self.aggregator_layers.append(Aggregator(self.conv_dim_list[k], self.conv_dim_list[k + 1], self.mess_dropout[k], self.aggregation_type))

        # Attention-weighted adjacency: held as a non-trainable Parameter so it
        # is saved in the state_dict and moved with the model; it starts empty.
        n_nodes = self.n_users + self.n_entities
        empty_adjacency = torch.sparse_coo_tensor(
            torch.empty((2, 0), dtype=torch.long), torch.empty(0), (n_nodes, n_nodes))
        self.A_in = nn.Parameter(empty_adjacency, requires_grad=False)
        if A_in is not None:
            self.A_in.data = A_in

    @staticmethod
    def _parse_list_arg(value):
        """Return ``value`` parsed as a Python literal if it is a string, else unchanged."""
        import ast  # local import: this notebook cell only imports torch at the top
        if isinstance(value, str):
            return ast.literal_eval(value)
        return value

    def calc_cf_embeddings(self):
        """Run every aggregator layer over ``A_in`` and concatenate the results.

        Returns one row per user/entity, of width
        ``embed_dim + sum(conv_dim_list[1:])``: the layer-0 embedding followed by
        each layer's L2-normalised output.
        """
        ego_embed = self.entity_user_embed.weight
        all_embed = [ego_embed]

        for layer in self.aggregator_layers:
            ego_embed = layer(ego_embed, self.A_in)
            # The normalised copy is stored; the raw output feeds the next layer.
            all_embed.append(F.normalize(ego_embed, p=2, dim=1))

        return torch.cat(all_embed, dim=1)

    def calc_cf_loss(self, user_ids, item_pos_ids, item_neg_ids):
        """BPR loss plus L2 penalty for a batch of (user, positive item, negative item).

        Each id indexes a row of ``calc_cf_embeddings()``; presumably 1-D
        LongTensors of equal length -- confirm against the caller.
        """
        all_embed = self.calc_cf_embeddings()
        user_embed = all_embed[user_ids]
        item_pos_embed = all_embed[item_pos_ids]
        item_neg_embed = all_embed[item_neg_ids]

        pos_score = torch.sum(user_embed * item_pos_embed, dim=1)
        neg_score = torch.sum(user_embed * item_neg_embed, dim=1)

        # -log(sigmoid(pos - neg)) rewards ranking the positive item higher.
        cf_loss = torch.mean((-1.0) * F.logsigmoid(pos_score - neg_score))

        l2_loss = _L2_loss_mean(user_embed) + _L2_loss_mean(item_pos_embed) + _L2_loss_mean(item_neg_embed)
        return cf_loss + self.cf_l2loss_lambda * l2_loss

    def calc_kg_loss(self, h, r, pos_t, neg_t):
        """TransR-style pairwise loss plus L2 penalty for a batch of KG triples.

        h, r, pos_t, neg_t: 1-D index tensors of equal length (the batched
        ``bmm`` below requires this).
        The score is ``||W_r h + e_r - W_r t||^2``; a true tail should score
        lower than a corrupted one.
        """
        r_embed = self.relation_embed(r)
        W_r = self.trans_M[r]

        h_embed = self.entity_user_embed(h)
        pos_t_embed = self.entity_user_embed(pos_t)
        neg_t_embed = self.entity_user_embed(neg_t)

        # Project each embedding into its relation's space: (B, 1, d) @ (B, d, k).
        r_mul_h = torch.bmm(h_embed.unsqueeze(1), W_r).squeeze(1)
        r_mul_pos_t = torch.bmm(pos_t_embed.unsqueeze(1), W_r).squeeze(1)
        r_mul_neg_t = torch.bmm(neg_t_embed.unsqueeze(1), W_r).squeeze(1)

        pos_score = torch.sum(torch.pow(r_mul_h + r_embed - r_mul_pos_t, 2), dim=1)
        neg_score = torch.sum(torch.pow(r_mul_h + r_embed - r_mul_neg_t, 2), dim=1)

        kg_loss = torch.mean((-1.0) * F.logsigmoid(neg_score - pos_score))

        l2_loss = _L2_loss_mean(r_mul_h) + _L2_loss_mean(r_embed) + _L2_loss_mean(r_mul_pos_t) + _L2_loss_mean(r_mul_neg_t)
        return kg_loss + self.kg_l2loss_lambda * l2_loss

    def update_attention_batch(self, h_list, t_list, r_idx):
        """Unnormalised attention ``(W_r t)^T tanh(W_r h + e_r)`` for edges of one relation.

        h_list, t_list: head/tail node indices of the edges; r_idx: a single
        relation id shared by all of them.
        """
        r_embed = self.relation_embed.weight[r_idx]
        W_r = self.trans_M[r_idx]

        h_embed = self.entity_user_embed.weight[h_list]
        t_embed = self.entity_user_embed.weight[t_list]

        r_mul_h = torch.matmul(h_embed, W_r)
        r_mul_t = torch.matmul(t_embed, W_r)
        return torch.sum(r_mul_t * torch.tanh(r_mul_h + r_embed), dim=1)

    def update_attention(self, h_list, t_list, r_list, relations):
        """Recompute ``A_in`` as the row-wise softmax of per-edge attention scores.

        h_list, t_list, r_list: parallel tensors describing every edge;
        relations: the relation ids to process. Edges whose relation is not
        listed get no entry in the new matrix.
        NOTE(review): the scores carry autograd history unless the caller runs
        this under ``torch.no_grad()`` -- confirm.
        """
        device = self.A_in.device
        rows, cols, values = [], [], []

        for r_idx in relations:
            index_list = torch.where(r_list == r_idx)
            batch_h_list = h_list[index_list]
            batch_t_list = t_list[index_list]

            rows.append(batch_h_list)
            cols.append(batch_t_list)
            values.append(self.update_attention_batch(batch_h_list, batch_t_list, r_idx))

        indices = torch.stack([torch.cat(rows), torch.cat(cols)])
        A_in = torch.sparse_coo_tensor(indices, torch.cat(values), self.A_in.shape)

        # Sparse softmax normalises each row over its stored entries only.
        A_in = torch.sparse.softmax(A_in.cpu(), dim=1)
        self.A_in.data = A_in.to(device)

    def calc_score(self, user_ids, item_ids):
        """Return the matrix of dot-product scores between the given users and items.

        Row i / column j correspond to ``user_ids[i]`` / ``item_ids[j]``.
        """
        all_embed = self.calc_cf_embeddings()
        user_embed = all_embed[user_ids]
        item_embed = all_embed[item_ids]
        return torch.matmul(user_embed, item_embed.transpose(0, 1))

    def forward(self, *input, mode):
        """Dispatch to one of the model's entry points.

        mode: 'train_cf' -> calc_cf_loss, 'train_kg' -> calc_kg_loss,
              'update_att' -> update_attention, 'predict' -> calc_score.

        Raises:
        ValueError: for any other mode.
        """
        if mode == 'train_cf':
            return self.calc_cf_loss(*input)
        if mode == 'train_kg':
            return self.calc_kg_loss(*input)
        if mode == 'update_att':
            return self.update_attention(*input)
        if mode == 'predict':
            return self.calc_score(*input)
        raise ValueError('Unknown KGAT mode: {!r}'.format(mode))


## 2.2. Log_helper

In [19]:
import os
import logging
import csv
from collections import OrderedDict

# Function to create a unique log ID by incrementing a counter
def create_log_id(dir_path):
    """Return the smallest n >= 0 for which ``dir_path/log{n}.log`` does not exist yet."""
    candidate = 0
    while os.path.exists(os.path.join(dir_path, 'log{:d}.log'.format(candidate))):
        candidate += 1
    return candidate

# Function to configure logging settings
def logging_config(folder=None, name=None,
                   level=logging.DEBUG,
                   console_level=logging.DEBUG,
                   no_console=True):
    """Route the root logger to ``folder/name.log`` (and optionally the console).

    Every handler already attached to the root logger is detached and closed,
    so repeated calls neither duplicate output nor leak open log files.

    Parameters:
    folder: directory for the log file; created if missing (required despite
            the ``None`` default)
    name: log file name without the ``.log`` extension (required)
    level: level for the root logger and the file handler
    console_level: level for the console handler
    no_console: when False, also log to stderr

    Returns:
    folder, unchanged.
    """
    os.makedirs(folder, exist_ok=True)

    # Iterate over a copy: removeHandler mutates root.handlers, and closing
    # releases the file descriptors held by FileHandlers.
    for handler in list(logging.root.handlers):
        logging.root.removeHandler(handler)
        handler.close()

    logpath = os.path.join(folder, name + ".log")
    print("All logs will be saved to %s" % logpath)

    logging.root.setLevel(level)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    logfile = logging.FileHandler(logpath)
    logfile.setLevel(level)
    logfile.setFormatter(formatter)
    logging.root.addHandler(logfile)

    if not no_console:
        logconsole = logging.StreamHandler()
        logconsole.setLevel(console_level)
        logconsole.setFormatter(formatter)
        logging.root.addHandler(logconsole)

    return folder


## 2.3. Metrics

In [20]:
import torch
import numpy as np
from sklearn.metrics import roc_auc_score, log_loss, mean_squared_error

# Function to calculate recall for a single example at a given cutoff k
def calc_recall(rank, ground_truth, k):
    """Recall@k for a single user.

    rank: list of predicted items, best first
    ground_truth: list of true items (duplicates are ignored)
    k: number of top-ranked items to consider

    Returns the fraction of distinct true items found in ``rank[:k]``, or 0.0
    when ``ground_truth`` is empty (instead of dividing by zero).
    """
    truth = set(ground_truth)
    if not truth:
        return 0.0
    return len(set(rank[:k]) & truth) / float(len(truth))

# Function to calculate precision at k for a single user
def precision_at_k(hit, k):
    """Precision@k for one user.

    hit: binary (0/1) relevance flags in ranked order
    Returns the mean of the first ``k`` flags.
    """
    top_k = np.asarray(hit)[:k]
    return top_k.mean()

# Function to calculate precision at k for a batch of users
def precision_at_k_batch(hits, k):
    """Precision@k for every user at once.

    hits: 2-D array, one row per user, binary (0/1) relevance in ranked order
    Returns a 1-D array with the mean of each row's first ``k`` entries.
    """
    return np.mean(hits[:, :k], axis=1)

# Function to calculate average precision
def average_precision(hit, cut):
    """Average precision over the first ``cut`` ranked positions.

    hit: binary (0/1) relevance flags in ranked order, covering the full ranking
    cut: maximum number of positions to consider

    AP@cut = sum(precision@(i+1) for each relevant position i < cut)
             / min(cut, total number of relevant items).
    Only relevant positions contribute; summing precision at every position
    (as a naive loop does) can yield values above 1. Returns 0.0 when there is
    no relevant item among the first ``cut`` positions.
    """
    hit = np.asarray(hit)
    top = hit[:max(cut, 0)]
    relevant = top != 0
    if not relevant.any():
        return 0.0
    # precision@(i+1) for every prefix, computed in one pass.
    precisions = np.cumsum(top) / np.arange(1, top.size + 1)
    return float(np.sum(precisions[relevant])) / float(min(cut, np.sum(hit)))

# Function to calculate Discounted Cumulative Gain (DCG) at k
def dcg_at_k(rel, k):
    """Discounted cumulative gain of the first ``k`` relevance scores.

    rel: relevance scores (binary or graded) in ranked order
    Returns sum_i (2**rel_i - 1) / log2(i + 2) over i < k (0.0 for empty input).
    """
    # np.asfarray was removed in NumPy 2.0; asarray(dtype=float64) is equivalent.
    rel = np.asarray(rel, dtype=np.float64)[:k]
    return np.sum((2 ** rel - 1) / np.log2(np.arange(2, rel.size + 2)))

# Function to calculate Normalized Discounted Cumulative Gain (NDCG) at k
def ndcg_at_k(rel, k):
    """Normalised DCG@k: DCG of ``rel`` divided by the DCG of its ideal ordering.

    rel: relevance scores (binary or graded) in ranked order
    Returns 0.0 when the ideal DCG is zero (no relevant item).
    """
    ideal = dcg_at_k(sorted(rel, reverse=True), k)
    return dcg_at_k(rel, k) / ideal if ideal else 0.0

# Function to calculate NDCG at k for a batch of users
def ndcg_at_k_batch(hits, k):
    """NDCG@k for every user at once.

    hits: 2-D array, one row per user, binary (0/1) relevance in ranked order
    Returns a 1-D array of NDCG values; rows with no relevant item get 0.

    ``k`` may exceed the number of columns: the discount vector is sized to
    the columns actually present, instead of failing to broadcast.
    """
    hits_k = hits[:, :k]
    discounts = np.log2(np.arange(2, hits_k.shape[1] + 2))
    dcg = np.sum((2 ** hits_k - 1) / discounts, axis=1)

    # Ideal ranking: each row sorted in descending order.
    sorted_hits_k = np.flip(np.sort(hits), axis=1)[:, :k]
    idcg = np.sum((2 ** sorted_hits_k - 1) / discounts, axis=1)

    idcg[idcg == 0] = np.inf  # yields NDCG 0 for rows without relevant items
    return dcg / idcg

# Function to calculate recall at k for a single user
def recall_at_k(hit, k, all_pos_num):
    """Recall@k for one user.

    hit: binary (0/1) relevance flags in ranked order
    all_pos_num: total number of relevant items for the user
    Returns the number of hits in the first ``k`` positions divided by
    ``all_pos_num``, or 0.0 when ``all_pos_num`` is 0.
    """
    if all_pos_num == 0:
        return 0.0
    # np.asfarray was removed in NumPy 2.0; asarray(dtype=float64) is equivalent.
    hit = np.asarray(hit, dtype=np.float64)[:k]
    return np.sum(hit) / all_pos_num

# Function to calculate recall at k for a batch of users
def recall_at_k_batch(hits, k):
    """Recall@k for every user at once.

    hits: 2-D array, one row per user, binary (0/1) relevance over the full ranking
    Returns, per row, the hits in the first ``k`` columns divided by the row's
    total hits (NaN for rows without any hit).
    """
    hits_in_top_k = hits[:, :k].sum(axis=1)
    total_hits = hits.sum(axis=1)
    return hits_in_top_k / total_hits

# Function to calculate F1 score from precision and recall
def F1(pre, rec):
    """Harmonic mean of precision and recall; 0.0 unless ``pre + rec`` is positive."""
    total = pre + rec
    if not total > 0:
        return 0.0
    return (2.0 * pre * rec) / total

# Function to calculate Area Under Curve (AUC) score
def calc_auc(ground_truth, prediction):
    """ROC AUC of ``prediction`` against binary ``ground_truth``.

    Best-effort: any failure in ``roc_auc_score`` (for example when only one
    class is present) yields 0.0 instead of an exception.
    """
    try:
        return roc_auc_score(y_true=ground_truth, y_score=prediction)
    except Exception:
        return 0.0

# Function to calculate log loss
def logloss(ground_truth, prediction):
    """Binary cross-entropy (sklearn ``log_loss``) of ``prediction`` against ``ground_truth``."""
    return log_loss(np.asarray(ground_truth), np.asarray(prediction))

# Function to calculate various metrics at different cutoff points (Ks) for collaborative filtering
def calc_metrics_at_k(cf_scores, train_user_dict, test_user_dict, user_ids, item_ids, Ks):
    """Compute precision, recall and NDCG at every cutoff in ``Ks`` for a batch of users.

    Parameters:
    cf_scores: torch tensor with one row per user in ``user_ids`` and one score
               per item column; the train/test item lists are used directly as
               column indices, so column j is assumed to be item id j -- confirm.
               NOTE: modified in place (training items are set to -inf).
    train_user_dict: user id -> training items (excluded from the ranking)
    test_user_dict: user id -> test items (the relevant set)
    user_ids: user ids, aligned with the rows of ``cf_scores``
    item_ids: candidate item ids; only its length is used here
    Ks: cutoffs to evaluate

    Returns:
    dict mapping each k to {'precision', 'recall', 'ndcg'}, each a 1-D numpy
    array with one value per user.
    """
    n_users, n_items = len(user_ids), len(item_ids)
    test_pos_item_binary = np.zeros([n_users, n_items], dtype=np.float32)
    for idx, u in enumerate(user_ids):
        cf_scores[idx][train_user_dict.get(u, [])] = -np.inf  # never rank seen items
        test_pos_item_binary[idx][test_user_dict.get(u, [])] = 1

    # Sort on the GPU when one is available; fall back to the CPU when CUDA is
    # absent or the transfer/sort fails (e.g. out of memory, a RuntimeError).
    # A bare except here would also swallow KeyboardInterrupt/SystemExit.
    rank_indices = None
    if torch.cuda.is_available():
        try:
            _, rank_indices = torch.sort(cf_scores.cuda(), descending=True)
        except RuntimeError:
            rank_indices = None
    if rank_indices is None:
        _, rank_indices = torch.sort(cf_scores, descending=True)
    rank_indices = rank_indices.cpu().numpy()

    # binary_hit[i, j] == 1 iff the item ranked j-th for user i is a test item.
    binary_hit = np.take_along_axis(test_pos_item_binary, rank_indices, axis=1)

    metrics_dict = {}
    for k in Ks:
        metrics_dict[k] = {
            'precision': precision_at_k_batch(binary_hit, k),
            'recall': recall_at_k_batch(binary_hit, k),
            'ndcg': ndcg_at_k_batch(binary_hit, k),
        }
    return metrics_dict

## 2.4. Model_helper

In [21]:
import os
from collections import OrderedDict
import torch

# Function to determine early stopping based on recall scores
def early_stopping(recall_list, stopping_steps):
    """Decide whether training should stop.

    recall_list: validation recall per evaluation step, oldest first
    stopping_steps: patience, i.e. how many evaluations without a new best are allowed

    Returns (best_recall, should_stop). The best value's first occurrence counts,
    and training should stop once at least ``stopping_steps`` evaluations have
    followed it.
    """
    best_recall = max(recall_list)
    steps_since_best = len(recall_list) - recall_list.index(best_recall) - 1
    return best_recall, steps_since_best >= stopping_steps

# Function to save model checkpoint
def save_model(model, model_dir, current_epoch, last_best_epoch=None):
    """Save a checkpoint for ``current_epoch`` and delete the previous best one.

    The checkpoint is written to ``model_dir/model_epoch{current_epoch}.pth`` as
    ``{'model_state_dict': ..., 'epoch': current_epoch}``. When
    ``last_best_epoch`` is given and differs from ``current_epoch``, that
    epoch's checkpoint file is removed if it exists.
    """
    os.makedirs(model_dir, exist_ok=True)

    model_state_file = os.path.join(model_dir, 'model_epoch{}.pth'.format(current_epoch))
    torch.save({'model_state_dict': model.state_dict(), 'epoch': current_epoch}, model_state_file)

    if last_best_epoch is not None and current_epoch != last_best_epoch:
        old_model_state_file = os.path.join(model_dir, 'model_epoch{}.pth'.format(last_best_epoch))
        # os.remove instead of shelling out to `rm`: portable, and safe for
        # paths containing spaces or shell metacharacters.
        try:
            os.remove(old_model_state_file)
        except FileNotFoundError:
            pass

# Function to load model checkpoint
def load_model(model, model_path):
    """Restore ``model`` from a checkpoint written by ``save_model``.

    The checkpoint is mapped to the CPU, its ``'model_state_dict'`` entry is
    loaded into ``model``, and the model is switched to evaluation mode.
    Returns the same ``model`` object.
    """
    state = torch.load(model_path, map_location='cpu')['model_state_dict']
    model.load_state_dict(state)
    return model.eval()

## 2.5. Loader_base

In [22]:
import os
import time
import random
import collections

import torch
import numpy as np
import pandas as pd

# Base class for loading and processing data for recommendation models
class DataLoaderBase(object):
    """Base data loader for knowledge-graph-enhanced recommendation.

    Reads ``train.txt`` and ``test.txt`` from ``<data_dir>/<data_name>``. Each
    line has the form ``user_id item_id item_id ...``; lines with fewer than two
    integers are skipped. The loader records user/item counts, optionally loads
    MF pre-trained embeddings, and provides the random samplers used to build CF
    and KG training batches.
    """

    def __init__(self, args, logging):
        """
        Parameters:
        args: namespace providing data_name, data_dir, use_pretrain,
              pretrain_embedding_dir and (when pre-training is used) embed_dim
        logging: not used by this base class; kept for subclasses
        """
        self.args = args
        self.data_name = args.data_name
        self.use_pretrain = args.use_pretrain
        self.pretrain_embedding_dir = args.pretrain_embedding_dir

        self.data_dir = os.path.join(args.data_dir, args.data_name)
        self.train_file = os.path.join(self.data_dir, 'train.txt')
        self.test_file = os.path.join(self.data_dir, 'test.txt')
        self.kg_file = os.path.join(self.data_dir, "kg_final.txt")

        self.cf_train_data, self.train_user_dict = self.load_cf(self.train_file)
        self.cf_test_data, self.test_user_dict = self.load_cf(self.test_file)
        self.statistic_cf()

        if self.use_pretrain == 1:
            self.load_pretrained_data()

    def load_cf(self, filename):
        """Parse a user-item interaction file.

        Returns ((user, item), user_dict), where ``user``/``item`` are parallel
        int32 arrays with one entry per distinct (user, item) pair on a line,
        and ``user_dict`` maps each user id to its de-duplicated item list.
        """
        user, item = [], []
        user_dict = dict()

        # The file is closed when the block exits.
        with open(filename, 'r') as f:
            for line in f:
                inter = [int(i) for i in line.strip().split()]
                if len(inter) > 1:
                    user_id, item_ids = inter[0], list(set(inter[1:]))
                    for item_id in item_ids:
                        user.append(user_id)
                        item.append(item_id)
                    user_dict[user_id] = item_ids

        return (np.array(user, dtype=np.int32), np.array(item, dtype=np.int32)), user_dict

    def statistic_cf(self):
        """Set n_users / n_items (max id + 1 over train and test) and interaction counts."""
        self.n_users = max(self.cf_train_data[0].max(), self.cf_test_data[0].max()) + 1
        self.n_items = max(self.cf_train_data[1].max(), self.cf_test_data[1].max()) + 1
        self.n_cf_train = len(self.cf_train_data[0])
        self.n_cf_test = len(self.cf_test_data[0])

    def load_kg(self, filename):
        """Load space-separated (head, relation, tail) triples into a DataFrame, dropping duplicates."""
        kg_data = pd.read_csv(filename, sep=' ', names=['h', 'r', 't'], engine='python')
        return kg_data.drop_duplicates()

    def sample_pos_items_for_u(self, user_dict, user_id, n_sample_pos_items):
        """Sample ``n_sample_pos_items`` distinct positive items of ``user_id``.

        Raises:
        ValueError: if the user has fewer distinct positive items than requested
                    (the rejection loop would otherwise never terminate).
        """
        pos_items = user_dict[user_id]
        n_pos_items = len(pos_items)
        if n_sample_pos_items > len(set(pos_items)):
            raise ValueError('user %s has fewer than %d distinct positive items'
                             % (user_id, n_sample_pos_items))

        sample_pos_items = []
        while len(sample_pos_items) < n_sample_pos_items:
            pos_item_idx = np.random.randint(low=0, high=n_pos_items, size=1)[0]
            pos_item_id = pos_items[pos_item_idx]
            if pos_item_id not in sample_pos_items:
                sample_pos_items.append(pos_item_id)
        return sample_pos_items

    def sample_neg_items_for_u(self, user_dict, user_id, n_sample_neg_items):
        """Sample ``n_sample_neg_items`` distinct items in [0, n_items) that ``user_id`` has not interacted with."""
        pos_items = set(user_dict[user_id])  # O(1) membership checks in the rejection loop

        sample_neg_items = []
        while len(sample_neg_items) < n_sample_neg_items:
            neg_item_id = np.random.randint(low=0, high=self.n_items, size=1)[0]
            if neg_item_id not in pos_items and neg_item_id not in sample_neg_items:
                sample_neg_items.append(neg_item_id)
        return sample_neg_items

    def generate_cf_batch(self, user_dict, batch_size):
        """Build one CF batch of (user, positive item, negative item) LongTensors.

        Users are sampled without replacement when there are at least
        ``batch_size`` of them, otherwise with replacement.
        """
        exist_users = list(user_dict.keys())
        if batch_size <= len(exist_users):
            batch_user = random.sample(exist_users, batch_size)
        else:
            batch_user = [random.choice(exist_users) for _ in range(batch_size)]

        batch_pos_item, batch_neg_item = [], []
        for u in batch_user:
            batch_pos_item += self.sample_pos_items_for_u(user_dict, u, 1)
            batch_neg_item += self.sample_neg_items_for_u(user_dict, u, 1)

        return torch.LongTensor(batch_user), torch.LongTensor(batch_pos_item), torch.LongTensor(batch_neg_item)

    def sample_pos_triples_for_h(self, kg_dict, head, n_sample_pos_triples):
        """Sample positive (relation, tail) pairs for ``head``.

        ``kg_dict[head]`` is indexed as a list of (tail, relation) pairs.
        Returns (relations, tails); relations and tails are each distinct.
        NOTE(review): loops forever if fewer such pairs exist than requested;
        callers in this file request exactly one.
        """
        pos_triples = kg_dict[head]
        n_pos_triples = len(pos_triples)

        sample_relations, sample_pos_tails = [], []
        while len(sample_relations) < n_sample_pos_triples:
            pos_triple_idx = np.random.randint(low=0, high=n_pos_triples, size=1)[0]
            tail = pos_triples[pos_triple_idx][0]
            relation = pos_triples[pos_triple_idx][1]

            if relation not in sample_relations and tail not in sample_pos_tails:
                sample_relations.append(relation)
                sample_pos_tails.append(tail)
        return sample_relations, sample_pos_tails

    def sample_neg_triples_for_h(self, kg_dict, head, relation, n_sample_neg_triples, highest_neg_idx):
        """Sample distinct tails in [0, highest_neg_idx) such that (tail, relation) is not a known fact of ``head``."""
        pos_triples = kg_dict[head]

        sample_neg_tails = []
        while len(sample_neg_tails) < n_sample_neg_triples:
            tail = np.random.randint(low=0, high=highest_neg_idx, size=1)[0]
            if (tail, relation) not in pos_triples and tail not in sample_neg_tails:
                sample_neg_tails.append(tail)
        return sample_neg_tails

    def generate_kg_batch(self, kg_dict, batch_size, highest_neg_idx):
        """Build one KG batch of (head, relation, positive tail, negative tail) LongTensors.

        Heads are sampled without replacement when there are at least
        ``batch_size`` of them, otherwise with replacement.
        """
        exist_heads = list(kg_dict.keys())
        if batch_size <= len(exist_heads):
            batch_head = random.sample(exist_heads, batch_size)
        else:
            batch_head = [random.choice(exist_heads) for _ in range(batch_size)]

        batch_relation, batch_pos_tail, batch_neg_tail = [], [], []
        for h in batch_head:
            relation, pos_tail = self.sample_pos_triples_for_h(kg_dict, h, 1)
            batch_relation += relation
            batch_pos_tail += pos_tail
            batch_neg_tail += self.sample_neg_triples_for_h(kg_dict, h, relation[0], 1, highest_neg_idx)

        return (torch.LongTensor(batch_head), torch.LongTensor(batch_relation),
                torch.LongTensor(batch_pos_tail), torch.LongTensor(batch_neg_tail))

    def load_pretrained_data(self):
        """Load MF embeddings from ``<pretrain_embedding_dir>/<data_name>/mf.npz``.

        Raises:
        ValueError: if either embedding matrix does not have shape
                    (n_users or n_items, args.embed_dim). An explicit check is
                    used because ``assert`` is stripped under ``python -O``.
        """
        pre_model = 'mf'
        pretrain_path = '%s/%s/%s.npz' % (self.pretrain_embedding_dir, self.data_name, pre_model)
        pretrain_data = np.load(pretrain_path)
        self.user_pre_embed = pretrain_data['user_embed']
        self.item_pre_embed = pretrain_data['item_embed']

        for key, embed, n_rows in (('user_embed', self.user_pre_embed, self.n_users),
                                   ('item_embed', self.item_pre_embed, self.n_items)):
            if embed.shape[0] != n_rows or embed.shape[1] != self.args.embed_dim:
                raise ValueError('pre-trained %s has shape %s, expected (%d, %d)'
                                 % (key, embed.shape, n_rows, self.args.embed_dim))

## 2.6. Loader_KGAT

In [23]:
import os
import random
import collections

import torch
import numpy as np
import pandas as pd
import scipy.sparse as sp

# DataLoaderKGAT class, extending from DataLoaderBase for the KGAT model
class DataLoaderKGAT(DataLoaderBase):
    """KGAT data loader: fuses the knowledge graph with user-item interactions.

    Reads (but does not set) these attributes, presumably provided by
    ``DataLoaderBase``: ``cf_train_data`` / ``cf_test_data`` (indexed as
    ``[0]`` = user IDs, ``[1]`` = item IDs), ``train_user_dict`` /
    ``test_user_dict``, ``n_cf_train``, ``n_cf_test``, ``n_items``, ``kg_file``
    and ``load_kg``.

    After construction it exposes:

    * ``kg_train_data``: one DataFrame of (h, r, t) triples forming the
      collaborative knowledge graph. Relation 0 is user -> item and relation 1
      is item -> user. An original KG relation ``r`` becomes ``r + 2``, and its
      inverse becomes ``r + n_kg_relations + 2``.
    * User IDs shifted by ``n_entities``, so users and entities share a single
      node-ID space of size ``n_users_entities``.
    * ``train_kg_dict`` (head -> [(tail, relation)]), ``train_relation_dict``
      (relation -> [(head, tail)]) and the ``h_list`` / ``r_list`` / ``t_list``
      LongTensors.
    * ``adjacency_dict`` / ``laplacian_dict`` (one scipy COO matrix per
      relation) and ``A_in``, the sum of all normalised matrices as a sparse
      torch tensor.
    """

    def __init__(self, args, logging):
        super().__init__(args, logging)

        # Provisional sizes from the interaction data. construct_data() refines
        # n_entities with the KG and then shifts user IDs past it.
        self.n_users = max(self.cf_train_data[0].max(), self.cf_test_data[0].max()) + 1
        self.n_entities = max(self.cf_train_data[1].max(), self.cf_test_data[1].max()) + 1

        # Batch sizes for CF training, KG training and evaluation.
        self.cf_batch_size = args.cf_batch_size
        self.kg_batch_size = args.kg_batch_size
        self.test_batch_size = args.test_batch_size

        kg_data = self.load_kg(self.kg_file)
        self.construct_data(kg_data)
        self.print_info(logging)

        self.laplacian_type = args.laplacian_type
        self.create_adjacency_dict()
        self.create_laplacian_dict()

    def construct_data(self, kg_data):
        """Build the collaborative KG and every lookup structure derived from it.

        Args:
            kg_data: DataFrame with integer columns 'h', 'r', 't', one KG triple
                per row. The caller's frame is not modified.
        """
        # Inverse triples (t, r + n_relations, h) let messages flow both ways.
        n_relations = int(kg_data['r'].max()) + 1
        inverse_kg_data = kg_data.copy()
        inverse_kg_data = inverse_kg_data.rename({'h': 't', 't': 'h'}, axis='columns')
        inverse_kg_data['r'] += n_relations
        kg_data = pd.concat([kg_data, inverse_kg_data], axis=0, ignore_index=True, sort=False)

        # Relation IDs 0 and 1 are reserved for user->item / item->user edges.
        kg_data['r'] += 2
        self.n_relations = int(kg_data['r'].max()) + 1

        # Items are entities too. The provisional count from __init__ (derived
        # from CF item IDs) is kept as a lower bound, so an item that never
        # appears in the KG cannot share an ID with a shifted user.
        self.n_entities = int(max(self.n_entities,
                                  kg_data['h'].max() + 1,
                                  kg_data['t'].max() + 1))
        self.n_users_entities = self.n_users + self.n_entities

        # Shift user IDs past every entity ID so both live in one node-ID space.
        offset = self.n_entities
        self.cf_train_data = ((np.asarray(self.cf_train_data[0]) + offset).astype(np.int32),
                              self.cf_train_data[1].astype(np.int32))
        self.cf_test_data = ((np.asarray(self.cf_test_data[0]) + offset).astype(np.int32),
                             self.cf_test_data[1].astype(np.int32))

        self.train_user_dict = {k + offset: np.unique(v).astype(np.int32) for k, v in self.train_user_dict.items()}
        self.test_user_dict = {k + offset: np.unique(v).astype(np.int32) for k, v in self.test_user_dict.items()}

        # User -> item interactions as relation 0 ...
        cf2kg_train_data = pd.DataFrame(np.zeros((self.n_cf_train, 3), dtype=np.int32), columns=['h', 'r', 't'])
        cf2kg_train_data['h'] = self.cf_train_data[0]
        cf2kg_train_data['t'] = self.cf_train_data[1]

        # ... and their item -> user inverses as relation 1.
        inverse_cf2kg_train_data = pd.DataFrame(np.ones((self.n_cf_train, 3), dtype=np.int32), columns=['h', 'r', 't'])
        inverse_cf2kg_train_data['h'] = self.cf_train_data[1]
        inverse_cf2kg_train_data['t'] = self.cf_train_data[0]

        self.kg_train_data = pd.concat([kg_data, cf2kg_train_data, inverse_cf2kg_train_data], ignore_index=True)
        self.n_kg_train = len(self.kg_train_data)

        # Read each column once as plain Python ints, accessed by name rather
        # than by position. This is much faster than a per-row iterrows() loop.
        heads = self.kg_train_data['h'].tolist()
        relations = self.kg_train_data['r'].tolist()
        tails = self.kg_train_data['t'].tolist()

        self.train_kg_dict = collections.defaultdict(list)
        self.train_relation_dict = collections.defaultdict(list)
        for h, r, t in zip(heads, relations, tails):
            self.train_kg_dict[h].append((t, r))
            self.train_relation_dict[r].append((h, t))

        self.h_list = torch.LongTensor(heads)
        self.t_list = torch.LongTensor(tails)
        self.r_list = torch.LongTensor(relations)

    def convert_coo2tensor(self, coo):
        """Convert a scipy COO matrix into an equivalent float32 sparse COO torch tensor."""
        indices = torch.LongTensor(np.vstack((coo.row, coo.col)))
        values = torch.FloatTensor(coo.data)
        # torch.sparse_coo_tensor replaces the deprecated torch.sparse.FloatTensor constructor.
        return torch.sparse_coo_tensor(indices, values, torch.Size(coo.shape))

    def create_adjacency_dict(self):
        """Build one binary (n_users_entities x n_users_entities) COO adjacency matrix per relation."""
        self.adjacency_dict = {}
        for r, ht_list in self.train_relation_dict.items():
            rows = [e[0] for e in ht_list]
            cols = [e[1] for e in ht_list]
            vals = [1] * len(rows)
            adj = sp.coo_matrix((vals, (rows, cols)), shape=(self.n_users_entities, self.n_users_entities))
            self.adjacency_dict[r] = adj

    def create_laplacian_dict(self):
        """Normalise every per-relation adjacency matrix and sum them into ``A_in``.

        'symmetric' computes D^-1/2 A D^-1/2 and 'random-walk' computes D^-1 A,
        where D holds the row sums. Nodes with zero degree get a zero row.

        Raises:
            NotImplementedError: for any other ``laplacian_type``.
        """

        def symmetric_norm_lap(adj):
            rowsum = np.array(adj.sum(axis=1))
            # Zero-degree rows give 0 ** -0.5 == inf. That is expected and zeroed
            # below, so the divide-by-zero warning is silenced.
            with np.errstate(divide='ignore'):
                d_inv_sqrt = np.power(rowsum, -0.5).flatten()
            d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0
            d_mat_inv_sqrt = sp.diags(d_inv_sqrt)

            norm_adj = d_mat_inv_sqrt.dot(adj).dot(d_mat_inv_sqrt)
            return norm_adj.tocoo()

        def random_walk_norm_lap(adj):
            rowsum = np.array(adj.sum(axis=1))
            # Same as above: 1 / 0 == inf for isolated nodes, then zeroed.
            with np.errstate(divide='ignore'):
                d_inv = np.power(rowsum, -1.0).flatten()
            d_inv[np.isinf(d_inv)] = 0
            d_mat_inv = sp.diags(d_inv)

            norm_adj = d_mat_inv.dot(adj)
            return norm_adj.tocoo()

        if self.laplacian_type == 'symmetric':
            norm_lap_func = symmetric_norm_lap
        elif self.laplacian_type == 'random-walk':
            norm_lap_func = random_walk_norm_lap
        else:
            raise NotImplementedError

        self.laplacian_dict = {}
        for r, adj in self.adjacency_dict.items():
            self.laplacian_dict[r] = norm_lap_func(adj)

        # Combine all relations into a single propagation matrix for the model.
        A_in = sum(self.laplacian_dict.values())
        self.A_in = self.convert_coo2tensor(A_in.tocoo())

    def print_info(self, logging):
        """Log dataset and graph size statistics through the given logger."""
        logging.info('n_users:           %d' % self.n_users)
        logging.info('n_items:           %d' % self.n_items)
        logging.info('n_entities:        %d' % self.n_entities)
        logging.info('n_users_entities:  %d' % self.n_users_entities)
        logging.info('n_relations:       %d' % self.n_relations)

        logging.info('n_h_list:          %d' % len(self.h_list))
        logging.info('n_t_list:          %d' % len(self.t_list))
        logging.info('n_r_list:          %d' % len(self.r_list))

        logging.info('n_cf_train:        %d' % self.n_cf_train)
        logging.info('n_cf_test:         %d' % self.n_cf_test)

        logging.info('n_kg_train:        %d' % self.n_kg_train)

## 2.7. Parser

In [26]:
import argparse


def parse_kgat_args(argv=None):
    """Build the KGAT option namespace and derive ``save_dir`` from it.

    Args:
        argv: Optional list of command-line style arguments to parse. When left
            as ``None``, an empty list is parsed, so ``sys.argv`` is ignored and
            every option takes its default.

    Returns:
        argparse.Namespace with all options plus ``save_dir``, a path encoding
        the dataset and main hyper-parameters.

    Raises:
        ValueError: if ``--conv_dim_list`` is not a Python literal.
    """
    import ast

    parser = argparse.ArgumentParser(description="Run KGAT.")

    parser.add_argument('--seed', type=int, default=2024,
                        help='Random seed.')

    parser.add_argument('--data_name', nargs='?', default='exchange-students',
                        help='Choose a dataset from {yelp2018, last-fm, amazon-book, exchange-students}')
    parser.add_argument('--data_dir', nargs='?', default='/Users/gayeonlee/Documents/2024/rec_system/kgat_data/',
                        help='Input data path.')

    parser.add_argument('--use_pretrain', type=int, default=0,
                        help='0: No pretrain, 1: Pretrain with the learned embeddings, 2: Pretrain with stored model.')
    parser.add_argument('--pretrain_embedding_dir', nargs='?', default='/Users/gayeonlee/Documents/2024/rec_system/',
                        help='Path of learned embeddings.')
    parser.add_argument('--pretrain_model_path', nargs='?', default='trained_model/model.pth',
                        help='Path of stored model.')

    parser.add_argument('--cf_batch_size', type=int, default=64,
                        help='CF batch size.')
    parser.add_argument('--kg_batch_size', type=int, default=128,
                        help='KG batch size.')
    parser.add_argument('--test_batch_size', type=int, default=100,
                        help='Test batch size (the user number to test every batch).')

    parser.add_argument('--embed_dim', type=int, default=64,
                        help='User / entity Embedding size.')
    parser.add_argument('--relation_dim', type=int, default=64,
                        help='Relation Embedding size.')

    parser.add_argument('--laplacian_type', type=str, default='random-walk',
                        help='Specify the type of the adjacency (laplacian) matrix from {symmetric, random-walk}.')
    parser.add_argument('--aggregation_type', type=str, default='bi-interaction',
                        help='Specify the type of the aggregation layer from {gcn, graphsage, bi-interaction}.')
    parser.add_argument('--conv_dim_list', nargs='?', default='[64, 32, 16]',
                        help='Output sizes of every aggregation layer.')
    parser.add_argument('--mess_dropout', nargs='?', default='[0.1, 0.1, 0.1]',
                        help='Dropout probability w.r.t. message dropout for each deep layer. 0: no dropout.')

    parser.add_argument('--kg_l2loss_lambda', type=float, default=1e-5,
                        help='Lambda when calculating KG l2 loss.')
    parser.add_argument('--cf_l2loss_lambda', type=float, default=1e-5,
                        help='Lambda when calculating CF l2 loss.')

    parser.add_argument('--lr', type=float, default=0.0001,
                        help='Learning rate.')
    parser.add_argument('--n_epoch', type=int, default=3,
                        help='Number of epoch.')
    parser.add_argument('--stopping_steps', type=int, default=1,
                        help='Number of epoch for early stopping')

    parser.add_argument('--cf_print_every', type=int, default=3,
                        help='Iter interval of printing CF loss.')
    parser.add_argument('--kg_print_every', type=int, default=3,
                        help='Iter interval of printing KG loss.')
    parser.add_argument('--evaluate_every', type=int, default=3,
                        help='Epoch interval of evaluating CF.')

    parser.add_argument('--Ks', nargs='?', default='[10, 12, 14, 16, 20]',
                        help='Calculate metric@K when evaluating.')

    args = parser.parse_args(args=[] if argv is None else list(argv))

    # literal_eval only accepts Python literals, so list-valued options cannot
    # run arbitrary code the way eval() would.
    conv_dims = ast.literal_eval(args.conv_dim_list)
    save_dir = 'trained_model/KGAT/{}/embed-dim{}_relation-dim{}_{}_{}_{}_lr{}_pretrain{}/'.format(
        args.data_name, args.embed_dim, args.relation_dim, args.laplacian_type, args.aggregation_type,
        '-'.join(str(i) for i in conv_dims), args.lr, args.use_pretrain)
    args.save_dir = save_dir

    return args

## 2.8. Train and test

In [27]:
import os
import sys
import random
from time import time

import pandas as pd
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim

def evaluate(model, dataloader, Ks, device):
    """Score every test user against every item and aggregate ranking metrics.

    Test users are processed in chunks of ``dataloader.test_batch_size``. Each
    chunk is scored with ``model(..., mode='predict')`` under ``torch.no_grad()``
    and passed to ``calc_metrics_at_k``.

    Returns:
        cf_scores: numpy array of all per-batch score matrices stacked along
            axis 0 (one row per test user, one column per item).
        metrics_dict: ``metrics_dict[k][m]`` is the mean over all test users of
            metric ``m`` ('precision', 'recall', 'ndcg') at cutoff ``k``.
    """
    batch_size = dataloader.test_batch_size
    train_user_dict = dataloader.train_user_dict
    test_user_dict = dataloader.test_user_dict

    model.eval()

    test_users = list(test_user_dict.keys())
    batches = [torch.LongTensor(test_users[start:start + batch_size])
               for start in range(0, len(test_users), batch_size)]

    all_items = torch.arange(dataloader.n_items, dtype=torch.long).to(device)

    metric_names = ('precision', 'recall', 'ndcg')
    score_chunks = []
    per_batch = {k: {name: [] for name in metric_names} for k in Ks}

    with tqdm(total=len(batches), desc='Evaluating Iteration') as progress:
        for users in batches:
            users = users.to(device)

            with torch.no_grad():
                scores = model(users, all_items, mode='predict')  # (n_batch_users, n_items)
            scores = scores.cpu()

            batch_metrics = calc_metrics_at_k(scores, train_user_dict, test_user_dict,
                                              users.cpu().numpy(), all_items.cpu().numpy(), Ks)
            score_chunks.append(scores.numpy())

            for k in Ks:
                for name in metric_names:
                    per_batch[k][name].append(batch_metrics[k][name])
            progress.update(1)

    cf_scores = np.concatenate(score_chunks, axis=0)
    # Each batch contributes one array of per-user values; average across all users.
    metrics_dict = {
        k: {name: np.concatenate(per_batch[k][name]).mean() for name in metric_names}
        for k in Ks
    }
    return cf_scores, metrics_dict


def train(args):
    """Train KGAT by alternating CF and KG updates every epoch.

    Each epoch does three things in order:

    1. CF mini-batches (``mode='train_cf'``).
    2. KG mini-batches (``mode='train_kg'``).
    3. An attention refresh (``mode='update_att'``).

    Every ``args.evaluate_every`` epochs, and on the final epoch, the model is
    evaluated. The checkpoint with the best recall@min(Ks) so far is saved with
    ``save_model``, and training stops once ``early_stopping`` reports so.
    Per-evaluation metrics are written to ``<save_dir>/metrics.tsv``. Training
    exits the process if a batch loss becomes NaN.
    """
    import ast

    # Seed every RNG that batch sampling and model initialisation draw from.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    log_save_id = create_log_id(args.save_dir)
    logging_config(folder=args.save_dir, name='log{:d}'.format(log_save_id), no_console=False)
    logging.info(args)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    data = DataLoaderKGAT(args, logging)
    if args.use_pretrain == 1:
        user_pre_embed = torch.tensor(data.user_pre_embed)
        item_pre_embed = torch.tensor(data.item_pre_embed)
    else:
        user_pre_embed, item_pre_embed = None, None

    # Build the model exactly once, so the seeded RNG state produces its initial weights.
    model = KGAT(args, data.n_users, data.n_entities, data.n_relations, data.A_in,
                 user_pre_embed, item_pre_embed)
    if args.use_pretrain == 2:
        model = load_model(model, args.pretrain_model_path)

    model.to(device)
    logging.info(model)

    # Separate optimisers keep independent Adam moment estimates for the CF and KG objectives.
    cf_optimizer = optim.Adam(model.parameters(), lr=args.lr)
    kg_optimizer = optim.Adam(model.parameters(), lr=args.lr)

    best_epoch = -1
    best_recall = 0

    Ks = ast.literal_eval(args.Ks)
    k_min = min(Ks)
    k_max = max(Ks)

    epoch_list = []
    metrics_list = {k: {'precision': [], 'recall': [], 'ndcg': []} for k in Ks}

    for epoch in range(1, args.n_epoch + 1):
        time0 = time()
        model.train()

        # --- CF phase ---
        time1 = time()
        cf_total_loss = 0
        n_cf_batch = data.n_cf_train // data.cf_batch_size + 1

        for cf_iter in range(1, n_cf_batch + 1):
            time2 = time()
            cf_batch_user, cf_batch_pos_item, cf_batch_neg_item = data.generate_cf_batch(data.train_user_dict, data.cf_batch_size)
            cf_batch_user = cf_batch_user.to(device)
            cf_batch_pos_item = cf_batch_pos_item.to(device)
            cf_batch_neg_item = cf_batch_neg_item.to(device)

            cf_batch_loss = model(cf_batch_user, cf_batch_pos_item, cf_batch_neg_item, mode='train_cf')

            if np.isnan(cf_batch_loss.cpu().detach().numpy()):
                logging.info('ERROR (CF Training): Epoch {:04d} Iter {:04d} / {:04d} Loss is nan.'.format(epoch, cf_iter, n_cf_batch))
                sys.exit()

            cf_batch_loss.backward()
            cf_optimizer.step()
            cf_optimizer.zero_grad()
            cf_total_loss += cf_batch_loss.item()

            if (cf_iter % args.cf_print_every) == 0:
                logging.info('CF Training: Epoch {:04d} Iter {:04d} / {:04d} | Time {:.1f}s | Iter Loss {:.4f} | Iter Mean Loss {:.4f}'.format(epoch, cf_iter, n_cf_batch, time() - time2, cf_batch_loss.item(), cf_total_loss / cf_iter))
        logging.info('CF Training: Epoch {:04d} Total Iter {:04d} | Total Time {:.1f}s | Iter Mean Loss {:.4f}'.format(epoch, n_cf_batch, time() - time1, cf_total_loss / n_cf_batch))

        # --- KG phase ---
        time3 = time()
        kg_total_loss = 0
        n_kg_batch = data.n_kg_train // data.kg_batch_size + 1

        for kg_iter in range(1, n_kg_batch + 1):
            time4 = time()
            kg_batch_head, kg_batch_relation, kg_batch_pos_tail, kg_batch_neg_tail = data.generate_kg_batch(data.train_kg_dict, data.kg_batch_size, data.n_users_entities)
            kg_batch_head = kg_batch_head.to(device)
            kg_batch_relation = kg_batch_relation.to(device)
            kg_batch_pos_tail = kg_batch_pos_tail.to(device)
            kg_batch_neg_tail = kg_batch_neg_tail.to(device)

            kg_batch_loss = model(kg_batch_head, kg_batch_relation, kg_batch_pos_tail, kg_batch_neg_tail, mode='train_kg')

            if np.isnan(kg_batch_loss.cpu().detach().numpy()):
                logging.info('ERROR (KG Training): Epoch {:04d} Iter {:04d} / {:04d} Loss is nan.'.format(epoch, kg_iter, n_kg_batch))
                sys.exit()

            kg_batch_loss.backward()
            kg_optimizer.step()
            kg_optimizer.zero_grad()
            kg_total_loss += kg_batch_loss.item()

            if (kg_iter % args.kg_print_every) == 0:
                logging.info('KG Training: Epoch {:04d} Iter {:04d} / {:04d} | Time {:.1f}s | Iter Loss {:.4f} | Iter Mean Loss {:.4f}'.format(epoch, kg_iter, n_kg_batch, time() - time4, kg_batch_loss.item(), kg_total_loss / kg_iter))
        logging.info('KG Training: Epoch {:04d} Total Iter {:04d} | Total Time {:.1f}s | Iter Mean Loss {:.4f}'.format(epoch, n_kg_batch, time() - time3, kg_total_loss / n_kg_batch))

        # --- attention refresh over all training triples ---
        time5 = time()
        h_list = data.h_list.to(device)
        t_list = data.t_list.to(device)
        r_list = data.r_list.to(device)
        relations = list(data.laplacian_dict.keys())
        model(h_list, t_list, r_list, relations, mode='update_att')
        logging.info('Update Attention: Epoch {:04d} | Total Time {:.1f}s'.format(epoch, time() - time5))

        logging.info('CF + KG Training: Epoch {:04d} | Total Time {:.1f}s'.format(epoch, time() - time0))

        # --- evaluation, early stopping and checkpointing ---
        if (epoch % args.evaluate_every) == 0 or epoch == args.n_epoch:
            time6 = time()
            _, metrics_dict = evaluate(model, data, Ks, device)
            logging.info('CF Evaluation: Epoch {:04d} | Total Time {:.1f}s | Precision [{:.4f}, {:.4f}], Recall [{:.4f}, {:.4f}], NDCG [{:.4f}, {:.4f}]'.format(
                epoch, time() - time6, metrics_dict[k_min]['precision'], metrics_dict[k_max]['precision'], metrics_dict[k_min]['recall'], metrics_dict[k_max]['recall'], metrics_dict[k_min]['ndcg'], metrics_dict[k_max]['ndcg']))

            epoch_list.append(epoch)
            for k in Ks:
                for m in ['precision', 'recall', 'ndcg']:
                    metrics_list[k][m].append(metrics_dict[k][m])
            best_recall, should_stop = early_stopping(metrics_list[k_min]['recall'], args.stopping_steps)

            if should_stop:
                break

            # Save only when the latest evaluation holds the best recall so far.
            if metrics_list[k_min]['recall'].index(best_recall) == len(epoch_list) - 1:
                save_model(model, args.save_dir, epoch, best_epoch)
                logging.info('Save model on epoch {:04d}!'.format(epoch))
                best_epoch = epoch

    # Persist one row per evaluation, one column per metric@K.
    metrics_df = [epoch_list]
    metrics_cols = ['epoch_idx']
    for k in Ks:
        for m in ['precision', 'recall', 'ndcg']:
            metrics_df.append(metrics_list[k][m])
            metrics_cols.append('{}@{}'.format(m, k))
    metrics_df = pd.DataFrame(metrics_df).transpose()
    metrics_df.columns = metrics_cols
    metrics_df.to_csv(args.save_dir + '/metrics.tsv', sep='\t', index=False)

    if best_epoch == -1:
        # No checkpoint was saved (no evaluation ran, or early stopping fired
        # before the first save), so there is no best row to look up.
        logging.info('Best CF Evaluation: no model was saved, nothing to report.')
        return

    best_metrics = metrics_df.loc[metrics_df['epoch_idx'] == best_epoch].iloc[0].to_dict()
    logging.info('Best CF Evaluation: Epoch {:04d} | Precision [{:.4f}, {:.4f}], Recall [{:.4f}, {:.4f}], NDCG [{:.4f}, {:.4f}]'.format(
        int(best_metrics['epoch_idx']), best_metrics['precision@{}'.format(k_min)], best_metrics['precision@{}'.format(k_max)], best_metrics['recall@{}'.format(k_min)], best_metrics['recall@{}'.format(k_max)], best_metrics['ndcg@{}'.format(k_min)], best_metrics['ndcg@{}'.format(k_max)]))


def predict(args):
    """Load a stored KGAT model, score all test users and print ranking metrics.

    Scores are saved to ``<save_dir>/cf_scores.npy``, and the directory is
    created if missing. Precision/recall/NDCG at min(Ks) and max(Ks) are printed.
    """
    import ast

    # Use the default CUDA device, as train() does. A hard-coded "cuda:1"
    # fails on machines with a single GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    data = DataLoaderKGAT(args, logging)

    # Build the model with the same graph train() passes in, then restore the weights.
    model = KGAT(args, data.n_users, data.n_entities, data.n_relations, data.A_in)
    model = load_model(model, args.pretrain_model_path)
    model.to(device)

    Ks = ast.literal_eval(args.Ks)
    k_min = min(Ks)
    k_max = max(Ks)

    cf_scores, metrics_dict = evaluate(model, data, Ks, device)
    os.makedirs(args.save_dir, exist_ok=True)
    np.save(os.path.join(args.save_dir, 'cf_scores.npy'), cf_scores)
    print('CF Evaluation: Precision [{:.4f}, {:.4f}], Recall [{:.4f}, {:.4f}], NDCG [{:.4f}, {:.4f}]'.format(
        metrics_dict[k_min]['precision'], metrics_dict[k_max]['precision'], metrics_dict[k_min]['recall'], metrics_dict[k_max]['recall'], metrics_dict[k_min]['ndcg'], metrics_dict[k_max]['ndcg']))

if __name__ == '__main__':
    # Parse options with their defaults (sys.argv is ignored) and run training.
    # `args` is bound at module level, so later notebook cells can reuse it.
    args = parse_kgat_args()
    train(args)
    # To evaluate a stored model instead of training, run:
    # predict(args)

2024-12-10 04:10:24,792 - root - INFO - Namespace(seed=2024, data_name='exchange-students', data_dir='/Users/gayeonlee/Documents/2024/rec_system/kgat_data/', use_pretrain=0, pretrain_embedding_dir='/Users/gayeonlee/Documents/2024/rec_system/', pretrain_model_path='trained_model/model.pth', cf_batch_size=64, kg_batch_size=128, test_batch_size=100, embed_dim=64, relation_dim=64, laplacian_type='random-walk', aggregation_type='bi-interaction', conv_dim_list='[64, 32, 16]', mess_dropout='[0.1, 0.1, 0.1]', kg_l2loss_lambda=1e-05, cf_l2loss_lambda=1e-05, lr=0.0001, n_epoch=3, stopping_steps=1, cf_print_every=3, kg_print_every=3, evaluate_every=3, Ks='[10, 12, 14, 16, 20]', save_dir='trained_model/KGAT/exchange-students/embed-dim64_relation-dim64_random-walk_bi-interaction_64-32-16_lr0.0001_pretrain0/')


All logs will be saved to trained_model/KGAT/exchange-students/embed-dim64_relation-dim64_random-walk_bi-interaction_64-32-16_lr0.0001_pretrain0/log4.log


2024-12-10 04:10:25,116 - root - INFO - n_users:           2309
2024-12-10 04:10:25,116 - root - INFO - n_items:           196
2024-12-10 04:10:25,117 - root - INFO - n_entities:        2309
2024-12-10 04:10:25,117 - root - INFO - n_users_entities:  4618
2024-12-10 04:10:25,117 - root - INFO - n_relations:       30
2024-12-10 04:10:25,118 - root - INFO - n_h_list:          27520
2024-12-10 04:10:25,118 - root - INFO - n_t_list:          27520
2024-12-10 04:10:25,118 - root - INFO - n_r_list:          27520
2024-12-10 04:10:25,118 - root - INFO - n_cf_train:        1847
2024-12-10 04:10:25,118 - root - INFO - n_cf_test:         462
2024-12-10 04:10:25,119 - root - INFO - n_kg_train:        27520
  d_inv = np.power(rowsum, -1.0).flatten()
2024-12-10 04:10:25,152 - root - INFO - KGAT(
  (user_entity_embed): Embedding(4618, 64)
  (relation_embed): Embedding(30, 64)
  (entity_user_embed): Embedding(4618, 64)
  (aggregator_layers): ModuleList(
    (0): Aggregator(
      (message_dropout): Dr