In [1]:
import random
from pathlib import Path

import numpy as np
import torch

from src.data import Dataset, PairwiseDataset
from src.models import MatrixFactorizationModel, MatrixFactorizationBPRModel
from src.trainer import Trainer, BPRTrainer
from src.evaluation import hitratio, ndcg

# Seed every RNG the experiment touches so runs are reproducible. torch must be
# seeded too: the models are torch modules, so their initial weights (and any
# torch-based sampling) presumably draw from torch's generator.
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

# Use the GPU only when one is present. Do not hard-code 'cuda' afterwards:
# that crashes on CPU-only machines and makes the printed value misleading.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"{device=}")


In [7]:
class config:
    """Hyperparameters shared by the cells below.

    Used as a plain attribute namespace (never instantiated); values are read
    as ``config.<name>``.
    """

    # --- data ---
    data_dir: str = 'ml-100k'   # directory handed to PairwiseDataset(...)
    neg_count: int = 16         # not read by any cell shown here; presumably negatives per positive

    # --- training ---
    epochs: int = 50            # forwarded to BPRTrainer(epochs=...)
    batch_size: int = 128       # forwarded to BPRTrainer(batch_size=...)

    # --- model / optimiser ---
    dim: int = 40               # third MatrixFactorizationBPRModel argument (presumably latent size)
    lr: float = 0.1             # SGD learning rate

In [8]:
# Pairwise (BPR) setup. For the pointwise baseline, swap in Dataset,
# MatrixFactorizationModel and Trainer -- they take the same arguments.
dataset = PairwiseDataset(config.data_dir)
dataset.gen_adjacency()
dataset.make_train_test()

model = MatrixFactorizationBPRModel(
    dataset.user_count,
    dataset.item_count,
    config.dim,
)

# L2 penalty for SGD; the notes at the end show 0.01 scoring worse than 0.001.
WEIGHT_DECAY = 0.001
optimizer = torch.optim.SGD(model.parameters(), lr=config.lr, weight_decay=WEIGHT_DECAY)

# Ranking metrics at cut-offs 1, 5 and 10: name -> (metric function, kwargs),
# where kwargs are presumably passed to the metric function by the trainer.
metrics = {
    f"{label}@{k}": (metric_fn, {"top_n": k})
    for label, metric_fn in (("HR", hitratio), ("NDCG", ndcg))
    for k in (1, 5, 10)
}

trainer = BPRTrainer(
    dataset,
    model,
    optimizer,
    metrics,
    epochs=config.epochs,
    batch_size=config.batch_size,
    device=device,
)

In [9]:
# Train for config.epochs epochs (passed to BPRTrainer above). With
# evaluate=True the trainer reports HR/NDCG@{1,5,10} after every epoch, as the
# output below shows -- presumably on the held-out split from make_train_test().
trainer.train(evaluate=True)

                                                                       

Epoch 0: Avg Loss/Batch 27.770695           


                                                   

HR@1: 0.021208907741251327
HR@5: 0.07953340402969247
HR@10: 0.1474019088016967
NDCG@1: 0.021208907741251327
NDCG@5: 0.049077059217569774
NDCG@10: 0.07077449547767423


                                                                       

Epoch 1: Avg Loss/Batch 14.730691           


                                                   

HR@1: 0.07529162248144221
HR@5: 0.2513255567338282
HR@10: 0.3753976670201485
NDCG@1: 0.07529162248144221
NDCG@5: 0.16413535589558928
NDCG@10: 0.20497390442219876


                                                                       

Epoch 2: Avg Loss/Batch 6.196596            


                                                   

HR@1: 0.13997879109225875
HR@5: 0.39236479321314954
HR@10: 0.5662778366914104
NDCG@1: 0.13997879109225875
NDCG@5: 0.26753838890829595
NDCG@10: 0.3235394948426757


                                                                      

Epoch 3: Avg Loss/Batch 3.886213            


                                                   

HR@1: 0.19618239660657477
HR@5: 0.5079533404029692
HR@10: 0.6691410392364793
NDCG@1: 0.19618239660657477
NDCG@5: 0.35647060819366483
NDCG@10: 0.4085733746831486


                                                                      

Epoch 4: Avg Loss/Batch 3.250696            


                                                   

HR@1: 0.23860021208907742
HR@5: 0.5567338282078473
HR@10: 0.7444326617179216
NDCG@1: 0.23860021208907742
NDCG@5: 0.404114267168876
NDCG@10: 0.46498160354978735


                                                                      

Epoch 5: Avg Loss/Batch 2.969429            


                                                   

HR@1: 0.26299045599151644
HR@5: 0.5949098621420997
HR@10: 0.7667020148462355
NDCG@1: 0.26299045599151644
NDCG@5: 0.4331971134820213
NDCG@10: 0.48908178597659935


                                                                      

Epoch 6: Avg Loss/Batch 2.793473            


                                                   

HR@1: 0.2788971367974549
HR@5: 0.6129374337221634
HR@10: 0.7857900318133616
NDCG@1: 0.2788971367974549
NDCG@5: 0.4498606593989456
NDCG@10: 0.5055859408976854


                                                                      

Epoch 7: Avg Loss/Batch 2.665497            


                                                   

HR@1: 0.2926829268292683
HR@5: 0.6224814422057264
HR@10: 0.7910922587486744
NDCG@1: 0.2926829268292683
NDCG@5: 0.46063611655760445
NDCG@10: 0.5150974211266615


                                                                      

Epoch 8: Avg Loss/Batch 2.566079            


                                                   

HR@1: 0.30010604453870626
HR@5: 0.623541887592789
HR@10: 0.7963944856839873
NDCG@1: 0.30010604453870626
NDCG@5: 0.46660437398498267
NDCG@10: 0.5226120361559595


                                                                      

Epoch 9: Avg Loss/Batch 2.487267            


                                                   

HR@1: 0.2990455991516437
HR@5: 0.630965005302227
HR@10: 0.8038176033934252
NDCG@1: 0.2990455991516437
NDCG@5: 0.4702329214398446
NDCG@10: 0.5260823563658648


                                                                       

Epoch 10: Avg Loss/Batch 2.423366            


                                                   

HR@1: 0.30858960763520676
HR@5: 0.6394485683987274
HR@10: 0.806998939554613
NDCG@1: 0.30858960763520676
NDCG@5: 0.477628395469891
NDCG@10: 0.5315522641339968


                                                                       

Epoch 11: Avg Loss/Batch 2.370682            


                                                   

HR@1: 0.3117709437963945
HR@5: 0.6405090137857901
HR@10: 0.8038176033934252
NDCG@1: 0.3117709437963945
NDCG@5: 0.479549325196149
NDCG@10: 0.5324387151834145


                                                                       

Epoch 12: Avg Loss/Batch 2.326468            


                                                   

HR@1: 0.3075291622481442
HR@5: 0.6436903499469777
HR@10: 0.8038176033934252
NDCG@1: 0.3075291622481442
NDCG@5: 0.48057969349875573
NDCG@10: 0.5323774068446605


                                                                       

Epoch 13: Avg Loss/Batch 2.289120            


                                                   

HR@1: 0.30434782608695654
HR@5: 0.6468716861081655
HR@10: 0.8048780487804879
NDCG@1: 0.30434782608695654
NDCG@5: 0.48081090851074304
NDCG@10: 0.5319359841521711


                                                                       

Epoch 14: Avg Loss/Batch 2.257349            


                                                   

HR@1: 0.30328738069989397
HR@5: 0.6511134676564156
HR@10: 0.8048780487804879
NDCG@1: 0.30328738069989397
NDCG@5: 0.4817915448888911
NDCG@10: 0.5314859155528835


                                                                       

Epoch 15: Avg Loss/Batch 2.230203            


                                                   

HR@1: 0.3022269353128314
HR@5: 0.6542948038176034
HR@10: 0.8123011664899258
NDCG@1: 0.3022269353128314
NDCG@5: 0.4828049433514744
NDCG@10: 0.5335856812428974


                                                                       

Epoch 16: Avg Loss/Batch 2.206875            


                                                   

HR@1: 0.3064687168610817
HR@5: 0.6585365853658537
HR@10: 0.8112407211028632
NDCG@1: 0.3064687168610817
NDCG@5: 0.4860384448421851
NDCG@10: 0.535100790597066


                                                                       

Epoch 17: Avg Loss/Batch 2.186599            


                                                  

HR@1: 0.3075291622481442
HR@5: 0.6564156945917285
HR@10: 0.8123011664899258
NDCG@1: 0.3075291622481442
NDCG@5: 0.48574001139069956
NDCG@10: 0.5360752563500992


                                                                       

Epoch 18: Avg Loss/Batch 2.168873            


                                                   

HR@1: 0.3107104984093319
HR@5: 0.6564156945917285
HR@10: 0.8144220572640509
NDCG@1: 0.3107104984093319
NDCG@5: 0.4870223120308232
NDCG@10: 0.5380875584899648


                                                                       

Epoch 19: Avg Loss/Batch 2.153338            


                                                   

HR@1: 0.31283138918345704
HR@5: 0.6574761399787911
HR@10: 0.8144220572640509
NDCG@1: 0.31283138918345704
NDCG@5: 0.48877125562925644
NDCG@10: 0.539622901131599


                                                                       

Epoch 20: Avg Loss/Batch 2.139672            


                                                   

HR@1: 0.31601272534464475
HR@5: 0.6585365853658537
HR@10: 0.8144220572640509
NDCG@1: 0.31601272534464475
NDCG@5: 0.4907155876810238
NDCG@10: 0.541168375665005


                                                                       

Epoch 21: Avg Loss/Batch 2.127610            


                                                   

HR@1: 0.3170731707317073
HR@5: 0.6585365853658537
HR@10: 0.8144220572640509
NDCG@1: 0.3170731707317073
NDCG@5: 0.4915046406463864
NDCG@10: 0.5419383119267721


                                                                       

Epoch 22: Avg Loss/Batch 2.116935            


                                                   

HR@1: 0.3170731707317073
HR@5: 0.6595970307529162
HR@10: 0.8112407211028632
NDCG@1: 0.3170731707317073
NDCG@5: 0.4923860647705322
NDCG@10: 0.5416563561698551


                                                                       

Epoch 23: Avg Loss/Batch 2.107437            


                                                   

HR@1: 0.31283138918345704
HR@5: 0.6606574761399788
HR@10: 0.8123011664899258
NDCG@1: 0.31283138918345704
NDCG@5: 0.49185967482115023
NDCG@10: 0.5410941416836722


                                                                       

Epoch 24: Avg Loss/Batch 2.098942            


                                                   

HR@1: 0.3138918345705196
HR@5: 0.662778366914104
HR@10: 0.8133616118769883
NDCG@1: 0.3138918345705196
NDCG@5: 0.49321855365969836
NDCG@10: 0.5419507787795029


                                                                       

Epoch 25: Avg Loss/Batch 2.091308            


                                                   

HR@1: 0.31283138918345704
HR@5: 0.6659597030752916
HR@10: 0.8133616118769883
NDCG@1: 0.31283138918345704
NDCG@5: 0.493845526066673
NDCG@10: 0.5415292095786765


                                                                       

Epoch 26: Avg Loss/Batch 2.084414            


                                                   

HR@1: 0.31601272534464475
HR@5: 0.6691410392364793
HR@10: 0.8123011664899258
NDCG@1: 0.31601272534464475
NDCG@5: 0.49674917452145245
NDCG@10: 0.5429931057666033


                                                                       

Epoch 27: Avg Loss/Batch 2.078152            


                                                   

HR@1: 0.3149522799575822
HR@5: 0.672322375397667
HR@10: 0.8123011664899258
NDCG@1: 0.3149522799575822
NDCG@5: 0.4975803209106289
NDCG@10: 0.5426820882809368


                                                                       

Epoch 28: Avg Loss/Batch 2.072426            


                                                   

HR@1: 0.3181336161187699
HR@5: 0.672322375397667
HR@10: 0.8112407211028632
NDCG@1: 0.3181336161187699
NDCG@5: 0.49872741640182494
NDCG@10: 0.5435004623456209


                                                                       

Epoch 29: Avg Loss/Batch 2.067155            


                                                   

HR@1: 0.3181336161187699
HR@5: 0.6712619300106044
HR@10: 0.8133616118769883
NDCG@1: 0.3181336161187699
NDCG@5: 0.4985019224444875
NDCG@10: 0.5442557821272538


                                                                       

Epoch 30: Avg Loss/Batch 2.062261            


                                                   

HR@1: 0.31919406150583246
HR@5: 0.6765641569459173
HR@10: 0.8133616118769883
NDCG@1: 0.31919406150583246
NDCG@5: 0.5006856523875038
NDCG@10: 0.5445265642555144


                                                                       

Epoch 31: Avg Loss/Batch 2.057691            


                                                   

HR@1: 0.3213149522799576
HR@5: 0.6765641569459173
HR@10: 0.8123011664899258
NDCG@1: 0.3213149522799576
NDCG@5: 0.5013107087803773
NDCG@10: 0.5448144683805604


                                                                       

Epoch 32: Avg Loss/Batch 2.053458            


                                                   

HR@1: 0.3213149522799576
HR@5: 0.6744432661717922
HR@10: 0.8133616118769883
NDCG@1: 0.3213149522799576
NDCG@5: 0.5005826073893523
NDCG@10: 0.5451779454111002


                                                                       

Epoch 33: Avg Loss/Batch 2.049571            


                                                   

HR@1: 0.32025450689289503
HR@5: 0.6733828207847296
HR@10: 0.8133616118769883
NDCG@1: 0.32025450689289503
NDCG@5: 0.5000209651113755
NDCG@10: 0.5450336052773489


                                                                       

Epoch 34: Avg Loss/Batch 2.045985            


                                                   

HR@1: 0.3213149522799576
HR@5: 0.6712619300106044
HR@10: 0.8144220572640509
NDCG@1: 0.3213149522799576
NDCG@5: 0.4996848167907361
NDCG@10: 0.5457541630047548


                                                                       

Epoch 35: Avg Loss/Batch 2.042663            


                                                   

HR@1: 0.31919406150583246
HR@5: 0.6702014846235419
HR@10: 0.8154825026511134
NDCG@1: 0.31919406150583246
NDCG@5: 0.4987695105415859
NDCG@10: 0.5455896417901206


                                                                       

Epoch 36: Avg Loss/Batch 2.039585            


                                                   

HR@1: 0.31919406150583246
HR@5: 0.6702014846235419
HR@10: 0.8154825026511134
NDCG@1: 0.31919406150583246
NDCG@5: 0.4983718364165505
NDCG@10: 0.5452010149545021


                                                                       

Epoch 37: Avg Loss/Batch 2.036734            


                                                   

HR@1: 0.31919406150583246
HR@5: 0.6733828207847296
HR@10: 0.8165429480381761
NDCG@1: 0.31919406150583246
NDCG@5: 0.49978786178888773
NDCG@10: 0.5457786966993716


                                                                       

Epoch 38: Avg Loss/Batch 2.034091            


                                                  

HR@1: 0.32237539766702017
HR@5: 0.6765641569459173
HR@10: 0.8165429480381761
NDCG@1: 0.32237539766702017
NDCG@5: 0.5019609178925357
NDCG@10: 0.5467805709916781


                                                                       

Epoch 39: Avg Loss/Batch 2.031634            


                                                  

HR@1: 0.32343584305408274
HR@5: 0.6755037115588547
HR@10: 0.8165429480381761
NDCG@1: 0.32343584305408274
NDCG@5: 0.5017837849401194
NDCG@10: 0.5470334285921257


                                                                       

Epoch 40: Avg Loss/Batch 2.029343            


                                                   

HR@1: 0.32343584305408274
HR@5: 0.6755037115588547
HR@10: 0.8165429480381761
NDCG@1: 0.32343584305408274
NDCG@5: 0.5019502440531675
NDCG@10: 0.5471530415909127


                                                                       

Epoch 41: Avg Loss/Batch 2.027203            


                                                   

HR@1: 0.32343584305408274
HR@5: 0.6744432661717922
HR@10: 0.8165429480381761
NDCG@1: 0.32343584305408274
NDCG@5: 0.5015864804726843
NDCG@10: 0.5471617083284368


                                                                       

Epoch 42: Avg Loss/Batch 2.025200            


                                                   

HR@1: 0.32237539766702017
HR@5: 0.6733828207847296
HR@10: 0.8165429480381761
NDCG@1: 0.32237539766702017
NDCG@5: 0.5006924941986698
NDCG@10: 0.5467076138401888


                                                                       

Epoch 43: Avg Loss/Batch 2.023323            


                                                  

HR@1: 0.32237539766702017
HR@5: 0.672322375397667
HR@10: 0.8144220572640509
NDCG@1: 0.32237539766702017
NDCG@5: 0.5002822579237657
NDCG@10: 0.5460773483135288


                                                                       

Epoch 44: Avg Loss/Batch 2.021562            


                                                   

HR@1: 0.32237539766702017
HR@5: 0.672322375397667
HR@10: 0.8165429480381761
NDCG@1: 0.32237539766702017
NDCG@5: 0.5001893125349239
NDCG@10: 0.5466054047928085


                                                                       

Epoch 45: Avg Loss/Batch 2.019906            


                                                   

HR@1: 0.32237539766702017
HR@5: 0.6691410392364793
HR@10: 0.8144220572640509
NDCG@1: 0.32237539766702017
NDCG@5: 0.4992444750118286
NDCG@10: 0.5461833261461331


                                                                       

Epoch 46: Avg Loss/Batch 2.018344            


                                                   

HR@1: 0.32237539766702017
HR@5: 0.6691410392364793
HR@10: 0.8154825026511134
NDCG@1: 0.32237539766702017
NDCG@5: 0.49956863541265767
NDCG@10: 0.5468899862791481


                                                                       

Epoch 47: Avg Loss/Batch 2.016868            


                                                   

HR@1: 0.32025450689289503
HR@5: 0.6702014846235419
HR@10: 0.8133616118769883
NDCG@1: 0.32025450689289503
NDCG@5: 0.49956674710215715
NDCG@10: 0.5459125921800769


                                                                       

Epoch 48: Avg Loss/Batch 2.015467            


                                                   

HR@1: 0.3213149522799576
HR@5: 0.6702014846235419
HR@10: 0.8133616118769883
NDCG@1: 0.3213149522799576
NDCG@5: 0.4997916668294362
NDCG@10: 0.5461122333098749


                                                                       

Epoch 49: Avg Loss/Batch 2.014135            


                                                   

HR@1: 0.3213149522799576
HR@5: 0.6702014846235419
HR@10: 0.8123011664899258
NDCG@1: 0.3213149522799576
NDCG@5: 0.49979166682943627
NDCG@10: 0.5458003878978441


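### Earlier runs: final-epoch metrics per hyperparameter setting

Each entry below records the last epoch of a previous training run together with the hyperparameters it used. Entries with epoch numbers above 49 come from runs longer than the current `config.epochs = 50`. The epoch-49 entry for `lr=0.1, weight_decay=0.001, batch_size=128` does not match the run above with the same settings (e.g. HR@10 0.850 vs 0.812), so results are not exactly reproducible between runs.

#### Epoch 49: Avg Loss/Batch 20.135686
lr=0.1, weight_decay=0.01, batch_size=256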
                                                   
- HR@1: 0.27465535524920465
- HR@5: 0.6246023329798516
- HR@10: 0.7762460233297985
- NDCG@1: 0.27465535524920465
- NDCG@5: 0.45567622373046285
- NDCG@10: 0.5050744632837261

#### Epoch 49: Avg Loss/Batch 2.268884            
lr=0.1, weight_decay=0.001, batch_size=256                             

- HR@1: 0.271474019088017
- HR@5: 0.6521739130434783
- HR@10: 0.8154825026511134
- NDCG@1: 0.271474019088017
- NDCG@5: 0.46938699909136805
- NDCG@10: 0.5226068198556473

#### Epoch 49: Avg Loss/Batch 1.926479            
lr=0.1, weight_decay=0.001, batch_size=128                             

- HR@1: 0.31919406150583246
- HR@5: 0.7104984093319194
- HR@10: 0.8504772004241782
- NDCG@1: 0.31919406150583246
- NDCG@5: 0.5225957334522389
- NDCG@10: 0.5683831099482718

#### Epoch 99: Avg Loss/Batch 1.799917            
lr=0.1, weight_decay=0.001, batch_size=128        
                                                   
- HR@1: 0.3244962884411453
- HR@5: 0.7232237539766702
- HR@10: 0.8642629904559915
- NDCG@1: 0.3244962884411453
- NDCG@5: 0.5335037960184896
- NDCG@10: 0.5796056642923276

#### Epoch 143: Avg Loss/Batch 1.769505            
lr=0.1, weight_decay=0.001, batch_size=128                                                   

- HR@1: 0.3372216330858961
- HR@5: 0.7253446447507953
- HR@10: 0.8685047720042418
- NDCG@1: 0.3372216330858961
- NDCG@5: 0.5399725224060331
- NDCG@10: 0.5865801610610061

In [None]:
# Persist the trained BPR weights. Save under "mfbpr.pt" -- the name the restore
# line below expects -- rather than "mf.pt", since this is the BPR model.
checkpoint_dir = Path("saved_models")
checkpoint_dir.mkdir(parents=True, exist_ok=True)  # torch.save won't create missing dirs
torch.save(trainer.model.state_dict(), checkpoint_dir / "mfbpr.pt")
# To restore: trainer.model.load_state_dict(torch.load(checkpoint_dir / "mfbpr.pt"))