In [1]:
import random
from pathlib import Path

import numpy as np
import torch

from src.dataloaders import Dataset, PairwiseDataset
from src.models import MatrixFactorizationRMSEModel, MatrixFactorizationBPRModel
from src.trainer import Trainer
from src.metrics import hitratio, ndcg

# Seed every RNG the notebook may touch so Restart & Run All reproduces the results.
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
# torch keeps its own generator (manual_seed seeds the CPU and all CUDA devices);
# anything in src.* that samples with torch — presumably embedding init — needs it too.
torch.manual_seed(SEED)

# Use the GPU when present, otherwise the CPU. Set DEVICE_OVERRIDE (e.g. 'cpu')
# to force a device; an unconditional 'cuda' fails on machines without a GPU.
DEVICE_OVERRIDE = None
device = DEVICE_OVERRIDE or ('cuda' if torch.cuda.is_available() else 'cpu')
print(f"{device=}")


In [2]:
class config:
    """Paths and hyperparameters for this notebook, read as class attributes (e.g. config.lr)."""

    data_dir = 'ml-100k'  # passed to Dataset(...)
    # Not referenced elsewhere in this notebook; presumably the number of negative
    # items sampled per positive — TODO confirm where (if anywhere) it is consumed.
    neg_count = 10
    epochs = 50           # passed to Trainer(epochs=...)
    batch_size = 128      # passed to Trainer(batch_size=...)
    dim = 40              # latent dimension passed to the model constructor
    lr = 0.02             # optimizer learning rate
    # L2 penalty intended for the optimizer's weight_decay (same value the model cell uses).
    # NOTE(review): in the logged run with Adam at lr=0.02, the training loss stayed at
    # 0.443693 from epoch 1 onward; the recorded runs at the end of the notebook list
    # lr=0.1 with weight_decay 0.01 or 0.001 — consider lowering this.
    weight_decay = 0.1

In [3]:
# Build the dataset: adjacency structure first, then the train/test split.
# (The pairwise alternative is PairwiseDataset together with MatrixFactorizationBPRModel.)
dataset = Dataset(config.data_dir)
dataset.gen_adjacency()
dataset.make_train_test()

model = MatrixFactorizationRMSEModel(dataset.user_count, dataset.item_count, config.dim)

# Weight decay comes from config when it defines one; 0.1 is the value this notebook has used.
# NOTE(review): with Adam at lr=0.02 and weight_decay=0.1 the loss logged by the next cell
# stays at 0.443693 from epoch 1 on — a smaller value is likely needed.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=config.lr,
    weight_decay=getattr(config, "weight_decay", 0.1),
)

# HR@k and NDCG@k for each cutoff; keys come out in the order HR@1, HR@5, HR@10, NDCG@1, ...
TOP_NS = (1, 5, 10)
metrics = {
    f"{name}@{k}": (fn, {"top_n": k})
    for name, fn in (("HR", hitratio), ("NDCG", ndcg))
    for k in TOP_NS
}

trainer = Trainer(dataset, model, optimizer, metrics,
                  epochs=config.epochs, batch_size=config.batch_size,
                  device=device)

In [4]:
# Run training for the epochs configured on the Trainer (config.epochs). With evaluate=True
# the metrics dict from the previous cell is reported as well — per the output below, after every epoch.
trainer.train(evaluate=True)

                                                                       

Epoch 0: Avg Loss/Batch 0.482849            


                                                   

HR@1: 0.012725344644750796
HR@5: 0.08589607635206786
HR@10: 0.1823966065747614
NDCG@1: 0.012725344644750796
NDCG@5: 0.047902182130257605
NDCG@10: 0.078729489394622


                                                                       

Epoch 1: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.015906680805938492
HR@5: 0.07529162248144221
HR@10: 0.1474019088016967
NDCG@1: 0.015906680805938492
NDCG@5: 0.04440924843405658
NDCG@10: 0.06723657894733304


                                                                       

Epoch 2: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.010604453870625663
HR@5: 0.04665959703075292
HR@10: 0.10816542948038176
NDCG@1: 0.010604453870625663
NDCG@5: 0.027012426597915442
NDCG@10: 0.046440197805015355


                                                                       

Epoch 3: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.019088016967126194
HR@5: 0.05514316012725345
HR@10: 0.11664899257688228
NDCG@1: 0.019088016967126194
NDCG@5: 0.03647722872217011
NDCG@10: 0.05616749360038396


                                                                       

Epoch 4: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.01166489925768823
HR@5: 0.04984093319194061
HR@10: 0.0975609756097561
NDCG@1: 0.01166489925768823
NDCG@5: 0.029827536638061022
NDCG@10: 0.044748051361000483


                                                                       

Epoch 5: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.008483563096500531
HR@5: 0.04665959703075292
HR@10: 0.09331919406150584
NDCG@1: 0.008483563096500531
NDCG@5: 0.026986728068119643
NDCG@10: 0.04193026783110648


                                                                       

Epoch 6: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.006362672322375398
HR@5: 0.05090137857900318
HR@10: 0.10392364793213149
NDCG@1: 0.006362672322375398
NDCG@5: 0.02783424164771332
NDCG@10: 0.04476733473889774


                                                                       

Epoch 7: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.009544008483563097
HR@5: 0.04772004241781548
HR@10: 0.09650053022269353
NDCG@1: 0.009544008483563097
NDCG@5: 0.02866596298289514
NDCG@10: 0.04402228010555428


                                                                       

Epoch 8: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.007423117709437964
HR@5: 0.044538706256627786
HR@10: 0.09437963944856839
NDCG@1: 0.007423117709437964
NDCG@5: 0.025756019243407424
NDCG@10: 0.04175557929222965


                                                                       

Epoch 9: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.0042417815482502655
HR@5: 0.043478260869565216
HR@10: 0.08907741251325557
NDCG@1: 0.0042417815482502655
NDCG@5: 0.023803503596431495
NDCG@10: 0.03824630026660365


                                                                        

Epoch 10: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.016967126193001062
HR@5: 0.05620360551431601
HR@10: 0.1102863202545069
NDCG@1: 0.016967126193001062
NDCG@5: 0.036722922614074296
NDCG@10: 0.05382767193453573


                                                                        

Epoch 11: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.009544008483563097
HR@5: 0.05408271474019088
HR@10: 0.10392364793213149
NDCG@1: 0.009544008483563097
NDCG@5: 0.03026158313226044
NDCG@10: 0.04610942787704149


                                                                        

Epoch 12: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.04559915164369035
HR@10: 0.0975609756097561
NDCG@1: 0.013785790031813362
NDCG@5: 0.02861757355842613
NDCG@10: 0.04522396210130697


                                                                        

Epoch 13: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.005302226935312832
HR@5: 0.05620360551431601
HR@10: 0.1007423117709438
NDCG@1: 0.005302226935312832
NDCG@5: 0.030669357582552643
NDCG@10: 0.04515352569884035


                                                                        

Epoch 14: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.018027571580063628
HR@5: 0.06256627783669141
HR@10: 0.11983032873807
NDCG@1: 0.018027571580063628
NDCG@5: 0.03827146813535963
NDCG@10: 0.056770931178710037


                                                                        

Epoch 15: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.006362672322375398
HR@5: 0.04878048780487805
HR@10: 0.0975609756097561
NDCG@1: 0.006362672322375398
NDCG@5: 0.027376218597946388
NDCG@10: 0.042839042835877215


                                                                        

Epoch 16: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.007423117709437964
HR@5: 0.05408271474019088
HR@10: 0.1102863202545069
NDCG@1: 0.007423117709437964
NDCG@5: 0.03156583481667376
NDCG@10: 0.049411355387102916


                                                                        

Epoch 17: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.010604453870625663
HR@5: 0.03923647932131495
HR@10: 0.09013785790031813
NDCG@1: 0.010604453870625663
NDCG@5: 0.025202560560639232
NDCG@10: 0.041610010843402684


                                                                        

Epoch 18: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.008483563096500531
HR@5: 0.04559915164369035
HR@10: 0.09225874867444327
NDCG@1: 0.008483563096500531
NDCG@5: 0.025698871993671338
NDCG@10: 0.04061931821866086


                                                                        

Epoch 19: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.005302226935312832
HR@5: 0.042417815482502653
HR@10: 0.08695652173913043
NDCG@1: 0.005302226935312832
NDCG@5: 0.023062811635990207
NDCG@10: 0.03719671825114308


                                                                        

Epoch 20: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.01166489925768823
HR@5: 0.05620360551431601
HR@10: 0.10604453870625663
NDCG@1: 0.01166489925768823
NDCG@5: 0.03351834974770267
NDCG@10: 0.04925570989543896


                                                                        

Epoch 21: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.009544008483563097
HR@5: 0.03923647932131495
HR@10: 0.09437963944856839
NDCG@1: 0.009544008483563097
NDCG@5: 0.024367034900855802
NDCG@10: 0.04157526520193852


                                                                        

Epoch 22: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.010604453870625663
HR@5: 0.03817603393425239
HR@10: 0.08483563096500531
NDCG@1: 0.010604453870625663
NDCG@5: 0.024062306164662095
NDCG@10: 0.03904524162219123


                                                                        

Epoch 23: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.012725344644750796
HR@5: 0.04029692470837752
HR@10: 0.09650053022269353
NDCG@1: 0.012725344644750796
NDCG@5: 0.025665536395124623
NDCG@10: 0.043392239746988935


                                                                        

Epoch 24: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.014846235418875928
HR@5: 0.05090137857900318
HR@10: 0.09862142099681867
NDCG@1: 0.014846235418875928
NDCG@5: 0.03295421722556625
NDCG@10: 0.04842443706974414


                                                                        

Epoch 25: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.01166489925768823
HR@5: 0.04878048780487805
HR@10: 0.0911983032873807
NDCG@1: 0.01166489925768823
NDCG@5: 0.03006370515071384
NDCG@10: 0.04388482933856297


                                                                        

Epoch 26: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.018027571580063628
HR@5: 0.06256627783669141
HR@10: 0.11134676564156946
NDCG@1: 0.018027571580063628
NDCG@5: 0.03945573268519871
NDCG@10: 0.05513468011812291


                                                                        

Epoch 27: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.006362672322375398
HR@5: 0.04878048780487805
HR@10: 0.08695652173913043
NDCG@1: 0.006362672322375398
NDCG@5: 0.028041480820080154
NDCG@10: 0.03994563630507638


                                                                        

Epoch 28: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.009544008483563097
HR@5: 0.044538706256627786
HR@10: 0.0975609756097561
NDCG@1: 0.009544008483563097
NDCG@5: 0.027260611449926222
NDCG@10: 0.04420275183052765


                                                                        

Epoch 29: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.01166489925768823
HR@5: 0.060445387062566275
HR@10: 0.11346765641569459
NDCG@1: 0.01166489925768823
NDCG@5: 0.03528803909121325
NDCG@10: 0.05208451415738529


                                                                        

Epoch 30: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.012725344644750796
HR@5: 0.04665959703075292
HR@10: 0.09437963944856839
NDCG@1: 0.012725344644750796
NDCG@5: 0.02929350961992814
NDCG@10: 0.04448933201373068


                                                                        

Epoch 31: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.012725344644750796
HR@5: 0.05938494167550371
HR@10: 0.1102863202545069
NDCG@1: 0.012725344644750796
NDCG@5: 0.03513033780343229
NDCG@10: 0.0514865861848163


                                                                        

Epoch 32: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.044538706256627786
HR@10: 0.09013785790031813
NDCG@1: 0.013785790031813362
NDCG@5: 0.029444342108889496
NDCG@10: 0.04354303917181251


                                                                        

Epoch 33: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.008483563096500531
HR@5: 0.04772004241781548
HR@10: 0.10180275715800637
NDCG@1: 0.008483563096500531
NDCG@5: 0.027435827672294334
NDCG@10: 0.04480262239704417


                                                                        

Epoch 34: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.01166489925768823
HR@5: 0.05408271474019088
HR@10: 0.11983032873807
NDCG@1: 0.01166489925768823
NDCG@5: 0.03263254706889681
NDCG@10: 0.053719457438719316


                                                                        

Epoch 35: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.008483563096500531
HR@5: 0.04029692470837752
HR@10: 0.09331919406150584
NDCG@1: 0.008483563096500531
NDCG@5: 0.024007649875032604
NDCG@10: 0.04105450406161621


                                                                        

Epoch 36: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.014846235418875928
HR@5: 0.05408271474019088
HR@10: 0.11240721102863202
NDCG@1: 0.014846235418875928
NDCG@5: 0.03420435771491377
NDCG@10: 0.05257455435664521


                                                                        

Epoch 37: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.01166489925768823
HR@5: 0.042417815482502653
HR@10: 0.07953340402969247
NDCG@1: 0.01166489925768823
NDCG@5: 0.025986465985463465
NDCG@10: 0.03787681418079415


                                                                        

Epoch 38: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.010604453870625663
HR@5: 0.053022269353128315
HR@10: 0.1007423117709438
NDCG@1: 0.010604453870625663
NDCG@5: 0.03138678513420931
NDCG@10: 0.046379284645930864


                                                                        

Epoch 39: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.05090137857900318
HR@10: 0.0975609756097561
NDCG@1: 0.013785790031813362
NDCG@5: 0.03211869156578282
NDCG@10: 0.04685983235165465


                                                                        

Epoch 40: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.012725344644750796
HR@5: 0.043478260869565216
HR@10: 0.08589607635206786
NDCG@1: 0.012725344644750796
NDCG@5: 0.027525708586918007
NDCG@10: 0.04080936714797035


                                                                        

Epoch 41: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.005302226935312832
HR@5: 0.031813361611876985
HR@10: 0.08271474019088017
NDCG@1: 0.005302226935312832
NDCG@5: 0.01810226075204055
NDCG@10: 0.03446457345488602


                                                                        

Epoch 42: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.012725344644750796
HR@5: 0.05514316012725345
HR@10: 0.10604453870625663
NDCG@1: 0.012725344644750796
NDCG@5: 0.03227885467817566
NDCG@10: 0.048335707097252684


                                                                        

Epoch 43: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.0042417815482502655
HR@5: 0.04878048780487805
HR@10: 0.10392364793213149
NDCG@1: 0.0042417815482502655
NDCG@5: 0.025133998905204315
NDCG@10: 0.04271977249863316


                                                                        

Epoch 44: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.015906680805938492
HR@5: 0.05938494167550371
HR@10: 0.09225874867444327
NDCG@1: 0.015906680805938492
NDCG@5: 0.03655454677097807
NDCG@10: 0.04722378666911381


                                                                        

Epoch 45: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.005302226935312832
HR@5: 0.05726405090137858
HR@10: 0.11134676564156946
NDCG@1: 0.005302226935312832
NDCG@5: 0.03247621598470628
NDCG@10: 0.04994568026591505


                                                                        

Epoch 46: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.009544008483563097
HR@5: 0.044538706256627786
HR@10: 0.08271474019088017
NDCG@1: 0.009544008483563097
NDCG@5: 0.0264187905054346
NDCG@10: 0.03864659842545753


                                                                        

Epoch 47: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.007423117709437964
HR@5: 0.04984093319194061
HR@10: 0.088016967126193
NDCG@1: 0.007423117709437964
NDCG@5: 0.027799017022719138
NDCG@10: 0.03989756000465327


                                                                        

Epoch 48: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.006362672322375398
HR@5: 0.043478260869565216
HR@10: 0.08695652173913043
NDCG@1: 0.006362672322375398
NDCG@5: 0.024103825358417454
NDCG@10: 0.037858508739738755


                                                                        

Epoch 49: Avg Loss/Batch 0.443693            


                                                   

HR@1: 0.003181336161187699
HR@5: 0.04984093319194061
HR@10: 0.09650053022269353
NDCG@1: 0.003181336161187699
NDCG@5: 0.025890483826004472
NDCG@10: 0.04089193387087757




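Final-epoch results from other training runs; the hyperparameters used for each run are listed under its heading.

#### Epoch 49: Avg Loss/Batch 20.135686
lr=0.1, weight_decay=0.01, batch_size=256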
                                                   
- HR@1: 0.27465535524920465
- HR@5: 0.6246023329798516
- HR@10: 0.7762460233297985
- NDCG@1: 0.27465535524920465
- NDCG@5: 0.45567622373046285
- NDCG@10: 0.5050744632837261

#### Epoch 49: Avg Loss/Batch 2.268884            
lr=0.1, weight_decay=0.001, batch_size=256                             

- HR@1: 0.271474019088017
- HR@5: 0.6521739130434783
- HR@10: 0.8154825026511134
- NDCG@1: 0.271474019088017
- NDCG@5: 0.46938699909136805
- NDCG@10: 0.5226068198556473

#### Epoch 49: Avg Loss/Batch 1.926479            
lr=0.1, weight_decay=0.001, batch_size=128                             

- HR@1: 0.31919406150583246
- HR@5: 0.7104984093319194
- HR@10: 0.8504772004241782
- NDCG@1: 0.31919406150583246
- NDCG@5: 0.5225957334522389
- NDCG@10: 0.5683831099482718

#### Epoch 49: Avg Loss/Batch 1.937064            
lr=0.1, weight_decay=0.001, batch_size=128                             
 
- HR@1: 0.3297985153764581
- HR@5: 0.6839872746553552
- HR@10: 0.848356309650053
- NDCG@1: 0.3297985153764581
- NDCG@5: 0.5171409534856304
- NDCG@10: 0.5709949698866551

#### Epoch 99: Avg Loss/Batch 1.799917            
lr=0.1, weight_decay=0.001, batch_size=128        
                                                   
- HR@1: 0.3244962884411453
- HR@5: 0.7232237539766702
- HR@10: 0.8642629904559915
- NDCG@1: 0.3244962884411453
- NDCG@5: 0.5335037960184896
- NDCG@10: 0.5796056642923276

#### Epoch 143: Avg Loss/Batch 1.769505            
lr=0.1, weight_decay=0.001, batch_size=128                                                   

- HR@1: 0.3372216330858961
- HR@5: 0.7253446447507953
- HR@10: 0.8685047720042418
- NDCG@1: 0.3372216330858961
- NDCG@5: 0.5399725224060331
- NDCG@10: 0.5865801610610061

In [None]:
# Persist the trained weights. torch.save does not create missing parent
# directories, so create saved_models/ first (no-op if it already exists).
SAVE_PATH = Path("saved_models") / "mf.pt"
SAVE_PATH.parent.mkdir(parents=True, exist_ok=True)
torch.save(trainer.model.state_dict(), SAVE_PATH)

# To restore these weights into a model of the same architecture:
# trainer.model.load_state_dict(torch.load(SAVE_PATH))
# Weights for the BPR variant were, presumably, stored at saved_models/mfbpr.pt.