In [1]:
import json
import optuna
import os
import torch
from src.models.evaluate import check_ndcg_on_val_set
from src.models.train import TrainKNRM

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Absolute path of the parent of the current working directory; all data
# files below are addressed relative to it.
PARENT_DIR = os.path.abspath(os.pardir)

# QQP train/dev splits and the 50-dimensional GloVe vectors.
TRAIN_PATH, VAL_PATH, GLOVE_PATH = (
    PARENT_DIR + relative_path
    for relative_path in (
        '/data/raw/QQP/train.tsv',
        '/data/raw/QQP/dev.tsv',
        '/data/raw/glove.6B.50d.txt',
    )
)

In [3]:
def objective(trial, train_path=TRAIN_PATH, val_path=VAL_PATH, glove_path=GLOVE_PATH):
    """Optuna objective: train one KNRM configuration and score it.

    Hyper-parameters are drawn from ``trial``; a ``TrainKNRM`` instance is
    built with them plus a set of fixed settings, prepared, fitted, and then
    evaluated with ``check_ndcg_on_val_set``.

    Args:
        trial: Optuna trial used to sample the hyper-parameters.
        train_path: Path to the training file (defaults to ``TRAIN_PATH``).
        val_path: Path to the validation file (defaults to ``VAL_PATH``).
        glove_path: Path to the GloVe vectors (defaults to ``GLOVE_PATH``).

    Returns:
        The value returned by ``check_ndcg_on_val_set`` for the fitted model
        on its validation dataloader.
    """
    # Sampling order is kept stable: it determines the order in which the
    # sampler is asked for values.
    sampled = dict(
        freeze_emb=trial.suggest_categorical('freeze_emb', [False, True]),
        min_token_occurancies=trial.suggest_categorical('min_token_occurancies', [1, 2]),
        num_kernels=trial.suggest_categorical('num_kernels', [*range(10, 25)]),
        sigma=trial.suggest_float('sigma', 1e-4, 1e-1),
        lr=trial.suggest_float('lr', 1e-2, 6e-2),
        num_pos_ex=trial.suggest_categorical('num_pos_ex', [2, 3, 4]),
        num_same_rel_ex=trial.suggest_categorical('num_same_rel_ex', [2, 3, 4]),
    )

    # Settings that are held constant across all trials.
    trainer = TrainKNRM(
        train_path=train_path,
        val_path=val_path,
        glove_path=glove_path,
        random_vec_bound=1.0,
        out_layers=[],
        seed=0,
        num_epochs=9,
        batch_size=1024,
        change_every_num_ep=10,
        **sampled,
    )
    trainer.get_ready_for_train()
    trainer.fit()
    return check_ndcg_on_val_set(trainer.knrm, trainer.val_dataloader)

In [4]:
# Run a 30-trial search; the study maximises the value returned by `objective`.
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=30)

# Summarise the search: how many trials ran and which parameters won.
finished_trials = study.trials
best_trial = study.best_trial
print('Number of finished trials:', len(finished_trials))
print('Best trial:', best_trial.params)

[32m[I 2023-03-24 17:08:38,297][0m A new study created in memory with name: no-name-97328a4f-e45b-48c8-b7a1-f3d98d35b290[0m


Epoch: 0, validation ndcg 0.4999908658236024
Epoch: 1, validation ndcg 0.5408568331823683
Epoch: 2, validation ndcg 0.6078280302676119
Epoch: 3, validation ndcg 0.6972278865509562
Epoch: 4, validation ndcg 0.7654122405038731
Epoch: 5, validation ndcg 0.7939163805685272
Epoch: 6, validation ndcg 0.8208769793682593
Epoch: 7, validation ndcg 0.8421623581951306
Epoch: 8, validation ndcg 0.8569277655727069


[32m[I 2023-03-24 17:14:33,250][0m Trial 0 finished with value: 0.8569277655727069 and parameters: {'freeze_emb': False, 'min_token_occurancies': 1, 'num_kernels': 18, 'sigma': 0.02919610084991252, 'lr': 0.01789777003586229, 'num_pos_ex': 4, 'num_same_rel_ex': 2}. Best is trial 0 with value: 0.8569277655727069.[0m


Epoch: 0, validation ndcg 0.6910064902423401
Epoch: 1, validation ndcg 0.7536045953860123
Epoch: 2, validation ndcg 0.7867215964391047
Epoch: 3, validation ndcg 0.8057299715222321
Epoch: 4, validation ndcg 0.8181194595037061
Epoch: 5, validation ndcg 0.8134855432719692
Epoch: 6, validation ndcg 0.8236075069646928
Epoch: 7, validation ndcg 0.8000949431178515
Epoch: 8, validation ndcg 0.8239285272041342


[32m[I 2023-03-24 17:20:01,460][0m Trial 1 finished with value: 0.8239285272041342 and parameters: {'freeze_emb': False, 'min_token_occurancies': 1, 'num_kernels': 10, 'sigma': 0.06658337489032985, 'lr': 0.03293379269262377, 'num_pos_ex': 3, 'num_same_rel_ex': 3}. Best is trial 0 with value: 0.8569277655727069.[0m


Epoch: 0, validation ndcg 0.5462126999137429
Epoch: 1, validation ndcg 0.5835551129638429
Epoch: 2, validation ndcg 0.6299016394466223
Epoch: 3, validation ndcg 0.6713568987053395
Epoch: 4, validation ndcg 0.7104609191727518
Epoch: 5, validation ndcg 0.7434487869679933
Epoch: 6, validation ndcg 0.7638235546933912
Epoch: 7, validation ndcg 0.7835979377126969
Epoch: 8, validation ndcg 0.8041134924868433


[32m[I 2023-03-24 17:25:55,637][0m Trial 2 finished with value: 0.8041134924868433 and parameters: {'freeze_emb': True, 'min_token_occurancies': 1, 'num_kernels': 21, 'sigma': 0.033648331994969895, 'lr': 0.016713026153585005, 'num_pos_ex': 2, 'num_same_rel_ex': 3}. Best is trial 0 with value: 0.8569277655727069.[0m


Epoch: 0, validation ndcg 0.7071913493247128
Epoch: 1, validation ndcg 0.7805024414731012
Epoch: 2, validation ndcg 0.8029391720686986
Epoch: 3, validation ndcg 0.8325072245645103
Epoch: 4, validation ndcg 0.84330078890516
Epoch: 5, validation ndcg 0.8092954433292243
Epoch: 6, validation ndcg 0.8085583177698248
Epoch: 7, validation ndcg 0.8458482474219322
Epoch: 8, validation ndcg 0.8215980209018787


[32m[I 2023-03-24 17:31:13,283][0m Trial 3 finished with value: 0.8215980209018787 and parameters: {'freeze_emb': True, 'min_token_occurancies': 2, 'num_kernels': 10, 'sigma': 0.06178577554659226, 'lr': 0.04013418892172624, 'num_pos_ex': 3, 'num_same_rel_ex': 2}. Best is trial 0 with value: 0.8569277655727069.[0m


Epoch: 0, validation ndcg 0.4366095148931582
Epoch: 1, validation ndcg 0.6122283300812665
Epoch: 2, validation ndcg 0.7258644849190905
Epoch: 3, validation ndcg 0.7956836693194993
Epoch: 4, validation ndcg 0.8222408565798693
Epoch: 5, validation ndcg 0.8346138465177824
Epoch: 6, validation ndcg 0.8490149311084162
Epoch: 7, validation ndcg 0.8508174474819797
Epoch: 8, validation ndcg 0.8332070124548439


[32m[I 2023-03-24 17:37:02,113][0m Trial 4 finished with value: 0.8332070124548439 and parameters: {'freeze_emb': True, 'min_token_occurancies': 2, 'num_kernels': 13, 'sigma': 0.03112178008443607, 'lr': 0.025246562323033545, 'num_pos_ex': 4, 'num_same_rel_ex': 2}. Best is trial 0 with value: 0.8569277655727069.[0m


Epoch: 0, validation ndcg 0.3140124998833172
Epoch: 1, validation ndcg 0.39181881793792056
Epoch: 2, validation ndcg 0.446404342201079
Epoch: 3, validation ndcg 0.4935780088851667
Epoch: 4, validation ndcg 0.5333427874070024
Epoch: 5, validation ndcg 0.5885862368623387
Epoch: 6, validation ndcg 0.6496882942093846
Epoch: 7, validation ndcg 0.6914449070603449
Epoch: 8, validation ndcg 0.7250112604670533


[32m[I 2023-03-24 17:43:19,412][0m Trial 5 finished with value: 0.7250112604670533 and parameters: {'freeze_emb': True, 'min_token_occurancies': 1, 'num_kernels': 15, 'sigma': 0.019056836702174427, 'lr': 0.015314650440810547, 'num_pos_ex': 2, 'num_same_rel_ex': 3}. Best is trial 0 with value: 0.8569277655727069.[0m


Epoch: 0, validation ndcg 0.5456098044278181
Epoch: 1, validation ndcg 0.665283700287922
Epoch: 2, validation ndcg 0.7011370133267141
Epoch: 3, validation ndcg 0.7192229572624221
Epoch: 4, validation ndcg 0.7323902399839957
Epoch: 5, validation ndcg 0.7460734388510555
Epoch: 6, validation ndcg 0.7545258920969709
Epoch: 7, validation ndcg 0.7581304314609256
Epoch: 8, validation ndcg 0.7649617846972036


[32m[I 2023-03-24 17:49:10,659][0m Trial 6 finished with value: 0.7649617846972036 and parameters: {'freeze_emb': True, 'min_token_occurancies': 1, 'num_kernels': 10, 'sigma': 0.014494242075363178, 'lr': 0.011447205881410134, 'num_pos_ex': 3, 'num_same_rel_ex': 4}. Best is trial 0 with value: 0.8569277655727069.[0m


Epoch: 0, validation ndcg 0.7395042532236352
Epoch: 1, validation ndcg 0.8099356994958518
Epoch: 2, validation ndcg 0.8335159835661999
Epoch: 3, validation ndcg 0.8230621674238575
Epoch: 4, validation ndcg 0.8339652851643591
Epoch: 5, validation ndcg 0.808897446869677
Epoch: 6, validation ndcg 0.8425196062145367
Epoch: 7, validation ndcg 0.8219520310707944
Epoch: 8, validation ndcg 0.8388110438134614


[32m[I 2023-03-24 17:54:28,425][0m Trial 7 finished with value: 0.8388110438134614 and parameters: {'freeze_emb': True, 'min_token_occurancies': 2, 'num_kernels': 10, 'sigma': 0.09641201945728675, 'lr': 0.04367125657119838, 'num_pos_ex': 4, 'num_same_rel_ex': 2}. Best is trial 0 with value: 0.8569277655727069.[0m


Epoch: 0, validation ndcg 0.46966811515041856
Epoch: 1, validation ndcg 0.6045862457283313
Epoch: 2, validation ndcg 0.7525403786347759
Epoch: 3, validation ndcg 0.8245916664385019
Epoch: 4, validation ndcg 0.8692841554681817
Epoch: 5, validation ndcg 0.8662210138102927
Epoch: 6, validation ndcg 0.8632084164710533
Epoch: 7, validation ndcg 0.87099092757768
Epoch: 8, validation ndcg 0.8570790655508024


[32m[I 2023-03-24 18:00:32,873][0m Trial 8 finished with value: 0.8570790655508024 and parameters: {'freeze_emb': False, 'min_token_occurancies': 1, 'num_kernels': 17, 'sigma': 0.03170117877576946, 'lr': 0.02852924205694775, 'num_pos_ex': 4, 'num_same_rel_ex': 2}. Best is trial 8 with value: 0.8570790655508024.[0m


Epoch: 0, validation ndcg 0.7129066441273217
Epoch: 1, validation ndcg 0.7176544192873874
Epoch: 2, validation ndcg 0.7295938603254527
Epoch: 3, validation ndcg 0.7256914529619776
Epoch: 4, validation ndcg 0.7098361516102503
Epoch: 5, validation ndcg 0.7168689193018689
Epoch: 6, validation ndcg 0.7152329294208396
Epoch: 7, validation ndcg 0.7210785878251303
Epoch: 8, validation ndcg 0.7215900158330673


[32m[I 2023-03-24 18:06:09,041][0m Trial 9 finished with value: 0.7215900158330673 and parameters: {'freeze_emb': False, 'min_token_occurancies': 1, 'num_kernels': 10, 'sigma': 0.0020281368675768427, 'lr': 0.046749067658169045, 'num_pos_ex': 3, 'num_same_rel_ex': 2}. Best is trial 8 with value: 0.8570790655508024.[0m


Epoch: 0, validation ndcg 0.6904157110000462
Epoch: 1, validation ndcg 0.7274941694794896
Epoch: 2, validation ndcg 0.7328800555254272
Epoch: 3, validation ndcg 0.8449681493911871
Epoch: 4, validation ndcg 0.7530766844985178
Epoch: 5, validation ndcg 0.6032434270583402
Epoch: 6, validation ndcg 0.7328205439862929
Epoch: 7, validation ndcg 0.772510884047465
Epoch: 8, validation ndcg 0.6789389982009238


[32m[I 2023-03-24 18:12:08,915][0m Trial 10 finished with value: 0.6789389982009238 and parameters: {'freeze_emb': False, 'min_token_occurancies': 1, 'num_kernels': 17, 'sigma': 0.050562647227776876, 'lr': 0.059867339165684624, 'num_pos_ex': 4, 'num_same_rel_ex': 4}. Best is trial 8 with value: 0.8570790655508024.[0m


Epoch: 0, validation ndcg 0.5155516634450985
Epoch: 1, validation ndcg 0.5934097098539428
Epoch: 2, validation ndcg 0.7146630649460921
Epoch: 3, validation ndcg 0.8015883049383533
Epoch: 4, validation ndcg 0.8436361875164362
Epoch: 5, validation ndcg 0.8304632317910102
Epoch: 6, validation ndcg 0.8480092602408902
Epoch: 7, validation ndcg 0.8618223296793275
Epoch: 8, validation ndcg 0.8787920857481054


[32m[I 2023-03-24 18:18:09,124][0m Trial 11 finished with value: 0.8787920857481054 and parameters: {'freeze_emb': False, 'min_token_occurancies': 1, 'num_kernels': 18, 'sigma': 0.03513255787175094, 'lr': 0.02263052821483228, 'num_pos_ex': 4, 'num_same_rel_ex': 2}. Best is trial 11 with value: 0.8787920857481054.[0m


Epoch: 0, validation ndcg 0.5378159728255021
Epoch: 1, validation ndcg 0.6698191174082979
Epoch: 2, validation ndcg 0.7889056026096435
Epoch: 3, validation ndcg 0.8459283260581534
Epoch: 4, validation ndcg 0.8634766176817925
Epoch: 5, validation ndcg 0.8307357725014931
Epoch: 6, validation ndcg 0.8538748729073189
Epoch: 7, validation ndcg 0.8659725330430197
Epoch: 8, validation ndcg 0.8842465871584155


[32m[I 2023-03-24 18:24:07,024][0m Trial 12 finished with value: 0.8842465871584155 and parameters: {'freeze_emb': False, 'min_token_occurancies': 1, 'num_kernels': 18, 'sigma': 0.04244369635985321, 'lr': 0.026936622928302213, 'num_pos_ex': 4, 'num_same_rel_ex': 2}. Best is trial 12 with value: 0.8842465871584155.[0m


Epoch: 0, validation ndcg 0.5274884308074719
Epoch: 1, validation ndcg 0.6446558779756362
Epoch: 2, validation ndcg 0.7706539086123781
Epoch: 3, validation ndcg 0.8366608598099837
Epoch: 4, validation ndcg 0.8569604109606063
Epoch: 5, validation ndcg 0.8315509947095109
Epoch: 6, validation ndcg 0.8515447779374934
Epoch: 7, validation ndcg 0.8612613202183937
Epoch: 8, validation ndcg 0.8829973244307151


[32m[I 2023-03-24 18:30:05,477][0m Trial 13 finished with value: 0.8829973244307151 and parameters: {'freeze_emb': False, 'min_token_occurancies': 1, 'num_kernels': 18, 'sigma': 0.04502620853777602, 'lr': 0.02388062612433616, 'num_pos_ex': 4, 'num_same_rel_ex': 2}. Best is trial 12 with value: 0.8842465871584155.[0m


Epoch: 0, validation ndcg 0.5630098691998795
Epoch: 1, validation ndcg 0.7412720501710279
Epoch: 2, validation ndcg 0.8192551880382098
Epoch: 3, validation ndcg 0.8588474293885447
Epoch: 4, validation ndcg 0.8711406556044119
Epoch: 5, validation ndcg 0.8233281901112219
Epoch: 6, validation ndcg 0.8540042609483895
Epoch: 7, validation ndcg 0.8640411561888508
Epoch: 8, validation ndcg 0.884757673239304


[32m[I 2023-03-24 18:35:59,145][0m Trial 14 finished with value: 0.884757673239304 and parameters: {'freeze_emb': False, 'min_token_occurancies': 1, 'num_kernels': 18, 'sigma': 0.050537134338129945, 'lr': 0.031113348263449088, 'num_pos_ex': 4, 'num_same_rel_ex': 2}. Best is trial 14 with value: 0.884757673239304.[0m


Epoch: 0, validation ndcg 0.40290598987498094
Epoch: 1, validation ndcg 0.7234812104723193
Epoch: 2, validation ndcg 0.8286585675221333
Epoch: 3, validation ndcg 0.8221802132672457
Epoch: 4, validation ndcg 0.8262082802508657
Epoch: 5, validation ndcg 0.8204710488629066
Epoch: 6, validation ndcg 0.8607850738596411
Epoch: 7, validation ndcg 0.8440784437800377
Epoch: 8, validation ndcg 0.7777695193405296


[32m[I 2023-03-24 18:41:29,701][0m Trial 15 finished with value: 0.7777695193405296 and parameters: {'freeze_emb': False, 'min_token_occurancies': 2, 'num_kernels': 16, 'sigma': 0.06091006542572825, 'lr': 0.03151531312529715, 'num_pos_ex': 4, 'num_same_rel_ex': 4}. Best is trial 14 with value: 0.884757673239304.[0m


Epoch: 0, validation ndcg 0.7256163158328283
Epoch: 1, validation ndcg 0.6571319765697559
Epoch: 2, validation ndcg 0.7457355665016988
Epoch: 3, validation ndcg 0.7077280483652613
Epoch: 4, validation ndcg 0.7860599636295607
Epoch: 5, validation ndcg 0.7810276967829258
Epoch: 6, validation ndcg 0.7641784787005991
Epoch: 7, validation ndcg 0.8142508998109004
Epoch: 8, validation ndcg 0.7743898824618262


[32m[I 2023-03-24 18:47:17,012][0m Trial 16 finished with value: 0.7743898824618262 and parameters: {'freeze_emb': False, 'min_token_occurancies': 1, 'num_kernels': 24, 'sigma': 0.07386569059750386, 'lr': 0.03591278124328795, 'num_pos_ex': 2, 'num_same_rel_ex': 2}. Best is trial 14 with value: 0.884757673239304.[0m


Epoch: 0, validation ndcg 0.6767259188692447
Epoch: 1, validation ndcg 0.7468875889337914
Epoch: 2, validation ndcg 0.8021819358456931
Epoch: 3, validation ndcg 0.7986251210943567
Epoch: 4, validation ndcg 0.8474258244069819
Epoch: 5, validation ndcg 0.8455745312048858
Epoch: 6, validation ndcg 0.8715365270099298
Epoch: 7, validation ndcg 0.8699578241189004
Epoch: 8, validation ndcg 0.8311341761792491


[32m[I 2023-03-24 18:53:23,174][0m Trial 17 finished with value: 0.8311341761792491 and parameters: {'freeze_emb': False, 'min_token_occurancies': 1, 'num_kernels': 22, 'sigma': 0.045791260405216035, 'lr': 0.026925776390107105, 'num_pos_ex': 4, 'num_same_rel_ex': 2}. Best is trial 14 with value: 0.884757673239304.[0m


Epoch: 0, validation ndcg 0.46851603301333916
Epoch: 1, validation ndcg 0.6279081236834234
Epoch: 2, validation ndcg 0.7005904300893309
Epoch: 3, validation ndcg 0.7226813540614454
Epoch: 4, validation ndcg 0.7977438541199872
Epoch: 5, validation ndcg 0.8090954183407806
Epoch: 6, validation ndcg 0.835583627566719
Epoch: 7, validation ndcg 0.8102514174004826
Epoch: 8, validation ndcg 0.8401872396342511


[32m[I 2023-03-24 18:59:20,559][0m Trial 18 finished with value: 0.8401872396342511 and parameters: {'freeze_emb': False, 'min_token_occurancies': 2, 'num_kernels': 19, 'sigma': 0.05213441731452412, 'lr': 0.03480891132322845, 'num_pos_ex': 4, 'num_same_rel_ex': 4}. Best is trial 14 with value: 0.884757673239304.[0m


Epoch: 0, validation ndcg 0.36976915729901416
Epoch: 1, validation ndcg 0.5325709068333946
Epoch: 2, validation ndcg 0.692228088538499
Epoch: 3, validation ndcg 0.7467202207100445
Epoch: 4, validation ndcg 0.7761654818989508
Epoch: 5, validation ndcg 0.799121459722008
Epoch: 6, validation ndcg 0.807095818214659
Epoch: 7, validation ndcg 0.8245286320148809
Epoch: 8, validation ndcg 0.8339186344530718


[32m[I 2023-03-24 19:04:45,464][0m Trial 19 finished with value: 0.8339186344530718 and parameters: {'freeze_emb': False, 'min_token_occurancies': 1, 'num_kernels': 14, 'sigma': 0.07509750653602855, 'lr': 0.021565493311945013, 'num_pos_ex': 2, 'num_same_rel_ex': 3}. Best is trial 14 with value: 0.884757673239304.[0m


Epoch: 0, validation ndcg 0.6276147664778664
Epoch: 1, validation ndcg 0.7042737026846781
Epoch: 2, validation ndcg 0.7500306718921357
Epoch: 3, validation ndcg 0.7728634712977299
Epoch: 4, validation ndcg 0.7757254142729538
Epoch: 5, validation ndcg 0.7954995184454227
Epoch: 6, validation ndcg 0.8098539328774739
Epoch: 7, validation ndcg 0.8055479316299091
Epoch: 8, validation ndcg 0.8261922665294971


[32m[I 2023-03-24 19:11:02,452][0m Trial 20 finished with value: 0.8261922665294971 and parameters: {'freeze_emb': False, 'min_token_occurancies': 1, 'num_kernels': 23, 'sigma': 0.043350408545146456, 'lr': 0.027850387916877738, 'num_pos_ex': 4, 'num_same_rel_ex': 2}. Best is trial 14 with value: 0.884757673239304.[0m


Epoch: 0, validation ndcg 0.5193645082538918
Epoch: 1, validation ndcg 0.6093287624420354
Epoch: 2, validation ndcg 0.7311686151834248
Epoch: 3, validation ndcg 0.815285152401234
Epoch: 4, validation ndcg 0.8485046123576153
Epoch: 5, validation ndcg 0.8304862584406625
Epoch: 6, validation ndcg 0.8492969737913927
Epoch: 7, validation ndcg 0.8600304532622723
Epoch: 8, validation ndcg 0.8793134079951107


[32m[I 2023-03-24 19:16:58,831][0m Trial 21 finished with value: 0.8793134079951107 and parameters: {'freeze_emb': False, 'min_token_occurancies': 1, 'num_kernels': 18, 'sigma': 0.0411502177225315, 'lr': 0.02179746188424296, 'num_pos_ex': 4, 'num_same_rel_ex': 2}. Best is trial 14 with value: 0.884757673239304.[0m


Epoch: 0, validation ndcg 0.5550054559813599
Epoch: 1, validation ndcg 0.734574645050735
Epoch: 2, validation ndcg 0.8156568994501834
Epoch: 3, validation ndcg 0.8561652963961741
Epoch: 4, validation ndcg 0.8689348063487334
Epoch: 5, validation ndcg 0.8226695546116163
Epoch: 6, validation ndcg 0.8538070697542093
Epoch: 7, validation ndcg 0.8634667336493768
Epoch: 8, validation ndcg 0.8854371359627972


[32m[I 2023-03-24 19:22:40,073][0m Trial 22 finished with value: 0.8854371359627972 and parameters: {'freeze_emb': False, 'min_token_occurancies': 1, 'num_kernels': 18, 'sigma': 0.053412400073788134, 'lr': 0.02956711289116614, 'num_pos_ex': 4, 'num_same_rel_ex': 2}. Best is trial 22 with value: 0.8854371359627972.[0m


Epoch: 0, validation ndcg 0.5211779535733554
Epoch: 1, validation ndcg 0.6597913426479479
Epoch: 2, validation ndcg 0.759928534043926
Epoch: 3, validation ndcg 0.8054136349100572
Epoch: 4, validation ndcg 0.8403705137643803
Epoch: 5, validation ndcg 0.8582787107647402
Epoch: 6, validation ndcg 0.8202329588678141
Epoch: 7, validation ndcg 0.8709689713457953
Epoch: 8, validation ndcg 0.8486875548460024


[32m[I 2023-03-24 19:28:25,412][0m Trial 23 finished with value: 0.8486875548460024 and parameters: {'freeze_emb': False, 'min_token_occurancies': 1, 'num_kernels': 20, 'sigma': 0.05326091932664389, 'lr': 0.03063915376131891, 'num_pos_ex': 4, 'num_same_rel_ex': 2}. Best is trial 22 with value: 0.8854371359627972.[0m


Epoch: 0, validation ndcg 0.6759047219067625
Epoch: 1, validation ndcg 0.7497722962389591
Epoch: 2, validation ndcg 0.7878909201239717
Epoch: 3, validation ndcg 0.8038502457827252
Epoch: 4, validation ndcg 0.8183717588236
Epoch: 5, validation ndcg 0.843799364721312
Epoch: 6, validation ndcg 0.828819510815372
Epoch: 7, validation ndcg 0.8358168637444359
Epoch: 8, validation ndcg 0.8415448408460658


[32m[I 2023-03-24 19:33:37,888][0m Trial 24 finished with value: 0.8415448408460658 and parameters: {'freeze_emb': False, 'min_token_occurancies': 1, 'num_kernels': 11, 'sigma': 0.0521879422336814, 'lr': 0.03635311710176743, 'num_pos_ex': 4, 'num_same_rel_ex': 2}. Best is trial 22 with value: 0.8854371359627972.[0m


Epoch: 0, validation ndcg 0.5488041862962498
Epoch: 1, validation ndcg 0.7280407832201482
Epoch: 2, validation ndcg 0.8134765978127172
Epoch: 3, validation ndcg 0.8530272949625977
Epoch: 4, validation ndcg 0.8669480404257264
Epoch: 5, validation ndcg 0.8183141821993496
Epoch: 6, validation ndcg 0.8513937504961697
Epoch: 7, validation ndcg 0.8620616460917953
Epoch: 8, validation ndcg 0.8859967063176127


[32m[I 2023-03-24 19:39:11,535][0m Trial 25 finished with value: 0.8859967063176127 and parameters: {'freeze_emb': False, 'min_token_occurancies': 1, 'num_kernels': 18, 'sigma': 0.058603207330843614, 'lr': 0.028682702082785127, 'num_pos_ex': 4, 'num_same_rel_ex': 2}. Best is trial 25 with value: 0.8859967063176127.[0m


Epoch: 0, validation ndcg 0.6265750672031521
Epoch: 1, validation ndcg 0.752718582007985
Epoch: 2, validation ndcg 0.8244749593765469
Epoch: 3, validation ndcg 0.845581455136606
Epoch: 4, validation ndcg 0.8642939369774235
Epoch: 5, validation ndcg 0.8711142029889972
Epoch: 6, validation ndcg 0.8365591623768426
Epoch: 7, validation ndcg 0.858682659639196
Epoch: 8, validation ndcg 0.8566655488944341


[32m[I 2023-03-24 19:44:20,492][0m Trial 26 finished with value: 0.8566655488944341 and parameters: {'freeze_emb': False, 'min_token_occurancies': 2, 'num_kernels': 12, 'sigma': 0.05865779801344104, 'lr': 0.0310427386924709, 'num_pos_ex': 4, 'num_same_rel_ex': 2}. Best is trial 25 with value: 0.8859967063176127.[0m


Epoch: 0, validation ndcg 0.5156346183904225
Epoch: 1, validation ndcg 0.6172967020091513
Epoch: 2, validation ndcg 0.7461797937829543
Epoch: 3, validation ndcg 0.7959813106022549
Epoch: 4, validation ndcg 0.8185477962086982
Epoch: 5, validation ndcg 0.8003758862308809
Epoch: 6, validation ndcg 0.8206176599246396
Epoch: 7, validation ndcg 0.8355236516287206
Epoch: 8, validation ndcg 0.8636310889916566


[32m[I 2023-03-24 19:49:51,344][0m Trial 27 finished with value: 0.8636310889916566 and parameters: {'freeze_emb': False, 'min_token_occurancies': 1, 'num_kernels': 18, 'sigma': 0.07015793703925031, 'lr': 0.01995065497485282, 'num_pos_ex': 4, 'num_same_rel_ex': 2}. Best is trial 25 with value: 0.8859967063176127.[0m


Epoch: 0, validation ndcg 0.524288201433482
Epoch: 1, validation ndcg 0.6945383882163093
Epoch: 2, validation ndcg 0.7806298441112284
Epoch: 3, validation ndcg 0.7881938773304182
Epoch: 4, validation ndcg 0.7876545777911947
Epoch: 5, validation ndcg 0.8045577812394312
Epoch: 6, validation ndcg 0.8153696497734997
Epoch: 7, validation ndcg 0.8338897473751695
Epoch: 8, validation ndcg 0.8173318155495765


[32m[I 2023-03-24 19:55:20,700][0m Trial 28 finished with value: 0.8173318155495765 and parameters: {'freeze_emb': False, 'min_token_occurancies': 1, 'num_kernels': 18, 'sigma': 0.07870836188058805, 'lr': 0.024890727009044412, 'num_pos_ex': 3, 'num_same_rel_ex': 3}. Best is trial 25 with value: 0.8859967063176127.[0m


Epoch: 0, validation ndcg 0.43902903119057957
Epoch: 1, validation ndcg 0.5697943627357523
Epoch: 2, validation ndcg 0.6730489884105669
Epoch: 3, validation ndcg 0.7453054836845033
Epoch: 4, validation ndcg 0.7638719573524476
Epoch: 5, validation ndcg 0.7907860573537872
Epoch: 6, validation ndcg 0.7405044510340801
Epoch: 7, validation ndcg 0.8082786293821322
Epoch: 8, validation ndcg 0.764581740446755


[32m[I 2023-03-24 20:00:51,495][0m Trial 29 finished with value: 0.764581740446755 and parameters: {'freeze_emb': False, 'min_token_occurancies': 1, 'num_kernels': 18, 'sigma': 0.05753951453781773, 'lr': 0.018068288050994155, 'num_pos_ex': 2, 'num_same_rel_ex': 4}. Best is trial 25 with value: 0.8859967063176127.[0m


Number of finished trials: 30
Best trial: {'freeze_emb': False, 'min_token_occurancies': 1, 'num_kernels': 18, 'sigma': 0.058603207330843614, 'lr': 0.028682702082785127, 'num_pos_ex': 4, 'num_same_rel_ex': 2}


In [5]:
# Show the study summary again without re-running the search.
for label, value in (
    ('Number of finished trials:', len(study.trials)),
    ('Best trial:', study.best_trial.params),
):
    print(label, value)

Number of finished trials: 30
Best trial: {'freeze_emb': False, 'min_token_occurancies': 1, 'num_kernels': 18, 'sigma': 0.058603207330843614, 'lr': 0.028682702082785127, 'num_pos_ex': 4, 'num_same_rel_ex': 2}


In [4]:
# Best configuration reported by the study above, pinned as literals so the
# final model can be trained without repeating the search.
best_params = dict(
    freeze_emb=False,
    min_token_occurancies=1,
    num_kernels=18,
    sigma=0.058603207330843614,
    lr=0.028682702082785127,
    num_pos_ex=4,
    num_same_rel_ex=2,
)

In [5]:
# Retrain with the best hyper-parameters. Up to 30 epochs are allowed and
# benchmark_ndcg_score=0.88 is passed to fit(); validation NDCG can vary by
# about 0.005 between runs, but training should finish on the 9th epoch
# (presumably fit() stops once the benchmark score is reached - TODO confirm).
# NOTE(review): unlike objective(), get_ready_for_train() is not called before
# fit() here - confirm that fit() does not depend on it.
final_run_kwargs = dict(
    train_path=TRAIN_PATH,
    val_path=VAL_PATH,
    glove_path=GLOVE_PATH,
    random_vec_bound=1.0,
    out_layers=[],
    seed=0,
    num_epochs=30,
    batch_size=1024,
    change_every_num_ep=10,
    **best_params,
)
model = TrainKNRM(**final_run_kwargs)
model.fit(benchmark_ndcg_score=0.88)

Epoch: 0, validation ndcg 0.547091338064236
Epoch: 1, validation ndcg 0.7272683947089158
Epoch: 2, validation ndcg 0.8101669488836125
Epoch: 3, validation ndcg 0.8554089298265041
Epoch: 4, validation ndcg 0.8704065931207523
Epoch: 5, validation ndcg 0.8181210311344154
Epoch: 6, validation ndcg 0.8499429726372575
Epoch: 7, validation ndcg 0.8606942492454339
Epoch: 8, validation ndcg 0.8886904451674117


In [8]:
# Persist the trained artefacts: MLP weights, embedding weights and vocabulary.
MODELS_DIR = PARENT_DIR + '/models'
MLP_SAVE_PATH = PARENT_DIR + '/models/knrm_mlp.bin'
EMB_SAVE_PATH = PARENT_DIR + '/models/knrm_emb.bin'
VOCAB_SAVE_PATH = PARENT_DIR + '/models/vocab.json'

# Make sure the target directory exists so saving works on a fresh checkout.
os.makedirs(MODELS_DIR, exist_ok=True)

state_mlp = model.knrm.mlp.state_dict()
state_emb = model.knrm.embeddings.state_dict()
vocab = model.vocab

# torch.save accepts a path and opens/closes the file itself; the previous
# code passed bare open(...) handles that were never closed.
torch.save(state_mlp, MLP_SAVE_PATH)
torch.save(state_emb, EMB_SAVE_PATH)

# The context manager guarantees the JSON is flushed and the handle closed.
with open(VOCAB_SAVE_PATH, 'w', encoding='utf-8') as vocab_file:
    json.dump(vocab, vocab_file, ensure_ascii=False, indent=4)