In [1]:
import random

import numpy as np
import torch
import torch.optim as optim
import torch.utils.data as data

from Models.BPR import BPR
from Utils.data_utils import read_LOO_settings
from Utils.dataset import implicit_CF_dataset, implicit_CF_dataset_test
from run import LOO_run

In [2]:
# dummy class
class opt:
    """Lightweight stand-in for a parsed-arguments object, handed to ``LOO_run``.

    It only stores the three training-control values it is constructed with,
    exposing them as attributes of the same names.

    Attributes:
        max_epoch: upper bound on training epochs (the log prints ``Epoch [i/max_epoch]``).
        early_stop: presumably the early-stopping patience in epochs — confirm in ``run.LOO_run``.
        es_epoch: early-stopping related setting whose exact meaning is defined by
            ``run.LOO_run`` — TODO confirm.
    """

    def __init__(self, max_epoch, early_stop, es_epoch):
        self.max_epoch, self.early_stop, self.es_epoch = max_epoch, early_stop, es_epoch

In [3]:
# gpu setting
# Requested CUDA device index. Fall back to CPU when CUDA is unavailable or the
# machine has fewer devices, so a fresh "Restart & Run All" does not crash on .to(gpu).
GPU_ID = 2
if torch.cuda.is_available() and GPU_ID < torch.cuda.device_count():
    gpu = torch.device(f'cuda:{GPU_ID}')
else:
    gpu = torch.device('cpu')

# reproducibility: seed every RNG that parameter init, DataLoader shuffling or
# dataset sampling might draw from (which ones the project code uses is not visible here)
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the default generator on all devices, incl. CUDA

# for training
# `model_name` (not `model`) so this string is not confused with the BPR module built later
model_name, lr, batch_size, num_ns = 'BPR', 0.001, 1024, 1  # num_ns: presumably negatives per positive — confirm in implicit_CF_dataset
max_epoch, early_stop, es_epoch = 500, 30, 0
reg = 0.001  # passed to Adam as weight_decay when the optimizer is built

# dataset (LOO_seed presumably selects which leave-one-out split to load — confirm in read_LOO_settings)
data_path, dataset, LOO_seed = 'Dataset/', 'citeULike', 0
user_count, item_count, train_mat, train_interactions, valid_sample, test_sample, candidates = read_LOO_settings(data_path, dataset, LOO_seed)

train_dataset = implicit_CF_dataset(user_count, item_count, train_mat, train_interactions, num_ns)
test_dataset = implicit_CF_dataset_test(user_count, test_sample, valid_sample, candidates)

# shuffle=True draws from torch's default generator, which is seeded above
train_loader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [4]:
# model: BPR from Models.BPR with a latent size of `dim` (presumably the embedding width),
# placed on the selected device; nn.Module.to returns the module itself.
dim = 100
model = BPR(user_count, item_count, dim, gpu).to(gpu)

# optimizer: Adam, with `reg` applied as weight_decay (L2 penalty)
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=reg)

# sanity check on dataset size before training
print("User::", user_count, "Item::", item_count, "Interactions::", len(train_interactions))

User:: 5220 Item:: 25182 Interactions:: 115142


In [5]:
# start train: LOO_run prints per-epoch train loss plus valid/test H@K, N@K, M@K
# (see output below); a trailing '*' presumably marks an improved validation epoch.
run_options = opt(max_epoch, early_stop, es_epoch)
LOO_run(run_options, model, gpu, optimizer, train_loader, test_dataset, model_save_path=None)

Epoch [0/500], Train Loss: 705.6689, Elapsed: train 0.52 test 2.15 *
valid H@5: 0.1255, N@5: 0.0904, M@5: 0.0788
valid H@10: 0.1705, N@10: 0.1049, M@10: 0.0848
valid H@20: 0.2336, N@20: 0.1208, M@20: 0.0891

test H@5: 0.1391, N@5: 0.0994, M@5: 0.0864
test H@10: 0.1903, N@10: 0.1159, M@10: 0.0931
test H@20: 0.2518, N@20: 0.1314, M@20: 0.0973

Epoch [1/500], Train Loss: 668.7119, Elapsed: train 0.50 test 1.82 *
valid H@5: 0.2152, N@5: 0.1567, M@5: 0.1374
valid H@10: 0.2820, N@10: 0.1783, M@10: 0.1462
valid H@20: 0.3688, N@20: 0.2001, M@20: 0.1522

test H@5: 0.2261, N@5: 0.1679, M@5: 0.1486
test H@10: 0.2957, N@10: 0.1904, M@10: 0.1579
test H@20: 0.3740, N@20: 0.2102, M@20: 0.1633

Epoch [2/500], Train Loss: 544.1658, Elapsed: train 0.50 test 1.93 *
valid H@5: 0.2462, N@5: 0.1795, M@5: 0.1575
valid H@10: 0.3280, N@10: 0.2058, M@10: 0.1683
valid H@20: 0.4053, N@20: 0.2253, M@20: 0.1736

test H@5: 0.2522, N@5: 0.1848, M@5: 0.1626
test H@10: 0.3292, N@10: 0.2096, M@10: 0.1728
test H@20: 0.41

Epoch [24/500], Train Loss: 56.4162, Elapsed: train 0.53 test 1.75 *
valid H@5: 0.4074, N@5: 0.3066, M@5: 0.2733
valid H@10: 0.5152, N@10: 0.3415, M@10: 0.2877
valid H@20: 0.6197, N@20: 0.3680, M@20: 0.2951

test H@5: 0.4223, N@5: 0.3162, M@5: 0.2811
test H@10: 0.5313, N@10: 0.3514, M@10: 0.2956
test H@20: 0.6354, N@20: 0.3778, M@20: 0.3030

Epoch [25/500], Train Loss: 53.7824, Elapsed: train 0.49 test 1.76 *
valid H@5: 0.4118, N@5: 0.3092, M@5: 0.2753
valid H@10: 0.5206, N@10: 0.3444, M@10: 0.2898
valid H@20: 0.6227, N@20: 0.3703, M@20: 0.2970

test H@5: 0.4229, N@5: 0.3160, M@5: 0.2806
test H@10: 0.5317, N@10: 0.3512, M@10: 0.2952
test H@20: 0.6392, N@20: 0.3785, M@20: 0.3027

Epoch [26/500], Train Loss: 52.7137, Elapsed: train 0.49 test 1.76 *
valid H@5: 0.4144, N@5: 0.3101, M@5: 0.2756
valid H@10: 0.5242, N@10: 0.3457, M@10: 0.2903
valid H@20: 0.6223, N@20: 0.3705, M@20: 0.2972

test H@5: 0.4265, N@5: 0.3196, M@5: 0.2843
test H@10: 0.5340, N@10: 0.3544, M@10: 0.2986
test H@20: 0.64

Epoch [48/500], Train Loss: 30.2603, Elapsed: train 0.46 test 1.81 *
valid H@5: 0.4532, N@5: 0.3423, M@5: 0.3055
valid H@10: 0.5597, N@10: 0.3770, M@10: 0.3200
valid H@20: 0.6612, N@20: 0.4027, M@20: 0.3271

test H@5: 0.4635, N@5: 0.3531, M@5: 0.3166
test H@10: 0.5681, N@10: 0.3871, M@10: 0.3306
test H@20: 0.6760, N@20: 0.4145, M@20: 0.3383

Epoch [49/500], Train Loss: 28.6725, Elapsed: train 0.46 test 1.81 *
valid H@5: 0.4530, N@5: 0.3437, M@5: 0.3074
valid H@10: 0.5622, N@10: 0.3792, M@10: 0.3222
valid H@20: 0.6620, N@20: 0.4045, M@20: 0.3292

test H@5: 0.4646, N@5: 0.3537, M@5: 0.3169
test H@10: 0.5727, N@10: 0.3886, M@10: 0.3313
test H@20: 0.6771, N@20: 0.4152, M@20: 0.3387

Epoch [50/500], Train Loss: 28.2332, Elapsed: train 0.53 test 1.64 *
valid H@5: 0.4526, N@5: 0.3446, M@5: 0.3088
valid H@10: 0.5652, N@10: 0.3811, M@10: 0.3239
valid H@20: 0.6609, N@20: 0.4052, M@20: 0.3305

test H@5: 0.4654, N@5: 0.3546, M@5: 0.3179
test H@10: 0.5735, N@10: 0.3896, M@10: 0.3324
test H@20: 0.67

Epoch [72/500], Train Loss: 22.9766, Elapsed: train 0.43 test 1.75 *
valid H@5: 0.4775, N@5: 0.3625, M@5: 0.3245
valid H@10: 0.5823, N@10: 0.3965, M@10: 0.3386
valid H@20: 0.6762, N@20: 0.4202, M@20: 0.3451

test H@5: 0.4853, N@5: 0.3706, M@5: 0.3327
test H@10: 0.5940, N@10: 0.4057, M@10: 0.3472
test H@20: 0.6925, N@20: 0.4307, M@20: 0.3541

Epoch [73/500], Train Loss: 22.7681, Elapsed: train 0.49 test 1.72 *
valid H@5: 0.4779, N@5: 0.3627, M@5: 0.3245
valid H@10: 0.5806, N@10: 0.3960, M@10: 0.3384
valid H@20: 0.6768, N@20: 0.4204, M@20: 0.3451

test H@5: 0.4850, N@5: 0.3725, M@5: 0.3353
test H@10: 0.5913, N@10: 0.4069, M@10: 0.3495
test H@20: 0.6929, N@20: 0.4327, M@20: 0.3567

Epoch [74/500], Train Loss: 22.6391, Elapsed: train 0.53 test 1.74 *
valid H@5: 0.4788, N@5: 0.3633, M@5: 0.3250
valid H@10: 0.5827, N@10: 0.3970, M@10: 0.3390
valid H@20: 0.6781, N@20: 0.4212, M@20: 0.3457

test H@5: 0.4834, N@5: 0.3723, M@5: 0.3356
test H@10: 0.5928, N@10: 0.4077, M@10: 0.3502
test H@20: 0.69

Epoch [96/500], Train Loss: 20.1196, Elapsed: train 0.50 test 1.72 *
valid H@5: 0.4873, N@5: 0.3748, M@5: 0.3374
valid H@10: 0.5930, N@10: 0.4090, M@10: 0.3516
valid H@20: 0.6860, N@20: 0.4325, M@20: 0.3581

test H@5: 0.4957, N@5: 0.3823, M@5: 0.3448
test H@10: 0.6026, N@10: 0.4167, M@10: 0.3589
test H@20: 0.7003, N@20: 0.4416, M@20: 0.3658

Epoch [97/500], Train Loss: 20.0154, Elapsed: train 0.55 test 1.70
valid H@5: 0.4873, N@5: 0.3748, M@5: 0.3374
valid H@10: 0.5919, N@10: 0.4087, M@10: 0.3515
valid H@20: 0.6873, N@20: 0.4329, M@20: 0.3582

test H@5: 0.4942, N@5: 0.3808, M@5: 0.3432
test H@10: 0.6016, N@10: 0.4156, M@10: 0.3576
test H@20: 0.7007, N@20: 0.4409, M@20: 0.3647

Epoch [98/500], Train Loss: 19.9199, Elapsed: train 0.44 test 1.73 *
valid H@5: 0.4882, N@5: 0.3750, M@5: 0.3374
valid H@10: 0.5942, N@10: 0.4093, M@10: 0.3516
valid H@20: 0.6869, N@20: 0.4328, M@20: 0.3581

test H@5: 0.4945, N@5: 0.3828, M@5: 0.3457
test H@10: 0.5995, N@10: 0.4168, M@10: 0.3598
test H@20: 0.7005

Epoch [120/500], Train Loss: 18.7800, Elapsed: train 0.52 test 1.91 *
valid H@5: 0.4991, N@5: 0.3823, M@5: 0.3436
valid H@10: 0.5957, N@10: 0.4136, M@10: 0.3566
valid H@20: 0.6911, N@20: 0.4378, M@20: 0.3633

test H@5: 0.5147, N@5: 0.3955, M@5: 0.3561
test H@10: 0.6105, N@10: 0.4265, M@10: 0.3689
test H@20: 0.7059, N@20: 0.4507, M@20: 0.3756

Epoch [121/500], Train Loss: 18.6241, Elapsed: train 0.51 test 1.81
valid H@5: 0.4980, N@5: 0.3825, M@5: 0.3442
valid H@10: 0.5978, N@10: 0.4149, M@10: 0.3577
valid H@20: 0.6909, N@20: 0.4385, M@20: 0.3642

test H@5: 0.5108, N@5: 0.3924, M@5: 0.3532
test H@10: 0.6120, N@10: 0.4252, M@10: 0.3669
test H@20: 0.7084, N@20: 0.4497, M@20: 0.3736

Epoch [122/500], Train Loss: 18.6301, Elapsed: train 0.49 test 1.74 *
valid H@5: 0.5001, N@5: 0.3829, M@5: 0.3441
valid H@10: 0.5978, N@10: 0.4146, M@10: 0.3573
valid H@20: 0.6898, N@20: 0.4379, M@20: 0.3637

test H@5: 0.5118, N@5: 0.3922, M@5: 0.3527
test H@10: 0.6128, N@10: 0.4251, M@10: 0.3663
test H@20: 0.7

Epoch [144/500], Train Loss: 17.6609, Elapsed: train 0.54 test 1.76 *
valid H@5: 0.5026, N@5: 0.3871, M@5: 0.3488
valid H@10: 0.6022, N@10: 0.4196, M@10: 0.3624
valid H@20: 0.6965, N@20: 0.4434, M@20: 0.3689

test H@5: 0.5160, N@5: 0.3988, M@5: 0.3599
test H@10: 0.6189, N@10: 0.4322, M@10: 0.3738
test H@20: 0.7124, N@20: 0.4558, M@20: 0.3803

Epoch [145/500], Train Loss: 17.5927, Elapsed: train 0.48 test 1.81 *
valid H@5: 0.5003, N@5: 0.3855, M@5: 0.3474
valid H@10: 0.6055, N@10: 0.4197, M@10: 0.3616
valid H@20: 0.6953, N@20: 0.4424, M@20: 0.3678

test H@5: 0.5150, N@5: 0.3994, M@5: 0.3610
test H@10: 0.6214, N@10: 0.4340, M@10: 0.3754
test H@20: 0.7111, N@20: 0.4566, M@20: 0.3816

Epoch [146/500], Train Loss: 17.6923, Elapsed: train 0.44 test 1.79
valid H@5: 0.5026, N@5: 0.3862, M@5: 0.3476
valid H@10: 0.6047, N@10: 0.4193, M@10: 0.3613
valid H@20: 0.6959, N@20: 0.4423, M@20: 0.3676

test H@5: 0.5160, N@5: 0.4008, M@5: 0.3626
test H@10: 0.6204, N@10: 0.4347, M@10: 0.3768
test H@20: 0.7

Epoch [168/500], Train Loss: 17.2043, Elapsed: train 0.46 test 1.78
valid H@5: 0.5062, N@5: 0.3891, M@5: 0.3502
valid H@10: 0.6068, N@10: 0.4217, M@10: 0.3637
valid H@20: 0.6988, N@20: 0.4450, M@20: 0.3701

test H@5: 0.5185, N@5: 0.4012, M@5: 0.3623
test H@10: 0.6189, N@10: 0.4336, M@10: 0.3757
test H@20: 0.7176, N@20: 0.4587, M@20: 0.3826

Epoch [169/500], Train Loss: 17.2313, Elapsed: train 0.44 test 1.73
valid H@5: 0.5043, N@5: 0.3874, M@5: 0.3486
valid H@10: 0.6049, N@10: 0.4200, M@10: 0.3621
valid H@20: 0.6984, N@20: 0.4437, M@20: 0.3687

test H@5: 0.5166, N@5: 0.4011, M@5: 0.3628
test H@10: 0.6197, N@10: 0.4345, M@10: 0.3766
test H@20: 0.7139, N@20: 0.4585, M@20: 0.3833

Epoch [170/500], Train Loss: 17.2694, Elapsed: train 0.54 test 1.82 *
valid H@5: 0.5045, N@5: 0.3881, M@5: 0.3495
valid H@10: 0.6041, N@10: 0.4204, M@10: 0.3629
valid H@20: 0.7011, N@20: 0.4450, M@20: 0.3696

test H@5: 0.5156, N@5: 0.4000, M@5: 0.3616
test H@10: 0.6193, N@10: 0.4335, M@10: 0.3755
test H@20: 0.714

Epoch [192/500], Train Loss: 16.6973, Elapsed: train 0.48 test 1.74 *
valid H@5: 0.5101, N@5: 0.3945, M@5: 0.3562
valid H@10: 0.6118, N@10: 0.4273, M@10: 0.3697
valid H@20: 0.7007, N@20: 0.4498, M@20: 0.3759

test H@5: 0.5194, N@5: 0.4026, M@5: 0.3638
test H@10: 0.6256, N@10: 0.4370, M@10: 0.3781
test H@20: 0.7158, N@20: 0.4599, M@20: 0.3844

Epoch [193/500], Train Loss: 16.6581, Elapsed: train 0.46 test 2.01 *
valid H@5: 0.5108, N@5: 0.3945, M@5: 0.3560
valid H@10: 0.6122, N@10: 0.4271, M@10: 0.3693
valid H@20: 0.7009, N@20: 0.4495, M@20: 0.3754

test H@5: 0.5173, N@5: 0.4016, M@5: 0.3632
test H@10: 0.6266, N@10: 0.4371, M@10: 0.3779
test H@20: 0.7162, N@20: 0.4599, M@20: 0.3842

Epoch [194/500], Train Loss: 16.7940, Elapsed: train 0.56 test 1.85 *
valid H@5: 0.5081, N@5: 0.3950, M@5: 0.3574
valid H@10: 0.6128, N@10: 0.4288, M@10: 0.3713
valid H@20: 0.7024, N@20: 0.4514, M@20: 0.3775

test H@5: 0.5214, N@5: 0.4047, M@5: 0.3660
test H@10: 0.6266, N@10: 0.4388, M@10: 0.3801
test H@20: 0

Epoch [216/500], Train Loss: 16.5284, Elapsed: train 0.47 test 1.76
valid H@5: 0.5074, N@5: 0.3940, M@5: 0.3563
valid H@10: 0.6122, N@10: 0.4280, M@10: 0.3704
valid H@20: 0.6992, N@20: 0.4501, M@20: 0.3766

test H@5: 0.5254, N@5: 0.4078, M@5: 0.3688
test H@10: 0.6258, N@10: 0.4405, M@10: 0.3824
test H@20: 0.7193, N@20: 0.4643, M@20: 0.3890

Epoch [217/500], Train Loss: 16.6604, Elapsed: train 0.55 test 1.78
valid H@5: 0.5078, N@5: 0.3935, M@5: 0.3555
valid H@10: 0.6107, N@10: 0.4268, M@10: 0.3693
valid H@20: 0.6988, N@20: 0.4492, M@20: 0.3755

test H@5: 0.5252, N@5: 0.4078, M@5: 0.3689
test H@10: 0.6271, N@10: 0.4408, M@10: 0.3825
test H@20: 0.7214, N@20: 0.4647, M@20: 0.3891

Epoch [218/500], Train Loss: 16.5670, Elapsed: train 0.47 test 1.79
valid H@5: 0.5078, N@5: 0.3948, M@5: 0.3573
valid H@10: 0.6108, N@10: 0.4282, M@10: 0.3711
valid H@20: 0.6994, N@20: 0.4508, M@20: 0.3774

test H@5: 0.5240, N@5: 0.4084, M@5: 0.3700
test H@10: 0.6264, N@10: 0.4415, M@10: 0.3837
test H@20: 0.7220,

Epoch [240/500], Train Loss: 16.4269, Elapsed: train 0.55 test 1.91
valid H@5: 0.5118, N@5: 0.3959, M@5: 0.3575
valid H@10: 0.6153, N@10: 0.4294, M@10: 0.3713
valid H@20: 0.7043, N@20: 0.4521, M@20: 0.3776

test H@5: 0.5210, N@5: 0.4074, M@5: 0.3697
test H@10: 0.6300, N@10: 0.4427, M@10: 0.3843
test H@20: 0.7229, N@20: 0.4662, M@20: 0.3907

Epoch [241/500], Train Loss: 16.2538, Elapsed: train 0.48 test 1.90
valid H@5: 0.5095, N@5: 0.3953, M@5: 0.3573
valid H@10: 0.6147, N@10: 0.4294, M@10: 0.3715
valid H@20: 0.7057, N@20: 0.4525, M@20: 0.3778

test H@5: 0.5233, N@5: 0.4090, M@5: 0.3711
test H@10: 0.6285, N@10: 0.4430, M@10: 0.3851
test H@20: 0.7256, N@20: 0.4676, M@20: 0.3919

Epoch [242/500], Train Loss: 16.4864, Elapsed: train 0.52 test 1.80 *
valid H@5: 0.5104, N@5: 0.3965, M@5: 0.3586
valid H@10: 0.6160, N@10: 0.4307, M@10: 0.3728
valid H@20: 0.7045, N@20: 0.4532, M@20: 0.3790

test H@5: 0.5235, N@5: 0.4089, M@5: 0.3708
test H@10: 0.6312, N@10: 0.4437, M@10: 0.3852
test H@20: 0.724

Epoch [264/500], Train Loss: 16.3850, Elapsed: train 0.45 test 1.93 *
valid H@5: 0.5187, N@5: 0.4014, M@5: 0.3624
valid H@10: 0.6195, N@10: 0.4341, M@10: 0.3760
valid H@20: 0.7097, N@20: 0.4569, M@20: 0.3823

test H@5: 0.5286, N@5: 0.4095, M@5: 0.3700
test H@10: 0.6319, N@10: 0.4430, M@10: 0.3839
test H@20: 0.7208, N@20: 0.4656, M@20: 0.3902

Epoch [265/500], Train Loss: 16.2632, Elapsed: train 0.50 test 1.80
valid H@5: 0.5170, N@5: 0.4010, M@5: 0.3625
valid H@10: 0.6179, N@10: 0.4338, M@10: 0.3762
valid H@20: 0.7099, N@20: 0.4572, M@20: 0.3826

test H@5: 0.5263, N@5: 0.4086, M@5: 0.3695
test H@10: 0.6310, N@10: 0.4426, M@10: 0.3836
test H@20: 0.7214, N@20: 0.4657, M@20: 0.3901

Epoch [266/500], Train Loss: 16.4641, Elapsed: train 0.47 test 1.80
valid H@5: 0.5156, N@5: 0.4009, M@5: 0.3628
valid H@10: 0.6187, N@10: 0.4345, M@10: 0.3768
valid H@20: 0.7103, N@20: 0.4576, M@20: 0.3832

test H@5: 0.5269, N@5: 0.4088, M@5: 0.3696
test H@10: 0.6321, N@10: 0.4429, M@10: 0.3837
test H@20: 0.723

Epoch [288/500], Train Loss: 16.1825, Elapsed: train 0.45 test 1.81
valid H@5: 0.5141, N@5: 0.3972, M@5: 0.3584
valid H@10: 0.6154, N@10: 0.4302, M@10: 0.3722
valid H@20: 0.7089, N@20: 0.4539, M@20: 0.3787

test H@5: 0.5275, N@5: 0.4104, M@5: 0.3715
test H@10: 0.6329, N@10: 0.4447, M@10: 0.3858
test H@20: 0.7256, N@20: 0.4682, M@20: 0.3923

Epoch [289/500], Train Loss: 16.1864, Elapsed: train 0.47 test 1.79
valid H@5: 0.5108, N@5: 0.3953, M@5: 0.3570
valid H@10: 0.6156, N@10: 0.4296, M@10: 0.3713
valid H@20: 0.7082, N@20: 0.4530, M@20: 0.3778

test H@5: 0.5304, N@5: 0.4112, M@5: 0.3716
test H@10: 0.6338, N@10: 0.4448, M@10: 0.3856
test H@20: 0.7249, N@20: 0.4679, M@20: 0.3920

Epoch [290/500], Train Loss: 15.8208, Elapsed: train 0.57 test 1.85
valid H@5: 0.5145, N@5: 0.3975, M@5: 0.3587
valid H@10: 0.6139, N@10: 0.4299, M@10: 0.3722
valid H@20: 0.7082, N@20: 0.4539, M@20: 0.3789

test H@5: 0.5300, N@5: 0.4109, M@5: 0.3714
test H@10: 0.6312, N@10: 0.4438, M@10: 0.3851
test H@20: 0.7268,