In [1]:
# %load main.py
import os
import random

import torch
import numpy as np

from time import time
from tqdm import tqdm
from copy import deepcopy
import logging
from prettytable import PrettyTable

from utils.parser import parse_args
from utils.data_loader import load_data
from utils.evaluate import test
from utils.helper import early_stopping

# Module-level placeholders. The __main__ section overwrites both from
# n_params after the dataset is loaded; the negative sampler inside
# get_feed_dict reads n_items as a global.
n_users, n_items = 0, 0


def get_feed_dict(train_entity_pairs, train_pos_set, start, end, n_negs=1):
    """Assemble one training mini-batch with randomly sampled negatives.

    Rows ``start:end`` of ``train_entity_pairs`` are taken as the batch;
    column 0 is used as the user id and column 1 as the positive item id.
    For every row, ``n_negs * K`` item ids are drawn with
    ``random.choice(range(n_items))`` and redrawn while the candidate is
    contained in ``train_pos_set[user]``.

    Relies on the module-level globals ``n_items``, ``K`` and ``device``
    (the latter two are only assigned in the ``__main__`` section).

    Args:
        train_entity_pairs: 2-D tensor of (user, item) pairs.
        train_pos_set: mapping from user id to a container of that user's
            training items (only ``in`` membership tests are performed).
        start: first row of the batch (inclusive).
        end: last row of the batch (exclusive).
        n_negs: multiplier for the number of negatives per pair.

    Returns:
        dict with keys ``'users'``, ``'pos_items'`` (slices of the input
        tensor) and ``'neg_items'`` (a ``torch.LongTensor`` with one row of
        sampled ids per pair, moved to ``device``).

    Note:
        The redraw loop never terminates for a user whose positive set
        contains every id in ``range(n_items)``.
    """

    def draw_negatives(user, how_many):
        # Rejection sampling: keep drawing until `how_many` ids outside the
        # user's positive set have been collected.
        candidates = range(n_items)
        drawn = []
        while len(drawn) < how_many:
            candidate = random.choice(candidates)
            if candidate in train_pos_set[user]:
                continue
            drawn.append(candidate)
        return drawn

    batch_pairs = train_entity_pairs[start:end]
    users = batch_pairs[:, 0]
    pos_items = batch_pairs[:, 1]

    per_pair = n_negs * K
    sampled = [
        draw_negatives(int(row[0]), per_pair)
        for row in batch_pairs.cpu().numpy()
    ]

    return {
        'users': users,
        'pos_items': pos_items,
        'neg_items': torch.LongTensor(sampled).to(device),
    }


if __name__ == '__main__':
    # Script entry point: seed the RNGs, load the data, build the model and
    # run the train / evaluate / early-stop loop.

    # fix the random seed
    seed = 2020
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # read args
    # Names assigned here are module-level globals; get_feed_dict reads
    # `device`, `K` and `n_items` from module scope.
    args = parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    device = torch.device("cuda:0") if args.cuda else torch.device("cpu")

    # build dataset
    train_cf, user_dict, n_params, norm_mat = load_data(args)
    train_cf_size = len(train_cf)
    train_cf = torch.LongTensor(np.array([[cf[0], cf[1]] for cf in train_cf], np.int32))

    n_users = n_params['n_users']
    n_items = n_params['n_items']
    n_negs = args.n_negs
    K = args.K

    # define model
    from modules.LightGCN import LightGCN
    if args.gnn == 'lightgcn':
        model = LightGCN(n_params, args, norm_mat).to(device)
    else:
        # Fail here with a clear message instead of a NameError on `model`
        # a few lines further down.
        raise ValueError('unsupported gnn model: %s' % args.gnn)

    # define optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    cur_best_pre_0 = 0
    stopping_step = 0
    should_stop = False

    print("start training ...")
    for epoch in range(args.epoch):
        # shuffle training data
        index = np.arange(train_cf_size)
        np.random.shuffle(index)
        train_cf_ = train_cf[index].to(device)

        # training
        model.train()
        # Accumulate the epoch loss as a Python float so that no autograd
        # history is kept alive across batches, and so the value can be
        # printed even when no full batch fits into the training set.
        loss, s = 0.0, 0
        train_s_t = time()
        # Only full batches are used; a trailing partial batch is skipped.
        while s + args.batch_size <= train_cf_size:
            batch = get_feed_dict(train_cf_,
                                  user_dict['train_user_set'],
                                  s, s + args.batch_size,
                                  n_negs)

            batch_loss, _, _ = model(epoch, batch)

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            loss += batch_loss.item()
            s += args.batch_size

        train_e_t = time()

        if epoch % 5 == 0:
            # testing (every 5th epoch)

            train_res = PrettyTable()
            train_res.field_names = ["Epoch", "training time(s)", "tesing time(s)", "Loss", "recall", "ndcg", "precision", "hit_ratio"]

            model.eval()
            test_s_t = time()
            test_ret = test(model, user_dict, n_params, mode='test')
            test_e_t = time()
            train_res.add_row(
                [epoch, train_e_t - train_s_t, test_e_t - test_s_t, loss, test_ret['recall'], test_ret['ndcg'],
                 test_ret['precision'], test_ret['hit_ratio']])

            # Without a validation split the test metrics drive early
            # stopping and checkpointing.
            if user_dict['valid_user_set'] is None:
                valid_ret = test_ret
            else:
                test_s_t = time()
                valid_ret = test(model, user_dict, n_params, mode='valid')
                test_e_t = time()
                train_res.add_row(
                    [epoch, train_e_t - train_s_t, test_e_t - test_s_t, loss, valid_ret['recall'], valid_ret['ndcg'],
                     valid_ret['precision'], valid_ret['hit_ratio']])
            print(train_res)

            # *********************************************************
            # early stopping on the first recall value
            # NOTE(review): presumably stops after recall fails to improve
            # for `flag_step` successive evaluations - confirm against
            # utils.helper.early_stopping.
            cur_best_pre_0, stopping_step, should_stop = early_stopping(valid_ret['recall'][0], cur_best_pre_0,
                                                                        stopping_step, expected_order='acc',
                                                                        flag_step=10)
            if should_stop:
                break

            # save weight when this evaluation equals the tracked best
            if valid_ret['recall'][0] == cur_best_pre_0 and args.save:
                torch.save(model.state_dict(), args.out_dir + 'model_' + '.ckpt')
        else:
            # logging.info('training loss at epoch %d: %f' % (epoch, loss))
            print('using time %.4fs, training loss at epoch %d: %.4f' % (train_e_t - train_s_t, epoch, loss))

    # Only claim early stopping when it actually triggered.
    if should_stop:
        print('early stopping at %d, recall@20:%.4f' % (epoch, cur_best_pre_0))
    else:
        print('training finished at %d, recall@20:%.4f' % (epoch, cur_best_pre_0))


reading train and test user-item set ...
building the adj mat ...
{'n_users': 640, 'n_items': 4165}
loading over ...
start training ...
+-------+--------------------+--------------------+-------------------+--------------+-------------+--------------+-------------+
| Epoch |  training time(s)  |   tesing time(s)   |        Loss       |    recall    |     ndcg    |  precision   |  hit_ratio  |
+-------+--------------------+--------------------+-------------------+--------------+-------------+--------------+-------------+
|   0   | 0.2412886619567871 | 1.3720169067382812 | 5.816640853881836 | [0.01407513] | [0.0130865] | [0.01095652] | [0.1026087] |
+-------+--------------------+--------------------+-------------------+--------------+-------------+--------------+-------------+
using time 0.2541s, training loss at epoch 1: 5.8094
using time 0.2424s, training loss at epoch 2: 5.7962
using time 0.2447s, training loss at epoch 3: 5.7751
using time 0.2807s, training loss at epoch 4: 5.7446
+-

using time 0.2751s, training loss at epoch 46: 2.7922
using time 0.1741s, training loss at epoch 47: 2.7461
using time 0.1768s, training loss at epoch 48: 2.7138
using time 0.1713s, training loss at epoch 49: 2.6774
+-------+--------------------+------------------+------------------+--------------+-------------+--------------+--------------+
| Epoch |  training time(s)  |  tesing time(s)  |       Loss       |    recall    |     ndcg    |  precision   |  hit_ratio   |
+-------+--------------------+------------------+------------------+--------------+-------------+--------------+--------------+
|   50  | 0.1774299144744873 | 1.45259428024292 | 2.65997052192688 | [0.09662753] | [0.0913306] | [0.04991304] | [0.32521739] |
+-------+--------------------+------------------+------------------+--------------+-------------+--------------+--------------+
using time 0.2560s, training loss at epoch 51: 2.6432
using time 0.2395s, training loss at epoch 52: 2.6245
using time 0.2421s, training loss at

using time 0.2541s, training loss at epoch 96: 1.7816
using time 0.2783s, training loss at epoch 97: 1.7725
using time 0.2434s, training loss at epoch 98: 1.7296
using time 0.2458s, training loss at epoch 99: 1.7568
+-------+--------------------+--------------------+-------------------+--------------+--------------+--------------+--------------+
| Epoch |  training time(s)  |   tesing time(s)   |        Loss       |    recall    |     ndcg     |  precision   |  hit_ratio   |
+-------+--------------------+--------------------+-------------------+--------------+--------------+--------------+--------------+
|  100  | 0.2445824146270752 | 1.3857660293579102 | 1.727109432220459 | [0.11332071] | [0.10206089] | [0.05669565] | [0.36695652] |
+-------+--------------------+--------------------+-------------------+--------------+--------------+--------------+--------------+
using time 0.2518s, training loss at epoch 101: 1.7183
using time 0.2465s, training loss at epoch 102: 1.6773
using time 0.2

using time 0.2520s, training loss at epoch 146: 1.1965
using time 0.2416s, training loss at epoch 147: 1.1730
using time 0.2833s, training loss at epoch 148: 1.1813
using time 0.2416s, training loss at epoch 149: 1.1704
+-------+---------------------+-------------------+--------------------+--------------+--------------+-------------+--------------+
| Epoch |   training time(s)  |   tesing time(s)  |        Loss        |    recall    |     ndcg     |  precision  |  hit_ratio   |
+-------+---------------------+-------------------+--------------------+--------------+--------------+-------------+--------------+
|  150  | 0.24079132080078125 | 1.402925968170166 | 1.1527090072631836 | [0.11868586] | [0.10895337] | [0.0606087] | [0.38608696] |
+-------+---------------------+-------------------+--------------------+--------------+--------------+-------------+--------------+
using time 0.2337s, training loss at epoch 151: 1.1535
using time 0.2359s, training loss at epoch 152: 1.1332
using time

using time 0.2485s, training loss at epoch 196: 0.8423
using time 0.2446s, training loss at epoch 197: 0.8432
using time 0.2373s, training loss at epoch 198: 0.8247
using time 0.2806s, training loss at epoch 199: 0.8264
+-------+---------------------+--------------------+--------------------+-------------+--------------+--------------+--------------+
| Epoch |   training time(s)  |   tesing time(s)   |        Loss        |    recall   |     ndcg     |  precision   |  hit_ratio   |
+-------+---------------------+--------------------+--------------------+-------------+--------------+--------------+--------------+
|  200  | 0.23378992080688477 | 1.3838880062103271 | 0.8254035711288452 | [0.1232404] | [0.11152897] | [0.06095652] | [0.39826087] |
+-------+---------------------+--------------------+--------------------+-------------+--------------+--------------+--------------+
using time 0.2611s, training loss at epoch 201: 0.8442
using time 0.2427s, training loss at epoch 202: 0.8114
using

+-------+---------------------+--------------------+--------------------+--------------+--------------+--------------+--------------+
| Epoch |   training time(s)  |   tesing time(s)   |        Loss        |    recall    |     ndcg     |  precision   |  hit_ratio   |
+-------+---------------------+--------------------+--------------------+--------------+--------------+--------------+--------------+
|  245  | 0.24157190322875977 | 1.3697524070739746 | 0.6440037488937378 | [0.12529553] | [0.11488307] | [0.06226087] | [0.40521739] |
+-------+---------------------+--------------------+--------------------+--------------+--------------+--------------+--------------+
using time 0.2520s, training loss at epoch 246: 0.6755
using time 0.2373s, training loss at epoch 247: 0.6516
using time 0.2414s, training loss at epoch 248: 0.6401
using time 0.2410s, training loss at epoch 249: 0.6108
+-------+--------------------+--------------------+--------------------+--------------+--------------+--------

using time 0.2569s, training loss at epoch 291: 0.5514
using time 0.2408s, training loss at epoch 292: 0.5399
using time 0.2474s, training loss at epoch 293: 0.5462
using time 0.2393s, training loss at epoch 294: 0.5380
+-------+-------------------+--------------------+-------------------+--------------+--------------+--------------+--------------+
| Epoch |  training time(s) |   tesing time(s)   |        Loss       |    recall    |     ndcg     |  precision   |  hit_ratio   |
+-------+-------------------+--------------------+-------------------+--------------+--------------+--------------+--------------+
|  295  | 0.237687349319458 | 1.3807532787322998 | 0.548681378364563 | [0.12869027] | [0.11822192] | [0.06286957] | [0.40869565] |
+-------+-------------------+--------------------+-------------------+--------------+--------------+--------------+--------------+
using time 0.2411s, training loss at epoch 296: 0.5324
using time 0.2430s, training loss at epoch 297: 0.5397
using time 0.24

using time 0.2621s, training loss at epoch 341: 0.1863
using time 0.2422s, training loss at epoch 342: 0.1872
using time 0.2411s, training loss at epoch 343: 0.1982
using time 0.2467s, training loss at epoch 344: 0.1867
+-------+--------------------+--------------------+---------------------+--------------+--------------+--------------+-----------+
| Epoch |  training time(s)  |   tesing time(s)   |         Loss        |    recall    |     ndcg     |  precision   | hit_ratio |
+-------+--------------------+--------------------+---------------------+--------------+--------------+--------------+-----------+
|  345  | 0.2539842128753662 | 1.4626522064208984 | 0.18454313278198242 | [0.12991752] | [0.11872316] | [0.06365217] |   [0.4]   |
+-------+--------------------+--------------------+---------------------+--------------+--------------+--------------+-----------+
using time 0.2503s, training loss at epoch 346: 0.1962
using time 0.2407s, training loss at epoch 347: 0.1836
using time 0.24