In [1]:
import numpy as np
import random
import torch
from src.data import Dataset, PairwiseDataset
from src.models import MatrixFactorizationModel, MatrixFactorizationBPRModel
from src.trainer import Trainer, BPRTrainer
from src.evaluation import hitratio, ndcg

SEED = 42
# Seed every RNG the pipeline may draw from: Python's `random`, NumPy, and
# torch (torch.manual_seed seeds CPU and CUDA generators) so runs are repeatable.
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when one is present, otherwise fall back to CPU.
# (Previously a later hard-coded `device='cuda'` defeated this fallback and
# crashed on CPU-only machines.) To force CPU, set `device = 'cpu'` here.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"{device=}")


In [2]:
class config:
    """Hyper-parameters for the BPR matrix-factorization run on ml-100k.

    Used as a plain namespace: the cells below read attributes straight off
    the class (e.g. ``config.lr``) without instantiating it.
    """

    # --- data ---------------------------------------------------------------
    # Directory passed to the dataset constructor.
    data_dir: str = 'ml-100k'
    # Presumably the number of negative samples per positive interaction.
    # NOTE(review): not referenced by the cells shown here — confirm whether
    # the dataset or trainer actually reads it.
    neg_count: int = 16

    # --- optimisation -------------------------------------------------------
    epochs: int = 50
    batch_size: int = 256
    lr: float = 0.1

    # --- model --------------------------------------------------------------
    # Latent dimension passed as the third argument to the model constructor.
    dim: int = 40

In [3]:
# Build the pairwise (BPR) dataset and split it into train/test.
# The pointwise alternative (`Dataset` + `MatrixFactorizationModel` + `Trainer`,
# all imported above) can be swapped in here for comparison.
dataset = PairwiseDataset(config.data_dir)
dataset.gen_adjacency()
dataset.make_train_test()

# BPR matrix factorization trained with plain SGD; weight_decay adds an L2
# penalty on all parameters.
model = MatrixFactorizationBPRModel(dataset.user_count, dataset.item_count, config.dim)
optimizer = torch.optim.SGD(model.parameters(), lr=config.lr, weight_decay=0.01)

# Ranking metrics at several cut-offs. Each entry maps a display label to a
# (metric_fn, kwargs) pair; key order is HR@1, HR@5, HR@10, NDCG@1, NDCG@5, NDCG@10.
metrics = {
    f"{label}@{cutoff}": (metric_fn, {"top_n": cutoff})
    for label, metric_fn in (("HR", hitratio), ("NDCG", ndcg))
    for cutoff in (1, 5, 10)
}

trainer = BPRTrainer(
    dataset,
    model,
    optimizer,
    metrics,
    epochs=config.epochs,
    batch_size=config.batch_size,
    device=device,
)

In [4]:
# Run the training loop; with evaluate=True the trainer reports the HR/NDCG
# metrics configured above after each epoch (see the per-epoch output below).
trainer.train(evaluate=True)

                                                                      

Epoch 0: Avg Loss/Batch 110.081229          


                                                   

HR@1: 0.007423117709437964
HR@5: 0.06468716861081654
HR@10: 0.11983032873807
NDCG@1: 0.007423117709437964
NDCG@5: 0.0360929127205354
NDCG@10: 0.05361513171912654


                                                                     

Epoch 1: Avg Loss/Batch 73.822857           


                                                   

HR@1: 0.032873806998939555
HR@5: 0.12513255567338283
HR@10: 0.19618239660657477
NDCG@1: 0.032873806998939555
NDCG@5: 0.07792015515288123
NDCG@10: 0.10051049651197438


                                                                     

Epoch 2: Avg Loss/Batch 36.437239           


                                                   

HR@1: 0.10180275715800637
HR@5: 0.3372216330858961
HR@10: 0.4973488865323436
NDCG@1: 0.10180275715800637
NDCG@5: 0.21871046609919292
NDCG@10: 0.27062460123812276


                                                                     

Epoch 3: Avg Loss/Batch 24.501098           


                                                   

HR@1: 0.1919406150583245
HR@5: 0.5015906680805938
HR@10: 0.672322375397667
NDCG@1: 0.1919406150583245
NDCG@5: 0.3510729456253505
NDCG@10: 0.4065080625966442


                                                                     

Epoch 4: Avg Loss/Batch 22.280228           


                                                   

HR@1: 0.23647932131495228
HR@5: 0.5800636267232238
HR@10: 0.7232237539766702
NDCG@1: 0.23647932131495228
NDCG@5: 0.4162464390221339
NDCG@10: 0.46280380032806595


                                                                     

Epoch 5: Avg Loss/Batch 21.488761           


                                                   

HR@1: 0.2598091198303287
HR@5: 0.5874867444326617
HR@10: 0.7338282078472959
NDCG@1: 0.2598091198303287
NDCG@5: 0.43059499208844054
NDCG@10: 0.47857700984572665


                                                                     

Epoch 6: Avg Loss/Batch 21.071222           


                                                   

HR@1: 0.264050901378579
HR@5: 0.6055143160127253
HR@10: 0.7465535524920467
NDCG@1: 0.264050901378579
NDCG@5: 0.44102119153418684
NDCG@10: 0.48688427701001097


                                                                     

Epoch 7: Avg Loss/Batch 20.828900           


                                                   

HR@1: 0.26617179215270415
HR@5: 0.6076352067868505
HR@10: 0.7560975609756098
NDCG@1: 0.26617179215270415
NDCG@5: 0.44324378552242055
NDCG@10: 0.4914702268162411


                                                                     

Epoch 8: Avg Loss/Batch 20.677551           


                                                   

HR@1: 0.2672322375397667
HR@5: 0.6055143160127253
HR@10: 0.7592788971367974
NDCG@1: 0.2672322375397667
NDCG@5: 0.44272097887011364
NDCG@10: 0.49291767638631195


                                                                     

Epoch 9: Avg Loss/Batch 20.573831           


                                                   

HR@1: 0.2704135737009544
HR@5: 0.6097560975609756
HR@10: 0.7624602332979852
NDCG@1: 0.2704135737009544
NDCG@5: 0.44553663472076993
NDCG@10: 0.4952580359229108


                                                                      

Epoch 10: Avg Loss/Batch 20.496696           


                                                   

HR@1: 0.2757158006362672
HR@5: 0.6139978791092259
HR@10: 0.7656415694591728
NDCG@1: 0.2757158006362672
NDCG@5: 0.44895983131376554
NDCG@10: 0.498202897198879


                                                                      

Epoch 11: Avg Loss/Batch 20.436227           


                                                   

HR@1: 0.2757158006362672
HR@5: 0.6171792152704135
HR@10: 0.7656415694591728
NDCG@1: 0.2757158006362672
NDCG@5: 0.45059582362866357
NDCG@10: 0.49869543711310105


                                                                      

Epoch 12: Avg Loss/Batch 20.387533           


                                                   

HR@1: 0.2788971367974549
HR@5: 0.616118769883351
HR@10: 0.7667020148462355
NDCG@1: 0.2788971367974549
NDCG@5: 0.4521004158351824
NDCG@10: 0.500999257338095


                                                                      

Epoch 13: Avg Loss/Batch 20.347859           


                                                   

HR@1: 0.2788971367974549
HR@5: 0.6214209968186638
HR@10: 0.7688229056203606
NDCG@1: 0.2788971367974549
NDCG@5: 0.45432566568790095
NDCG@10: 0.5020479891707565


                                                                      

Epoch 14: Avg Loss/Batch 20.315368           


                                                   

HR@1: 0.2788971367974549
HR@5: 0.6246023329798516
HR@10: 0.7688229056203606
NDCG@1: 0.2788971367974549
NDCG@5: 0.4553527747604697
NDCG@10: 0.5020332230732373


                                                                      

Epoch 15: Avg Loss/Batch 20.288670           


                                                   

HR@1: 0.28101802757158006
HR@5: 0.623541887592789
HR@10: 0.7698833510074231
NDCG@1: 0.28101802757158006
NDCG@5: 0.45573654423564663
NDCG@10: 0.5031022904032139


                                                                      

Epoch 16: Avg Loss/Batch 20.266652           


                                                   

HR@1: 0.28207847295864263
HR@5: 0.6246023329798516
HR@10: 0.7720042417815483
NDCG@1: 0.28207847295864263
NDCG@5: 0.45676994859292364
NDCG@10: 0.5043772915290267


                                                                      

Epoch 17: Avg Loss/Batch 20.248408           


                                                   

HR@1: 0.2788971367974549
HR@5: 0.6246023329798516
HR@10: 0.7720042417815483
NDCG@1: 0.2788971367974549
NDCG@5: 0.45538345449453155
NDCG@10: 0.5030593092814709


                                                                      

Epoch 18: Avg Loss/Batch 20.233205           


                                                   

HR@1: 0.2788971367974549
HR@5: 0.6203605514316013
HR@10: 0.7730646871686108
NDCG@1: 0.2788971367974549
NDCG@5: 0.4538354547837569
NDCG@10: 0.5033399199032118


                                                                      

Epoch 19: Avg Loss/Batch 20.220456           


                                                   

HR@1: 0.2767762460233298
HR@5: 0.6203605514316013
HR@10: 0.7741251325556734
NDCG@1: 0.2767762460233298
NDCG@5: 0.45368158624012533
NDCG@10: 0.5035547423385704


                                                                      

Epoch 20: Avg Loss/Batch 20.209694           


                                                   

HR@1: 0.2767762460233298
HR@5: 0.6214209968186638
HR@10: 0.7730646871686108
NDCG@1: 0.2767762460233298
NDCG@5: 0.45409182251502944
NDCG@10: 0.5033625195910126


                                                                      

Epoch 21: Avg Loss/Batch 20.200548           


                                                   

HR@1: 0.2767762460233298
HR@5: 0.6214209968186638
HR@10: 0.7741251325556734
NDCG@1: 0.2767762460233298
NDCG@5: 0.45406478148524393
NDCG@10: 0.5036815798986761


                                                                      

Epoch 22: Avg Loss/Batch 20.192726           


                                                   

HR@1: 0.2757158006362672
HR@5: 0.6214209968186638
HR@10: 0.7741251325556734
NDCG@1: 0.2757158006362672
NDCG@5: 0.4537198753393376
NDCG@10: 0.5034590064994454


                                                                      

Epoch 23: Avg Loss/Batch 20.185996           


                                                   

HR@1: 0.27465535524920465
HR@5: 0.6214209968186638
HR@10: 0.7741251325556734
NDCG@1: 0.27465535524920465
NDCG@5: 0.45344848291763756
NDCG@10: 0.5032155116696866


                                                                      

Epoch 24: Avg Loss/Batch 20.180172           


                                                   

HR@1: 0.27465535524920465
HR@5: 0.6214209968186638
HR@10: 0.7730646871686108
NDCG@1: 0.27465535524920465
NDCG@5: 0.4535873267708417
NDCG@10: 0.5030884040618395


                                                                      

Epoch 25: Avg Loss/Batch 20.175104           


                                                   

HR@1: 0.27465535524920465
HR@5: 0.6214209968186638
HR@10: 0.7741251325556734
NDCG@1: 0.27465535524920465
NDCG@5: 0.4535873267708417
NDCG@10: 0.5034076299321089


                                                                      

Epoch 26: Avg Loss/Batch 20.170672           


                                                   

HR@1: 0.27465535524920465
HR@5: 0.6214209968186638
HR@10: 0.7751855779427359
NDCG@1: 0.27465535524920465
NDCG@5: 0.453660840495048
NDCG@10: 0.503787681117946


                                                                      

Epoch 27: Avg Loss/Batch 20.166779           


                                                   

HR@1: 0.27465535524920465
HR@5: 0.6214209968186638
HR@10: 0.7741251325556734
NDCG@1: 0.27465535524920465
NDCG@5: 0.4537343542192544
NDCG@10: 0.5035493494301216


                                                                      

Epoch 28: Avg Loss/Batch 20.163343           


                                                   

HR@1: 0.2757158006362672
HR@5: 0.6224814422057264
HR@10: 0.7741251325556734
NDCG@1: 0.2757158006362672
NDCG@5: 0.45453596933448565
NDCG@10: 0.5039732262769457


                                                                      

Epoch 29: Avg Loss/Batch 20.160297           


                                                   

HR@1: 0.2757158006362672
HR@5: 0.6224814422057264
HR@10: 0.7741251325556734
NDCG@1: 0.2757158006362672
NDCG@5: 0.4548218406361024
NDCG@10: 0.5042590975785626


                                                                      

Epoch 30: Avg Loss/Batch 20.157588           


                                                   

HR@1: 0.2757158006362672
HR@5: 0.6224814422057264
HR@10: 0.7741251325556734
NDCG@1: 0.2757158006362672
NDCG@5: 0.4548218406361024
NDCG@10: 0.5042780461008824


                                                                      

Epoch 31: Avg Loss/Batch 20.155168           


                                                   

HR@1: 0.2757158006362672
HR@5: 0.623541887592789
HR@10: 0.7741251325556734
NDCG@1: 0.2757158006362672
NDCG@5: 0.4551856042165857
NDCG@10: 0.5042640714129585


                                                                      

Epoch 32: Avg Loss/Batch 20.153000           


                                                   

HR@1: 0.2757158006362672
HR@5: 0.623541887592789
HR@10: 0.7741251325556734
NDCG@1: 0.2757158006362672
NDCG@5: 0.45530559063521286
NDCG@10: 0.5044083143043053


                                                                      

Epoch 33: Avg Loss/Batch 20.151049           


                                                   

HR@1: 0.2757158006362672
HR@5: 0.623541887592789
HR@10: 0.7741251325556734
NDCG@1: 0.2757158006362672
NDCG@5: 0.45530559063521286
NDCG@10: 0.5044083143043053


                                                                      

Epoch 34: Avg Loss/Batch 20.149290           


                                                   

HR@1: 0.2757158006362672
HR@5: 0.623541887592789
HR@10: 0.7751855779427359
NDCG@1: 0.2757158006362672
NDCG@5: 0.45530559063521286
NDCG@10: 0.5047301591690345


                                                                      

Epoch 35: Avg Loss/Batch 20.147699           


                                                   

HR@1: 0.2757158006362672
HR@5: 0.623541887592789
HR@10: 0.7751855779427359
NDCG@1: 0.2757158006362672
NDCG@5: 0.4554255770538401
NDCG@10: 0.5048744020603815


                                                                      

Epoch 36: Avg Loss/Batch 20.146256           


                                                   

HR@1: 0.2757158006362672
HR@5: 0.623541887592789
HR@10: 0.7751855779427359
NDCG@1: 0.2757158006362672
NDCG@5: 0.4554990907780465
NDCG@10: 0.504928967262268


                                                                      

Epoch 37: Avg Loss/Batch 20.144945           


                                                   

HR@1: 0.2757158006362672
HR@5: 0.623541887592789
HR@10: 0.7751855779427359
NDCG@1: 0.2757158006362672
NDCG@5: 0.4555455634724674
NDCG@10: 0.5049754399566888


                                                                      

Epoch 38: Avg Loss/Batch 20.143749           


                                                   

HR@1: 0.2757158006362672
HR@5: 0.623541887592789
HR@10: 0.7751855779427359
NDCG@1: 0.2757158006362672
NDCG@5: 0.4555185224426819
NDCG@10: 0.5049637063300018


                                                                      

Epoch 39: Avg Loss/Batch 20.142659           


                                                   

HR@1: 0.2757158006362672
HR@5: 0.623541887592789
HR@10: 0.7751855779427359
NDCG@1: 0.2757158006362672
NDCG@5: 0.4554914814128963
NDCG@10: 0.504964661111953


                                                                      

Epoch 40: Avg Loss/Batch 20.141661           


                                                   

HR@1: 0.2757158006362672
HR@5: 0.623541887592789
HR@10: 0.7751855779427359
NDCG@1: 0.2757158006362672
NDCG@5: 0.4554914814128963
NDCG@10: 0.504964661111953


                                                                      

Epoch 41: Avg Loss/Batch 20.140748           


                                                   

HR@1: 0.2757158006362672
HR@5: 0.623541887592789
HR@10: 0.7751855779427359
NDCG@1: 0.2757158006362672
NDCG@5: 0.4554914814128963
NDCG@10: 0.504964661111953


                                                                      

Epoch 42: Avg Loss/Batch 20.139912           


                                                   

HR@1: 0.27465535524920465
HR@5: 0.623541887592789
HR@10: 0.7751855779427359
NDCG@1: 0.27465535524920465
NDCG@5: 0.4551001025725691
NDCG@10: 0.5045579748685275


                                                                      

Epoch 43: Avg Loss/Batch 20.139144           


                                                   

HR@1: 0.27465535524920465
HR@5: 0.623541887592789
HR@10: 0.7762460233297985
NDCG@1: 0.27465535524920465
NDCG@5: 0.45496125871936505
NDCG@10: 0.5047499249496739


                                                                      

Epoch 44: Avg Loss/Batch 20.138439           


                                                   

HR@1: 0.27465535524920465
HR@5: 0.623541887592789
HR@10: 0.7762460233297985
NDCG@1: 0.27465535524920465
NDCG@5: 0.45496125871936505
NDCG@10: 0.5047372365410354


                                                                      

Epoch 45: Avg Loss/Batch 20.137792           


                                                   

HR@1: 0.27465535524920465
HR@5: 0.623541887592789
HR@10: 0.7762460233297985
NDCG@1: 0.27465535524920465
NDCG@5: 0.45510010257256917
NDCG@10: 0.5048760803942396


                                                                      

Epoch 46: Avg Loss/Batch 20.137197           


                                                   

HR@1: 0.27465535524920465
HR@5: 0.623541887592789
HR@10: 0.7762460233297985
NDCG@1: 0.27465535524920465
NDCG@5: 0.45510010257256917
NDCG@10: 0.5048760803942396


                                                                      

Epoch 47: Avg Loss/Batch 20.136650           


                                                   

HR@1: 0.27465535524920465
HR@5: 0.623541887592789
HR@10: 0.7762460233297985
NDCG@1: 0.27465535524920465
NDCG@5: 0.4551271436023547
NDCG@10: 0.504903121424025


                                                                      

Epoch 48: Avg Loss/Batch 20.136148           


                                                   

HR@1: 0.27465535524920465
HR@5: 0.6246023329798516
HR@10: 0.7762460233297985
NDCG@1: 0.27465535524920465
NDCG@5: 0.45553737987725873
NDCG@10: 0.5049356194305219


                                                                      

Epoch 49: Avg Loss/Batch 20.135686           


                                                   

HR@1: 0.27465535524920465
HR@5: 0.6246023329798516
HR@10: 0.7762460233297985
NDCG@1: 0.27465535524920465
NDCG@5: 0.45567622373046285
NDCG@10: 0.5050744632837261


#### Comparison run (lr=0.01, weight_decay=0.01) — Epoch 19: Avg Loss/Batch 21.236540
- HR@1: 0.271474019088017
- HR@5: 0.6129374337221634
- HR@10: 0.7698833510074231
- NDCG@1: 0.271474019088017
- NDCG@5: 0.4491453703208111
- NDCG@10: 0.5000956103254153

In [None]:
# Restore trained BPR matrix-factorization weights from disk.
# NOTE: this overwrites whatever weights `trainer.model` currently holds,
# including those produced by the training cell above.
checkpoint_path = "saved_models/mfbpr.pt"

# To persist the current weights instead, use the SAME path so save and load
# agree (the old commented-out save wrote to a different file, mf.pt):
# torch.save(trainer.model.state_dict(), checkpoint_path)

# map_location remaps tensors that were saved on a GPU onto `device`, so a
# checkpoint written on a CUDA machine also loads on a CPU-only one.
trainer.model.load_state_dict(torch.load(checkpoint_path, map_location=device))