In [1]:
import sys
import torch
print(torch.cuda.device_count())
sys.path.append("..")
import numpy as np
import os
import pickle, importlib, random, Engine, tqdm, copy, json, time, argparse
import util.Generator as Generator
import util.Datahelper as dh

def _str2bool(value):
    """Convert a command-line token such as 'true'/'0' into a bool.

    Real booleans pass through unchanged so the option defaults keep working.
    Raises argparse.ArgumentTypeError for tokens that are not recognisable.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


# Command-line style hyper-parameters. Inside a notebook the kernel's own
# argv must not be parsed, so an explicit empty list is handed to
# parse_args() below instead of overwriting sys.argv.
parser = argparse.ArgumentParser()
parser.add_argument('-test', action="store_true", default=False)
parser.add_argument('-ckptname', dest='ckptname', default=None, required=False)

# Numeric options carry type=int (like -seed) so that values supplied as
# strings are converted instead of silently staying strings.
parser.add_argument('-nl', dest='num_layers', default=2, required=False, type=int)
parser.add_argument('-nhd', dest='num_hidden_dims', default=2**9, required=False, type=int)
parser.add_argument('-nh', dest='num_heads', default=8, required=False, type=int)

# Without a converter any supplied string (even "False") would be truthy.
parser.add_argument('-i', dest='use_item_feat', default=True, required=False, type=_str2bool)
parser.add_argument('-u', dest='use_user_feat', default=True, required=False, type=_str2bool)

# The sampling/history settings stay strings; they are forwarded unchanged
# to the generator configuration.
parser.add_argument(
    '-pt_sample_func', dest='pt_sample_func', default='(lambda x:x)', required=False)
parser.add_argument(
    '-pt_sample_param', dest='pt_sample_param', default='0', required=False)
parser.add_argument(
    '-pt_history_func', dest='pt_history_func', default='(lambda x:x)', required=False)
parser.add_argument(
    '-pt_history_param', dest='pt_history_param', default='20', required=False)

parser.add_argument('-nonimprove_limit', dest='nonimprove_limit', default=10, required=False, type=int)
parser.add_argument('-seed', dest='seed', default=0, required=False, type=int)

# Parse an explicit empty argument list: every option takes its default and
# the Jupyter kernel's arguments are never inspected.
args = parser.parse_args([])
print(args)

# Seed through the project's Engine helper before the shuffle further down.
Engine.set_random_seed(args.seed)

# Run-wide settings shared by the generators and the evaluation loop.
basic_config = dict(
    cuda_num=Engine.GPU_max_free_memory(),
    course_file='../datasets/Amazon_data/Beauty/basket_dataset.pkl',
    num_times=20,
    num_items=10000,
    batch_size=32,
    feats=[5],
)

# Path stem for the saved results; the extension is appended when saving.
save_name = 'checkpoint/KNN'
print(basic_config)

# NOTE(review): torch.cuda.device_count() has already been called at the top
# of this cell, so this assignment may come too late to restrict the visible
# devices -- confirm it has the intended effect.
os.environ['CUDA_VISIBLE_DEVICES'] = str(basic_config['cuda_num'])

# NOTE: pickle.load can execute arbitrary code; only open trusted dataset files.
with open(basic_config['course_file'], 'rb') as dataset_file:
    user_dict = pickle.load(dataset_file)
print('Total Number of Users : ' + str(len(user_dict)))

# Sort first so the shuffle starts from a fixed order, then shuffle with
# NumPy's global RNG.
all_keys = sorted(user_dict.keys())
np.random.shuffle(all_keys)
# (disabled) optional sub-sampling: used_keys, _ = dh.list_partition(all_keys, 0.1, seed=0)

# Two-stage partition: first split with ratio 0.7, then split the remainder
# with ratio 0.5.
train_keys, tv_keys = dh.list_partition(all_keys, 0.7, seed=0)
test_keys, valid_keys = dh.list_partition(tv_keys, 0.5, seed=0)

8
Namespace(ckptname=None, nonimprove_limit=10, num_heads=8, num_hidden_dims=512, num_layers=2, pt_history_func='(lambda x:x)', pt_history_param='20', pt_sample_func='(lambda x:x)', pt_sample_param='0', seed=0, test=False, use_item_feat=True, use_user_feat=True)
{'cuda_num': 1, 'course_file': '../datasets/Amazon_data/Beauty/basket_dataset.pkl', 'num_times': 20, 'num_items': 10000, 'batch_size': 32, 'feats': [5]}
Total Number of Users : 10000


In [2]:
def generator2feature(generator):
    """Pull one 'MAX'-sized batch from ``generator`` and return (first, last).

    The generator is asked for batch 0 with ``batch_size='MAX'``; the first
    element of that batch is then indexed, and its first and last entries are
    returned as a ``(course, target)`` pair.
    """
    whole_batch = generator.__getitem__(batch_id=0, batch_size='MAX')
    fields = whole_batch[0]
    return fields[0], fields[-1]

# Settings for the generator over the training users; the sampling and
# history options come straight from the parsed arguments.
train_generator_config = dict(
    name=None,
    training=True,
    sample_func=args.pt_sample_func,
    sample_param=args.pt_sample_param,
    history_func=args.pt_history_func,
    history_param=args.pt_history_param,
    next_basket=True,
    batch_size=basic_config['batch_size'],
    shuffle=True,
    fixed_seed=False,
)

train_generator = Generator.TimeMultihotGenerator(
    user_dict, train_keys, basic_config, train_generator_config)

# Materialise the training split once as (course, target) arrays.
train_courses, train_target = generator2feature(train_generator)

from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.neighbors import NearestNeighbors
import util.Metrics as Metrics

def userKNN_GPU(train_features, test_courses, cuda_num, train_baskets=None, num_times=None):
    """Score items for test users as similarity-weighted sums of training baskets.

    Each test user is compared with every training user by cosine similarity
    of their time-summed feature vectors. For every time step the training
    baskets are then combined with those similarities as weights. No top-k
    neighbour selection is applied: all training users contribute.

    Args:
        train_features: torch.FloatTensor on the GPU, [num_train, num_courses].
        test_courses: np.ndarray on the CPU, [num_test, num_sem, num_courses].
        cuda_num: CUDA device index passed to ``.cuda()``.
        train_baskets: optional array indexed as [num_train, time, num_courses].
            Defaults to the module-level ``train_courses`` (previous behaviour).
        num_times: optional number of time steps to score. Defaults to
            ``basic_config['num_times']`` (previous behaviour).

    Returns:
        np.ndarray of shape [num_test, num_times, num_courses] with the
        predicted scores.
    """
    # Fall back to the notebook globals so existing three-argument calls
    # behave exactly as before.
    if train_baskets is None:
        train_baskets = train_courses
    if num_times is None:
        num_times = basic_config['num_times']

    test_features = torch.FloatTensor(test_courses.sum(1).astype(float)).cuda(cuda_num)
    num_test = test_features.shape[0]
    if num_test == 0:
        # torch.cat() rejects an empty list, so answer with an empty prediction.
        return np.zeros((0, num_times, train_baskets.shape[-1]), dtype=np.float32)

    # Loop-invariant view of the training features: [num_train, 1, num_courses].
    train_expanded = train_features.unsqueeze(1)

    # One similarity column ([num_train, 1]) per test user, computed user by
    # user so only a single broadcasted comparison is alive at a time.
    sim_columns = []
    for user_idx in tqdm.tqdm(range(num_test)):
        query = test_features[user_idx].unsqueeze(0).unsqueeze(0)
        sim_columns.append(torch.cosine_similarity(train_expanded, query, dim=-1))
    # [num_train, num_test] -> [num_test, num_train]
    sim = torch.cat(sim_columns, dim=-1).T

    # Weighted sum of the training baskets for each time step.
    step_scores = []
    for step in range(num_times):
        step_baskets = torch.FloatTensor(train_baskets[:, step]).cuda(cuda_num)
        step_scores.append(torch.matmul(sim, step_baskets)[:, np.newaxis])
    pred = torch.cat(step_scores, dim=1)
    return pred.cpu().numpy()

In [3]:
# Base settings shared by every evaluation generator; the per-run entries
# (sample/history functions and parameters, name) are filled in inside the loop.
test_generator_config = {
    'training' : False,
    'max_sampling' : 6,
    'mask_rate' : None,
    'historical' : None,
    'batch_size' : 16,
    'shuffle' : False,
    'fixed_seed' : True}

# The training-user feature vectors do not depend on (h, r): build them and
# move them to the GPU once instead of once per run.
train_features = torch.FloatTensor(train_courses.sum(1).astype(float))
train_features_cuda = train_features.cuda(basic_config['cuda_num'])

# results_mat[h][r] = [overall recall@10, per-time-step recall]
# NOTE(review): with basic_config['num_times'] == 20, h == 20 leaves
# `test_target[:, h:]` empty and the recorded run prints `Recall: nan` for
# every H=20 setting -- confirm whether h == 20 should be evaluated at all.
results_mat = {}
for h in list(range(4)) + list(range(5, 21, 5)):
    results_mat[h] = {}
    for r in range(6):
        test_generator_config['sample_func'] = '(lambda x:x)'
        test_generator_config['sample_param'] = str(r)
        test_generator_config['history_func'] = '(lambda x:x)'
        test_generator_config['history_param'] = str(h)
        test_generator_config['name'] = 'H={1}_R={0}'.format(r, h)
        test_generator = Generator.TimeMultihotGenerator(
            user_dict, test_keys, basic_config, test_generator_config)

        print(test_generator.name + ' ' + str(test_generator.batch_size))
        test_courses, test_target = generator2feature(test_generator)
        user_pred = userKNN_GPU(train_features_cuda, test_courses, basic_config['cuda_num'])
        # Only the time steps from index h onwards are scored.
        recall, recall_per_sem = Metrics.recall(test_target[:, h:], user_pred[:, h:], at_n=10)
        print('Recall: {:.4f}'.format(recall))
        results_mat[h][r] = [recall, recall_per_sem]

# Append the extension only once so that re-running this cell does not turn
# the path into 'checkpoint/KNN.npy.npy'.
if not save_name.endswith('.npy'):
    save_name = save_name + '.npy'
print(save_name)

# Create the output directory up front; otherwise np.save would fail only
# after the whole sweep has finished.
save_dir = os.path.dirname(save_name)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)

# The nested dict is stored as a 0-d object array; read it back with
# np.load(save_name, allow_pickle=True).item().
np.save(save_name, np.array(results_mat))

H=0_R=0 16


100%|██████████| 1500/1500 [00:05<00:00, 251.75it/s]


Recall: 0.0014
H=0_R=1 16


100%|██████████| 1500/1500 [00:06<00:00, 247.82it/s]


Recall: 0.0485
H=0_R=2 16


100%|██████████| 1500/1500 [00:06<00:00, 230.16it/s]


Recall: 0.0549
H=0_R=3 16


100%|██████████| 1500/1500 [00:05<00:00, 259.37it/s]


Recall: 0.0619
H=0_R=4 16


100%|██████████| 1500/1500 [00:05<00:00, 259.32it/s]


Recall: 0.0706
H=0_R=5 16


100%|██████████| 1500/1500 [00:10<00:00, 145.45it/s]


Recall: 0.0727
H=1_R=0 16


100%|██████████| 1500/1500 [00:07<00:00, 202.07it/s]


Recall: 0.0359
H=1_R=1 16


100%|██████████| 1500/1500 [00:05<00:00, 251.04it/s]


Recall: 0.0501
H=1_R=2 16


100%|██████████| 1500/1500 [00:06<00:00, 236.42it/s]


Recall: 0.0539
H=1_R=3 16


100%|██████████| 1500/1500 [00:06<00:00, 242.19it/s]


Recall: 0.0607
H=1_R=4 16


100%|██████████| 1500/1500 [00:05<00:00, 259.30it/s]


Recall: 0.0653
H=1_R=5 16


100%|██████████| 1500/1500 [00:05<00:00, 296.13it/s]


Recall: 0.0670
H=2_R=0 16


100%|██████████| 1500/1500 [00:05<00:00, 296.03it/s]


Recall: 0.0415
H=2_R=1 16


100%|██████████| 1500/1500 [00:05<00:00, 297.22it/s]


Recall: 0.0482
H=2_R=2 16


100%|██████████| 1500/1500 [00:05<00:00, 280.31it/s]


Recall: 0.0515
H=2_R=3 16


100%|██████████| 1500/1500 [00:08<00:00, 176.37it/s]


Recall: 0.0566
H=2_R=4 16


100%|██████████| 1500/1500 [00:09<00:00, 159.91it/s]


Recall: 0.0585
H=2_R=5 16


100%|██████████| 1500/1500 [00:09<00:00, 152.83it/s]


Recall: 0.0599
H=3_R=0 16


100%|██████████| 1500/1500 [00:11<00:00, 135.29it/s]


Recall: 0.0491
H=3_R=1 16


100%|██████████| 1500/1500 [00:06<00:00, 237.96it/s]


Recall: 0.0533
H=3_R=2 16


100%|██████████| 1500/1500 [00:05<00:00, 296.46it/s]


Recall: 0.0522
H=3_R=3 16


100%|██████████| 1500/1500 [00:05<00:00, 297.19it/s]


Recall: 0.0581
H=3_R=4 16


100%|██████████| 1500/1500 [00:05<00:00, 297.25it/s]


Recall: 0.0601
H=3_R=5 16


100%|██████████| 1500/1500 [00:05<00:00, 296.62it/s]


Recall: 0.0615
H=5_R=0 16


100%|██████████| 1500/1500 [00:05<00:00, 297.31it/s]


Recall: 0.0518
H=5_R=1 16


100%|██████████| 1500/1500 [00:05<00:00, 297.22it/s]


Recall: 0.0557
H=5_R=2 16


100%|██████████| 1500/1500 [00:05<00:00, 296.56it/s]


Recall: 0.0566
H=5_R=3 16


100%|██████████| 1500/1500 [00:05<00:00, 296.46it/s]


Recall: 0.0613
H=5_R=4 16


100%|██████████| 1500/1500 [00:05<00:00, 296.46it/s]


Recall: 0.0664
H=5_R=5 16


100%|██████████| 1500/1500 [00:05<00:00, 297.39it/s]


Recall: 0.0681
H=10_R=0 16


100%|██████████| 1500/1500 [00:05<00:00, 296.47it/s]


Recall: 0.0582
H=10_R=1 16


100%|██████████| 1500/1500 [00:05<00:00, 296.62it/s]


Recall: 0.0601
H=10_R=2 16


100%|██████████| 1500/1500 [00:05<00:00, 296.69it/s]


Recall: 0.0660
H=10_R=3 16


100%|██████████| 1500/1500 [00:05<00:00, 297.24it/s]


Recall: 0.0650
H=10_R=4 16


100%|██████████| 1500/1500 [00:05<00:00, 297.23it/s]


Recall: 0.0679
H=10_R=5 16


100%|██████████| 1500/1500 [00:05<00:00, 296.56it/s]


Recall: 0.0679
H=15_R=0 16


100%|██████████| 1500/1500 [00:05<00:00, 296.61it/s]


Recall: 0.0941
H=15_R=1 16


100%|██████████| 1500/1500 [00:05<00:00, 296.49it/s]


Recall: 0.0916
H=15_R=2 16


100%|██████████| 1500/1500 [00:05<00:00, 296.75it/s]


Recall: 0.0941
H=15_R=3 16


100%|██████████| 1500/1500 [00:05<00:00, 296.61it/s]


Recall: 0.0965
H=15_R=4 16


100%|██████████| 1500/1500 [00:05<00:00, 297.23it/s]


Recall: 0.0965
H=15_R=5 16


100%|██████████| 1500/1500 [00:05<00:00, 296.70it/s]


Recall: 0.0965
H=20_R=0 16


100%|██████████| 1500/1500 [00:05<00:00, 297.24it/s]
  return np.array(recall_list).sum() / target.sum(), np.array(recall_list).sum(0).sum(-1) / target.sum(0).sum(-1)


Recall: nan
H=20_R=1 16


100%|██████████| 1500/1500 [00:05<00:00, 297.26it/s]


Recall: nan
H=20_R=2 16


100%|██████████| 1500/1500 [00:05<00:00, 296.58it/s]


Recall: nan
H=20_R=3 16


100%|██████████| 1500/1500 [00:05<00:00, 296.61it/s]


Recall: nan
H=20_R=4 16


100%|██████████| 1500/1500 [00:05<00:00, 296.59it/s]


Recall: nan
H=20_R=5 16


100%|██████████| 1500/1500 [00:05<00:00, 297.35it/s]


Recall: nan
checkpoint/KNN.npy
