In [1]:
import os, time, copy
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import normalize
from scipy.optimize import brentq
from scipy.stats import binom
from tqdm import tqdm

In [2]:
# Get data 2006-2014 from the following link: https://darchive.mblwhoilibrary.org/handle/1912/7341
# Unzip and merge the datasets in the following directory
calib_data = np.load('../calib-outputs.npz')
test_data = np.load('../test-outputs.npz')
calib_preds = calib_data['preds'].astype(int)
calib_labels = calib_data['labels'].astype(int)
test_preds = test_data['preds'].astype(int)
test_labels = test_data['labels'].astype(int)
classes = np.load('../classes.npy')
num_classes = classes.shape[0]

plankton_classes = np.isin(classes,['mix','mix_elongated','detritus','bad', 'bead', 'bubble', 'other_interaction', 'pollen', 'spore'],invert=True)
plankton_classes_list = np.array(np.where(plankton_classes)[0])

true_count = np.isin(test_labels, plankton_classes_list).sum()
uncorrected_est = np.isin(test_preds, plankton_classes_list).sum()

calib_preds = np.isin(calib_preds, plankton_classes_list)
calib_labels = np.isin(calib_labels, plankton_classes_list)
test_preds = np.isin(test_preds, plankton_classes_list)
test_labels = np.isin(test_labels, plankton_classes_list)

print(f"Calib acc: {(calib_preds == calib_labels).astype(int).mean()}")
print(f"Test acc: {(test_preds == test_labels).astype(int).mean()}")

Calib acc: 0.962313466496375
Test acc: 0.9789407941012395


In [3]:
# Look at the unique classes 
calib_uq, calib_uq_counts = np.unique(calib_labels, return_counts=True)
calib_uq_freq = calib_uq_counts/calib_uq_counts.sum()
calib_uq_sort = np.argsort(calib_uq_freq)[::-1]
calib_uq_freq = calib_uq_freq[calib_uq_sort]; calib_uq = calib_uq[calib_uq_sort];
calib_uq_cumsum = np.cumsum(calib_uq_freq)

In [4]:
# Problem setup
alpha = 0.1
delta_1 = 0.099
delta_2 = 0.001
nu = np.array([0,1])

In [5]:
# Construct the confusion matrix
n = calib_preds.shape[0]
N = test_preds.shape[0]

C = np.zeros((2,2)).astype(int)
for j in range(2):
    for l in range(2):
        C[j,l] = np.bitwise_and(calib_preds == j,calib_labels == l).astype(int).sum()
Ahat = C / C.sum(axis=0)

In [6]:
# Construct Ahat
Ahatinv = np.linalg.inv(Ahat)
print(np.array_str(Ahat, precision=3))
print(np.array_str(Ahatinv, precision=3))

[[0.98  0.152]
 [0.02  0.848]]
[[ 1.024 -0.184]
 [-0.024  1.184]]


In [7]:
# Construct the point estimate
target_uq, target_uq_counts = np.unique(test_preds, return_counts=True)
target_uq_freq = target_uq_counts/target_uq_counts.sum()
target_uq_sort = np.argsort(target_uq_freq)[::-1]
target_uq_freq = target_uq_freq[target_uq_sort]; target_uq = target_uq[target_uq_sort];
qfhat = target_uq_freq

In [8]:
# Run MAI
nmin = int(calib_uq_freq.min()*n)
point_estimate = nu@Ahatinv@qfhat

term1 = np.abs(nu@Ahatinv).sum()*np.sqrt(2/nmin*np.log(6/delta_1))
term2 = np.abs(nu@Ahatinv).sum()*np.sqrt(2/N*np.log(6/delta_1))

qyhat_lb = np.maximum(np.minimum(point_estimate - term1 - term2,1),0)
qyhat_ub = np.maximum(np.minimum(point_estimate + term1 + term2,1),0)

count_plankton_lb = int(binom.ppf(delta_2, N, qyhat_lb))
count_plankton_ub = int(binom.ppf(1-delta_2, N, qyhat_ub))

In [15]:
print(f"The rectified confidence interval for the number of plankton observed in 2014 is [{count_plankton_lb},{count_plankton_ub}] ([{100*qyhat_lb:.1f}%,{100*qyhat_ub:.1f}%]).")
print(f"The true number of plankton observed in 2014 was {true_count} ({true_count/N*100:.1f}%), which lies in the interval.")
print(f"The uncorrected estimate was {uncorrected_est} ({uncorrected_est/N*100:.1f}%).")

The rectified confidence interval for the number of plankton observed in 2014 is [11562,25962] ([3.6%,7.7%]).
The true number of plankton observed in 2014 was 23538 (7.1%), which lies in the interval.
The uncorrected estimate was 22068 (6.7%).


In [10]:
# Baselines

# Naive intervals
naive_epsilon = np.sqrt(1/(2*N) * np.log(1/delta_1))
naive_lb = qfhat[1] - naive_epsilon
naive_ub = qfhat[1] + naive_epsilon
naive_count_lb = int(binom.ppf(delta_2, N, naive_lb))
naive_count_ub = int(binom.ppf(1-delta_2, N, naive_ub))
print(f"The uncorrected confidence interval for the number of plankton observed in 2014 is [{naive_count_lb},{naive_count_ub}] ([{100*naive_lb:.1f}%,{100*naive_ub:.1f}%]).")

# IID intervals

iid_epsilon = np.sqrt(2)*(2*np.sqrt(np.log(3/delta_1)/n) + np.sqrt(np.log(3/delta_1)/N))
bias_estimate = (calib_preds.astype(float) - calib_labels.astype(float)).mean()
iid_lb = qfhat[1] - bias_estimate - iid_epsilon
iid_ub = qfhat[1] - bias_estimate + iid_epsilon
iid_count_lb = int(binom.ppf(delta_2, N, naive_lb))
iid_count_ub = int(binom.ppf(1-delta_2, N, naive_ub))
print(f"The i.i.d. confidence interval for the number of plankton observed in 2014 is [{iid_count_lb},{iid_count_ub}] ([{100*iid_lb:.1f}%,{100*iid_ub:.1f}%]).")

The uncorrected confidence interval for the number of plankton observed in 2014 is [21014,23136] ([6.5%,6.9%]).
The i.i.d. confidence interval for the number of plankton observed in 2014 is [21014,23136] ([5.7%,8.3%]).
