In [1]:
import random
from pathlib import Path

import numpy as np
import torch

from src.dataloaders import Dataset100k
from src.metrics import hitratio, ndcg
from src.models import MLPBCEModel
from src.trainer import Trainer

# Seed every RNG this notebook touches so a Restart & Run All is reproducible.
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)

# Compute device. Leave DEVICE_OVERRIDE as None to use CUDA only when it is
# actually available; set it to 'cpu' or 'cuda' to force a specific device.
# The device is decided in exactly one place so the printed value below is the
# one every later cell really uses (forcing 'cuda' unconditionally breaks
# CPU-only machines).
DEVICE_OVERRIDE = None
device = DEVICE_OVERRIDE or ('cuda' if torch.cuda.is_available() else 'cpu')
print(f"{device=}")


In [2]:
class config:
    """Hyper-parameters for this notebook's training run, read as `config.<name>`."""

    # Directory handed to Dataset100k when the dataset is loaded.
    data_dir: str = 'ml-100k'

    # Embedding width passed to the model as `embed_size`.
    dim: int = 32

    # Training schedule.
    epochs: int = 40
    batch_size: int = 2_048

    # Adam optimizer settings.
    lr: float = 5e-3
    weight_decay: float = 1e-4

# Load the dataset from config.data_dir and prepare its train/test split.
dataset = Dataset100k(config.data_dir)
# NOTE(review): gen_adjacency() runs before make_train_test(); presumably the
# split depends on data built here — confirm the required order in src.dataloaders.
dataset.gen_adjacency()
dataset.make_train_test()
# Report the split sizes as a quick sanity check.
print(f"{dataset.train_size=}, {dataset.test_size=}")

# Evaluation metrics: display name -> (metric function, keyword arguments).
# Presumably the Trainer calls each function with its kwargs after every
# evaluation pass — confirm in src.trainer. Built in the order
# HR@1, HR@5, HR@10, NDCG@1, NDCG@5, NDCG@10.
metrics = {
    f"{label}@{cutoff}": (metric_fn, {"top_n": cutoff})
    for label, metric_fn in (("HR", hitratio), ("NDCG", ndcg))
    for cutoff in (1, 5, 10)
}

dataset.train_size=198114, dataset.test_size=943


In [3]:
# MLP recommender sized to the dataset's user/item counts.
# NOTE(review): the hidden-layer widths [32, 16, 8] are hard-coded here rather
# than in `config`.
model = MLPBCEModel(
    dataset.user_count,
    dataset.item_count,
    embed_size=config.dim,
    layers=[32, 16, 8],
)

# Adam over all model parameters, with L2 regularisation via weight_decay.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=config.lr,
    weight_decay=config.weight_decay,
)

# The Trainer owns the training loop; `device` comes from the setup cell.
trainer = Trainer(
    dataset,
    model,
    optimizer,
    metrics,
    epochs=config.epochs,
    batch_size=config.batch_size,
    device=device,
)

In [12]:
# Run training; with evaluate=True the metrics above are reported after each epoch.
trainer.train(
    evaluate=True,
    verbose=True,
    progressbar=True,
)
# Alternative (evaluation only, presumably): trainer.test(verbose=False, pbar=False)

                                                                   

Epoch 0: Avg Loss/Batch 0.292200            


                                                   

HR@1: 0.13149522799575822
HR@5: 0.3828207847295864
HR@10: 0.5471898197242842
NDCG@1: 0.13149522799575822
NDCG@5: 0.25693256483407984
NDCG@10: 0.3106867825003414


                                                                   

Epoch 1: Avg Loss/Batch 0.291306            


                                                   

HR@1: 0.1325556733828208
HR@5: 0.3828207847295864
HR@10: 0.5471898197242842
NDCG@1: 0.1325556733828208
NDCG@5: 0.25732394367440703
NDCG@10: 0.3111636191674662


                                                                   

Epoch 2: Avg Loss/Batch 0.290412            


                                                   

HR@1: 0.13043478260869565
HR@5: 0.3870625662778367
HR@10: 0.5482502651113468
NDCG@1: 0.13043478260869565
NDCG@5: 0.25803452941489774
NDCG@10: 0.31035708847691207


                                                                   

Epoch 3: Avg Loss/Batch 0.289516            


                                                   

HR@1: 0.13043478260869565
HR@5: 0.3860021208907741
HR@10: 0.5440084835630965
NDCG@1: 0.13043478260869565
NDCG@5: 0.25785608238203944
NDCG@10: 0.30952059793830866


                                                                   

Epoch 4: Avg Loss/Batch 0.288673            


                                                   

HR@1: 0.13043478260869565
HR@5: 0.38812301166489926
HR@10: 0.542948038176034
NDCG@1: 0.13043478260869565
NDCG@5: 0.2590088989278852
NDCG@10: 0.309520536488716


                                                                   

Epoch 5: Avg Loss/Batch 0.287828            


                                                   

HR@1: 0.13149522799575822
HR@5: 0.3828207847295864
HR@10: 0.542948038176034
NDCG@1: 0.13149522799575822
NDCG@5: 0.25699028559792725
NDCG@10: 0.3092258896467716


                                                                   

Epoch 6: Avg Loss/Batch 0.286979            


                                                   

HR@1: 0.13149522799575822
HR@5: 0.38494167550371156
HR@10: 0.5397667020148462
NDCG@1: 0.13149522799575822
NDCG@5: 0.2579966489254188
NDCG@10: 0.30853995167552356


                                                                   

Epoch 7: Avg Loss/Batch 0.286111            


                                                   

HR@1: 0.12725344644750794
HR@5: 0.38812301166489926
HR@10: 0.5418875927889714
NDCG@1: 0.12725344644750794
NDCG@5: 0.2574765258411973
NDCG@10: 0.3074417927593053


                                                                   

Epoch 8: Avg Loss/Batch 0.285315            


                                                   

HR@1: 0.12195121951219512
HR@5: 0.38494167550371156
HR@10: 0.542948038176034
NDCG@1: 0.12195121951219512
NDCG@5: 0.25436243653905544
NDCG@10: 0.30575476731537754


                                                                   

Epoch 9: Avg Loss/Batch 0.284453            


                                                   

HR@1: 0.12407211028632026
HR@5: 0.3902439024390244
HR@10: 0.5418875927889714
NDCG@1: 0.12407211028632026
NDCG@5: 0.2574740633006383
NDCG@10: 0.3067201821151889


                                                                    

Epoch 10: Avg Loss/Batch 0.283689            


                                                   

HR@1: 0.12407211028632026
HR@5: 0.3870625662778367
HR@10: 0.5440084835630965
NDCG@1: 0.12407211028632026
NDCG@5: 0.25709335901562635
NDCG@10: 0.3079618105709713


                                                                    

Epoch 11: Avg Loss/Batch 0.282847            


                                                   

HR@1: 0.12513255567338283
HR@5: 0.3870625662778367
HR@10: 0.545068928950159
NDCG@1: 0.12513255567338283
NDCG@5: 0.25682004986387824
NDCG@10: 0.30802143782289265


                                                                    

Epoch 12: Avg Loss/Batch 0.282078            


                                                   

HR@1: 0.12407211028632026
HR@5: 0.38918345705196183
HR@10: 0.5461293743372216
NDCG@1: 0.12407211028632026
NDCG@5: 0.25758909693454707
NDCG@10: 0.3083287874482592


                                                                    

Epoch 13: Avg Loss/Batch 0.281246            


                                                   

HR@1: 0.12513255567338283
HR@5: 0.3870625662778367
HR@10: 0.5440084835630965
NDCG@1: 0.12513255567338283
NDCG@5: 0.2576311910743079
NDCG@10: 0.30839027837716537


                                                                    

Epoch 14: Avg Loss/Batch 0.280471            


                                                   

HR@1: 0.12301166489925769
HR@5: 0.383881230116649
HR@10: 0.545068928950159
NDCG@1: 0.12301166489925769
NDCG@5: 0.2560348303586119
NDCG@10: 0.30818158582514116


                                                                    

Epoch 15: Avg Loss/Batch 0.279739            


                                                   

HR@1: 0.1261930010604454
HR@5: 0.383881230116649
HR@10: 0.5471898197242842
NDCG@1: 0.1261930010604454
NDCG@5: 0.25744075612163925
NDCG@10: 0.31035557831382576


                                                                    

Epoch 16: Avg Loss/Batch 0.278921            


                                                   

HR@1: 0.12513255567338283
HR@5: 0.38494167550371156
HR@10: 0.5471898197242842
NDCG@1: 0.12513255567338283
NDCG@5: 0.2572742970085913
NDCG@10: 0.3098278786415063


                                                                    

Epoch 17: Avg Loss/Batch 0.278134            


                                                   

HR@1: 0.12937433722163308
HR@5: 0.3828207847295864
HR@10: 0.5493107104984093
NDCG@1: 0.12937433722163308
NDCG@5: 0.2577146110838982
NDCG@10: 0.3117131829725854


                                                                    

Epoch 18: Avg Loss/Batch 0.277354            


                                                   

HR@1: 0.12725344644750794
HR@5: 0.383881230116649
HR@10: 0.5503711558854719
NDCG@1: 0.12725344644750794
NDCG@5: 0.2580614339598527
NDCG@10: 0.31188579912275316


                                                                    

Epoch 19: Avg Loss/Batch 0.276585            


                                                   

HR@1: 0.1261930010604454
HR@5: 0.3828207847295864
HR@10: 0.5503711558854719
NDCG@1: 0.1261930010604454
NDCG@5: 0.25671511727117335
NDCG@10: 0.310825670943278


                                                                    

Epoch 20: Avg Loss/Batch 0.275790            


                                                   

HR@1: 0.1261930010604454
HR@5: 0.3828207847295864
HR@10: 0.5461293743372216
NDCG@1: 0.1261930010604454
NDCG@5: 0.257262309088781
NDCG@10: 0.31009193985524164


                                                                    

Epoch 21: Avg Loss/Batch 0.275079            


                                                   

HR@1: 0.1283138918345705
HR@5: 0.3828207847295864
HR@10: 0.5471898197242842
NDCG@1: 0.1283138918345705
NDCG@5: 0.2586550984718813
NDCG@10: 0.3118576078231747


                                                                    

Epoch 22: Avg Loss/Batch 0.274341            


                                                   

HR@1: 0.1283138918345705
HR@5: 0.383881230116649
HR@10: 0.545068928950159
NDCG@1: 0.1283138918345705
NDCG@5: 0.25825112906213765
NDCG@10: 0.3105046580450104


                                                                    

Epoch 23: Avg Loss/Batch 0.273555            


                                                   

HR@1: 0.12937433722163308
HR@5: 0.3828207847295864
HR@10: 0.545068928950159
NDCG@1: 0.12937433722163308
NDCG@5: 0.2583716897108233
NDCG@10: 0.3109418255695054


                                                                    

Epoch 24: Avg Loss/Batch 0.272845            


                                                   

HR@1: 0.13043478260869565
HR@5: 0.3828207847295864
HR@10: 0.5461293743372216
NDCG@1: 0.13043478260869565
NDCG@5: 0.2587906838109945
NDCG@10: 0.31162405391647224


                                                                    

Epoch 25: Avg Loss/Batch 0.272124            


                                                   

HR@1: 0.13043478260869565
HR@5: 0.38494167550371156
HR@10: 0.5471898197242842
NDCG@1: 0.13043478260869565
NDCG@5: 0.2594635546823315
NDCG@10: 0.31182372923990587


                                                                    

Epoch 26: Avg Loss/Batch 0.271347            


                                                   

HR@1: 0.12725344644750794
HR@5: 0.3796394485683987
HR@10: 0.5461293743372216
NDCG@1: 0.12725344644750794
NDCG@5: 0.25718061248531304
NDCG@10: 0.3110533087278991


                                                                    

Epoch 27: Avg Loss/Batch 0.270658            


                                                   

HR@1: 0.12937433722163308
HR@5: 0.383881230116649
HR@10: 0.5461293743372216
NDCG@1: 0.12937433722163308
NDCG@5: 0.25914131101155063
NDCG@10: 0.31151941196668326


                                                                    

Epoch 28: Avg Loss/Batch 0.269936            


                                                   

HR@1: 0.13043478260869565
HR@5: 0.3860021208907741
HR@10: 0.545068928950159
NDCG@1: 0.13043478260869565
NDCG@5: 0.2605379047192525
NDCG@10: 0.311801907572759


                                                                    

Epoch 29: Avg Loss/Batch 0.269258            


                                                   

HR@1: 0.13149522799575822
HR@5: 0.3806998939554613
HR@10: 0.5461293743372216
NDCG@1: 0.13149522799575822
NDCG@5: 0.25879334039142626
NDCG@10: 0.31212350355010227


                                                                    

Epoch 30: Avg Loss/Batch 0.268542            


                                                   

HR@1: 0.1325556733828208
HR@5: 0.3796394485683987
HR@10: 0.5471898197242842
NDCG@1: 0.1325556733828208
NDCG@5: 0.25877448295684946
NDCG@10: 0.31289294988933547


                                                                    

Epoch 31: Avg Loss/Batch 0.267835            


                                                   

HR@1: 0.13043478260869565
HR@5: 0.3796394485683987
HR@10: 0.5461293743372216
NDCG@1: 0.13043478260869565
NDCG@5: 0.25768642230994293
NDCG@10: 0.3113724478763811


                                                                    

Epoch 32: Avg Loss/Batch 0.267148            


                                                   

HR@1: 0.13043478260869565
HR@5: 0.3753976670201485
HR@10: 0.5440084835630965
NDCG@1: 0.13043478260869565
NDCG@5: 0.25530727549404486
NDCG@10: 0.31000830222692183


                                                                    

Epoch 33: Avg Loss/Batch 0.266467            


                                                   

HR@1: 0.12937433722163308
HR@5: 0.37857900318133614
HR@10: 0.5418875927889714
NDCG@1: 0.12937433722163308
NDCG@5: 0.2558689177720217
NDCG@10: 0.30895860597024927


                                                                    

Epoch 34: Avg Loss/Batch 0.265782            


                                                   

HR@1: 0.13149522799575822
HR@5: 0.37645811240721105
HR@10: 0.5461293743372216
NDCG@1: 0.13149522799575822
NDCG@5: 0.25557237263103666
NDCG@10: 0.31053489681805074


                                                                    

Epoch 35: Avg Loss/Batch 0.265104            


                                                   

HR@1: 0.13043478260869565
HR@5: 0.37857900318133614
HR@10: 0.5440084835630965
NDCG@1: 0.13043478260869565
NDCG@5: 0.25618678288814256
NDCG@10: 0.30987631111437214


                                                                    

Epoch 36: Avg Loss/Batch 0.264477            


                                                   

HR@1: 0.12937433722163308
HR@5: 0.3743372216330859
HR@10: 0.5471898197242842
NDCG@1: 0.12937433722163308
NDCG@5: 0.25430148639661176
NDCG@10: 0.3102290269495078


                                                                    

Epoch 37: Avg Loss/Batch 0.263854            


                                                   

HR@1: 0.12513255567338283
HR@5: 0.37327677624602335
HR@10: 0.545068928950159
NDCG@1: 0.12513255567338283
NDCG@5: 0.2524181059191822
NDCG@10: 0.3083012243762159


                                                                    

Epoch 38: Avg Loss/Batch 0.263212            


                                                   

HR@1: 0.1261930010604454
HR@5: 0.3753976670201485
HR@10: 0.5482502651113468
NDCG@1: 0.1261930010604454
NDCG@5: 0.25313933779544046
NDCG@10: 0.30887655197912994


                                                                    

Epoch 39: Avg Loss/Batch 0.262561            


                                                   

HR@1: 0.12513255567338283
HR@5: 0.37327677624602335
HR@10: 0.5471898197242842
NDCG@1: 0.12513255567338283
NDCG@5: 0.251788642552101
NDCG@10: 0.3080984164361115




In [None]:
# Select the evaluation pass with the highest NDCG@10 from the trainer's log
# (np.argmax returns the first index on ties).
# NOTE(review): the saved output shows index 104 while config.epochs is 40,
# which suggests trainer.test_log accumulates across repeated trainer.train(...)
# calls. The index is only a per-run epoch number after a fresh Restart & Run All.
SELECTION_METRIC = "NDCG@10"
if not trainer.test_log:
    raise RuntimeError(
        "trainer.test_log is empty; run trainer.train(evaluate=True) first"
    )
best_epoch = np.argmax([r[SELECTION_METRIC] for r in trainer.test_log])
print(f"{best_epoch}: {trainer.test_log[best_epoch]}")

104: {'HR@1': 0.13679745493107104, 'HR@5': 0.3860021208907741, 'HR@10': 0.5535524920466596, 'NDCG@1': 0.13679745493107104, 'NDCG@5': 0.26183027654913826, 'NDCG@10': 0.3160622946925831}


In [None]:
# NOTE(review): identical to the preceding best-epoch cell; its saved output
# comes from a different training run. Consider deleting this cell.
best_epoch = np.argmax(
    [epoch_metrics["NDCG@10"] for epoch_metrics in trainer.test_log]
)
print(f"{best_epoch}: {trainer.test_log[best_epoch]}")

77: {'HR@1': 0.1357370095440085, 'HR@5': 0.37751855779427357, 'HR@10': 0.5556733828207847, 'NDCG@1': 0.1357370095440085, 'NDCG@5': 0.25860058009897624, 'NDCG@10': 0.3158008204897394}


In [None]:
# NOTE(review): identical to the two preceding best-epoch cells; its saved
# output comes from yet another training run. Consider deleting this cell.
best_epoch = np.asarray(
    [log_entry["NDCG@10"] for log_entry in trainer.test_log]
).argmax()
print(f"{best_epoch}: {trainer.test_log[best_epoch]}")

25: {'HR@1': 0.1357370095440085, 'HR@5': 0.34994697773064687, 'HR@10': 0.5471898197242842, 'NDCG@1': 0.1357370095440085, 'NDCG@5': 0.24658333238245064, 'NDCG@10': 0.30979900867095866}


In [None]:
# Persist the model weights as they stand after the LAST training epoch.
# NOTE(review): these are not necessarily the weights of the best epoch
# selected above — checkpoint during training if the best model is wanted.
# NOTE(review): the file is named "nmfbce" but the model is MLPBCEModel;
# confirm this is not overwriting another model's checkpoint.
checkpoint_path = Path("saved_models/nmfbce.pt")
# torch.save opens the target path directly and fails if the directory is missing.
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(trainer.model.state_dict(), checkpoint_path)
# To restore: trainer.model.load_state_dict(torch.load(checkpoint_path))