In [1]:
import numpy as np
import sys

sys.path.append('../general')
from pool import Pool
from metric import metric

sys.path.append('../eleven_regressions/')
from counterfactual_model import CounterfactualModel

from collections import Counter
import xgboost as xgb
from matplotlib import pyplot as pp
%matplotlib inline

In [2]:
def make_uniform_probas(pool):
    """Give each row of ``pool`` the value 1 / (number of rows sharing its query).

    Queries are compared after conversion to tuples. A row is counted for
    its query only when ``zip(pool.positions, pool.queries)`` yields it,
    i.e. when it is paired with an entry of ``pool.positions``.

    Args:
        pool: object exposing iterable ``queries`` and ``positions`` attributes.

    Returns:
        A list with one value per entry of ``pool.queries``, in the same order.

    Raises:
        ZeroDivisionError: if some query in ``pool.queries`` was never paired
            with a position (``pool.positions`` shorter than ``pool.queries``).
    """
    # Only the number of paired rows per query is needed, so tally it directly.
    rows_per_query = Counter(
        tuple(query) for _, query in zip(pool.positions, pool.queries)
    )
    uniform = []
    for query in pool.queries:
        uniform.append(1 / rows_per_query[tuple(query)])
    return uniform

In [3]:
# Build the Pool from '../data' (presumably this reads the dataset — confirm
# in ../general/pool.py), then set its `probas` attribute to the per-query
# values computed by make_uniform_probas.
pool = Pool('../data')
pool.probas = make_uniform_probas(pool)

In [5]:
# Repeat the split -> fit -> predict -> score experiment 100 times and collect
# one metric value per iteration in `scores`.
# NOTE(review): the iteration count 100 is a magic number; consider a named
# constant in a config cell.
# NOTE(review): no seed is set in the visible cells, and the recorded scores
# differ between iterations, so some step is randomized — confirm which one
# (train_test_split and/or XGBRegressor) and seed it for reproducibility.
scores = []
for i in range(100):
    # New train/test split each iteration; the training pool is then divided
    # by position into `train_pools`.
    train_pool, test_pool = pool.train_test_split()
    train_pools = train_pool.split_by_position()
    # One default-configured XGBRegressor per index below
    # CounterfactualModel.NONE_POSITION (the loop variable itself is unused).
    models = [
        xgb.XGBRegressor()
        for position in range(CounterfactualModel.NONE_POSITION)
    ]
    model = CounterfactualModel(models)
    model.fit(train_pools)
    prediction = model.predict(test_pool)
    # Score the prediction using the test pool's positions, targets and probas.
    scores.append(
        metric(prediction, test_pool.positions, test_pool.targets, test_pool.probas)
    )

In [6]:
# Mean of the metric over all collected iterations.
np.mean(scores)

-0.0135343472099525

In [7]:
# Standard deviation of the metric across iterations (np.std defaults to
# ddof=0, i.e. the population standard deviation).
np.std(scores)

0.05019685561637613

In [8]:
# Display every collected score.
# NOTE(review): this dumps the whole list; a histogram or a short summary
# would be easier to read.
scores

[-0.05083378607696957,
 0.012610223637686836,
 -0.06890368215243022,
 0.004797334021992153,
 -0.028906123201052347,
 -0.04947087431837011,
 0.014662805903840948,
 -0.0030661492149035136,
 -0.059829374688642994,
 -0.060039739379176386,
 0.026248488706571083,
 0.05761221792079784,
 -0.019525067974019934,
 -0.06774828891400937,
 -0.08858890645813058,
 0.04786186971487822,
 0.030322464982668558,
 -0.06265173255514216,
 0.0016100763347413804,
 -0.013303698976834616,
 -0.013262648756857271,
 -0.043151995007638576,
 0.031052817432968706,
 -0.01855678230744821,
 -0.07505318243415267,
 -0.04608134800416452,
 -0.2059195612139172,
 0.034908658553052835,
 0.0459641540867311,
 0.0007352010983007925,
 -0.07572078253604747,
 -0.03654697981145647,
 0.010289089856324372,
 -0.0022552040020624798,
 0.030369918598069087,
 0.04451807105982745,
 -0.11665899314703765,
 0.03922929565994828,
 0.016774457132374798,
 -0.1562500838747731,
 -0.05813260558799461,
 -0.0008278772919266312,
 -0.06565281097094218,
 -0.