In [2]:
import os, time, copy
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import normalize
from scipy.optimize import brentq
from scipy.stats import binom
from tqdm import tqdm

In [9]:
# Get data 2006-2014 from the following link: https://darchive.mblwhoilibrary.org/handle/1912/7341
# Unzip and merge the datasets in the following directory
calib_data = np.load('../calib-outputs.npz')
test_data = np.load('../test-outputs.npz')
calib_preds = calib_data['preds']
calib_labels = calib_data['labels']
test_preds = test_data['preds']
test_labels = test_data['labels']
classes = np.load('../classes.npy')
plankton_classes = np.isin(classes,['mix','mix_elongated','detritus','bad', 'bead', 'bubble', 'other_interaction', 'pollen', 'spore'],invert=True)

In [8]:
# Construct the confusion matrix
C = confusion_matrix(calib_labels, calib_preds)



In [39]:
n_y1 = calib_labels.sum()
n_y0 = calib_labels.shape[0]-calib_labels.sum()
N = test_preds.shape[0]

n_f0y0 = ((calib_labels == 0) & (calib_preds == 0)).astype(int).sum()
n_f1y1 = ((calib_labels == 1) & (calib_preds == 1)).astype(int).sum()
N_f1 = test_preds.sum()
N_f0 = N-N_f1

print(f"calib acc among non-plankton: {n_f0y0/n_y0:.3f} | calib acc among plankton: {n_f1y1/n_y1:.3f} | predicted fraction plankton: {N_f1/N:.3f}")

calib acc among non-plankton: 0.983 | calib acc among plankton: 0.574 | predicted fraction plankton: 0.051


In [40]:
test_n_y1 = test_labels.sum()
test_n_y0 = test_labels.shape[0]-test_labels.sum()

test_n_f0y0 = ((test_labels == 0) & (test_preds == 0)).astype(int).sum()
test_n_f1y1 = ((test_labels == 1) & (test_preds == 1)).astype(int).sum()

print(f"test acc among non-plankton: {test_n_f0y0/test_n_y0:.3f} | test acc among plankton: {test_n_f1y1/test_n_y1:.3f}")

test acc among non-plankton: 0.992 | test acc among plankton: 0.612


In [41]:
# Run MAI, estimating confusion matrix
delta = 0.05

def invert_for_lb_0(r): return binom.cdf(n_f0y0,n_y0,r)-(1-delta/8)
def invert_for_lb_1(r): return binom.cdf(n_f1y1,n_y1,r)-(1-delta/8)
def invert_for_ub_0(r): return binom.cdf(n_f0y0,n_y0,r)-(delta/8)
def invert_for_ub_1(r): return binom.cdf(n_f1y1,n_y1,r)-(delta/8)

def invert_for_lb_f(r): return binom.cdf(N_f1,N,r)-(1-delta/8)
def invert_for_ub_f(r): return binom.cdf(N_f1,N,r)-(delta/8)

c0_lb = brentq(invert_for_lb_0,0,1)
c0_ub = brentq(invert_for_ub_0,0,1)

c1_lb = brentq(invert_for_lb_1,0,1)
c1_ub = brentq(invert_for_ub_1,0,1)

f1_lb = brentq(invert_for_lb_f,0,1)
f1_ub = brentq(invert_for_ub_f,0,1)

A_lb = np.array([[c0_lb, 1-c1_ub], [1-c0_lb, c1_ub]])
A_ub = np.array([[c0_ub, 1-c1_lb], [1-c0_ub, c1_lb]])

qyhat_lb = (np.linalg.inv(A_lb)@np.array([1-f1_lb, f1_lb]))[1]
qyhat_ub = (np.linalg.inv(A_ub)@np.array([1-f1_ub, f1_ub]))[1]

count_plankton_lb = binom.ppf(delta/8, N, qyhat_lb)
count_plankton_ub = binom.ppf(1-delta/8, N, qyhat_ub)

print(c0_lb,c0_ub)
print(c1_lb,c1_ub)
print(f1_lb,f1_ub)
print(A_lb,A_ub)

0.9823060211427301 0.9833797296163465
0.5689263506318462 0.5793329141665229
0.050512296719076165 0.052434142356214616
[[0.98230602 0.42066709]
 [0.01769398 0.57933291]] [[0.98337973 0.43107365]
 [0.01662027 0.56892635]]


In [43]:
print(f"The model-assisted confidence interval for the fraction of plankton observed in 2014 is [{qyhat_lb:.4f},{qyhat_ub:.4f}]; the true fraction was {test_labels.mean():.4f}")
print(f"The model-assisted confidence interval for the number of plankton observed in 2014 is [{int(count_plankton_lb)},{int(count_plankton_ub)}].")
print(f"The true number of plankton observed in 2014 was {test_labels.sum()}, which lies in the interval.")
print(f"The point estimate was {test_preds.sum()}, which does not lie in the interval.")

The model-assisted confidence interval for the fraction of plankton observed in 2014 is [0.0584,0.0648]; the true fraction was 0.0714
The model-assisted confidence interval for the number of plankton observed in 2014 is [18937,21742].
The true number of plankton observed in 2014 was 23538, which lies in the interval.
The point estimate was 16975, which does not lie in the interval.
