In [1]:
import collections
import os
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

plt.rcParams['figure.figsize'] = (10, 6)

%matplotlib inline

In [2]:
from ps import dataset_io
from ps import rating_utils

from ps.oracles import map_oracle
from ps.metrics import map

In [3]:
# Directory holding the dataset used throughout this notebook; it is handed to
# the dataset_io loaders in the next cell. The path is relative, so it is
# resolved against the kernel's current working directory.
dataset_dir = 'bases/100k/'

In [4]:
# Load the ranking sets stored under dataset_dir. Later cells look entries up
# with dataset_io.RankingSetId(fold=..., source=...) keys.
ranking_set_by_id = dataset_io.load_ranking_sets(dataset_dir)
# Load the ratings of all folds from the same directory.
# NOTE(review): presumably keyed by fold id, as the name suggests -- confirm
# against ps.dataset_io.
rating_set_by_fold = dataset_io.load_ratings_for_all_folds(dataset_dir)

In [5]:
# Derive the per-fold hits from the loaded ratings. Later cells index the
# result as hits_by_fold[fold][user_id] and test item ids for membership in it.
# NOTE(review): what qualifies a rating as a "hit" is decided inside
# rating_utils.compute_hits_by_fold and is not visible in this notebook.
hits_by_fold = rating_utils.compute_hits_by_fold(rating_set_by_fold)

In [6]:
def count_hits_by_user(ranking_set, hits_by_fold, cutoff):
    """Count, per user, how many of the first `cutoff` ranked items are hits.

    Parameters
    ----------
    ranking_set
        Object exposing ``id.fold``, ``user_ids`` and a 2-D ``matrix``.
        Row ``i`` of ``matrix`` is paired positionally with ``user_ids[i]``;
        if the two differ in length the extra entries are silently ignored
        (``zip`` semantics).
    hits_by_fold
        Mapping ``fold -> (user_id -> container of item ids)``. The inner
        container only needs to support the ``in`` operator.
    cutoff
        Number of leading columns of ``matrix`` to inspect for each user.

    Returns
    -------
    dict
        ``user_id -> int``: number of items among the user's first `cutoff`
        ranked items that are present in that user's hits.

    Raises
    ------
    KeyError
        When the fold, or a user id of the ranking set, has no entry in
        `hits_by_fold` (for plain dict inputs).
    """
    fold_hits = hits_by_fold[ranking_set.id.fold]

    counts = {}
    rows = zip(ranking_set.user_ids, ranking_set.matrix[:, :cutoff])
    for uid, top_items in rows:
        # Look the user's hits up once per row rather than once per item.
        relevant = fold_hits[uid]
        counts[uid] = len([item for item in top_items if item in relevant])

    return counts

In [7]:
# Ranking set produced by the 'WRMF' source for fold '3'.
u3_wrmf = ranking_set_by_id[dataset_io.RankingSetId(fold='3', source='WRMF')]
# Per-user number of hits among the first 10 items of the WRMF ranking.
u3_wrmf_hits_by_user = count_hits_by_user(u3_wrmf, hits_by_fold, cutoff=10)

In [8]:
# Build the MAP oracle from the same ranking sets and ratings loaded above.
oracle = map_oracle.MAPOracle(ranking_set_by_id, rating_set_by_fold)
# Ask the oracle for its optimal ranking set on fold '3'.
# NOTE(review): presumably it selects 10 output items out of the first 20 items
# of the input rankings -- confirm the cutoff semantics in ps.oracles.map_oracle.
u3_map = oracle.compute_optimal_ranking_set(fold='3', input_cutoff=20, output_cutoff=10)
# Per-user hits in the oracle's first 10 items, counted exactly as for WRMF so
# the two dictionaries can be compared user by user.
u3_map_hits_by_user = count_hits_by_user(u3_map, hits_by_fold, cutoff=10)

In [9]:
# `map` is the ps.metrics.map module imported above; it shadows the builtin
# map() for the rest of the notebook.
map_metric = map.MAP(ranking_set_by_id, rating_set_by_fold)
# Display the metric computed with num_items=10 as a (WRMF, oracle) tuple.
map_metric.compute(u3_wrmf, num_items=10), map_metric.compute(u3_map, num_items=10)

(0.14120755574957256, 0.8039669807316853)

In [10]:
def users_with_more_hits_in_wrmf():
    """Yield the ids of users with strictly more hits in WRMF than in the oracle.

    Reads the notebook-level dictionaries ``u3_wrmf_hits_by_user`` and
    ``u3_map_hits_by_user`` (both ``user_id -> hit count``). Users are yielded
    lazily, in the iteration order of ``u3_wrmf_hits_by_user``.

    Raises
    ------
    KeyError
        While iterating, if a user present in ``u3_wrmf_hits_by_user`` is
        missing from ``u3_map_hits_by_user``.
    """
    yield from (
        uid
        for uid, wrmf_hits in u3_wrmf_hits_by_user.items()
        if wrmf_hits > u3_map_hits_by_user[uid]
    )

In [12]:
# Inspect the per-user hit counts (first 10 items) of the WRMF ranking, fold '3'.
# NOTE(review): this displays the whole dictionary, one line per user.
u3_wrmf_hits_by_user

{1: 2,
 2: 3,
 3: 1,
 4: 0,
 5: 1,
 6: 2,
 7: 4,
 8: 2,
 9: 0,
 10: 4,
 11: 2,
 12: 2,
 13: 1,
 14: 2,
 15: 0,
 16: 3,
 17: 2,
 18: 4,
 19: 0,
 20: 1,
 21: 2,
 22: 4,
 23: 1,
 24: 2,
 25: 3,
 26: 1,
 27: 1,
 28: 4,
 29: 0,
 30: 0,
 31: 0,
 32: 1,
 33: 1,
 34: 0,
 35: 1,
 36: 1,
 37: 3,
 38: 2,
 39: 1,
 40: 0,
 41: 2,
 42: 4,
 43: 4,
 44: 5,
 45: 2,
 46: 0,
 47: 1,
 48: 3,
 49: 1,
 50: 1,
 51: 0,
 52: 3,
 53: 1,
 54: 4,
 55: 1,
 56: 3,
 57: 4,
 58: 3,
 59: 4,
 60: 2,
 61: 1,
 62: 5,
 63: 0,
 64: 3,
 65: 1,
 66: 0,
 67: 0,
 68: 0,
 69: 1,
 70: 2,
 71: 0,
 72: 3,
 73: 3,
 74: 1,
 75: 1,
 76: 1,
 77: 2,
 78: 0,
 79: 2,
 80: 0,
 81: 1,
 82: 1,
 83: 3,
 84: 0,
 85: 4,
 86: 0,
 87: 3,
 88: 1,
 89: 4,
 90: 6,
 91: 4,
 92: 2,
 93: 1,
 94: 6,
 95: 4,
 96: 3,
 97: 3,
 98: 0,
 99: 2,
 100: 2,
 101: 0,
 102: 0,
 103: 1,
 104: 1,
 105: 1,
 106: 2,
 107: 1,
 108: 0,
 109: 3,
 110: 2,
 111: 1,
 112: 3,
 113: 2,
 114: 3,
 115: 0,
 116: 2,
 117: 4,
 118: 1,
 119: 1,
 120: 2,
 121: 1,
 122: 2,
 123: 4,
 

In [13]:
# Inspect the per-user hit counts (first 10 items) of the oracle ranking, fold '3'.
# NOTE(review): this displays the whole dictionary, one line per user.
u3_map_hits_by_user

{1: 10,
 2: 6,
 3: 2,
 4: 2,
 5: 5,
 6: 10,
 7: 10,
 8: 6,
 9: 2,
 10: 10,
 11: 7,
 12: 8,
 13: 10,
 14: 8,
 15: 4,
 16: 10,
 17: 2,
 18: 10,
 19: 0,
 20: 2,
 21: 4,
 22: 5,
 23: 9,
 24: 6,
 25: 10,
 26: 3,
 27: 1,
 28: 5,
 29: 2,
 30: 3,
 31: 1,
 32: 3,
 33: 2,
 34: 2,
 35: 1,
 36: 1,
 37: 5,
 38: 4,
 39: 2,
 40: 1,
 41: 5,
 42: 10,
 43: 10,
 44: 8,
 45: 2,
 46: 1,
 47: 3,
 48: 5,
 49: 6,
 50: 2,
 51: 1,
 52: 7,
 53: 3,
 54: 7,
 55: 2,
 56: 10,
 57: 6,
 58: 10,
 59: 10,
 60: 10,
 61: 1,
 62: 10,
 63: 4,
 64: 10,
 65: 10,
 66: 2,
 67: 2,
 68: 3,
 69: 5,
 70: 8,
 71: 3,
 72: 10,
 73: 6,
 74: 4,
 75: 4,
 76: 4,
 77: 3,
 78: 0,
 79: 4,
 80: 1,
 81: 3,
 82: 7,
 83: 10,
 84: 7,
 85: 10,
 86: 0,
 87: 10,
 88: 1,
 89: 7,
 90: 10,
 91: 10,
 92: 10,
 93: 2,
 94: 10,
 95: 10,
 96: 7,
 97: 6,
 98: 1,
 99: 9,
 100: 3,
 101: 3,
 102: 1,
 103: 2,
 104: 5,
 105: 2,
 106: 5,
 107: 1,
 108: 2,
 109: 10,
 110: 7,
 111: 2,
 112: 5,
 113: 3,
 114: 4,
 115: 7,
 116: 5,
 117: 7,
 118: 7,
 119: 10,
 120: 3,


In [14]:
# Position of user 7 in each ranking set's user_ids. The following cells use
# these positions to select the corresponding rows of the `matrix` attributes.
u3_map_7_index = u3_map.user_ids.index(7)
u3_wrmf_7_index = u3_wrmf.user_ids.index(7)

In [15]:
# Full row of the oracle ranking for user 7.
u3_map.matrix[u3_map_7_index]

array([   4.,   86.,   96.,  127.,  135.,  657.,  659.,  153.,  172.,  173.])

In [16]:
# First 10 items of the WRMF ranking for user 7.
u3_wrmf.matrix[u3_wrmf_7_index, :10]

array([ 98, 176,  56, 479, 180, 191, 496,  96, 684,  86], dtype=int32)

In [17]:
# Hits of user 7 in fold '3'. The u3_ prefix denotes the fold, as in
# u3_wrmf / u3_map; the value itself is specific to user 7.
u3_hits = hits_by_fold['3'][7]

In [18]:
# Entry for user 7 of the oracle's per-fold recommendation data, for fold '3'
# with the same value (20) that was passed as input_cutoff above.
# NOTE(review): this calls an underscore-prefixed (private) method of the
# oracle, and the result is not used by any of the cells visible here.
recommended_to_u3 = oracle._compute_recommended_in_fold('3', 20)[7]

In [19]:
# Hits of user 7 that appear among the first 10 items of the WRMF ranking.
# Barring duplicate items in the ranking row, the size of this set should equal
# u3_wrmf_hits_by_user[7].
set(u3_hits).intersection(u3_wrmf.matrix[u3_wrmf_7_index, :10])

{86, 96, 191, 479}