In [1]:
import numpy as np
import random
import torch
from src.dataloaders import Dataset100k
from src.models import GMFBCEModel, MLPBCEModel, NeuralMatrixFactorizationBCEModel
from src.trainer import Trainer
from src.metrics import hitratio, ndcg

# Pin every RNG used in this notebook (NumPy, Python's `random`, PyTorch)
# to the same fixed seed so repeated runs start from the same state.
for seed_rng in (np.random.seed, random.seed, torch.manual_seed):
    seed_rng(42)

# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
# (Assign device = 'cpu' after this block to force a CPU run.)
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"{device=}")

device='cuda'


In [2]:
class config:
    """Hyper-parameters and paths read by the cells of this notebook."""

    # Data location (passed to Dataset100k).
    data_dir = 'ml-100k'

    # Training loop settings (passed to Trainer).
    epochs = 40
    batch_size = 2048

    # Model sizes (passed to the GMF / MLP / NeuMF constructors).
    gmf_embed_size = 16
    mlp_embed_size = 32
    layers = [32, 16, 8]

    # Adam settings: learning rate, betas=(b1, b2) and weight decay.
    lr = 1e-3
    b1 = 0.3
    b2 = 0.6
    weight_decay = 1e-4

# Load the dataset from config.data_dir, then run its two preparation steps.
# NOTE(review): gen_adjacency() is called before make_train_test(); presumably
# the split relies on the adjacency data -- confirm in src.dataloaders.
dataset = Dataset100k(config.data_dir)
dataset.gen_adjacency()
dataset.make_train_test()
# Print both split sizes as a quick sanity check.
print(f"{dataset.train_size=}, {dataset.test_size=}")

# Evaluation metrics handed to the Trainer: a mapping from display name to
# (metric function, keyword arguments). Hit ratio and NDCG are each reported
# at cut-offs 1, 5 and 10; the insertion order is all HR@k, then all NDCG@k.
metrics = {
    f"{label}@{cutoff}": (metric_fn, {"top_n": cutoff})
    for label, metric_fn in (("HR", hitratio), ("NDCG", ndcg))
    for cutoff in (1, 5, 10)
}

dataset.train_size=198114, dataset.test_size=943


In [3]:
# Build the fused NeuMF model from the sizes in `config`.
# NOTE(review): alpha=0.5 is forwarded to the model as-is; presumably it is the
# GMF/MLP mixing weight used with the pretrained towers -- confirm in src.models.
model = NeuralMatrixFactorizationBCEModel(
    dataset.user_count,
    dataset.item_count,
    gmf_embed_size=config.gmf_embed_size,
    mlp_embed_size=config.mlp_embed_size,
    layers=config.layers,
    alpha=0.5,
)

# Restore the separately pretrained GMF and MLP models from disk.
# map_location="cpu" makes the checkpoints loadable on a CPU-only machine even
# when they were written from a CUDA run (without it torch.load tries to
# restore the tensors onto the device they were saved from and fails when CUDA
# is unavailable). load_state_dict copies the values into the modules'
# existing parameters, so loading onto the CPU first does not affect where the
# models live afterwards.
gmf_model = GMFBCEModel(dataset.user_count, dataset.item_count, embed_size=config.gmf_embed_size)
gmf_model.load_state_dict(torch.load("saved_models/gmfbce.pt", map_location="cpu"))
mlp_model = MLPBCEModel(dataset.user_count, dataset.item_count, embed_size=config.mlp_embed_size, layers=config.layers)
mlp_model.load_state_dict(torch.load("saved_models/mlpbce.pt", map_location="cpu"))

# Initialise the fused model from the two pretrained models, then drop the
# temporary models so they no longer hold memory.
model.load_pretrained_weights(gmf_model, mlp_model)
del gmf_model, mlp_model

# Adam with the non-default betas from `config`.
# Alternatives tried previously:
# optimizer = torch.optim.SGD(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
# optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=config.lr,
    betas=(config.b1, config.b2),
    weight_decay=config.weight_decay,
)

trainer = Trainer(dataset, model, optimizer, metrics,
                  epochs=config.epochs, batch_size=config.batch_size,
                  device=device)

In [4]:
# Run the full training schedule with per-epoch evaluation, verbose logging
# and a progress bar enabled.
trainer.train(
    evaluate=True,
    verbose=True,
    progressbar=True,
)
# trainer.test(verbose=False, pbar=False)

                                                                   

Epoch 0: Avg Loss/Batch 0.308001            


                                                   

HR@1: 0.2873806998939555
HR@5: 0.6256627783669141
HR@10: 0.7720042417815483
NDCG@1: 0.2873806998939555
NDCG@5: 0.4653906911234425
NDCG@10: 0.5133117450220354


                                                                   

Epoch 1: Avg Loss/Batch 0.307314            


                                                   

HR@1: 0.28207847295864263
HR@5: 0.6224814422057264
HR@10: 0.7762460233297985
NDCG@1: 0.28207847295864263
NDCG@5: 0.4612976597137803
NDCG@10: 0.5113373628812248


                                                                   

Epoch 2: Avg Loss/Batch 0.308059            


                                                   

HR@1: 0.2831389183457052
HR@5: 0.6299045599151644
HR@10: 0.7720042417815483
NDCG@1: 0.2831389183457052
NDCG@5: 0.4648077004507823
NDCG@10: 0.5111191228855387


                                                                   

Epoch 3: Avg Loss/Batch 0.309015            


                                                   

HR@1: 0.28950159066808057
HR@5: 0.6277836691410392
HR@10: 0.7698833510074231
NDCG@1: 0.28950159066808057
NDCG@5: 0.46715846445285186
NDCG@10: 0.5137269743885267


                                                                   

Epoch 4: Avg Loss/Batch 0.309796            


                                                   

HR@1: 0.28101802757158006
HR@5: 0.6256627783669141
HR@10: 0.7698833510074231
NDCG@1: 0.28101802757158006
NDCG@5: 0.4632722913094238
NDCG@10: 0.5103316416414804


                                                                   

Epoch 5: Avg Loss/Batch 0.310136            


                                                   

HR@1: 0.28525980911983034
HR@5: 0.6267232237539767
HR@10: 0.7645811240721103
NDCG@1: 0.28525980911983034
NDCG@5: 0.46398399709048394
NDCG@10: 0.5092502130433262


                                                                   

Epoch 6: Avg Loss/Batch 0.310052            


                                                   

HR@1: 0.2863202545068929
HR@5: 0.6246023329798516
HR@10: 0.7667020148462355
NDCG@1: 0.2863202545068929
NDCG@5: 0.46300029552217087
NDCG@10: 0.5095025625544393


                                                                   

Epoch 7: Avg Loss/Batch 0.309697            


                                                   

HR@1: 0.28101802757158006
HR@5: 0.6224814422057264
HR@10: 0.7667020148462355
NDCG@1: 0.28101802757158006
NDCG@5: 0.46007226261803763
NDCG@10: 0.5070789123988119


                                                                   

Epoch 8: Avg Loss/Batch 0.309248            


                                                   

HR@1: 0.27465535524920465
HR@5: 0.6214209968186638
HR@10: 0.7624602332979852
NDCG@1: 0.27465535524920465
NDCG@5: 0.45567318767579257
NDCG@10: 0.5017712466491344


                                                                   

Epoch 9: Avg Loss/Batch 0.308751            


                                                   

HR@1: 0.26935312831389185
HR@5: 0.6214209968186638
HR@10: 0.7592788971367974
NDCG@1: 0.26935312831389185
NDCG@5: 0.45465560469842214
NDCG@10: 0.49968905893717436


                                                                    

Epoch 10: Avg Loss/Batch 0.308159            


                                                   

HR@1: 0.264050901378579
HR@5: 0.6214209968186638
HR@10: 0.7582184517497349
NDCG@1: 0.264050901378579
NDCG@5: 0.4520943721453894
NDCG@10: 0.4967338568076701


                                                                    

Epoch 11: Avg Loss/Batch 0.307522            


                                                   

HR@1: 0.2598091198303287
HR@5: 0.616118769883351
HR@10: 0.76033934252386
NDCG@1: 0.2598091198303287
NDCG@5: 0.4471206841655656
NDCG@10: 0.4939382400679523


                                                                    

Epoch 12: Avg Loss/Batch 0.306744            


                                                   

HR@1: 0.26193001060445387
HR@5: 0.6108165429480382
HR@10: 0.7582184517497349
NDCG@1: 0.26193001060445387
NDCG@5: 0.44457754077717826
NDCG@10: 0.4924287091896132


                                                                    

Epoch 13: Avg Loss/Batch 0.305860            


                                                   

HR@1: 0.26935312831389185
HR@5: 0.6076352067868505
HR@10: 0.7613997879109226
NDCG@1: 0.26935312831389185
NDCG@5: 0.445931272791135
NDCG@10: 0.49557679414713696


                                                                    

Epoch 14: Avg Loss/Batch 0.304769            


                                                   

HR@1: 0.26193001060445387
HR@5: 0.6055143160127253
HR@10: 0.7624602332979852
NDCG@1: 0.26193001060445387
NDCG@5: 0.44133218856743484
NDCG@10: 0.49212684857985817


                                                                    

Epoch 15: Avg Loss/Batch 0.303658            


                                                   

HR@1: 0.2576882290562036
HR@5: 0.6065747613997879
HR@10: 0.7635206786850477
NDCG@1: 0.2576882290562036
NDCG@5: 0.4397470752052187
NDCG@10: 0.4904219932071109


                                                                    

Epoch 16: Avg Loss/Batch 0.302514            


                                                   

HR@1: 0.25556733828207845
HR@5: 0.6033934252386002
HR@10: 0.7613997879109226
NDCG@1: 0.25556733828207845
NDCG@5: 0.4381628687456052
NDCG@10: 0.4891758346619752


                                                                    

Epoch 17: Avg Loss/Batch 0.301325            


                                                   

HR@1: 0.25556733828207845
HR@5: 0.6033934252386002
HR@10: 0.7613997879109226
NDCG@1: 0.25556733828207845
NDCG@5: 0.4372587821463339
NDCG@10: 0.488236470012524


                                                                    

Epoch 18: Avg Loss/Batch 0.300174            


                                                   

HR@1: 0.2523860021208908
HR@5: 0.601272534464475
HR@10: 0.7635206786850477
NDCG@1: 0.2523860021208908
NDCG@5: 0.435090104597346
NDCG@10: 0.4873528886431732


                                                                    

Epoch 19: Avg Loss/Batch 0.299078            


                                                   

HR@1: 0.2513255567338282
HR@5: 0.6044538706256628
HR@10: 0.7667020148462355
NDCG@1: 0.2513255567338282
NDCG@5: 0.4366952515578566
NDCG@10: 0.48910601156870787


                                                                    

Epoch 20: Avg Loss/Batch 0.298003            


                                                   

HR@1: 0.2523860021208908
HR@5: 0.6076352067868505
HR@10: 0.7613997879109226
NDCG@1: 0.2523860021208908
NDCG@5: 0.43696910580416864
NDCG@10: 0.48668988909467825


                                                                    

Epoch 21: Avg Loss/Batch 0.296840            


                                                   

HR@1: 0.24708377518557795
HR@5: 0.6065747613997879
HR@10: 0.7709437963944857
NDCG@1: 0.24708377518557795
NDCG@5: 0.43471186213694585
NDCG@10: 0.48725988106516527


                                                                    

Epoch 22: Avg Loss/Batch 0.295787            


                                                   

HR@1: 0.2492046659597031
HR@5: 0.6097560975609756
HR@10: 0.76033934252386
NDCG@1: 0.2492046659597031
NDCG@5: 0.43512644854696214
NDCG@10: 0.48356801600264065


                                                                    

Epoch 23: Avg Loss/Batch 0.294791            


                                                   

HR@1: 0.24814422057264052
HR@5: 0.6097560975609756
HR@10: 0.7497348886532343
NDCG@1: 0.24814422057264052
NDCG@5: 0.4349643687045212
NDCG@10: 0.48030748247966143


                                                                    

Epoch 24: Avg Loss/Batch 0.293740            


                                                   

HR@1: 0.24284199363732767
HR@5: 0.6033934252386002
HR@10: 0.7582184517497349
NDCG@1: 0.24284199363732767
NDCG@5: 0.430822596099752
NDCG@10: 0.4813130161908322


                                                                    

Epoch 25: Avg Loss/Batch 0.292733            


                                                   

HR@1: 0.2492046659597031
HR@5: 0.6023329798515377
HR@10: 0.7613997879109226
NDCG@1: 0.2492046659597031
NDCG@5: 0.4314318311127188
NDCG@10: 0.4830442611022762


                                                                    

Epoch 26: Avg Loss/Batch 0.291819            


                                                   

HR@1: 0.2417815482502651
HR@5: 0.601272534464475
HR@10: 0.7571580063626723
NDCG@1: 0.2417815482502651
NDCG@5: 0.42858667169171816
NDCG@10: 0.4795133707076591


                                                                    

Epoch 27: Avg Loss/Batch 0.290843            


                                                   

HR@1: 0.2449628844114528
HR@5: 0.5980911983032874
HR@10: 0.7582184517497349
NDCG@1: 0.2449628844114528
NDCG@5: 0.4278353058919086
NDCG@10: 0.47970735726735575


                                                                    

Epoch 28: Avg Loss/Batch 0.289901            


                                                   

HR@1: 0.24602332979851538
HR@5: 0.5991516436903499
HR@10: 0.7529162248144221
NDCG@1: 0.24602332979851538
NDCG@5: 0.43042872701526524
NDCG@10: 0.4802707789848953


                                                                    

Epoch 29: Avg Loss/Batch 0.289160            


                                                   

HR@1: 0.2417815482502651
HR@5: 0.5949098621420997
HR@10: 0.7613997879109226
NDCG@1: 0.2417815482502651
NDCG@5: 0.42728759668333804
NDCG@10: 0.4812965600169047


                                                                    

Epoch 30: Avg Loss/Batch 0.288276            


                                                   

HR@1: 0.24284199363732767
HR@5: 0.601272534464475
HR@10: 0.7592788971367974
NDCG@1: 0.24284199363732767
NDCG@5: 0.4301270904568903
NDCG@10: 0.48101811650576703


                                                                    

Epoch 31: Avg Loss/Batch 0.287408            


                                                   

HR@1: 0.2417815482502651
HR@5: 0.6033934252386002
HR@10: 0.7571580063626723
NDCG@1: 0.2417815482502651
NDCG@5: 0.43034210389878563
NDCG@10: 0.4799277823880227


                                                                    

Epoch 32: Avg Loss/Batch 0.286746            


                                                   

HR@1: 0.23753976670201485
HR@5: 0.6002120890774125
HR@10: 0.7571580063626723
NDCG@1: 0.23753976670201485
NDCG@5: 0.4265204649108231
NDCG@10: 0.47709790875501074


                                                                    

Epoch 33: Avg Loss/Batch 0.285868            


                                                   

HR@1: 0.24284199363732767
HR@5: 0.5959703075291622
HR@10: 0.7645811240721103
NDCG@1: 0.24284199363732767
NDCG@5: 0.42727998660224076
NDCG@10: 0.4815969845838107


                                                                    

Epoch 34: Avg Loss/Batch 0.285125            


                                                   

HR@1: 0.23860021208907742
HR@5: 0.5949098621420997
HR@10: 0.7529162248144221
NDCG@1: 0.23860021208907742
NDCG@5: 0.42499189686468375
NDCG@10: 0.4765133243978178


                                                                    

Epoch 35: Avg Loss/Batch 0.284349            


                                                   

HR@1: 0.23860021208907742
HR@5: 0.5970307529162248
HR@10: 0.7550371155885471
NDCG@1: 0.23860021208907742
NDCG@5: 0.42566476773602097
NDCG@10: 0.4771447270667739


                                                                    

Epoch 36: Avg Loss/Batch 0.283736            


                                                   

HR@1: 0.23329798515376457
HR@5: 0.5991516436903499
HR@10: 0.7613997879109226
NDCG@1: 0.23329798515376457
NDCG@5: 0.4230835285723474
NDCG@10: 0.4756477400325269


                                                                    

Epoch 37: Avg Loss/Batch 0.283152            


                                                   

HR@1: 0.23223753976670203
HR@5: 0.5959703075291622
HR@10: 0.7592788971367974
NDCG@1: 0.23223753976670203
NDCG@5: 0.4213991752525283
NDCG@10: 0.474507502981838


                                                                    

Epoch 38: Avg Loss/Batch 0.282515            


                                                   

HR@1: 0.23117709437963946
HR@5: 0.5896076352067868
HR@10: 0.7539766702014846
NDCG@1: 0.23117709437963946
NDCG@5: 0.4183579977409337
NDCG@10: 0.4719181918837519


                                                                    

Epoch 39: Avg Loss/Batch 0.281985            


                                                   

HR@1: 0.22799575821845175
HR@5: 0.5938494167550371
HR@10: 0.76033934252386
NDCG@1: 0.22799575821845175
NDCG@5: 0.4176683503534991
NDCG@10: 0.47183542926737776


In [None]:
# Pick the epoch whose logged evaluation has the highest NDCG@10
# (np.argmax returns the first index on ties) and print its metrics.
ndcg10_by_epoch = [entry["NDCG@10"] for entry in trainer.test_log]
best_epoch = np.argmax(ndcg10_by_epoch)
print(f"{best_epoch}: {trainer.test_log[best_epoch]}")

22: {'HR@1': 0.2841993637327678, 'HR@5': 0.6224814422057264, 'HR@10': 0.7751855779427359, 'NDCG@1': 0.2841993637327678, 'NDCG@5': 0.4620285833055621, 'NDCG@10': 0.51180795041922}


In [None]:
# Index of the evaluation log entry with the largest NDCG@10, followed by the
# full set of metrics recorded for that epoch.
best_epoch = np.argmax([log_entry["NDCG@10"] for log_entry in trainer.test_log])
best_metrics = trainer.test_log[best_epoch]
print(f"{best_epoch}: {best_metrics}")

3: {'HR@1': 0.28950159066808057, 'HR@5': 0.6277836691410392, 'HR@10': 0.7698833510074231, 'NDCG@1': 0.28950159066808057, 'NDCG@5': 0.46715846445285186, 'NDCG@10': 0.5137269743885267}


In [None]:
# Persist the weights currently held by trainer.model.
# NOTE(review): in the run logged in this notebook NDCG@10 peaked at epoch 3
# and declined afterwards, so the weights saved here are not the best-scoring
# ones -- consider checkpointing at the best epoch instead.
final_weights = trainer.model.state_dict()
torch.save(final_weights, "saved_models/nmfbce.pt")
# To restore: trainer.model.load_state_dict(torch.load("saved_models/nmfbce.pt"))