In [188]:
import os, time, copy
import sys
sys.path.insert(1, '../../')
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import normalize
from tqdm import tqdm
from scipy.stats import norm, binom
from scipy.optimize import brentq
from concentration import linfty_dkw, linfty_binom, wsr_iid

In [189]:
# Get data 2006-2014 from the following link: https://darchive.mblwhoilibrary.org/handle/1912/7341
# Unzip and merge the datasets so that the files below sit one directory above this notebook ('../').
#
# np.load on an .npz returns an archive object that keeps the underlying file
# open, so each archive is opened in a `with` block and the arrays we need are
# materialized (via astype) before the file is closed. After the block the
# names `calib_data` / `test_data` still exist but refer to closed archives.
with np.load('../calib-outputs.npz') as calib_data:
    calib_preds = calib_data['preds'].astype(int)
    calib_labels = calib_data['labels'].astype(int)
with np.load('../test-outputs.npz') as test_data:
    test_preds = test_data['preds'].astype(int)
    test_labels = test_data['labels'].astype(int)
classes = np.load('../classes.npy')
num_classes = classes.shape[0]

# Boolean mask over `classes`: True for every class whose name is NOT in the
# exclusion list below (invert=True); these are the classes counted as plankton.
plankton_classes = np.isin(classes,['mix','mix_elongated','detritus','bad', 'bead', 'bubble', 'other_interaction', 'pollen', 'spore'],invert=True)
# Integer positions (into `classes`) of the plankton classes.
plankton_classes_list = np.where(plankton_classes)[0]

# Number of test points whose label / prediction is one of the plankton classes.
true_count = np.isin(test_labels, plankton_classes_list).sum()
uncorrected_est = np.isin(test_preds, plankton_classes_list).sum()
# Fraction of test points where the prediction equals the label.
print((test_labels == test_preds).astype(float).mean())

0.9334570326711781


In [190]:
# Rank the classes that occur in the calibration labels from most to least frequent.
# calib_uq_counts stays in np.unique's (ascending class id) order; the other
# arrays are reordered by descending frequency.
calib_uq, calib_uq_counts = np.unique(calib_labels, return_counts=True)
calib_uq_freq = calib_uq_counts / calib_uq_counts.sum()
# Permutation that sorts the frequencies in descending order.
calib_uq_sort = np.argsort(calib_uq_freq)[::-1]
calib_uq_freq = np.take(calib_uq_freq, calib_uq_sort)
calib_uq = np.take(calib_uq, calib_uq_sort)
# Cumulative share of calibration labels covered by the top-ranked classes.
calib_uq_cumsum = calib_uq_freq.cumsum()

In [191]:
# Problem setup
alpha = 0.05          # slack added to the upper proportion bound when the interval is built
delta_1 = 0.05-1e-2   # error budget for the epsilon terms (delta_1 + delta_2 = 0.05)
delta_2 = 1e-2        # error budget used in the binomial quantiles for the final counts
K = 10                # number of most frequent calibration classes that are kept
nu = plankton_classes.astype(int)  # 1 for plankton classes, 0 otherwise, indexed by class id

# Rebuild the full frequency ranking from the unique labels and the sort
# permutation instead of reading it back from `calib_uq`: this cell truncates
# `calib_uq` to K entries, so deriving `nu_trunc` from `calib_uq` itself would
# silently give a length-K `nu_trunc` if the cell were ever run a second time.
calib_classes_ranked = np.unique(calib_labels)[calib_uq_sort]
nu_trunc = nu[calib_classes_ranked]
calib_uq = calib_classes_ranked[:K]
print(classes[calib_uq])

['mix' 'detritus' 'Leptocylindrus' 'Chaetoceros' 'mix_elongated' 'dino30'
 'Cerataulina' 'Skeletonema' 'Cylindrotheca' 'Guinardia_delicatula']


In [192]:
# Construct the confusion matrix
n = calib_preds.shape[0]
N = test_preds.shape[0]

# C[j, l] = number of calibration points predicted as class j whose label is l.
# Each (pred, label) pair is encoded as the single integer pred*num_classes + label
# and counted in one pass with bincount, instead of scanning the calibration
# arrays once per matrix entry (num_classes**2 full passes).
# Pairs with an id outside [0, num_classes) are dropped, matching the
# entry-by-entry count, which never visits such ids.
in_range = (
    (calib_preds >= 0) & (calib_preds < num_classes)
    & (calib_labels >= 0) & (calib_labels < num_classes)
)
pair_codes = calib_preds[in_range] * num_classes + calib_labels[in_range]
C = (
    np.bincount(pair_codes, minlength=num_classes * num_classes)
    .reshape(num_classes, num_classes)
    .astype(int)
)

In [193]:
# Construct Ahat
# Restrict the confusion matrix to the K retained classes (rows and columns in
# calib_uq order), then scale every column so it sums to one.
top_idx = np.ix_(calib_uq, calib_uq)
C_top = C[top_idx]
Ahat = C_top / C_top.sum(axis=0, keepdims=True)
Ahatinv = np.linalg.inv(Ahat)
for mat in (Ahat, Ahatinv):
    print(np.array_str(mat, precision=3))

[[9.672e-01 7.258e-02 4.598e-03 1.267e-01 4.074e-02 1.410e-01 1.355e-03
  2.092e-02 3.830e-03 0.000e+00]
 [2.170e-02 9.093e-01 2.232e-02 5.128e-02 8.244e-02 2.693e-02 1.716e-01
  6.628e-02 1.596e-02 2.885e-03]
 [3.282e-06 3.499e-04 7.885e-01 6.047e-04 1.072e-01 0.000e+00 1.084e-02
  2.142e-02 1.596e-03 6.131e-03]
 [4.289e-03 9.365e-03 3.831e-04 8.045e-01 1.066e-02 3.591e-04 5.422e-04
  5.292e-03 9.575e-04 0.000e+00]
 [7.646e-04 5.743e-03 1.505e-01 1.282e-02 6.720e-01 0.000e+00 9.894e-02
  1.469e-01 5.841e-02 1.442e-03]
 [5.825e-03 5.146e-04 0.000e+00 1.209e-04 0.000e+00 8.317e-01 0.000e+00
  0.000e+00 0.000e+00 0.000e+00]
 [3.282e-06 7.616e-04 2.069e-02 1.209e-04 4.020e-02 0.000e+00 6.996e-01
  1.764e-03 0.000e+00 7.212e-04]
 [9.189e-05 7.616e-04 1.628e-03 2.902e-03 1.039e-02 0.000e+00 4.066e-03
  7.369e-01 0.000e+00 0.000e+00]
 [5.907e-05 4.117e-04 1.916e-03 0.000e+00 1.928e-02 0.000e+00 0.000e+00
  0.000e+00 9.192e-01 0.000e+00]
 [1.641e-05 2.264e-04 9.483e-03 9.675e-04 1.709e-02 0.0

In [194]:
# Construct the point estimate
# Classes ranked by how often they are predicted on the test set (kept for
# inspection; note this ranking is NOT used to index qfhat below).
target_uq, target_uq_counts = np.unique(test_preds, return_counts=True)
target_uq_freq = target_uq_counts/target_uq_counts.sum()
target_uq_sort = np.argsort(target_uq_freq)[::-1]
target_uq_freq = target_uq_freq[target_uq_sort]
target_uq = target_uq[target_uq_sort]
target_uq = target_uq[:K]

# Ahat / Ahatinv and nu1 are all indexed by the classes in `calib_uq`, in that
# order. qfhat therefore has to be the test-set prediction frequency of exactly
# those classes in exactly that order. Taking the K largest test frequencies
# instead is only correct when the test ranking happens to coincide with the
# calibration ranking; otherwise Ahatinv@qfhat mixes up classes.
# Frequencies are fractions of ALL test predictions (a class never predicted gets 0).
qfhat = np.array([(test_preds == k).mean() for k in calib_uq])
nu1 = nu_trunc[:K]
nu2 = nu_trunc[K:]
print(qfhat)
print(Ahatinv@qfhat)
print(nu1@Ahatinv@qfhat)

[0.78815882 0.13339215 0.0110723  0.00990201 0.00918346 0.00676708
 0.00629411 0.00622438 0.00366247 0.00289845]
[0.80385235 0.12407735 0.01288754 0.00643702 0.00539448 0.00242871
 0.00814223 0.00804347 0.003737   0.0025551 ]
0.04423105813425994


In [195]:
# Point estimate of the plankton proportion from the K retained classes.
point_estimate = nu1@Ahatinv@qfhat

# Column sums of the K x K confusion sub-matrix, i.e. the number of calibration
# points per retained label class. Computed once and reused below.
calib_col_counts = C[:,calib_uq][calib_uq,:].sum(axis=0)
nmin = calib_col_counts.min()

theta = 0.99 # The solution to the optimal theta is extreme in this case, so we clip.
# delta_1 is split between the two terms: theta*delta_1 for the Ahat columns
# (worst column taken via max) and (1-theta)*delta_1 for the test frequencies.
epsilon1 = max(linfty_binom(calib_col_counts[k], K, theta*delta_1, Ahat[:,k]) for k in range(K))
epsilon2 = linfty_dkw(N,K,(1-theta)*delta_1)

qyhat_lb = np.clip(point_estimate - epsilon1 - epsilon2, 0, 1)
# alpha is added after clipping, so the sum is capped at 1 again: a proportion
# above 1 makes binom.ppf return NaN and the int() conversion below raise.
qyhat_ub = np.minimum(np.clip(point_estimate + epsilon1 + epsilon2, 0, 1) + alpha, 1)

# Convert the proportion bounds into count bounds over the N test points using
# binomial quantiles at levels delta_2 and 1 - delta_2.
count_plankton_lb = int(binom.ppf(delta_2, N, qyhat_lb))
count_plankton_ub = int(binom.ppf(1-delta_2, N, qyhat_ub))

print(nmin)
print(epsilon1)
print(epsilon2)

2773
0.021913501602999652
0.0071864904510905315


In [196]:
# Report the interval next to the ground truth and both point estimates.
# Coverage is checked explicitly instead of being asserted in the message, so
# the text stays truthful if the data or parameters change.
interval_covers_truth = count_plankton_lb <= true_count <= count_plankton_ub
coverage_msg = "which lies in the interval" if interval_covers_truth else "which does NOT lie in the interval"
print(f"The rectified confidence interval for the number of plankton observed in 2014 is [{count_plankton_lb},{count_plankton_ub}] ([{100*qyhat_lb:.1f}%,{100*qyhat_ub:.1f}%]).")
print(f"The true number of plankton observed in 2014 was {true_count} ({true_count/N*100:.1f}%), {coverage_msg}.")
print(f"The uncorrected estimate was {uncorrected_est} ({uncorrected_est/N*100:.1f}%).")
print(f"The corrected estimate was {int(N*point_estimate)} ({point_estimate*100:.1f}%).")

The rectified confidence interval for the number of plankton observed in 2014 is [4828,41118] ([1.5%,12.3%]).
The true number of plankton observed in 2014 was 23538 (7.1%), which lies in the interval.
The uncorrected estimate was 22068 (6.7%).
The corrected estimate was 14588 (4.4%).
