In [1]:
import numpy as np
import sys

sys.path.append('../general')
from pool import Pool
from metric import metric

sys.path.append('../eleven_regressions/')
from counterfactual_model import CounterfactualModel

from collections import Counter
import xgboost as xgb
from matplotlib import pyplot as pp
%matplotlib inline

In [2]:
def make_uniform_probas(pool):
    """Give every row of ``pool`` a probability that is uniform within its query.

    A row's probability is ``1 / k``, where ``k`` is the number of
    (position, query) pairs in the pool that share the row's query.
    Repeated pairs are counted separately.

    Args:
        pool: object exposing parallel ``queries`` and ``positions``
            iterables; every query must be convertible with ``tuple`` into
            a hashable key.

    Returns:
        A list with one probability per entry of ``pool.queries``, in the
        same order.
    """
    # Pair queries with positions so only rows that have a position are counted.
    rows_per_query = Counter(
        tuple(query) for _, query in zip(pool.positions, pool.queries)
    )
    return [1 / rows_per_query[tuple(query)] for query in pool.queries]

def make_uniform_near_prod_probas(pool):
    """Filter ``pool`` to rows near the prod position, then build uniform probas.

    A row is kept when the absolute difference between its position and its
    prod position is strictly less than 2. The resulting boolean mask is
    passed to ``pool.filter`` and the probabilities are then computed with
    ``make_uniform_probas`` on the same ``pool`` object.

    NOTE(review): the return value of ``pool.filter`` is discarded, so this
    presumably relies on ``filter`` mutating ``pool`` in place -- confirm
    against the Pool implementation.

    Args:
        pool: object exposing ``positions``, ``prod_positions``, ``queries``
            and a ``filter(mask)`` method.

    Returns:
        The list produced by ``make_uniform_probas(pool)`` after filtering.
    """
    keep_row = []
    for position, prod_pos in zip(pool.positions, pool.prod_positions):
        distance_to_prod = abs(position - prod_pos)
        keep_row.append(distance_to_prod < 2)

    pool.filter(keep_row)
    return make_uniform_probas(pool)

In [3]:
# Build the Pool from the '../data' path (relative to the notebook's working directory).
pool = Pool('../data')
# NOTE(review): make_uniform_near_prod_probas calls pool.filter(mask) and ignores
# its return value, so `pool` is presumably filtered in place as a side effect of
# this line -- confirm in Pool.filter. The returned per-row probabilities are then
# stored on the same object.
pool.probas = make_uniform_near_prod_probas(pool)

In [4]:
# Run the split / fit / evaluate cycle 100 times and keep one metric value per
# iteration in `scores`.
# NOTE(review): no seed is set in this cell; the iterations presumably differ
# through randomness inside train_test_split and/or the models -- confirm.
scores = []
for i in range(100):
    train_pool, test_pool = pool.train_test_split()
    train_pools = train_pool.split_by_position()

    # One fresh regressor per index in range(CounterfactualModel.NONE_POSITION).
    models = [xgb.XGBRegressor() for _ in range(CounterfactualModel.NONE_POSITION)]
    model = CounterfactualModel(models)
    model.fit(train_pools)

    prediction = model.predict(test_pool)
    score = metric(
        prediction,
        test_pool.positions,
        test_pool.targets,
        test_pool.probas,
    )
    scores.append(score)

In [5]:
# Average of the per-iteration metric values collected in `scores`.
np.mean(scores)

0.02358460804109976

In [6]:
# Spread of the metric across iterations. np.std defaults to ddof=0, i.e. the
# population standard deviation, not the sample (ddof=1) estimate.
np.std(scores)

0.0547106709195335

In [7]:
# Raw per-iteration metric values, in the order the iterations ran.
scores

[0.026719860675004856,
 0.0592049457668479,
 0.04758205707060462,
 0.060999214578369015,
 0.04310655972820198,
 0.05736445456265128,
 0.006452009568806567,
 0.002761289397662974,
 -0.028663656184592563,
 0.10537846752099883,
 -0.0001582984264167518,
 0.009585234074810751,
 0.08838773294353655,
 0.036636563103183424,
 0.00911608503039083,
 0.0035828372161443647,
 0.03172564209038774,
 -0.02078027935589061,
 0.027900249421136124,
 0.07025092733947563,
 0.03522763252966615,
 0.04449002790932703,
 0.006184878054645468,
 0.11031880289475549,
 -0.013271431801082546,
 0.12063956970766196,
 0.04547092865707615,
 0.034204076282764145,
 -0.023530383267292152,
 -0.09966484719773382,
 0.016247156593534682,
 0.034578875955625865,
 0.005085165381644074,
 0.0069281857635602314,
 0.0001424481949491722,
 0.035613487819370175,
 0.0015710234217040254,
 0.05172077782863814,
 0.0027899352553623176,
 0.08465060752839186,
 0.10881931283823876,
 0.00488410917549564,
 0.03323496372587824,
 0.04423567293232846,