In [1]:
import os, time, copy
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import normalize
from scipy.optimize import brentq
from scipy.stats import binom
from tqdm import tqdm

In [2]:
# Data: the 2006-2014 archive at https://darchive.mblwhoilibrary.org/handle/1912/7341
# Unzip and merge the datasets; this cell reads the resulting files from the parent directory.
calib_data = np.load('../calib-outputs.npz')
test_data = np.load('../test-outputs.npz')

# Each archive is read through its 'preds' and 'labels' entries, cast to int so
# they can be compared with / used as integer class ids below.
calib_preds, calib_labels = (calib_data[key].astype(int) for key in ('preds', 'labels'))
test_preds, test_labels = (test_data[key].astype(int) for key in ('preds', 'labels'))

# Class names; position in this array is the integer class id.
classes = np.load('../classes.npy')
num_classes = len(classes)

# Boolean mask over class ids: True for every class whose name is NOT in the
# exclusion list below (those remaining classes are what the notebook counts as plankton).
plankton_classes = ~np.isin(
    classes,
    ['mix','mix_elongated','detritus','bad', 'bead', 'bubble', 'other_interaction', 'pollen', 'spore'],
)
# Integer ids of the classes flagged True in the mask.
plankton_classes_list = np.flatnonzero(plankton_classes)

# Number of test examples whose label / prediction falls in a plankton class.
true_count = np.isin(test_labels, plankton_classes_list).sum()
uncorrected_est = np.isin(test_preds, plankton_classes_list).sum()

# Fraction of test examples where the prediction equals the label.
print((test_labels == test_preds).astype(float).mean())

0.9334570326711781


In [3]:
# Empirical label distribution of the calibration set, most frequent class first.
calib_uq, calib_uq_counts = np.unique(calib_labels, return_counts=True)
calib_uq_freq = calib_uq_counts / calib_uq_counts.sum()

# Descending-frequency permutation; applied to the ids and the frequencies.
# Note: calib_uq_counts is left in np.unique's original (ascending id) order.
calib_uq_sort = calib_uq_freq.argsort()[::-1]
calib_uq_freq, calib_uq = calib_uq_freq[calib_uq_sort], calib_uq[calib_uq_sort]

# Running total of the sorted frequencies (mass covered by the top-k classes).
calib_uq_cumsum = calib_uq_freq.cumsum()

In [4]:
# Problem setup
K = 3               # number of most-frequent calibration classes that are kept
alpha = 0.1         # scales term3_ub / term3_lb in the interval cell
delta_2 = 1e-2      # tail level passed to binom.ppf in the interval cell
delta_1 = 0.1-1e-2  # level used inside the log factors of term1 / term2

# 0/1 indicator per class id: 1 where plankton_classes is True.
nu = plankton_classes.astype(int)

# Indicator re-ordered by descending calibration frequency. This must run
# BEFORE calib_uq is truncated on the next line: re-running this cell without
# re-running the frequency cell first would leave nu_trunc with only K entries.
nu_trunc = nu[calib_uq]
calib_uq = calib_uq[:K]

# Names of the K retained classes.
print(classes[calib_uq])

['mix' 'detritus' 'Leptocylindrus']


In [5]:
# Construct the confusion matrix
# n / N: number of calibration / test examples.
n = calib_preds.shape[0]
N = test_preds.shape[0]

# C[j, l] = number of calibration examples predicted as class j whose label is l.
# Every (pred, label) pair is encoded as the flat index pred*num_classes + label
# and counted in one bincount pass, instead of scanning both calibration arrays
# once per matrix entry (num_classes**2 full passes).
# Pairs with an id outside [0, num_classes) are dropped first: the per-entry
# comparison loop this replaces never counted them either.
in_range = (
    (calib_preds >= 0) & (calib_preds < num_classes)
    & (calib_labels >= 0) & (calib_labels < num_classes)
)
C = np.bincount(
    calib_preds[in_range] * num_classes + calib_labels[in_range],
    minlength=num_classes * num_classes,
).reshape(num_classes, num_classes).astype(int)

In [6]:
# Construct Ahat
# Sub-matrix of the confusion matrix restricted to the K retained classes;
# rows (predictions) and columns (labels) are both ordered as calib_uq.
Ahat = C[np.ix_(calib_uq, calib_uq)]

# Normalise each column to sum to 1, i.e. divide by the per-label totals
# taken over the retained predicted classes only.
Ahat = Ahat / Ahat.sum(axis=0, keepdims=True)

# Inverse used below to map predicted-class frequencies to an estimate.
Ahatinv = np.linalg.inv(Ahat)

print(np.array_str(Ahat, precision=3))
print(np.array_str(Ahatinv, precision=3))

[[9.781e-01 7.389e-02 5.638e-03]
 [2.194e-02 9.258e-01 2.737e-02]
 [3.318e-06 3.563e-04 9.670e-01]]
[[ 1.024e+00 -8.175e-02 -3.658e-03]
 [-2.428e-02  1.082e+00 -3.049e-02]
 [ 5.430e-06 -3.984e-04  1.034e+00]]


In [8]:
# Construct the point estimate
# Predicted-class frequencies on the test set, most frequent first.
# target_uq ends up holding the K most frequently predicted class ids.
target_uq, target_uq_counts = np.unique(test_preds, return_counts=True)
target_uq_freq = target_uq_counts/target_uq_counts.sum()
target_uq_sort = np.argsort(target_uq_freq)[::-1]
target_uq_freq = target_uq_freq[target_uq_sort]; target_uq = target_uq[target_uq_sort];
target_uq = target_uq[:K]

# qfhat[i] = fraction of test predictions equal to class calib_uq[i].
# Ahat's rows are indexed by calib_uq, so qfhat has to use the same classes in
# the same order. Taking target_uq_freq[:K] instead is only correct when the K
# most frequently predicted test classes coincide, in order, with calib_uq;
# otherwise Ahatinv @ qfhat would combine entries belonging to different classes.
qfhat = np.array([(test_preds == c).sum() for c in calib_uq[:K]]) / test_preds.shape[0]

# Plankton indicator split into the K retained classes and the remaining ones.
nu1 = nu_trunc[:K]
nu2 = nu_trunc[K:]

print(qfhat)
print(Ahatinv@qfhat)   # per-class estimate for the retained classes
print(nu1@Ahatinv@qfhat)   # nu1-weighted sum of that estimate

[0.78815882 0.13339215 0.0110723 ]
[0.79634491 0.12487685 0.01140152]
0.011401518499597628


In [9]:
# Run MAI
# Calibration count of the rarest retained class, recovered as frequency * n
# (assumes calib_labels and calib_preds have the same length -- TODO confirm).
# The frequencies are count/total ratios, so the product can land a hair below
# the true integer (e.g. 29/100*100 == 28.999...); round before converting so
# int() does not truncate it to count-1.
nmin = int(round(calib_uq_freq[:K].min()*n))
point_estimate = nu1@Ahatinv@qfhat

# Both deviation terms share the factor sum(|nu1 @ Ahatinv|) and the same log
# factor; term1 scales with 1/nmin (calibration side), term2 with 1/N (test side).
term1 = np.abs(nu1@Ahatinv).sum()*np.sqrt(2/nmin*np.log(2*(K+1)/delta_1))
term2 = np.abs(nu1@Ahatinv).sum()*np.sqrt(2/N*np.log(2*(K+1)/delta_1))

# NOTE(review): nu2 (the indicator for the classes outside the top K) is
# computed in the point-estimate cell but never used; confirm that nu1 is the
# intended vector for these two slack terms.
term3_ub = nu1.max()*alpha
term3_lb = nu1.min()*alpha

# Interval for the plankton proportion, clipped to [0, 1].
qyhat_ub = np.minimum(point_estimate + term1 + term2 + term3_ub,1)
qyhat_lb = np.maximum(point_estimate - term1 - term2 - term3_lb,0)

# Convert the proportion bounds into count bounds via binomial quantiles at
# levels delta_2 and 1 - delta_2 with N trials.
count_plankton_lb = int(binom.ppf(delta_2, N, qyhat_lb))
count_plankton_ub = int(binom.ppf(1-delta_2, N, qyhat_ub))

print(term1)
print(term2)
print(term3_lb)

0.028894820879305373
0.005396558931344003
0.0


In [10]:
# Report the interval, the ground truth, and both point estimates.
# Whether the true count is covered is checked here rather than asserted in the
# message text, so the sentence stays correct if the data or parameters change.
truth_is_covered = count_plankton_lb <= true_count <= count_plankton_ub
coverage_phrase = "lies in" if truth_is_covered else "does NOT lie in"

print(f"The rectified confidence interval for the number of plankton observed in 2014 is [{count_plankton_lb},{count_plankton_ub}] ([{100*qyhat_lb:.1f}%,{100*qyhat_ub:.1f}%]).")
print(f"The true number of plankton observed in 2014 was {true_count} ({true_count/N*100:.1f}%), which {coverage_phrase} the interval.")
print(f"The uncorrected estimate was {uncorrected_est} ({uncorrected_est/N*100:.1f}%).")
print(f"The corrected estimate was {int(N*point_estimate)} ({point_estimate*100:.1f}%).")

The rectified confidence interval for the number of plankton observed in 2014 is [0,48526] ([0.0%,14.6%]).
The true number of plankton observed in 2014 was 23538 (7.1%), which lies in the interval.
The uncorrected estimate was 22068 (6.7%).
The corrected estimate was 3760 (1.1%).
