In [1]:
import numpy as np
import random
import torch
from src.data import Dataset
from src.models import MatrixFactorizationModel
from src.trainer import Trainer
from src.evaluation import hitratio, ndcg

# Seed every random number generator this notebook can draw from.
# torch keeps its own generator, separate from NumPy's and the stdlib's,
# so seeding only those two leaves any torch-side randomness unseeded and
# makes runs differ from one kernel restart to the next.
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU when one is visible to torch; assign 'cpu' here instead
# to force a CPU-only run.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"{device=}")

device='cuda'


In [2]:
class config:
    """Plain namespace holding the tunable settings used by this notebook."""

    # --- data -------------------------------------------------------------
    # Passed to Dataset(...) when the dataset is constructed.
    data_dir = 'ml-100k'
    # Not referenced by any cell visible in this notebook; presumably the
    # number of negative samples per positive -- TODO confirm where it is read.
    neg_count = 16

    # --- model / optimizer ------------------------------------------------
    # Third argument to MatrixFactorizationModel (after user and item counts).
    dim = 40
    # Learning rate handed to torch.optim.SGD.
    lr = 0.01

    # --- training loop ----------------------------------------------------
    # Both are forwarded to Trainer as keyword arguments.
    epochs = 50
    batch_size = 256

In [3]:
# Load the dataset from the configured directory and run its two
# preparation steps (project helpers from src.data; presumably they build
# the user-item adjacency and the train/test split -- confirm there).
dataset = Dataset(config.data_dir)
dataset.gen_adjacency()
dataset.make_train_test()

# Model sized from the dataset's user/item counts and the configured dim.
model = MatrixFactorizationModel(
    dataset.user_count,
    dataset.item_count,
    config.dim,
)

# Plain SGD; the weight decay is fixed here rather than taken from config.
# NOTE(review): in the run recorded in this notebook the HR/NDCG values stay
# essentially flat across epochs while the loss plateaus -- lr/weight_decay
# may be worth revisiting; confirm.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=config.lr,
    weight_decay=0.01,
)

# Each entry maps a display name to (metric function, keyword arguments).
# Iteration order matches the hand-written table: HR@1, HR@5, HR@10,
# then NDCG@1, NDCG@5, NDCG@10.
metrics = {
    f"{label}@{cutoff}": (metric_fn, {"top_n": cutoff})
    for label, metric_fn in (("HR", hitratio), ("NDCG", ndcg))
    for cutoff in (1, 5, 10)
}

trainer = Trainer(
    dataset,
    model,
    optimizer,
    metrics,
    epochs=config.epochs,
    batch_size=config.batch_size,
    device=device,
)

In [4]:
# Quick sanity check on the loaded dataset: show (item count, user count).
# (A mean-interactions-per-user diagnostic over dataset.user_item used to
# live here as commented-out scratch code.)
(dataset.item_count, dataset.user_count)

(1682, 943)

In [5]:
# Run the training loop with evaluation enabled; in the recorded output each
# epoch prints its average per-batch loss followed by every configured metric.
trainer.train(evaluate=True)

                                                                         

Epoch 0: Avg Loss/Batch 5.236966            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.05938494167550371
HR@10: 0.10498409331919406
NDCG@1: 0.013785790031813362
NDCG@5: 0.03517357968736294
NDCG@10: 0.04967662554833153


                                                                         

Epoch 1: Avg Loss/Batch 3.530673            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.05938494167550371
HR@10: 0.10498409331919406
NDCG@1: 0.013785790031813362
NDCG@5: 0.03517357968736294
NDCG@10: 0.04967662554833153


                                                                         

Epoch 2: Avg Loss/Batch 2.393242            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.05938494167550371
HR@10: 0.10498409331919406
NDCG@1: 0.013785790031813362
NDCG@5: 0.03517357968736294
NDCG@10: 0.04967662554833153


                                                                         

Epoch 3: Avg Loss/Batch 1.640817            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.05832449628844114
HR@10: 0.10604453870625663
NDCG@1: 0.013785790031813362
NDCG@5: 0.03476334341245887
NDCG@10: 0.04995066500346554


                                                                         

Epoch 4: Avg Loss/Batch 1.150875            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.05832449628844114
HR@10: 0.10604453870625663
NDCG@1: 0.013785790031813362
NDCG@5: 0.03471687071803803
NDCG@10: 0.04990419230904469


                                                                         

Epoch 5: Avg Loss/Batch 0.841272            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.05832449628844114
HR@10: 0.10604453870625663
NDCG@1: 0.013785790031813362
NDCG@5: 0.03471687071803803
NDCG@10: 0.04990419230904469


                                                                         

Epoch 6: Avg Loss/Batch 0.655111            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.05832449628844114
HR@10: 0.10498409331919406
NDCG@1: 0.013785790031813362
NDCG@5: 0.03476334341245887
NDCG@10: 0.049646746536294556


                                                                         

Epoch 7: Avg Loss/Batch 0.550467            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.05832449628844114
HR@10: 0.10604453870625663
NDCG@1: 0.013785790031813362
NDCG@5: 0.03476334341245887
NDCG@10: 0.0499532839979254


                                                                         

Epoch 8: Avg Loss/Batch 0.495683            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.05832449628844114
HR@10: 0.10604453870625663
NDCG@1: 0.013785790031813362
NDCG@5: 0.034790384442244394
NDCG@10: 0.04998032502771093


                                                                         

Epoch 9: Avg Loss/Batch 0.468612            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.05832449628844114
HR@10: 0.10604453870625663
NDCG@1: 0.013785790031813362
NDCG@5: 0.034790384442244394
NDCG@10: 0.04998032502771093


                                                                          

Epoch 10: Avg Loss/Batch 0.455723            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.05832449628844114
HR@10: 0.10498409331919406
NDCG@1: 0.013785790031813362
NDCG@5: 0.034790384442244394
NDCG@10: 0.04967378756608008


                                                                          

Epoch 11: Avg Loss/Batch 0.449709            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.05832449628844114
HR@10: 0.10498409331919406
NDCG@1: 0.013785790031813362
NDCG@5: 0.034790384442244394
NDCG@10: 0.04967378756608008


                                                                          

Epoch 12: Avg Loss/Batch 0.446931            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.05832449628844114
HR@10: 0.10604453870625663
NDCG@1: 0.013785790031813362
NDCG@5: 0.034790384442244394
NDCG@10: 0.05000665400826938


                                                                          

Epoch 13: Avg Loss/Batch 0.445655            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.05832449628844114
HR@10: 0.10604453870625663
NDCG@1: 0.013785790031813362
NDCG@5: 0.034743911747823555
NDCG@10: 0.04993330584666897


                                                                          

Epoch 14: Avg Loss/Batch 0.445071            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.05832449628844114
HR@10: 0.10604453870625663
NDCG@1: 0.013785790031813362
NDCG@5: 0.034743911747823555
NDCG@10: 0.04993330584666897


                                                                          

Epoch 15: Avg Loss/Batch 0.444805            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.05832449628844114
HR@10: 0.10604453870625663
NDCG@1: 0.013785790031813362
NDCG@5: 0.03469743905340271
NDCG@10: 0.04989952156088662


                                                                          

Epoch 16: Avg Loss/Batch 0.444685            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.05832449628844114
HR@10: 0.10604453870625663
NDCG@1: 0.013785790031813362
NDCG@5: 0.03469743905340271
NDCG@10: 0.04989952156088662


                                                                          

Epoch 17: Avg Loss/Batch 0.444632            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.05832449628844114
HR@10: 0.10604453870625663
NDCG@1: 0.013785790031813362
NDCG@5: 0.03469743905340271
NDCG@10: 0.049886833152248136


                                                                          

Epoch 18: Avg Loss/Batch 0.444610            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.05832449628844114
HR@10: 0.10604453870625663
NDCG@1: 0.013785790031813362
NDCG@5: 0.0348362829066068
NDCG@10: 0.05001298859681373


                                                                          

Epoch 19: Avg Loss/Batch 0.444601            


                                                  

HR@1: 0.013785790031813362
HR@5: 0.05726405090137858
HR@10: 0.10604453870625663
NDCG@1: 0.013785790031813362
NDCG@5: 0.034426046631702734
NDCG@10: 0.0500111053965136


                                                                          

Epoch 20: Avg Loss/Batch 0.444598            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.05726405090137858
HR@10: 0.10604453870625663
NDCG@1: 0.013785790031813362
NDCG@5: 0.034426046631702734
NDCG@10: 0.04999841698787512


                                                                          

Epoch 21: Avg Loss/Batch 0.444597            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.05726405090137858
HR@10: 0.10604453870625663
NDCG@1: 0.013785790031813362
NDCG@5: 0.034426046631702734
NDCG@10: 0.04999841698787512


                                                                          

Epoch 22: Avg Loss/Batch 0.444597            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.05620360551431601
HR@10: 0.10604453870625663
NDCG@1: 0.013785790031813362
NDCG@5: 0.03396933766237782
NDCG@10: 0.04990675787831895


                                                                          

Epoch 23: Avg Loss/Batch 0.444598            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.05620360551431601
HR@10: 0.10604453870625663
NDCG@1: 0.013785790031813362
NDCG@5: 0.03396933766237782
NDCG@10: 0.04990675787831895


                                                                          

Epoch 24: Avg Loss/Batch 0.444598            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.05620360551431601
HR@10: 0.10604453870625663
NDCG@1: 0.013785790031813362
NDCG@5: 0.03396933766237782
NDCG@10: 0.04989406946968047


                                                                          

Epoch 25: Avg Loss/Batch 0.444599            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.05514316012725345
HR@10: 0.1071049840933192
NDCG@1: 0.013785790031813362
NDCG@5: 0.033559101387473755
NDCG@10: 0.05016810892481448


                                                                          

Epoch 26: Avg Loss/Batch 0.444599            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.05408271474019088
HR@10: 0.1071049840933192
NDCG@1: 0.013785790031813362
NDCG@5: 0.033102392418148836
NDCG@10: 0.0500891382238968


                                                                          

Epoch 27: Avg Loss/Batch 0.444600            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.053022269353128315
HR@10: 0.10816542948038176
NDCG@1: 0.013785790031813362
NDCG@5: 0.03269215614324476
NDCG@10: 0.05035048927039233


                                                                          

Epoch 28: Avg Loss/Batch 0.444600            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.053022269353128315
HR@10: 0.1071049840933192
NDCG@1: 0.013785790031813362
NDCG@5: 0.03269215614324476
NDCG@10: 0.05006932862603845


                                                                          

Epoch 29: Avg Loss/Batch 0.444600            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.053022269353128315
HR@10: 0.10604453870625663
NDCG@1: 0.013785790031813362
NDCG@5: 0.03269215614324476
NDCG@10: 0.049775479573046094


                                                                          

Epoch 30: Avg Loss/Batch 0.444600            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.053022269353128315
HR@10: 0.10604453870625663
NDCG@1: 0.013785790031813362
NDCG@5: 0.032830999996448854
NDCG@10: 0.04989006695353047


                                                                          

Epoch 31: Avg Loss/Batch 0.444600            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.053022269353128315
HR@10: 0.10604453870625663
NDCG@1: 0.013785790031813362
NDCG@5: 0.032830999996448854
NDCG@10: 0.04987475955043212


                                                                          

Epoch 32: Avg Loss/Batch 0.444600            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.053022269353128315
HR@10: 0.10498409331919406
NDCG@1: 0.013785790031813362
NDCG@5: 0.032830999996448854
NDCG@10: 0.04956822208880127


                                                                          

Epoch 33: Avg Loss/Batch 0.444601            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.053022269353128315
HR@10: 0.10498409331919406
NDCG@1: 0.013785790031813362
NDCG@5: 0.032969843849652944
NDCG@10: 0.049694377533366876


                                                                          

Epoch 34: Avg Loss/Batch 0.444601            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.053022269353128315
HR@10: 0.10498409331919406
NDCG@1: 0.013785790031813362
NDCG@5: 0.032969843849652944
NDCG@10: 0.049694377533366876


                                                                          

Epoch 35: Avg Loss/Batch 0.444601            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.053022269353128315
HR@10: 0.10498409331919406
NDCG@1: 0.013785790031813362
NDCG@5: 0.032969843849652944
NDCG@10: 0.049694377533366876


                                                                          

Epoch 36: Avg Loss/Batch 0.444601            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.051961823966065745
HR@10: 0.10498409331919406
NDCG@1: 0.013785790031813362
NDCG@5: 0.03260608026916971
NDCG@10: 0.04970835222129088


                                                                          

Epoch 37: Avg Loss/Batch 0.444601            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.051961823966065745
HR@10: 0.10498409331919406
NDCG@1: 0.013785790031813362
NDCG@5: 0.0327449241223738
NDCG@10: 0.04984719607449497


                                                                          

Epoch 38: Avg Loss/Batch 0.444601            


                                                  

HR@1: 0.012725344644750796
HR@5: 0.051961823966065745
HR@10: 0.10498409331919406
NDCG@1: 0.012725344644750796
NDCG@5: 0.03235354528204661
NDCG@10: 0.049455817234167775


                                                                          

Epoch 39: Avg Loss/Batch 0.444601            


                                                   

HR@1: 0.012725344644750796
HR@5: 0.051961823966065745
HR@10: 0.10498409331919406
NDCG@1: 0.012725344644750796
NDCG@5: 0.03235354528204661
NDCG@10: 0.049455817234167775


                                                                          

Epoch 40: Avg Loss/Batch 0.444601            


                                                   

HR@1: 0.012725344644750796
HR@5: 0.051961823966065745
HR@10: 0.10498409331919406
NDCG@1: 0.012725344644750796
NDCG@5: 0.03235354528204661
NDCG@10: 0.049455817234167775


                                                                          

Epoch 41: Avg Loss/Batch 0.444601            


                                                   

HR@1: 0.012725344644750796
HR@5: 0.051961823966065745
HR@10: 0.10392364793213149
NDCG@1: 0.012725344644750796
NDCG@5: 0.03235354528204661
NDCG@10: 0.0491176428415786


                                                                          

Epoch 42: Avg Loss/Batch 0.444601            


                                                   

HR@1: 0.012725344644750796
HR@5: 0.051961823966065745
HR@10: 0.10392364793213149
NDCG@1: 0.012725344644750796
NDCG@5: 0.03235354528204661
NDCG@10: 0.04908069796022041


                                                                          

Epoch 43: Avg Loss/Batch 0.444601            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.051961823966065745
HR@10: 0.10392364793213149
NDCG@1: 0.013785790031813362
NDCG@5: 0.0327449241223738
NDCG@10: 0.049472076800547604


                                                                          

Epoch 44: Avg Loss/Batch 0.444601            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.051961823966065745
HR@10: 0.10498409331919406
NDCG@1: 0.013785790031813362
NDCG@5: 0.032810254251371525
NDCG@10: 0.04984394439117617


                                                                          

Epoch 45: Avg Loss/Batch 0.444601            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.053022269353128315
HR@10: 0.10498409331919406
NDCG@1: 0.013785790031813362
NDCG@5: 0.0332204905262756
NDCG@10: 0.04989539091999284


                                                                          

Epoch 46: Avg Loss/Batch 0.444601            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.053022269353128315
HR@10: 0.10498409331919406
NDCG@1: 0.013785790031813362
NDCG@5: 0.0332204905262756
NDCG@10: 0.049871134447273126


                                                                          

Epoch 47: Avg Loss/Batch 0.444601            


                                                  

HR@1: 0.013785790031813362
HR@5: 0.053022269353128315
HR@10: 0.10498409331919406
NDCG@1: 0.013785790031813362
NDCG@5: 0.0332204905262756
NDCG@10: 0.049871134447273126


                                                                          

Epoch 48: Avg Loss/Batch 0.444601            


                                                  

HR@1: 0.013785790031813362
HR@5: 0.053022269353128315
HR@10: 0.10498409331919406
NDCG@1: 0.013785790031813362
NDCG@5: 0.0332204905262756
NDCG@10: 0.049852185924953295


                                                                          

Epoch 49: Avg Loss/Batch 0.444601            


                                                   

HR@1: 0.013785790031813362
HR@5: 0.053022269353128315
HR@10: 0.10498409331919406
NDCG@1: 0.013785790031813362
NDCG@5: 0.0332204905262756
NDCG@10: 0.049852185924953295


In [None]:
# To write a fresh checkpoint from the model currently held by the trainer,
# run the line below instead of the load:
# torch.save(trainer.model.state_dict(), "saved_models/mf.pt")

# Restore weights from disk, replacing whatever trainer.model currently holds.
# map_location remaps the stored tensors onto the active device, so a
# checkpoint written on a CUDA machine still loads on a CPU-only one.
trainer.model.load_state_dict(
    torch.load("saved_models/mf.pt", map_location=device)
)