In [207]:
import numpy as np
import pandas as pd
import scipy as sp
from scipy.stats import poisson, nbinom
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
import sys, os

In [208]:
# estimate recall for a given data frame (data) and parameters (param)
# data:
#    data is a pandas DataFrame, sorted by scores from high to low
#    data has a column 'Score': classifier-predicted score
#    data has a column 'Label': 1: relevant; 0: non-relevant; -1: unlabeled 
# param:
#    param is a dictionary

#        'ignore_score': bool
#        if not 'ignore_score':
#            'score_threshold': double [0.0, 1.0]
#        'interval_length': int, >= 1
#        'rate_curve': {'exp', 'poly', 'exp_poly', 'const'}
#        'lookback': int, >= 1
#        'forget_factor': double, >= 0
#        'est_num_rel': {'all', 'rest'}
#        'confidence_bound': {'none', 'poisson'}
#        'confidence_level': double [0.0, 1.0]
#

In [209]:
# processes the 'Label' and 'Score' columns of the input data
# such that each reference has a guessed/pseudo inclusion/exclusion label
# 1 means inclusion, 0 means exclusion
# if param['ignore_score'] is True, 
#     then we assign all unlabeled data a pseudo exclusion label
# if param['ignore_score'] is False,
#     then we assign an unlabeled data point a pseudo inclusion label 
#     if it has a predicted score no less than param['score_threshold']
#     otherwise, we assign an unlabeled data point a pseudo exclusion label
# 
# params:
#     data: input pandas DataFrame
#     param: global parameter dictionary
# return:
#     a list of pseudo labels
#     1: inclusion; 0: exclusion

def get_inclusion_list(data, param):
    l = []
    if param['ignore_score']:
        for i, row in data.iterrows():
            if row['Label'] < 0:
                l.append(0)
            else:
                l.append(row['Label'])
    else:
        for i, row in data.iterrows():
            if row['Label'] < 0:
                if row['Score'] >= param['score_threshold']:
                    l.append(1)
                else:
                    l.append(0)
            else:
                l.append(row['Label'])
    return l    

In [210]:
# exponential rate function 
# x: position index, starting from 1
# a: function parameter, to be determined by curve-fitting
# a should be greater than 0. large a means fast rate decay 
# rate = exp^{-a * (position - 1)}

def exponential_rate(x, a):
    return np.exp(-a * (x - 1))

In [211]:
# polynomial rate function 
# x: position index, starting from 1
# a: function parameter, to be determined by curve-fitting
# a should be greater than 0. large a means fast rate decay 
# rate = position^{-a}

def polynomial_rate(x, a):
    return np.power(x, -a)

In [212]:
# mixture of exponential and polynomial rate function 
# x: position index, starting from 1
# a1: exponential function parameter, to be determined by curve-fitting
# a1: polynomial function parameter, to be determined by curve-fitting
# c: weight of exponential. (1-c): weight of polynomial
# a1, a2 should be greater than 0. large a1, a2 means fast rate decay
# c should be between 0 and 1.
# rate = c * exp^{-a * (position - 1)}  +  (1-c) * position^{-a}

def exp_poly_rate(x, a1, a2, c):
    return c * np.exp(-a1 * (x - 1)) + (1-c) * np.power(x, -a2)

In [213]:
# constant rate function 
# x: position index, starting from 1
# a: function parameter, to be determined by curve-fitting
# a should be greater than 0. larger a means high inclusion rate 
# rate = a, for all positions

def constant_rate(x, a):
    return a * np.ones_like(x)

In [214]:
# helper function used in fit_inclusion_rate_ls
# travel back from the end of y's, until we see param['lookback'] 
# inclusions. return the travel length

def get_lookback_travel_length(y_list, lookback):
    num_rel_so_far = 0
    # traverse the y_list in reverse order
    for i in range(1, len(y_list) + 1):
        if y_list[-i] > 0:
            num_rel_so_far += 1
        if num_rel_so_far == lookback:
            return i
    return len(y_list)

In [215]:
# helper function used in fit_inclusion_rate_ls
# smooth x's and y's by averaging y's in equal-length intervals
# smoothing helps mitigate overfitting when the rate function
# is not constant.
# smoothing is not needed when the rate function is a constant.

def smooth(x, y, param):
    interval_length = param['interval_length']
    if len(x) <= interval_length:
        return [np.average(x)], [np.average(y)]
    
    new_x, new_y = [], []
    num_intervals = int(len(x) / interval_length)
    interval_multiples = num_intervals * interval_length
    start_pos = 0
    
    for i in range(num_intervals):
        new_x.append( sum(x[start_pos : start_pos+interval_length])/interval_length )
        new_y.append( sum(y[start_pos : start_pos+interval_length])/interval_length )
        start_pos += interval_length
    
    if len(x) > interval_multiples:
        new_x.append( sum(x[interval_multiples:]) / (len(x) - interval_multiples) )
        new_y.append( sum(y[interval_multiples:]) / (len(x) - interval_multiples) )
    
    return new_x, new_y

In [1]:
# fit inclusion rate using nonlinear least square curve-fitting
# params:
#    inc_list: inclusion list, same length as the input dataframe
#              each entry is a pseudo inclusion label (1 or 0)
#    curr_pos: current position, an index in [0, len(inc_list)-1]
#              the current screening position. we will only use
#              the inc_list before curr_pos to fit a rate function
#    param:    global parameter dictionary
# return:
#    rate: a list with equal length as inc_list
#              each entry is an estimated probability of inclusion

def fit_inclusion_rate_ls(inc_list, curr_pos, param):
    # prepare the list of x's and y's for curve fitting
    x, y = [], []
    
    # first, take out pseudo labels from inc_list up to curr_pos
    # for example:
    # inc_list:     1 0 1 1 0 ... 0   1   0            0   0   ...   0
    # position:     1 2 3 4 5 ... 118 119 curr_pos=120 121 122 ...   end
    #               |<------- considered ----------->|
    for i in range(curr_pos):
        x.append( i + 1 )
        y.append( inc_list[i]  )
    
    # if no relevant articles from beginning to curr_pos
    # let's don't give up so quickly. assume there were
    # one relevant document at the beginning
    #
    # TODO: add heuristic method to detect if:
    #       1) the inclusion rate is too low 
    #       2) the ranker seems to perform worse than random
    # if so, switch to conservative rate estimates
    #
    if sum(y) == 0:
        y[0] = 1
    
    # print ('full_x', x)
    # print ('full_y', y)
    
    # travel back from the end of y's, until we see param['lookback'] 
    # inclusions. return the travel length
    travel_length = get_lookback_travel_length(y, param['lookback'])
    
    # only take the most recent x's and y's within travel_length
    # for example: travel_length = 117
    # inc_list:     1 0 1 1 0 ... 0   1   0            0   0   ...   0
    # position:     1 2 3 4 5 ... 118 119 curr_pos=120 121 122 ...   end
    #                     |<---- considered -------->| 
    x, y = x[-travel_length:], y[-travel_length:]
    
    # print ('recent_x', x)
    # print ('recent_y', y)
    
    # smooth x's and y's by averaging y's in equal-length intervals
    # smoothing helps mitigate overfitting when the rate function
    # is not constant.
    # smoothing is not needed when the rate function is a constant.
    x, y = smooth(x, y, param)
    
    # print ('smooth_x', x)
    # print ('smooth_y', y)
    
    if param['rate_curve'] == 'exp':
        func = exponential_rate
    elif param['rate_curve'] == 'poly':
        func = polynomial_rate
    elif param['rate_curve'] == 'const':
        func = constant_rate
    elif param['rate_curve'] == 'exp_poly':
        func = exp_poly_rate
    
    # the i-th entry in sigma_list indicates the importance 
    # of the i-th (x,y) data point in curve fitting
    # the larger sigma, the less important a data point
    # we set the sigma to be larger for data points that are
    # 'further back in the past', i.e., to the left of the x's and y's
    # this achieves a forgetting effect: 
    # the further back an inclusion event, the less influence it has 
    # on estimating the rate function 
    
    # here, we implement the sigma_list as an increasing function
    # with respect to the reversed position of a data point, so:
    #    the most recent data point has sigma = 1 
    #    the second most recent data point has sigma slightly larger than 1
    #    and so on.
    #
    # param['forget_factor'] controlls how fast the forgetting is.
    # param['forget_factor'] = 0: no forgetting
    # param['forget_factor'] > 0: the larger the value, the faster the forgetting.
    
    # sigma_list = 1 + param['forget_factor'] / 1000 * np.power(np.array(range(1, len(x)+1)), 2)
    sigma_list = np.exp(param['forget_factor'] / 50 * np.array(range(len(x))))
    sigma_list = np.flip(sigma_list)
    # print ('sigma_list', sigma_list)
    
    # fit the rate function using the most recent x's and y's with forgetting
    if param['rate_curve'] == 'exp' or param['rate_curve'] == 'poly' or param['rate_curve'] == 'const': 
        popt, pcov = curve_fit(func, x, y, sigma = sigma_list, maxfev = 500)
    elif param['rate_curve'] == 'exp_poly':
        popt, pcov = curve_fit(func, x, y, sigma = sigma_list, maxfev = 500, bounds=(0, [np.inf, np.inf, 1.0]))
    
    # use the fitted curve to generate the rates at each position 
    x_array = np.array(range(1, len(inc_list)+1))
    rate = func(x_array, *popt)
    
    # print ('popt', *popt)
    # print (rate)

    
    # plt.plot(x, y, 'b-', label='data for fitting')
    # if param['rate_curve'] == 'exp' or param['rate_curve'] == 'poly' or param['rate_curve'] == 'const':
    #    plt.plot(x_array, rate, 'r-', label='fit: a=%5.3f' % tuple(popt))
    # else:
    #    plt.plot(x_array, rate, 'r-', label='fit: a1=%5.3f, a2=%5.3f, c=%5.3f' % tuple(popt))
    # plt.xlabel('rank position')
    # plt.ylabel('inclusion rate')
    # plt.ylim([0, 1])
    # plt.legend()
    # plt.show()
    
    return rate

SyntaxError: invalid syntax (4190081695.py, line 114)

In [249]:
# recall is a ratio:
#               number of included documents so far
#               ----------------------------------------- 
#               total number of documents to be included 
# we know the denominator, by counting included documents so far
# we can estimate the denominator in two ways:
# 1) param['est_num_rel'] == 'all':
#    total number of documents estimated to be included
# 2) param['est_num_rel'] == 'rest':
#    number of included documents so far 
#       + number of *unlabeled* documents estimated to be included

def take_ratio(data, rate, param):
    num_rel = len(data[data['Label'] == 1]) + 1e-20
    if param['est_num_rel'] == 'all':
        est_num_rel_all = sum(rate)
        if param['confidence_bound'] == 'none':
            denominator = est_num_rel_all
        elif param['confidence_bound'] == 'poisson':
            denominator = poisson.ppf(param['confidence_level'], est_num_rel_all)
        else: # param['confidence_bound'] == 'none'
            denominator = est_num_rel_all
        return num_rel / denominator
    else: # param['est_num_rel'] == 'rest':
        est_num_rel_rest = sum([rate[i] for i, row in data.iterrows() if row['Label'] == -1])
        # print ('est_num_rel_rest', est_num_rel_rest)
        if param['confidence_bound'] == 'none':
            partial_denominator = est_num_rel_rest
        elif param['confidence_bound'] == 'poisson':
            partial_denominator = poisson.ppf(param['confidence_level'], est_num_rel_rest)
        else: # param['confidence_bound'] == 'none'
            partial_denominator = est_num_rel_rest
        return num_rel / (num_rel + partial_denominator)

In [218]:
def estimate_recall(data, current_position, param):
    
    # inclusion_list is a Python list that has equal length as data
    # it processes the 'Label' and 'Score' columns of the input data
    # such that each reference has a *guessed* inclusion/exclusion label
    # 1 means inclusion, 0 means exclusion
    inclusion_list = get_inclusion_list(data, param)
    
    # rate is a Python list that has equal length as data
    # each value in the list is an estimated rate 
    # (i.e., probability) of inclusion
    rate = fit_inclusion_rate_ls(inclusion_list, current_position, param)
    
    # calculate recall using the estimated rate from above
    recall_est = take_ratio(data, rate, param)
    
    return recall_est

# Helper Functions in Evaluation

In [219]:
def nondecreasify(y):
    new_y = []
    current_max = -1
    for value in y:
        if value > current_max:
            current_max = value
        if value <= current_max:
            new_y.append(current_max)
        else:
            new_y.append(value)
    return new_y

In [240]:
def plot_true_and_estimated_recall_curves(all_data, param, step_size, source, topic, num_screened, plot_path):
        
    x_list = []
    true_recall_list = []
    esti_recall_list = []

    total_rel = len(all_data[all_data['Label'] == 1])
    # iterate over all positions, collect points on the recall curves
    num_rel = 0
    for i, row in all_data.iterrows():
        if row['Label'].item() == 1:
            num_rel += 1

        if i % step_size == 0 and i > 0:
            x_list.append( i + 1 )
            true_r = num_rel / total_rel
            true_recall_list.append( true_r )

            cutoff = i
            all_data_copy = all_data.copy()
            for j, row in all_data_copy.iterrows():
                if j > cutoff:
                    all_data_copy.at[j,'Label'] = -1

            est_r = estimate_recall(all_data_copy, cutoff, param)

            esti_recall_list.append( est_r )
            
            print ('Position:', cutoff, 'True:', true_r, 'Est:', est_r)
            if est_r > 0.99:
                break

    plt.title('{}, {}, rel: {}, screened: {}, S: {}'.format(source, topic, total_rel, num_screened, len(all_data)))
    plt.plot(x_list, true_recall_list, 'r', label='true recall')
    plt.plot(x_list, esti_recall_list, 'b', label='estimated recall')
    plt.plot(x_list, nondecreasify(esti_recall_list), 'g', label='estimated recall (non-descreasing)')
    plt.grid()
    plt.axhline(y = 0.95, color = 'magenta', linestyle = '--')

    plt.legend()
    plt.xlabel('Rank position')
    plt.ylabel('Recall')
    plt.savefig(plot_path)
    plt.show()



# Generate per-topic curves

In [279]:
param = {'ignore_score': True,
         'interval_length': 2,
         'lookback': 5,
         'forget_factor': 0,
         'rate_curve': 'const',
         'est_num_rel': 'rest',
         'confidence_bound': 'poisson',
         'confidence_level': 0.95
        }

In [None]:
per_topic_data_folder = '../data/per_topic_simulation'
# per_topic_output_folder = '../output/per_topic_const_forget'
per_topic_output_folder = '../output/per_topic_const_no_forget'

for file in os.listdir(per_topic_data_folder):
    file_path = os.path.join(per_topic_data_folder, file)
    source, topic, num_screened = file.split('.')[0].split('_')
    num_screened = int(num_screened)
    rank_dataframe = pd.read_csv(file_path)
    plot_path = os.path.join(per_topic_output_folder, file.replace('csv', 'pdf'))
    
    print ('{}, {}, screened: {}'.format(source, topic, num_screened))
    plot_true_and_estimated_recall_curves(rank_dataframe, param, 30, source, topic, num_screened, plot_path)
    

ASREVIEW, STcardioCHE2, screened: 210
Position: 30 True: 0.5 Est: 0.00045938289564351886
Position: 60 True: 0.5833333333333334 Est: 0.0012336975678533663
Position: 90 True: 0.5833333333333334 Est: 0.00186219739292365
Position: 120 True: 0.5833333333333334 Est: 0.0024857954545454545
Position: 150 True: 0.5833333333333334 Est: 0.003105590062111801
Position: 180 True: 0.75 Est: 0.0045708481462671405
Position: 210 True: 1.0 Est: 0.001973684210526316
Position: 240 True: 1.0 Est: 0.0030511060259344014
Position: 270 True: 1.0 Est: 0.004118050789293068
Position: 300 True: 1.0 Est: 0.005176876617773943
Position: 330 True: 1.0 Est: 0.006227296315516347
Position: 360 True: 1.0 Est: 0.007272727272727273
Position: 390 True: 1.0 Est: 0.008310249307479225
Position: 420 True: 1.0 Est: 0.00933852140077821
Position: 450 True: 1.0 Est: 0.010362694300518135
Position: 480 True: 1.0 Est: 0.011385199240986717
Position: 510 True: 1.0 Est: 0.012396694214876033
Position: 540 True: 1.0 Est: 0.013407821229050279


Position: 4980 True: 1.0 Est: 0.13043478260869565
Position: 5010 True: 1.0 Est: 0.13043478260869565
Position: 5040 True: 1.0 Est: 0.13186813186813187
Position: 5070 True: 1.0 Est: 0.13186813186813187
Position: 5100 True: 1.0 Est: 0.13333333333333333
Position: 5130 True: 1.0 Est: 0.13333333333333333
Position: 5160 True: 1.0 Est: 0.13333333333333333
Position: 5190 True: 1.0 Est: 0.1348314606741573
Position: 5220 True: 1.0 Est: 0.1348314606741573
Position: 5250 True: 1.0 Est: 0.13636363636363635
Position: 5280 True: 1.0 Est: 0.13636363636363635
Position: 5310 True: 1.0 Est: 0.13793103448275862
Position: 5340 True: 1.0 Est: 0.13793103448275862
Position: 5370 True: 1.0 Est: 0.13793103448275862
Position: 5400 True: 1.0 Est: 0.13953488372093023
Position: 5430 True: 1.0 Est: 0.13953488372093023
Position: 5460 True: 1.0 Est: 0.13953488372093023
Position: 5490 True: 1.0 Est: 0.1411764705882353
Position: 5520 True: 1.0 Est: 0.1411764705882353
Position: 5550 True: 1.0 Est: 0.14285714285714285
Posi

Position: 10080 True: 1.0 Est: 0.2222222222222222
Position: 10110 True: 1.0 Est: 0.2222222222222222
Position: 10140 True: 1.0 Est: 0.2222222222222222
Position: 10170 True: 1.0 Est: 0.2222222222222222
Position: 10200 True: 1.0 Est: 0.2222222222222222
Position: 10230 True: 1.0 Est: 0.22641509433962265
Position: 10260 True: 1.0 Est: 0.22641509433962265
Position: 10290 True: 1.0 Est: 0.22641509433962265
Position: 10320 True: 1.0 Est: 0.22641509433962265
Position: 10350 True: 1.0 Est: 0.22641509433962265
Position: 10380 True: 1.0 Est: 0.22641509433962265
Position: 10410 True: 1.0 Est: 0.22641509433962265
Position: 10440 True: 1.0 Est: 0.22641509433962265
Position: 10470 True: 1.0 Est: 0.22641509433962265
Position: 10500 True: 1.0 Est: 0.23076923076923078
Position: 10530 True: 1.0 Est: 0.23076923076923078
Position: 10560 True: 1.0 Est: 0.23076923076923078
Position: 10590 True: 1.0 Est: 0.23076923076923078
Position: 10620 True: 1.0 Est: 0.23076923076923078
Position: 10650 True: 1.0 Est: 0.230

Position: 15180 True: 1.0 Est: 0.2926829268292683
Position: 15210 True: 1.0 Est: 0.2926829268292683
Position: 15240 True: 1.0 Est: 0.2926829268292683
Position: 15270 True: 1.0 Est: 0.2926829268292683
Position: 15300 True: 1.0 Est: 0.2926829268292683
Position: 15330 True: 1.0 Est: 0.2926829268292683
Position: 15360 True: 1.0 Est: 0.2926829268292683
Position: 15390 True: 1.0 Est: 0.2926829268292683
Position: 15420 True: 1.0 Est: 0.2926829268292683
Position: 15450 True: 1.0 Est: 0.2926829268292683
Position: 15480 True: 1.0 Est: 0.2926829268292683
Position: 15510 True: 1.0 Est: 0.2926829268292683
Position: 15540 True: 1.0 Est: 0.2926829268292683
Position: 15570 True: 1.0 Est: 0.2926829268292683
Position: 15600 True: 1.0 Est: 0.2926829268292683
Position: 15630 True: 1.0 Est: 0.3
Position: 15660 True: 1.0 Est: 0.3
Position: 15690 True: 1.0 Est: 0.3
Position: 15720 True: 1.0 Est: 0.3
Position: 15750 True: 1.0 Est: 0.3
Position: 15780 True: 1.0 Est: 0.3
Position: 15810 True: 1.0 Est: 0.3
Posit

Position: 20280 True: 1.0 Est: 0.34285714285714286
Position: 20310 True: 1.0 Est: 0.34285714285714286
Position: 20340 True: 1.0 Est: 0.34285714285714286
Position: 20370 True: 1.0 Est: 0.34285714285714286
Position: 20400 True: 1.0 Est: 0.34285714285714286
Position: 20430 True: 1.0 Est: 0.34285714285714286
Position: 20460 True: 1.0 Est: 0.34285714285714286
Position: 20490 True: 1.0 Est: 0.34285714285714286
Position: 20520 True: 1.0 Est: 0.34285714285714286
Position: 20550 True: 1.0 Est: 0.34285714285714286
Position: 20580 True: 1.0 Est: 0.35294117647058826
Position: 20610 True: 1.0 Est: 0.35294117647058826
Position: 20640 True: 1.0 Est: 0.35294117647058826
Position: 20670 True: 1.0 Est: 0.35294117647058826
Position: 20700 True: 1.0 Est: 0.35294117647058826
Position: 20730 True: 1.0 Est: 0.35294117647058826
Position: 20760 True: 1.0 Est: 0.35294117647058826
Position: 20790 True: 1.0 Est: 0.35294117647058826
Position: 20820 True: 1.0 Est: 0.35294117647058826
Position: 20850 True: 1.0 Est: 

Position: 25530 True: 1.0 Est: 0.3870967741935484
Position: 25560 True: 1.0 Est: 0.3870967741935484
Position: 25590 True: 1.0 Est: 0.3870967741935484
Position: 25620 True: 1.0 Est: 0.3870967741935484
Position: 25650 True: 1.0 Est: 0.3870967741935484
Position: 25680 True: 1.0 Est: 0.3870967741935484
Position: 25710 True: 1.0 Est: 0.3870967741935484
Position: 25740 True: 1.0 Est: 0.3870967741935484
Position: 25770 True: 1.0 Est: 0.3870967741935484
Position: 25800 True: 1.0 Est: 0.3870967741935484
Position: 25830 True: 1.0 Est: 0.3870967741935484
Position: 25860 True: 1.0 Est: 0.3870967741935484
Position: 25890 True: 1.0 Est: 0.3870967741935484
Position: 25920 True: 1.0 Est: 0.3870967741935484
Position: 25950 True: 1.0 Est: 0.4
Position: 25980 True: 1.0 Est: 0.4
Position: 26010 True: 1.0 Est: 0.4
Position: 26040 True: 1.0 Est: 0.4
Position: 26070 True: 1.0 Est: 0.4
Position: 26100 True: 1.0 Est: 0.4
Position: 26130 True: 1.0 Est: 0.4
Position: 26160 True: 1.0 Est: 0.4
Position: 26190 True