In [5]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import pickle
from collections import defaultdict
import sys

from torch.utils.data import Dataset, DataLoader
from Mmetrics import *

import LTR
import datautil
import permutationgraph
import DTR
import EEL
import PPG
import PL

def df2ds(df_path):
    with open(df_path, 'rb') as f:
        df = pickle.load(f)
    ds = df.to_dict(orient='list')
    for k in ds:
        ds[k] = np.array(ds[k])
    ds['dlr'] = np.concatenate([np.zeros(1), np.where(np.diff(ds['qid'])==1)[0]+1, np.array([ds['qid'].shape[0]])]).astype(int)
    return type('ltr', (object,), ds)


def dict2ds(df_path):
    with open(df_path, 'rb') as f:
        ds = pickle.load(f)
    return type('ltr', (object,), ds)

ds2019 = df2ds('LTR2019.df')
ds2020 = df2ds('LTR2020.df')


def valid_queries(ds):
    dtr, eel = [], []
    groups = np.unique(ds.g)
    for qid in range(ds.dlr.shape[0] - 1):
        s, e = ds.dlr[qid:qid+2]
        lv = ds.lv[s:e]
        g = ds.g[s:e]
        z = False
        for group in groups:
            if lv[g==group].sum() == 0:
                z = True
                break
        if not z:
            dtr.append(qid)
        if len(np.unique(ds.g[s:e])) > 1:
            eel.append(qid)
            
    return {'DTR':np.array(dtr), 'EEL':np.array(eel)}

ds2019.valids = valid_queries(ds2019)
ds2020.valids = valid_queries(ds2020)


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
from tqdm.notebook import trange

exposure2020 = np.array([1./np.log2(2+i) for i in range(1,np.diff(ds2020.dlr).max()+2)])
exposure2019 = np.array([1./np.log2(2+i) for i in range(1,np.diff(ds2019.dlr).max()+2)])

def learn_one_PPG(metric, qid, verbose, y_pred, g, dlr, epochs, lr, exposure, grade_levels, samples_cnt, sessions_cnt):
    s, e = dlr[qid:qid+2]
    y_pred_s, g_s, sorted_docs_s, dlr_s = \
        EEL.copy_sessions(y=y_pred[s:e], g=g[s:e], sorted_docs=y_pred[s:e].argsort()[::-1], sessions=sessions_cnt)
    
    if metric == 'EEL':
        objective_ins = EEL.EEL(y_pred = y_pred_s, g = g_s, dlr = dlr_s, exposure=exposure, grade_levels = grade_levels)
    else:
        objective_ins = DTR.DTR(y_pred = y_pred_s, g = g_s, dlr = dlr_s, exposure=exposure)
        
    learner = PPG.Learner(  PPG_mat=None, samples_cnt=samples_cnt, 
                                objective_ins=objective_ins, 
                                sorted_docs = sorted_docs_s, 
                                dlr = dlr_s,
                                intra = g_s,
                                inter = np.repeat(dlr_s[:-1], np.diff(dlr_s)))
    vals = learner.fit(epochs, lr, verbose=verbose)
    return vals

def learn_all_PPG(metric, y_pred, g, dlr, epochs, lr, exposure, grade_levels, samples_cnt, sessions_cnt):
    sorted_docs = []
    
    for qid in trange(dlr.shape[0] - 1, leave=False):
#     for qid in range(dlr.shape[0] - 1):
        min_b = learn_one_PPG(metric, qid, 0, y_pred, g, dlr, epochs, lr, exposure, grade_levels, samples_cnt, sessions_cnt)
        sorted_docs.append(min_b)
        

    # print(ndcg_dtr(exposure, lv, np.concatenate(y_rerank), dlr, g, query_counts))
    return sorted_docs


In [31]:
def test(validid, res):
    qid = ds2020.valids['DTR'][validid]
    s, e = ds2020.dlr[qid:qid+2]

    # print(lv_1[qid])
    # for i in range(4):
    #     print('\t', lv_4[qid][i*(e-s):(i+1)*(e-s)] - i*(e-s))


    exposure = np.array([1./np.log2(2+i) for i in range(1,np.diff(ds2020.dlr).max()+2)])

    lv_s, g_s, sorted_docs_s, dlr_s = \
            EEL.copy_sessions(y=ds2020.lv[s:e], g=ds2020.g[s:e], sorted_docs=ds2020.lv[s:e].argsort()[::-1], sessions=4)



    objective_ins_4 = DTR.DTR(y_pred = lv_s, g = g_s, dlr = dlr_s, exposure=exposure)
    print('true labels:')
    print(objective_ins_4.eval(res[qid]))

    for i in range(4):
        print('\t', objective_ins_4.eval(res[qid][i*(e-s):(i+1)*(e-s)] - i*(e-s)))
        
    objective_ins_1 = DTR.DTR(y_pred = ds2020.lv[s:e], g = ds2020.g[s:e], dlr = np.array([0,e-s]), exposure=exposure)
#     print('true labels:')
    print(objective_ins_1.eval(res1[qid]))
       

In [35]:
res = learn_one_PPG('DTR', ds2020.valids['DTR'][0], 1, ds2020.lv, ds2020.g, ds2020.dlr, 30, 0.01, exposure=exposure2020, grade_levels=5, samples_cnt=16, sessions_cnt=4)
# res1 = learn_all_PPG('DTR', ds2020.lv, ds2020.g, ds2020.dlr, 30, 0.01, exposure=exposure2020, grade_levels=5, samples_cnt=16, sessions_cnt=1)

[ 1  5 18  8  2  3  4  6  7  9 17 10 11 12 13 14 15 16  0 20 24 37 27 21
 22 23 25 26 28 36 29 30 31 32 33 34 35 19 39 43 56 46 40 41 42 44 45 47
 55 48 49 50 51 52 53 54 38 58 62 75 65 59 60 61 63 64 66 74 67 68 69 70
 71 72 73 57] inter: [ 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0 19 19 19 19 19
 19 19 19 19 19 19 19 19 19 19 19 19 19 19 38 38 38 38 38 38 38 38 38 38
 38 38 38 38 38 38 38 38 38 57 57 57 57 57 57 57 57 57 57 57 57 57 57 57
 57 57 57 57] intra: ['L' 'L' 'H' 'H' 'H' 'H' 'H' 'L' 'H' 'L' 'H' 'L' 'H' 'H' 'H' 'L' 'L' 'H'
 'H' 'L' 'L' 'H' 'H' 'H' 'H' 'H' 'L' 'H' 'L' 'H' 'L' 'H' 'H' 'H' 'L' 'L'
 'H' 'H' 'L' 'L' 'H' 'H' 'H' 'H' 'H' 'L' 'H' 'L' 'H' 'L' 'H' 'H' 'H' 'L'
 'L' 'H' 'H' 'L' 'L' 'H' 'H' 'H' 'H' 'H' 'L' 'H' 'L' 'H' 'L' 'H' 'H' 'H'
 'L' 'L' 'H' 'H']
min_f: 0.7573575341340542 , mean_f: 0.9248748602532512
new ref permutation:
 [ 1  5 18  8  2  7  3  4  6 17 10  9 11 12 15 13 16 14  0 20 24 37 27 21
 26 22 23 25 36 29 28 30 31 34 32 35 33 19 39 43 56 46 40 45

In [46]:
res = learn_one_PPG('DTR', ds2020.valids['DTR'][0], 3, ds2020.lv, ds2020.g, ds2020.dlr, 30, 0.01, exposure=exposure2020, grade_levels=5, samples_cnt=16, sessions_cnt=4)

[ 1  5 18  8  2  3  4  6  7  9 17 10 11 12 13 14 15 16  0 20 24 37 27 21
 22 23 25 26 28 36 29 30 31 32 33 34 35 19 39 43 56 46 40 41 42 44 45 47
 55 48 49 50 51 52 53 54 38 58 62 75 65 59 60 61 63 64 66 74 67 68 69 70
 71 72 73 57] inter: [ 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0 19 19 19 19 19
 19 19 19 19 19 19 19 19 19 19 19 19 19 19 38 38 38 38 38 38 38 38 38 38
 38 38 38 38 38 38 38 38 38 57 57 57 57 57 57 57 57 57 57 57 57 57 57 57
 57 57 57 57] intra: ['L' 'L' 'H' 'H' 'H' 'H' 'H' 'L' 'H' 'L' 'H' 'L' 'H' 'H' 'H' 'L' 'L' 'H'
 'H' 'L' 'L' 'H' 'H' 'H' 'H' 'H' 'L' 'H' 'L' 'H' 'L' 'H' 'H' 'H' 'L' 'L'
 'H' 'H' 'L' 'L' 'H' 'H' 'H' 'H' 'H' 'L' 'H' 'L' 'H' 'L' 'H' 'H' 'H' 'L'
 'L' 'H' 'H' 'L' 'L' 'H' 'H' 'H' 'H' 'H' 'L' 'H' 'L' 'H' 'L' 'H' 'H' 'H'
 'L' 'L' 'H' 'H']
[ 0  1  2  3  4  5  6  7  8 10  9 11 13 12 14 16 17 15 18 19 20 21 22 23
 27 24 25 26 29 30 32 28 33 31 34 35 36 37 39 40 38 41 42 43 44 45 46 47
 48 49 51 50 52 53 54 55 56 57 58 59 60 61 62 63 64 65 67 66 68 

[ 0  1  2  3  4  7  5  6  9  8 10 12 11 13 14 15 16 17 18 19 20 21 22 23
 24 25 27 28 26 29 30 31 32 33 34 36 35 37 39 40 38 41 43 42 44 45 46 48
 47 51 49 50 52 55 53 54 56 57 58 59 60 61 62 63 65 64 67 66 68 69 72 70
 73 75 71 74] -> [ 1  5 18  8  2  7  3  4  6  9 17 10 11 12 13 14 15 16  0 20 24 37 27 21
 22 23 25 36 26 28 29 31 30 32 34 35 33 19 43 56 39 46 45 40 41 42 44 55
 47 49 48 50 53 54 51 52 38 58 62 75 65 59 60 61 63 64 74 66 67 68 72 69
 73 57 70 71]
[ 0  1  2  3  4  7  5  6  9  8 10 12 11 13 14 15 16 17 18 20 21 19 22 23
 24 25 27 26 29 28 30 32 34 31 33 36 35 37 38 39 40 41 42 43 44 45 47 46
 48 49 50 53 51 52 54 55 56 57 58 59 60 61 62 63 64 65 67 68 66 70 71 69
 72 73 75 74] -> [ 1  5 18  8  2  7  3  4  6  9 17 10 11 12 13 14 15 16  0 24 37 20 27 21
 22 23 25 26 28 36 29 30 34 31 32 35 33 19 39 43 56 46 40 45 41 42 47 44
 55 48 50 51 49 53 52 54 38 58 62 75 65 59 60 61 64 63 74 67 66 69 70 68
 72 73 57 71]
[ 0  1  2  3  4  5  6  7  9  8 10 12 11 13 14 15 16 17 18 19 2

[ 1  0  3  2  4  5  6  7  8  9 10 11 12 13 14 15 16 18 17 20 21 19 22 23
 24 25 26 28 27 29 30 32 31 34 35 33 37 36 38 39 40 41 42 43 44 46 45 47
 48 49 50 51 52 54 55 53 56 58 59 57 60 62 63 61 64 65 66 67 68 69 70 71
 72 73 74 75] -> [ 5  1  7 18  8  2  3  4  6 17 10  9 11 15 16 12 13  0 14 24 37 20 27 26
 21 22 23 25 28 36 29 31 30 34 35 32 33 19 39 45 43 56 46 40 41 42 47 49
 44 55 48 50 51 53 54 52 38 62 75 58 65 64 66 59 60 61 63 74 67 69 70 68
 71 72 73 57]
[ 1  0  2  3  4  5  6  7  8  9 10 11 12 13 15 14 18 16 17 19 20 21 22 23
 24 25 26 28 29 27 30 32 31 34 35 33 37 36 38 39 40 41 42 43 44 46 45 48
 47 49 50 51 52 53 54 55 56 57 58 59 62 63 60 61 64 65 66 67 68 69 70 72
 71 73 74 75] -> [ 5  1 18  7  8  2  3  4  6 17 10  9 11 15 12 16  0 13 14 20 24 37 27 26
 21 22 23 25 36 28 29 31 30 34 35 32 33 19 39 45 43 56 46 40 41 42 47 44
 49 55 48 50 51 52 53 54 38 58 62 75 64 66 65 59 60 61 63 74 67 69 70 71
 68 72 73 57]
[ 0  3  1  2  4  5  6  7  8  9 10 11 12 13 14 18 15 16 17 19 2

[ 0  1  2  3  4  6  5  7  8  9 11 10 12 15 13 14 16 17 18 19 20 22 21 23
 24 25 26 27 29 32 28 30 31 33 34 35 36 37 38 39 40 41 42 46 43 44 45 48
 47 49 51 50 53 54 55 52 56 57 60 58 59 61 62 63 64 65 66 67 68 69 70 71
 72 73 74 75] -> [ 1  7  5 18  8  2  9  3  4  6 11 17 15 16 10 12  0 13 14 20 26 28 24 37
 27 21 22 23 30 34 25 36 29 35 19 31 32 33 39 45 47 43 56 49 46 40 41 42
 53 44 55 54 48 50 51 38 52 58 62 64 66 75 65 59 60 61 63 74 67 69 70 71
 68 72 73 57]
[ 0  1  2  5  3  4  6  7  8  9 10 11 13 12 14 15 17 16 18 19 20 22 21 23
 24 25 26 29 27 28 30 32 31 33 34 35 36 37 38 39 40 41 42 46 43 44 45 48
 49 51 47 50 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
 72 73 74 75] -> [ 1  7  5  9 18  8  2  3  4  6 17 11 10 15 12 16 13  0 14 20 26 28 24 37
 27 21 22 30 23 25 36 34 29 35 19 31 32 33 39 45 47 43 56 49 46 40 41 42
 44 55 53 54 38 48 50 51 52 58 64 66 62 75 65 59 60 61 63 74 67 69 70 71
 68 72 73 57]
[ 0  2  1  3  5  4  6  7  8  9 10 13 11 12 14 15 17 16 18 19 2

[ 0  1  2  3  4  7  5  6  8  9 10 13 11 12 14 15 16 17 18 19 20 22 21 23
 24 25 27 26 28 30 29 31 32 34 33 35 36 37 38 39 40 41 42 43 44 45 46 47
 48 49 50 51 52 53 54 55 56 57 58 59 60 61 63 62 64 65 67 66 70 68 69 71
 72 73 74 75] -> [ 1  7  9  5 18 11  8  2 15  3  4 16  6 17 10  0 12 13 14 20 26 24 28 37
 27 21 30 22 23 34 25 36 35 19 29 31 32 33 39 45 47 49 53 43 56 46 40 54
 38 41 42 44 55 48 50 51 52 58 64 66 62 75 68 65 59 72 73 60 57 61 63 74
 67 69 70 71]
[ 0  1  2  3  4  7  5  6  9  8 10 13 11 15 12 14 16 17 18 19 20 21 22 23
 24 27 25 26 28 30 29 32 31 34 33 35 36 37 38 39 40 41 42 43 44 45 47 46
 49 50 48 51 52 53 54 55 56 57 58 60 59 63 61 65 62 64 66 67 70 68 69 71
 72 73 74 75] -> [ 1  7  9  5 18 11  8  2  3 15  4 16  6  0 17 10 12 13 14 20 26 28 24 37
 27 30 21 22 23 34 25 35 36 19 29 31 32 33 39 45 47 49 53 43 56 46 54 40
 41 42 38 44 55 48 50 51 52 58 64 62 66 68 75 72 65 59 60 73 57 61 63 74
 67 69 70 71]
[ 0  1  2  3  4  5  6  7  9  8 13 10 11 12 14 15 16 17 18 19 2

[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
 72 73 74 75] -> [ 1  7  9 11 15 16  5 18  8  2  3  0  4  6 17 10 12 13 14 20 26 28 30 34
 35 24 37 27 21 19 22 23 25 36 29 31 32 33 39 45 47 49 53 54 38 43 56 46
 40 41 42 44 55 48 50 51 52 58 64 66 68 72 73 57 62 75 65 59 60 61 63 74
 67 69 70 71]
[ 0  1  2  3  4  5  6  7  8  9 11 10 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 29 27 28 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
 72 73 74 75] -> [ 1  7  9 11 15 16  5 18  8  2  0  3  4  6 17 10 12 13 14 20 26 28 30 34
 35 24 37 19 27 21 22 23 25 36 29 31 32 33 39 45 47 49 53 54 38 43 56 46
 40 41 42 44 55 48 50 51 52 58 64 66 68 72 73 57 62 75 65 59 60 61 63 74
 67 69 70 71]
[ 0  1  2  3  4  6  5  7  8  9 11 10 12 13 14 15 16 17 18 19 2

[ 0  1  2  3  4  7  5  6  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 26 25 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
 72 73 74 75] -> [ 1  7  9 11 15  5 16  0 18  8  2  3  4  6 17 10 12 13 14 20 26 28 30 34
 35 24 19 37 27 21 22 23 25 36 29 31 32 33 39 45 47 49 53 54 38 43 56 46
 40 41 42 44 55 48 50 51 52 58 64 66 68 72 73 57 62 75 65 59 60 61 63 74
 67 69 70 71]
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 45 44 46 47
 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
 72 73 74 75] -> [ 1  7  9 11 15 16  0  5 18  8  2  3  4  6 17 10 12 13 14 20 26 28 30 34
 35 19 24 37 27 21 22 23 25 36 29 31 32 33 39 45 47 49 53 54 43 38 56 46
 40 41 42 44 55 48 50 51 52 58 64 66 68 72 73 57 62 75 65 59 60 61 63 74
 67 69 70 71]
[ 0  1  2  3  4  5  7  6  8  9 10 11 12 13 14 15 16 17 18 19 2

In [39]:
res1[ds2020.valids['DTR'][0]]

array([ 1,  7,  9, 11, 15, 16,  0,  5, 18,  8,  2,  3,  4,  6, 17, 10, 12,
       13, 14])

In [37]:
test(0, res)

true labels:


IndexError: invalid index to scalar variable.

HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))

