In [1]:
import os, time, copy
import sys
sys.path.insert(1, '../../')
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import normalize
from tqdm import tqdm
from scipy.stats import norm, binom
from scipy.optimize import brentq
from concentration import linfty_dkw, linfty_binom, wsr_iid

In [2]:
# Get data 2006-2014 from the following link: https://darchive.mblwhoilibrary.org/handle/1912/7341
# Unzip and merge the datasets into the parent directory ('..'): this cell reads
# '../calib-outputs.npz', '../test-outputs.npz' and '../classes.npy'.
# np.load on an .npz returns a lazily-loading NpzFile that holds an open file handle,
# so copy out the arrays we need inside a `with` block and let it close the file.
with np.load('../calib-outputs.npz') as calib_data:
    calib_preds = calib_data['preds'].astype(int)
    calib_labels = calib_data['labels'].astype(int)
with np.load('../test-outputs.npz') as test_data:
    test_preds = test_data['preds'].astype(int)
    test_labels = test_data['labels'].astype(int)
classes = np.load('../classes.npy')
num_classes = classes.shape[0]

# Class names excluded from the plankton count; every other class counts as plankton.
NON_PLANKTON_CLASSES = ['mix', 'mix_elongated', 'detritus', 'bad', 'bead', 'bubble',
                        'other_interaction', 'pollen', 'spore']
plankton_classes = np.isin(classes, NON_PLANKTON_CLASSES, invert=True)
plankton_idxs = np.where(plankton_classes)[0]

# Ground-truth and classifier-only (uncorrected) plankton counts on the test set.
true_count = np.isin(test_labels, plankton_idxs).sum()
uncorrected_est = np.isin(test_preds, plankton_idxs).sum()
# Top-1 classifier accuracy on the test set.
print((test_labels == test_preds).astype(float).mean())


def _class_frequencies(values, n_classes):
    """Empirical class frequencies of an integer array of class indices.

    Returns (present class ids, their counts, length-n_classes float vector of
    frequencies with zeros for classes that never occur).
    """
    present, counts = np.unique(values, return_counts=True)
    freq = np.zeros((n_classes,))
    freq[present] = counts / values.shape[0]
    return present, counts, freq


# Store the class frequencies in the calibration and test sets
calib_uq, calib_uq_counts, calib_uq_freq = _class_frequencies(calib_labels, num_classes)
# Class indices ordered from most to least frequent in the calibration labels.
calib_uq_sort = np.argsort(calib_uq_freq)[::-1]
calib_uq_cumsum = np.cumsum(calib_uq_freq[calib_uq_sort])

# Frequencies of the *predicted* classes on the test set (test labels are not used here).
test_pred_uq, test_pred_uq_counts, test_pred_uq_freq = _class_frequencies(test_preds, num_classes)

0.9334570326711781


In [3]:
# --- Problem setup ---
# Error budget: delta_1 is split between the two concentration bounds in the inference cell,
# delta_2 is the tail probability used when turning fraction bounds into count bounds.
# NOTE(review): delta_2 is used as a failure probability (binomial tail quantile) downstream,
# yet delta_1 + delta_2 = 0.95. If the `concentration` helpers also treat delta as a failure
# probability, the nominal coverage is ~5%, not 95% (was 0.05 - 1e-3 intended?). Confirm
# against the `concentration` module before changing.
delta_1 = 0.95 - 1e-3
delta_2 = 1e-3

# Retain the K most frequent calibration classes (Ical); everything else is Icalc.
K = 12
Ical, Icalc = calib_uq_sort[:K], calib_uq_sort[K:]

# nu[c] == 1 exactly for the non-plankton classes, 0 for plankton.
nu = (~plankton_classes).astype(int)
nucal, nucalc = nu[Ical], nu[Icalc]

# alpha: calibration label mass that falls outside the retained classes.
alpha = 1 - calib_uq_freq[Ical].sum()

print(classes[Ical])
print(alpha)
print(nucal)
print(nucal)

['mix' 'detritus' 'Leptocylindrus' 'Chaetoceros' 'mix_elongated' 'dino30'
 'Cerataulina' 'Skeletonema' 'Cylindrotheca' 'Guinardia_delicatula'
 'Rhizosolenia' 'Thalassiosira']
0.031884587810216636
[1 1 0 0 1 0 0 0 0 0 0 0]


In [4]:
# Construct the confusion matrix: C[j, l] = number of calibration points predicted as class j
# whose true label is l (rows = prediction, columns = label).
n = calib_preds.shape[0]
N = test_preds.shape[0]

# A single vectorized scatter-add replaces num_classes**2 boolean passes over the calibration set.
# np.add.at would wrap negative indices and reject too-large ones, so validate the range explicitly.
assert np.all((calib_preds >= 0) & (calib_preds < num_classes)), "calibration prediction outside [0, num_classes)"
assert np.all((calib_labels >= 0) & (calib_labels < num_classes)), "calibration label outside [0, num_classes)"
C = np.zeros((num_classes, num_classes), dtype=int)
np.add.at(C, (calib_preds, calib_labels), 1)

In [5]:
# Construct Ahat: keep only the K retained classes for both predictions (rows) and labels
# (columns), then normalize each column so Ahat[j, k] is the fraction of those calibration
# points with true label Ical[k] that were predicted as Ical[j].
CIcal2 = C[np.ix_(Ical, Ical)]
Ahat = CIcal2 / CIcal2.sum(axis=0, keepdims=True)
# The explicit inverse is kept because Ahatinv is reused downstream.
Ahatinv = np.linalg.inv(Ahat)

In [6]:
# Construct the point estimate of the non-plankton fraction: apply the weights nu^T Ahat^{-1}
# to the empirical test-prediction frequencies of the retained classes.
qfhatical = test_pred_uq_freq[Ical]
point_estimate = (nucal @ Ahatinv) @ qfhatical

In [7]:
# Do Prediction-Powered Inference
# Calibration points per retained true label (column totals of the restricted confusion
# matrix); computed once instead of once per class inside the comprehension below.
calib_label_totals = CIcal2.sum(axis=0)
nmin = calib_label_totals.min()

theta = 0.99 # The solution to the optimal theta is extreme in this case, so we clip.
# delta_1 is split: theta*delta_1 goes to the per-column bounds on Ahat (linfty_binom),
# the remaining (1-theta)*delta_1 to the bound over the N test predictions (linfty_dkw).
epsilon1 = max(linfty_binom(calib_label_totals[k], K, theta*delta_1, Ahat[:, k]) for k in range(K))
epsilon2 = linfty_dkw(N, K, (1-theta)*delta_1)

# Slack terms scaled by alpha, the calibration mass outside the retained classes.
nu_Ahatinv = nucal @ Ahatinv
lower_constant = alpha * np.abs(np.maximum(nu_Ahatinv.max(), 0) + nucalc.min())
upper_constant = alpha * np.abs(np.minimum(nu_Ahatinv.min(), 0) + nucalc.max())

# Bounds on the non-plankton fraction, clipped to the valid probability range.
qyhat_lb = np.clip(point_estimate - epsilon1 - epsilon2 - lower_constant, 0, 1)
qyhat_ub = np.clip(point_estimate + epsilon1 + epsilon2 + upper_constant, 0, 1)

# Turn the fraction bounds into count bounds via binomial quantiles at delta_2 / 1-delta_2.
count_nonplankton_lb = int(binom.ppf(delta_2, N, qyhat_lb))
count_nonplankton_ub = int(binom.ppf(1-delta_2, N, qyhat_ub))

  return _boost._binom_cdf(k, n, p)


In [8]:
# The plankton-count interval is the complement of the non-plankton count bounds.
plankton_count_lb = N - count_nonplankton_ub
plankton_count_ub = N - count_nonplankton_lb
# Check coverage rather than asserting it unconditionally in the message text.
covered = plankton_count_lb <= true_count <= plankton_count_ub
coverage_msg = "which lies in the interval" if covered else "which does NOT lie in the interval"
print(f"The prediction-powered confidence interval for the number of plankton observed in 2014 is [{plankton_count_lb},{plankton_count_ub}] ([{100*(1-qyhat_ub):.1f}%,{100*(1-qyhat_lb):.1f}%]).")
print(f"The true number of plankton observed in 2014 was {true_count} ({true_count/N*100:.1f}%), {coverage_msg}.")
print(f"The uncorrected estimate was {uncorrected_est} ({uncorrected_est/N*100:.1f}%).")
# Round to the nearest count instead of truncating toward zero.
print(f"The corrected estimate was {int(np.rint(N*(1-point_estimate)))} ({(1-point_estimate)*100:.1f}%).")

The prediction-powered confidence interval for the number of plankton observed in 2014 is [6769,41729] ([2.1%,12.5%]).
The true number of plankton observed in 2014 was 23538 (7.1%), which lies in the interval.
The uncorrected estimate was 22068 (6.7%).
The corrected estimate was 19515 (5.9%).
