In [None]:
from Podkopaev_Ramdas_Code.utils.concentrations import *
from utils.tests import Drop_tester,misclas_losses


def podkopaev_ramdas_algorithm1(cal_test_losses, n_cal, source_conc_type='betting',
                                target_conc_type='conj-bern', verbose=False, eps_tol=0.05):
    """
    DRAFT implementation of Podkopaev & Ramdas baseline, i.e., algorithm 1 in that paper.

    The first ``n_cal`` losses are used to obtain a rejection threshold from the
    source (calibration) data; the remaining losses are then fed to the tester
    one at a time, and an alarm is recorded the first time the tester's target
    lower bound strictly exceeds that threshold.

    Parameters
    ----------
    cal_test_losses  : Losses for calibration (holdout) and test sets, concatenated in that
                       order; for evaluating set losses, should be *mis*coverage indicators;
                       (if evaluating point losses, would be conformity scores)
    n_cal            : Number of calibration points (length of the leading calibration slice).
    source_conc_type : Concentration used for source UCB
    target_conc_type : Concentration used for target LCB
    verbose          : If True, print a message when the first alarm is raised.
    eps_tol          : Epsilon tolerance

    Returns
    -------
    alarm_idx                   : Index *within the test losses* of the first point at which the
                                  target lower bound exceeded the threshold, or None if no alarm
                                  was raised.
    source_upper_bound_plus_tol : Value of ``tester.source_rejection_threshold`` after the source
                                  risk has been estimated (presumably source UCB + eps_tol --
                                  confirm against Drop_tester).
    target_lower_bounds         : Array of length ``len(cal_test_losses) - n_cal`` holding
                                  ``tester.target_risk_lower_bound`` after each test loss.
    """

    ## Split into cal and test losses:
    cal_losses = cal_test_losses[:n_cal]
    test_losses = cal_test_losses[n_cal:]

    ## Set up Drop_tester for computing UCB on source risk and LCB on target risk
    tester = Drop_tester()
    tester.eps_tol = eps_tol
    tester.source_conc_type = source_conc_type
    tester.target_conc_type = target_conc_type
    tester.change_type = 'absolute'

    ## Estimate source risk UCB; the tester exposes the resulting threshold as an attribute
    tester.estimate_risk_source(cal_losses)
    source_upper_bound_plus_tol = tester.source_rejection_threshold

    ## Estimate target risk LCB, one test point at a time.
    ## NOTE(review): this assumes the tester accumulates running risk estimates across
    ## successive estimate_risk_target calls -- double check against Drop_tester.
    target_lower_bounds = np.zeros(len(test_losses))

    ## Index in test set of first alarm (stays None if the threshold is never exceeded)
    alarm_idx = None

    for t, loss in enumerate(test_losses):
        tester.estimate_risk_target(loss)
        target_lower_bounds[t] = tester.target_risk_lower_bound

        ## Only the first crossing is recorded; later bounds are still collected.
        if alarm_idx is None and target_lower_bounds[t] > source_upper_bound_plus_tol:
            alarm_idx = t
            if verbose:
                print(f'Alarm raise at test point {t}!')

    return alarm_idx, source_upper_bound_plus_tol, target_lower_bounds
