In [None]:
import os
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import scipy.sparse as sp
import logging
from tqdm import tqdm
from types import SimpleNamespace

In [None]:
# Logging Setup
def create_log_id(dir_path):
    """Return the first integer n for which ``<dir_path>/log<n>.log`` does not exist.

    A missing ``dir_path`` simply yields 0, because no candidate file exists.
    """
    candidate = 0
    while True:
        if not os.path.exists(os.path.join(dir_path, f'log{candidate}.log')):
            return candidate
        candidate += 1

def logging_config(folder=None, name='log', level=logging.INFO, no_console=True):
    """Send root-logger output to ``<folder>/<name>.log`` and optionally to the console.

    Any handlers already attached to the root logger are removed AND closed
    first. Re-running this cell therefore neither duplicates output nor leaks
    open log files.

    Args:
        folder: Directory for the log file. It is created if missing;
            ``None`` means the current working directory.
        name: Log file stem; ``.log`` is appended.
        level: Level set on the root logger.
        no_console: When True, no console (stderr) handler is attached.
    """
    if folder is None:
        folder = os.getcwd()
    os.makedirs(folder, exist_ok=True)

    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
        # removeHandler alone leaves a FileHandler's file open.
        handler.close()

    log_path = os.path.join(folder, name + ".log")
    print(f"Logging to {log_path}")
    handlers = [logging.FileHandler(log_path)]
    if not no_console:
        handlers.append(logging.StreamHandler())
    logging.basicConfig(level=level,
                        format='%(asctime)s - %(levelname)s - %(message)s',
                        handlers=handlers)


In [None]:
# Metric Functions (precision, recall, ndcg)
def precision_at_k_batch(hits, k):
    """Per-row precision@k.

    ``hits`` is a 2-D array: rows are users and columns are ranked positions.
    The result is the mean of the first ``k`` columns of every row.
    """
    top_k = hits[:, :k]
    return np.mean(top_k, axis=1)

def recall_at_k_batch(hits, k):
    """Per-row recall@k: hits in the first ``k`` columns divided by all hits in the row.

    ``hits`` is a 2-D array: rows are users and columns are ranked positions.
    A row with no hits at all has recall 0 rather than NaN (0/0). This matches
    ``ndcg_at_k_batch``, which also maps an all-zero row to 0, and keeps one
    such user from turning a batch mean into NaN.
    """
    n_hit = hits[:, :k].sum(axis=1)
    n_relevant = hits.sum(axis=1)
    out = np.zeros(n_hit.shape, dtype=np.result_type(n_hit, np.float32))
    return np.divide(n_hit, n_relevant, out=out, where=n_relevant > 0)

def ndcg_at_k_batch(hits, k):
    """Per-row NDCG@k with gain ``2**hit - 1`` and discount ``log2(rank + 1)``.

    ``hits`` is a 2-D array: rows are users and columns are ranked positions.
    The ideal ranking is obtained by sorting each row in descending order.
    Rows whose ideal DCG is 0 (no hits) get NDCG 0.

    If ``k`` exceeds the number of columns, all columns are used. The discount
    vector is sized to the columns actually kept, so it always broadcasts.
    """
    hits_k = hits[:, :k]
    discounts = np.log2(np.arange(2, hits_k.shape[1] + 2))
    dcg = np.sum((2 ** hits_k - 1) / discounts, axis=1)

    ideal_k = np.flip(np.sort(hits, axis=1), axis=1)[:, :k]
    idcg = np.sum((2 ** ideal_k - 1) / discounts, axis=1)
    # Dividing by inf maps "no relevant items" rows to 0 instead of NaN.
    idcg[idcg == 0] = np.inf
    return dcg / idcg

def calc_metrics_at_k(cf_scores, train_user_dict, test_user_dict, user_ids, item_ids, Ks):
    """Compute precision/recall/NDCG@k for one batch of users.

    Args:
        cf_scores: torch score tensor with one row per entry of ``user_ids``
            and one column per entry of ``item_ids``. **Modified in place**:
            each user's training items are set to ``-inf`` so they rank last.
        train_user_dict: user id -> list of item ids, used as column indices
            into ``cf_scores``. Users missing from this dict have nothing
            masked.
        test_user_dict: user id -> list of held-out item ids, also used as
            column indices. Users missing from this dict have no relevant items.
        user_ids: user ids corresponding to the rows of ``cf_scores``.
        item_ids: only its length (the number of score columns) is used.
        Ks: iterable of cut-offs.

    Returns:
        ``{k: {'precision': arr, 'recall': arr, 'ndcg': arr}}`` with one value per user.
    """
    test_binary = np.zeros((len(user_ids), len(item_ids)), dtype=np.float32)
    for row, u in enumerate(user_ids):
        train_items = train_user_dict.get(u, [])
        test_items = test_user_dict.get(u, [])
        if len(train_items) > 0:
            cf_scores[row][train_items] = -np.inf
        if len(test_items) > 0:
            test_binary[row][test_items] = 1.0

    # Sorting on the GPU is faster when one is available. Fall back to the CPU
    # only on a runtime failure there (e.g. CUDA OOM) instead of swallowing
    # every exception with a bare `except:`.
    if torch.cuda.is_available():
        try:
            _, rank_indices = torch.sort(cf_scores.cuda(), descending=True)
        except RuntimeError:
            _, rank_indices = torch.sort(cf_scores.cpu(), descending=True)
    else:
        _, rank_indices = torch.sort(cf_scores, descending=True)
    rank_indices = rank_indices.cpu().numpy()

    # binary_hit[i, r] == 1 iff the item ranked r-th for user i is a test item.
    binary_hit = np.take_along_axis(test_binary, rank_indices, axis=1)

    results = {}
    for k in Ks:
        results[k] = {
            'precision': precision_at_k_batch(binary_hit, k),
            'recall': recall_at_k_batch(binary_hit, k),
            'ndcg': ndcg_at_k_batch(binary_hit, k)
        }
    return results

In [None]:
# Model Definition

def _L2_loss_mean(x):
    return torch.mean(torch.sum(torch.pow(x, 2), dim=1) / 2.)

class Aggregator(nn.Module):
    """One KGAT propagation layer: combines each node with its neighbourhood summary.

    The neighbourhood summary is ``A_in @ ego_embeddings``. With ``e`` the
    node's own embedding and ``n`` its summary, the output is:

    * ``'gcn'``            -> LeakyReLU(W (e + n))
    * ``'graphsage'``      -> LeakyReLU(W concat(e, n))
    * ``'bi-interaction'`` -> LeakyReLU(W1 (e + n)) + LeakyReLU(W2 (e * n))

    Dropout is applied to the layer output. Any other ``aggregator_type``
    raises NotImplementedError.
    """

    def __init__(self, in_dim, out_dim, dropout, aggregator_type):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.LeakyReLU()

        if aggregator_type not in ('gcn', 'graphsage', 'bi-interaction'):
            raise NotImplementedError
        if aggregator_type == 'bi-interaction':
            # Created in this order so parameter init consumes the RNG as before.
            self.linear1 = nn.Linear(in_dim, out_dim)
            self.linear2 = nn.Linear(in_dim, out_dim)
        else:
            # graphsage sees the concatenation of self and neighbourhood.
            linear_in = in_dim * 2 if aggregator_type == 'graphsage' else in_dim
            self.linear = nn.Linear(linear_in, out_dim)
        self.aggregator_type = aggregator_type

    def forward(self, ego_embeddings, A_in):
        neighbours = torch.matmul(A_in, ego_embeddings)
        kind = self.aggregator_type
        if kind == 'bi-interaction':
            summed = self.activation(self.linear1(ego_embeddings + neighbours))
            product = self.activation(self.linear2(ego_embeddings * neighbours))
            out = summed + product
        elif kind == 'graphsage':
            joined = torch.cat([ego_embeddings, neighbours], dim=1)
            out = self.activation(self.linear(joined))
        elif kind == 'gcn':
            out = self.activation(self.linear(ego_embeddings + neighbours))
        return self.dropout(out)

class KGAT(nn.Module):
    """Knowledge Graph Attention Network (KGAT) recommender.

    Users and entities share one embedding table, ``entity_user_embed``. The
    pre-trained initialisation lays it out as entity rows ``[0, n_entities)``
    (pre-trained item rows first), followed by user rows
    ``[n_entities, n_entities + n_users)``. Every id passed to the methods
    below indexes that table, and the rows/columns of ``A_in``, directly.
    User ids must therefore already be offset by ``n_entities``; item and
    entity ids are used as-is.

    ``forward(*inputs, mode=...)`` dispatches on ``mode``:

    * ``'train_cf'``   -> ``calc_cf_loss(user_ids, item_pos_ids, item_neg_ids)``
    * ``'train_kg'``   -> ``calc_kg_loss(h, r, pos_t, neg_t)``
    * ``'update_att'`` -> ``update_attention(h_list, t_list, r_list, relations)``
    * ``'predict'``    -> ``calc_score(user_ids, item_ids)``
    """

    def __init__(self, args, n_users, n_entities, n_relations, A_in=None, user_pre_embed=None, item_pre_embed=None):
        super().__init__()
        self.embed_dim = args.embed_dim
        self.relation_dim = args.relation_dim
        self.n_users = n_users
        self.n_entities = n_entities
        self.n_relations = n_relations
        # L2 weights of the CF / KG losses; 1e-5 when args does not define them.
        self.cf_l2loss_lambda = getattr(args, 'cf_l2loss_lambda', 1e-5)
        self.kg_l2loss_lambda = getattr(args, 'kg_l2loss_lambda', 1e-5)

        self.entity_user_embed = nn.Embedding(n_entities + n_users, self.embed_dim)
        self.relation_embed = nn.Embedding(n_relations, self.relation_dim)
        # Per-relation TransR projections, shape (n_relations, embed_dim, relation_dim).
        self.trans_M = nn.Parameter(torch.empty(n_relations, self.embed_dim, self.relation_dim))

        if args.use_pretrain == 1 and user_pre_embed is not None:
            item_pre_embed = item_pre_embed.float()
            user_pre_embed = user_pre_embed.float()
            # Entities without a pre-trained vector get a Xavier init. Previously
            # they got whatever uninitialised memory torch.Tensor(...) held.
            other_entity_embed = torch.empty(n_entities - item_pre_embed.shape[0], self.embed_dim)
            if other_entity_embed.numel() > 0:
                nn.init.xavier_uniform_(other_entity_embed)
            entity_user_embed = torch.cat([item_pre_embed, other_entity_embed, user_pre_embed], dim=0)
            self.entity_user_embed.weight = nn.Parameter(entity_user_embed)
        else:
            nn.init.xavier_uniform_(self.entity_user_embed.weight)

        nn.init.xavier_uniform_(self.relation_embed.weight)
        nn.init.xavier_uniform_(self.trans_M)

        self.aggregation_type = args.aggregation_type
        # NOTE(review): eval() executes arbitrary code from these CLI strings;
        # ast.literal_eval would be safer if they are always plain list literals.
        self.conv_dim_list = [args.embed_dim] + eval(args.conv_dim_list)
        self.mess_dropout = eval(args.mess_dropout)
        self.n_layers = len(self.conv_dim_list) - 1

        self.aggregator_layers = nn.ModuleList([
            Aggregator(self.conv_dim_list[i], self.conv_dim_list[i+1], self.mess_dropout[i], self.aggregation_type)
            for i in range(self.n_layers)
        ])

        # Propagation matrix over all n_users + n_entities nodes. It stays a
        # non-trainable Parameter, so the 'A_in' state_dict key and .to(device)
        # handling are unchanged. Wrapping the given tensor directly works for
        # sparse and dense inputs alike. Assigning a dense tensor to .data of a
        # sparse Parameter raises instead.
        n_nodes = n_users + n_entities
        if A_in is None:
            A_in = torch.sparse_coo_tensor(torch.empty((2, 0), dtype=torch.long),
                                           torch.empty(0), (n_nodes, n_nodes))
        self.A_in = nn.Parameter(A_in, requires_grad=False)

    def calc_cf_embeddings(self):
        """Concatenate the ego embedding with each layer's L2-normalised output."""
        ego_embed = self.entity_user_embed.weight
        all_embed = [ego_embed]
        for layer in self.aggregator_layers:
            ego_embed = layer(ego_embed, self.A_in)
            norm_embed = F.normalize(ego_embed, p=2, dim=1)
            all_embed.append(norm_embed)
        return torch.cat(all_embed, dim=1)

    def calc_cf_loss(self, user_ids, item_pos_ids, item_neg_ids):
        """BPR loss on (user, positive item, negative item) triples plus L2 on the used embeddings."""
        all_embed = self.calc_cf_embeddings()
        user_embed = all_embed[user_ids]
        item_pos_embed = all_embed[item_pos_ids]
        item_neg_embed = all_embed[item_neg_ids]

        pos_score = torch.sum(user_embed * item_pos_embed, dim=1)
        neg_score = torch.sum(user_embed * item_neg_embed, dim=1)
        cf_loss = torch.mean(-F.logsigmoid(pos_score - neg_score))

        l2_loss = _L2_loss_mean(user_embed) + _L2_loss_mean(item_pos_embed) + _L2_loss_mean(item_neg_embed)
        return cf_loss + self.cf_l2loss_lambda * l2_loss

    def calc_kg_loss(self, h, r, pos_t, neg_t):
        """TransR pairwise loss: true (h, r, t) should score closer than corrupted (h, r, t')."""
        r_embed = self.relation_embed(r)
        W_r = self.trans_M[r]  # one (embed_dim, relation_dim) matrix per triple, for 1-D r

        h_embed = self.entity_user_embed(h)
        pos_t_embed = self.entity_user_embed(pos_t)
        neg_t_embed = self.entity_user_embed(neg_t)

        # Project entities into each triple's relation space.
        r_mul_h = torch.bmm(h_embed.unsqueeze(1), W_r).squeeze(1)
        r_mul_pos_t = torch.bmm(pos_t_embed.unsqueeze(1), W_r).squeeze(1)
        r_mul_neg_t = torch.bmm(neg_t_embed.unsqueeze(1), W_r).squeeze(1)

        pos_score = torch.sum((r_mul_h + r_embed - r_mul_pos_t) ** 2, dim=1)
        neg_score = torch.sum((r_mul_h + r_embed - r_mul_neg_t) ** 2, dim=1)
        kg_loss = torch.mean(-F.logsigmoid(neg_score - pos_score))

        l2_loss = (_L2_loss_mean(r_mul_h) + _L2_loss_mean(r_embed)
                   + _L2_loss_mean(r_mul_pos_t) + _L2_loss_mean(r_mul_neg_t))
        return kg_loss + self.kg_l2loss_lambda * l2_loss

    def update_attention_batch(self, h_list, t_list, r_idx):
        """Unnormalised attention (W_r e_t) . tanh(W_r e_h + e_r) for edges of relation ``r_idx``."""
        r_embed = self.relation_embed.weight[r_idx]
        W_r = self.trans_M[r_idx]
        h_embed = self.entity_user_embed.weight[h_list]
        t_embed = self.entity_user_embed.weight[t_list]
        r_mul_h = torch.matmul(h_embed, W_r)
        r_mul_t = torch.matmul(t_embed, W_r)
        return torch.sum(r_mul_t * torch.tanh(r_mul_h + r_embed), dim=1)

    @torch.no_grad()
    def update_attention(self, h_list, t_list, r_list, relations):
        """Replace ``A_in`` with row-softmax-normalised attention over the given edges.

        Only the supplied edges survive. The new ``A_in`` has an entry at
        ``(h, t)`` for every edge whose relation id is in ``relations`` and
        nothing else, so callers must pass the complete edge list of the
        propagation graph and every relation id. Scores of duplicate
        ``(h, t)`` pairs are summed before the softmax. The layout of the
        existing ``A_in`` (sparse or dense) is preserved.
        """
        device = self.A_in.device
        rows, cols, values = [], [], []
        for r_idx in relations:
            mask = r_list == r_idx
            batch_h = h_list[mask]
            batch_t = t_list[mask]
            rows.append(batch_h)
            cols.append(batch_t)
            values.append(self.update_attention_batch(batch_h, batch_t, r_idx))
        if not rows:
            return

        indices = torch.stack([torch.cat(rows), torch.cat(cols)]).cpu()
        scores = torch.cat(values).cpu()
        A_new = torch.sparse_coo_tensor(indices, scores, self.A_in.shape).coalesce()
        # The sparse softmax is computed on the CPU.
        A_new = torch.sparse.softmax(A_new, dim=1)
        if not self.A_in.is_sparse:
            A_new = A_new.to_dense()
        self.A_in.data = A_new.to(device)

    def calc_score(self, user_ids, item_ids):
        """Dot-product scores of shape (len(user_ids), len(item_ids))."""
        all_embed = self.calc_cf_embeddings()
        user_embed = all_embed[user_ids]
        item_embed = all_embed[item_ids]
        return torch.matmul(user_embed, item_embed.transpose(0, 1))

    def forward(self, *input, mode):
        if mode == 'train_cf':
            return self.calc_cf_loss(*input)
        if mode == 'train_kg':
            return self.calc_kg_loss(*input)
        if mode == 'update_att':
            return self.update_attention(*input)
        if mode == 'predict':
            return self.calc_score(*input)
        raise NotImplementedError

In [None]:
# DataLoaderKGAT Definition
class DataLoaderKGAT:
    """Loads the CF interactions and knowledge graph for KGAT and samples training batches.

    Reads three whitespace-separated files from ``<args.data_dir>/<args.data_name>/``:
    ``train.txt`` and ``test.txt`` (``user item item ...`` per line) and
    ``kg_final.txt`` (``head relation tail`` per line).

    Node index space, shared with KGAT's embedding table and ``A_in``:
    entities keep their raw ids and fill ``[0, n_entities)``; users fill
    ``[n_entities, n_entities + n_users)``. The keys of ``train_user_dict`` and
    ``test_user_dict`` are therefore ``raw_user_id + n_entities``, while item
    ids stay raw. KG relation ids stay raw, and user-item interactions use the
    extra relation id ``interact_relation == n_relations - 1``.
    """

    def __init__(self, args, logger):
        self.logger = logger
        self.data_path = os.path.join(args.data_dir, args.data_name)

        self.train_user_dict = self._load_user_item_dict('train.txt')
        self.test_user_dict = self._load_user_item_dict('test.txt')
        # Counted over train AND test so every test id is a valid index.
        self.n_users, self.n_items = self._get_user_item_num()

        kg_file = os.path.join(self.data_path, 'kg_final.txt')
        self.kg_data, n_kg_entities, n_kg_relations = self._load_kg(kg_file)
        # Items are entities too; reserve rows for item ids absent from the KG.
        self.n_entities = max(n_kg_entities, self.n_items)
        self.interact_relation = n_kg_relations
        self.n_relations = n_kg_relations + 1
        self.n_users_entities = self.n_users + self.n_entities

        # Users live after the entities in the shared node index space.
        self.train_user_dict = self._remap_users(self.train_user_dict)
        self.test_user_dict = self._remap_users(self.test_user_dict)

        self.train_kg_dict = self._construct_kg_dict(self.kg_data)
        self.h_list, self.t_list, self.r_list = self._build_relation_triplets()
        self.A_in, self.laplacian_dict = self._build_sparse_graph()

        self.cf_batch_size = args.cf_batch_size
        self.kg_batch_size = args.kg_batch_size
        self.test_batch_size = args.test_batch_size

        self.n_cf_train = sum(len(v) for v in self.train_user_dict.values())
        self.n_kg_train = len(self.kg_data)

        if args.use_pretrain == 1:
            pretrain_dir = os.path.join(args.pretrain_embedding_dir, args.data_name)
            self.user_pre_embed = np.load(os.path.join(pretrain_dir, 'user_embed.npy'))
            self.item_pre_embed = np.load(os.path.join(pretrain_dir, 'item_embed.npy'))

    def _load_user_item_dict(self, filename):
        """Read ``user item item ...`` lines into {raw user id: [item ids]}; lines without items are skipped."""
        user_dict = dict()
        filepath = os.path.join(self.data_path, filename)
        with open(filepath, 'r') as f:
            for line in f:
                items = list(map(int, line.strip().split()))
                if len(items) < 2:
                    continue
                user, item_list = items[0], items[1:]
                user_dict[user] = item_list
        return user_dict

    def _get_user_item_num(self):
        """Return ``(n_users, n_items)`` as (max raw id + 1) over both train and test dicts."""
        user_dicts = (self.train_user_dict, self.test_user_dict)
        max_user = max((u for d in user_dicts for u in d), default=-1)
        max_item = max((i for d in user_dicts for items in d.values() for i in items), default=-1)
        return max_user + 1, max_item + 1

    def _remap_users(self, user_dict):
        """Shift raw user ids past the entity block of the node index space."""
        return {user + self.n_entities: items for user, items in user_dict.items()}

    def _load_kg(self, file_path):
        """Read KG triples; return ``(triples, max entity id + 1, max relation id + 1)``."""
        kg_data = []
        max_entity, max_relation = -1, -1
        with open(file_path, 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:  # tolerate blank lines
                    continue
                h, r, t = map(int, fields)
                kg_data.append((h, r, t))
                max_entity = max(max_entity, h, t)
                max_relation = max(max_relation, r)
        return kg_data, max_entity + 1, max_relation + 1

    def _construct_kg_dict(self, kg_data):
        """Group triples by head: {h: [(t, r), ...]}."""
        kg_dict = dict()
        for h, r, t in kg_data:
            kg_dict.setdefault(h, []).append((t, r))
        return kg_dict

    def _build_relation_triplets(self):
        """Return (head, tail, relation) LongTensors for every directed edge of the propagation graph.

        Each KG triple appears in both directions with its own relation id.
        Each training user-item interaction appears in both directions with
        ``self.interact_relation``. Exact duplicates are dropped.
        """
        edges = dict()  # insertion-ordered set of (h, t, r)
        for h, pairs in self.train_kg_dict.items():
            for t, r in pairs:
                edges[(h, t, r)] = None
                edges[(t, h, r)] = None
        for user, items in self.train_user_dict.items():
            for item in items:
                edges[(user, item, self.interact_relation)] = None
                edges[(item, user, self.interact_relation)] = None

        if edges:
            h_list, t_list, r_list = zip(*edges)
        else:
            h_list = t_list = r_list = ()
        return (torch.tensor(list(h_list), dtype=torch.long),
                torch.tensor(list(t_list), dtype=torch.long),
                torch.tensor(list(r_list), dtype=torch.long))

    def _build_sparse_graph(self):
        """Build the initial propagation matrix and the per-relation adjacency matrices.

        Returns:
            (A_in, laplacian_dict). ``A_in`` is a sparse torch tensor holding
            the symmetric-normalised binary adjacency over all
            ``n_users_entities`` nodes; it is kept sparse because a dense
            N x N matrix does not fit in memory for realistic N.
            ``laplacian_dict`` maps each relation id to that relation's
            unnormalised scipy COO adjacency.
        """
        n_nodes = self.n_users_entities
        h = self.h_list.numpy()
        t = self.t_list.numpy()
        r = self.r_list.numpy()

        laplacian_dict = {}
        for rel in range(self.n_relations):
            mask = r == rel
            ones = np.ones(int(mask.sum()), dtype=np.float32)
            laplacian_dict[rel] = sp.coo_matrix((ones, (h[mask], t[mask])), shape=(n_nodes, n_nodes))

        # h/t already hold both directions of every edge, so adj is symmetric.
        # Duplicate (h, t) pairs (several relations between the same nodes)
        # are collapsed to weight 1.
        adj = sp.coo_matrix((np.ones(len(h), dtype=np.float32), (h, t)), shape=(n_nodes, n_nodes)).tocsr()
        adj.sum_duplicates()
        adj.data[:] = 1.0
        norm_adj = self._normalize_adj(adj)
        return self._convert_sp_mat_to_sp_tensor(norm_adj), laplacian_dict

    def _normalize_adj(self, adj):
        """Symmetric normalisation D^-1/2 A D^-1/2; isolated nodes get an all-zero row and column."""
        rowsum = np.asarray(adj.sum(1)).flatten()
        with np.errstate(divide='ignore'):
            d_inv = np.power(rowsum, -0.5)
        d_inv[np.isinf(d_inv)] = 0.
        d_mat = sp.diags(d_inv)
        return d_mat.dot(adj).dot(d_mat).tocoo()

    def _convert_sp_mat_to_sp_tensor(self, mat):
        """Convert a scipy sparse matrix to a coalesced float32 torch COO tensor."""
        mat = mat.tocoo()
        indices = torch.from_numpy(np.vstack((mat.row, mat.col)).astype(np.int64))
        values = torch.from_numpy(mat.data.astype(np.float32))
        return torch.sparse_coo_tensor(indices, values, torch.Size(mat.shape)).coalesce()

    def generate_cf_batch(self, user_dict, batch_size):
        """Sample ``batch_size`` (user, positive item, negative item) triples with replacement.

        A negative is drawn uniformly from ``[0, n_items)`` until it is not one
        of the user's items. Users with no items, or with every item (no valid
        negative exists), are skipped, so the batch can be shorter.
        """
        users = list(user_dict.keys())
        batch_users, batch_pos_items, batch_neg_items = [], [], []
        for _ in range(batch_size):
            u = random.choice(users)
            pos_items = user_dict[u]
            pos_set = set(pos_items)
            if not pos_items or len(pos_set) >= self.n_items:
                continue
            i = random.choice(pos_items)
            j = random.randint(0, self.n_items - 1)
            while j in pos_set:
                j = random.randint(0, self.n_items - 1)
            batch_users.append(u)
            batch_pos_items.append(i)
            batch_neg_items.append(j)
        return torch.LongTensor(batch_users), torch.LongTensor(batch_pos_items), torch.LongTensor(batch_neg_items)

    def generate_kg_batch(self, kg_dict, batch_size, entity_num):
        """Sample ``batch_size`` (h, r, true tail, corrupted tail) quadruples with replacement.

        The corrupted tail is drawn uniformly from ``[0, entity_num)`` until it
        is not a true tail of ``h`` under any relation. Heads whose true tails
        already cover every candidate are skipped.
        """
        h_list = list(kg_dict.keys())
        batch_h, batch_r, batch_pos_t, batch_neg_t = [], [], [], []
        for _ in range(batch_size):
            h = random.choice(h_list)
            t_r_pairs = kg_dict[h]
            if not t_r_pairs:
                continue
            true_tails = {tail for tail, _ in t_r_pairs}
            if len(true_tails) >= entity_num:
                continue
            t, r = random.choice(t_r_pairs)
            neg_t = random.randint(0, entity_num - 1)
            while neg_t in true_tails:
                neg_t = random.randint(0, entity_num - 1)
            batch_h.append(h)
            batch_r.append(r)
            batch_pos_t.append(t)
            batch_neg_t.append(neg_t)
        return torch.LongTensor(batch_h), torch.LongTensor(batch_r), torch.LongTensor(batch_pos_t), torch.LongTensor(batch_neg_t)

In [None]:
# Training / Evaluation / Prediction functions

def evaluate(model, dataloader, Ks, device):
    """Score every test user against every item and compute top-K metrics.

    Test users are processed in chunks of ``dataloader.test_batch_size``.
    Each chunk's scores are produced by ``model(..., mode='predict')`` without
    gradients and then passed to ``calc_metrics_at_k``, which masks
    training items in place.

    Returns:
        ``(all_scores, metrics_dict)``. ``all_scores`` stacks the per-chunk
        score matrices as numpy arrays. ``metrics_dict[k][metric]`` is the mean
        of that metric over all test users, for metric in precision/recall/ndcg.
    """
    metric_names = ['precision', 'recall', 'ndcg']
    batch_size = dataloader.test_batch_size
    train_user_dict = dataloader.train_user_dict
    test_user_dict = dataloader.test_user_dict

    model.eval()
    test_users = list(test_user_dict.keys())
    user_batches = [torch.LongTensor(test_users[start:start + batch_size])
                    for start in range(0, len(test_users), batch_size)]
    item_ids = torch.arange(dataloader.n_items, dtype=torch.long).to(device)
    item_ids_np = item_ids.cpu().numpy()

    per_batch = {k: {name: [] for name in metric_names} for k in Ks}
    score_chunks = []

    for batch_users in tqdm(user_batches, desc="Evaluating"):
        batch_users = batch_users.to(device)
        with torch.no_grad():
            batch_scores = model(batch_users, item_ids, mode='predict')
        batch_scores = batch_scores.cpu()
        batch_metrics = calc_metrics_at_k(batch_scores, train_user_dict, test_user_dict,
                                          batch_users.cpu().numpy(), item_ids_np, Ks)
        score_chunks.append(batch_scores.numpy())
        for k in Ks:
            for name in metric_names:
                per_batch[k][name].append(batch_metrics[k][name])

    all_scores = np.concatenate(score_chunks, axis=0)
    metrics_dict = {
        k: {name: np.concatenate(per_batch[k][name]).mean() for name in metric_names}
        for k in Ks
    }
    return all_scores, metrics_dict

def train(args):
    """Train KGAT with alternating CF / KG updates and periodic evaluation.

    Each epoch runs one pass of CF batches and one pass of KG batches, then
    refreshes the attention-weighted adjacency (``mode='update_att'``).
    Evaluation happens every ``args.evaluate_every`` epochs and at the last
    epoch. Each evaluation records precision/recall/ndcg for every K and calls
    ``early_stopping`` on recall@min(Ks). When the latest recall equals the
    best so far, the model is checkpointed via ``save_model``. Metrics are
    written to ``<save_dir>/metrics.tsv``, and the best epoch's metrics are
    logged when one was recorded.
    """
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    log_id = create_log_id(args.save_dir)
    logging_config(args.save_dir, f'log{log_id}', no_console=False)
    logging.info(args)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    data = DataLoaderKGAT(args, logging)
    user_pre_embed = torch.tensor(data.user_pre_embed) if args.use_pretrain == 1 else None
    item_pre_embed = torch.tensor(data.item_pre_embed) if args.use_pretrain == 1 else None

    model = KGAT(args, data.n_users, data.n_entities, data.n_relations, data.A_in, user_pre_embed, item_pre_embed)
    if args.use_pretrain == 2:
        model = load_model(model, args.pretrain_model_path)
    model.to(device)

    cf_optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    kg_optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    metric_names = ['precision', 'recall', 'ndcg']
    best_epoch = -1
    # NOTE(review): eval() on a CLI string; ast.literal_eval is safer for a list literal.
    Ks = eval(args.Ks)
    metric_store = {k: {m: [] for m in metric_names} for k in Ks}
    epoch_list = []
    n_cf_batch = data.n_cf_train // args.cf_batch_size + 1
    n_kg_batch = data.n_kg_train // args.kg_batch_size + 1

    for epoch in range(1, args.n_epoch + 1):
        model.train()

        cf_loss_total = 0.0
        for _ in range(n_cf_batch):
            u, pos_i, neg_i = data.generate_cf_batch(data.train_user_dict, args.cf_batch_size)
            u, pos_i, neg_i = u.to(device), pos_i.to(device), neg_i.to(device)
            cf_loss = model(u, pos_i, neg_i, mode='train_cf')
            cf_optimizer.zero_grad()
            cf_loss.backward()
            cf_optimizer.step()
            cf_loss_total += cf_loss.item()

        kg_loss_total = 0.0
        for _ in range(n_kg_batch):
            h, r, pt, nt = data.generate_kg_batch(data.train_kg_dict, args.kg_batch_size, data.n_users_entities)
            h, r, pt, nt = h.to(device), r.to(device), pt.to(device), nt.to(device)
            kg_loss = model(h, r, pt, nt, mode='train_kg')
            kg_optimizer.zero_grad()
            kg_loss.backward()
            kg_optimizer.step()
            kg_loss_total += kg_loss.item()

        # The attention refresh only rewrites A_in; no autograd graph is needed.
        h, t, r = data.h_list.to(device), data.t_list.to(device), data.r_list.to(device)
        with torch.no_grad():
            model(h, t, r, list(data.laplacian_dict.keys()), mode='update_att')
        logging.info(f'Epoch {epoch:04d} | CF loss {cf_loss_total / n_cf_batch:.4f} '
                     f'| KG loss {kg_loss_total / n_kg_batch:.4f}')

        if epoch % args.evaluate_every == 0 or epoch == args.n_epoch:
            _, metrics = evaluate(model, data, Ks, device)
            for k in Ks:
                for m in metric_names:
                    metric_store[k][m].append(metrics[k][m])
            epoch_list.append(epoch)
            logging.info(f'Epoch {epoch:04d} | recall@{min(Ks)}: {metrics[min(Ks)]["recall"]:.4f}')
            best_recall, stop = early_stopping(metric_store[min(Ks)]['recall'], args.stopping_steps)
            if stop:
                break
            if metric_store[min(Ks)]['recall'][-1] == best_recall:
                # save_model receives the previous best epoch before it is updated.
                save_model(model, args.save_dir, epoch, best_epoch)
                best_epoch = epoch

    # Save metrics to TSV
    records = {'epoch_idx': epoch_list}
    for k in Ks:
        for m in metric_names:
            records[f'{m}@{k}'] = metric_store[k][m]
    pd.DataFrame(records).to_csv(os.path.join(args.save_dir, 'metrics.tsv'), sep='\t', index=False)

    # Report the best epoch. With no evaluation, or a stop before any
    # checkpoint, there is none, and epoch_list.index(-1) would raise.
    if best_epoch not in epoch_list:
        logging.info('No best epoch was recorded; nothing to report.')
        return
    best_idx = epoch_list.index(best_epoch)
    logging.info(f'Best @ Epoch {best_epoch}:')
    for k in Ks:
        for m in metric_names:
            logging.info(f'{m}@{k}: {metric_store[k][m][best_idx]:.4f}')

def predict(args):
    """Evaluate a saved KGAT checkpoint on the test split and save its scores.

    Loads the data and restores the model from ``args.pretrain_model_path``.
    It then writes the score matrix returned by ``evaluate`` to
    ``<save_dir>/cf_scores.npy`` and prints precision/recall/NDCG for each K.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    data = DataLoaderKGAT(args, logging)
    model = KGAT(args, data.n_users, data.n_entities, data.n_relations)
    model = load_model(model, args.pretrain_model_path)
    model.to(device)

    # NOTE(review): eval() on a CLI string; ast.literal_eval is safer for a list literal.
    Ks = eval(args.Ks)
    scores, metrics = evaluate(model, data, Ks, device)

    # np.save does not create missing directories.
    os.makedirs(args.save_dir, exist_ok=True)
    np.save(os.path.join(args.save_dir, 'cf_scores.npy'), scores)
    for k in Ks:
        print(f'Precision@{k}: {metrics[k]["precision"]:.4f}, Recall@{k}: {metrics[k]["recall"]:.4f}, NDCG@{k}: {metrics[k]["ndcg"]:.4f}')


In [None]:
# Main Execution
if __name__ == '__main__':
    # NOTE(review): parse_kgat_args is not defined in any cell shown here; it
    # must come from another cell or module. `args` stays a module-level name
    # so later cells can reuse it.
    args = parse_kgat_args()
    train(args=args)