In [1]:
import os, time, copy
import sys
sys.path.insert(1, '../../')
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import normalize
from tqdm import tqdm
from scipy.stats import norm, binom
from scipy.optimize import brentq
from concentration import linfty_dkw, linfty_binom, wsr_iid

In [2]:
# Get data 2006-2014 from the following link: https://darchive.mblwhoilibrary.org/handle/1912/7341
# Unzip and merge the datasets in the following directory
calib_data = np.load('../calib-outputs.npz')
test_data = np.load('../test-outputs.npz')
calib_preds = calib_data['preds'].astype(int)
calib_labels = calib_data['labels'].astype(int)
test_preds = test_data['preds'].astype(int)
test_labels = test_data['labels'].astype(int)
classes = np.load('../classes.npy')
num_classes = classes.shape[0]

plankton_classes = np.isin(classes,['mix','mix_elongated','detritus','bad', 'bead', 'bubble', 'other_interaction', 'pollen', 'spore'],invert=True)
plankton_idxs = np.where(plankton_classes)[0]

true_count = np.isin(test_labels, plankton_idxs).sum()
uncorrected_est = np.isin(test_preds, plankton_idxs).sum()
print((test_labels == test_preds).astype(float).mean())

0.9334570326711781


In [3]:
# Look at the unique classes 
calib_uq, calib_uq_counts = np.unique(calib_labels, return_counts=True)
calib_uq_freq = calib_uq_counts/calib_uq_counts.sum()
calib_uq_sort = np.argsort(calib_uq_freq)[::-1]
calib_uq_freq = calib_uq_freq[calib_uq_sort]; calib_uq = calib_uq[calib_uq_sort];
calib_uq_cumsum = np.cumsum(calib_uq_freq)

In [4]:
# Problem setup
alpha = 0.018
delta_1 = 0.90-1e-2
delta_2 = 1e-2
K = 15
nu = 1-plankton_classes.astype(int)
nu_trunc = nu[calib_uq]
calib_uq = calib_uq[:K]
print(classes[calib_uq])
print(calib_uq_cumsum[15])

['mix' 'detritus' 'Leptocylindrus' 'Chaetoceros' 'mix_elongated' 'dino30'
 'Cerataulina' 'Skeletonema' 'Cylindrotheca' 'Guinardia_delicatula'
 'Rhizosolenia' 'Thalassiosira' 'bad' 'Ciliate_mix' 'DactFragCerataul']
0.9811721639548188


In [6]:
# Construct the confusion matrix
n = calib_preds.shape[0]
N = test_preds.shape[0]

C = np.zeros((num_classes,num_classes)).astype(int)
for j in range(num_classes):
    for l in range(num_classes):
        C[j,l] = ((calib_preds == j) & (calib_labels == l)).sum()

In [14]:
# Construct Ahat
Ahat = C[calib_uq,:][:,calib_uq]
Ahat = Ahat / Ahat.sum(axis=0)
Ahatinv = np.linalg.inv(Ahat)

In [15]:
# Construct the point estimate
target_uq, target_uq_counts = np.unique(test_preds, return_counts=True)
target_uq_freq = target_uq_counts/target_uq_counts.sum()
target_uq_sort = np.argsort(target_uq_freq)[::-1]
target_uq_freq = target_uq_freq[target_uq_sort]; target_uq = target_uq[target_uq_sort];
target_uq = target_uq[:K]
qfhat = target_uq_freq[:K]
nu1 = nu_trunc[:K]
nu2 = nu_trunc[K:]
print(qfhat)
print(Ahatinv@qfhat)
print(nu1@Ahatinv@qfhat)
print(target_uq)

[0.78815882 0.13339215 0.0110723  0.00990201 0.00918346 0.00676708
 0.00629411 0.00622438 0.00366247 0.00289845 0.00242245 0.00231027
 0.00178576 0.00154018 0.0014947 ]
[0.80449637 0.12426003 0.01349827 0.00641193 0.0056779  0.00238513
 0.00801195 0.00808476 0.00374792 0.00244556 0.0013734  0.00206961
 0.00167563 0.00136607 0.00160406]
0.9361099323597891
[94 88 95 49 90 65 19  8 38 15 77 21 25 66  5]


In [16]:
point_estimate = nu1@Ahatinv@qfhat

nmin = C[:,calib_uq][calib_uq,:].sum(axis=0).min()

theta = 0.99 # The solution to the optimal theta is extreme in this case, so we clip.
epsilon1 = max([linfty_binom(C[:,calib_uq][calib_uq,:].sum(axis=0)[k], K, theta*delta_1, Ahat[:,k]) for k in range(K)])
epsilon2 = linfty_dkw(N,K,(1-theta)*delta_1)

lower_constant = alpha * np.abs(np.maximum((nu1@Ahatinv).max(),0) + nu2.min())
upper_constant = alpha * np.abs(np.minimum((nu1@Ahatinv).min(),0) + nu2.max())

qyhat_lb = np.clip(point_estimate - epsilon1 - epsilon2 - lower_constant, 0, 1)
qyhat_ub = np.clip(point_estimate + epsilon1 + epsilon2 + upper_constant, 0, 1)

count_nonplankton_lb = int(binom.ppf(delta_2, N, qyhat_lb))
count_nonplankton_ub = int(binom.ppf(1-delta_2, N, qyhat_ub))

print(nmin)
print(epsilon1)
print(epsilon2)
print(lower_constant)
print(upper_constant)
print(nu1@Ahatinv)

  return _boost._binom_cdf(k, n, p)


1140
0.02789705999864933
0.005730095582478624
0.026490213141110375
0.00981241866581435
[ 1.01331872  1.01475416 -0.30177606 -0.24600134  1.47167851 -0.19618051
 -0.45486563 -0.4014275  -0.11443569 -0.00266586 -0.03275331 -0.15534191
  0.99991892 -0.0146222   0.09829654]


In [17]:
print(f"The rectified confidence interval for the number of plankton observed in 2014 is [{N-count_nonplankton_ub},{N-count_nonplankton_lb}] ([{100*(1-qyhat_ub):.1f}%,{100*(1-qyhat_lb):.1f}%]).")
print(f"The true number of plankton observed in 2014 was {true_count} ({true_count/N*100:.1f}%), which lies in the interval.")
print(f"The uncorrected estimate was {uncorrected_est} ({uncorrected_est/N*100:.1f}%).")
print(f"The corrected estimate was {int(N*(1-point_estimate))} ({(1-point_estimate)*100:.1f}%).")

The rectified confidence interval for the number of plankton observed in 2014 is [6557,41343] ([2.0%,12.4%]).
The true number of plankton observed in 2014 was 23538 (7.1%), which lies in the interval.
The uncorrected estimate was 22068 (6.7%).
The corrected estimate was 21072 (6.4%).
