In [1]:
import os, time, copy
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import normalize
from scipy.optimize import brentq
from scipy.stats import binom
from tqdm import tqdm

In [2]:
# Data (2006-2014) comes from: https://darchive.mblwhoilibrary.org/handle/1912/7341
# Unzip and merge the datasets; the loads below read calib-outputs.npz,
# test-outputs.npz and classes.npy from the parent directory ('../').
calib_data = np.load('../calib-outputs.npz')
test_data = np.load('../test-outputs.npz')
classes = np.load('../classes.npy')

# Each archive is read through its 'preds' and 'labels' entries.
calib_preds = calib_data['preds']
calib_labels = calib_data['labels']
test_preds = test_data['preds']
test_labels = test_data['labels']

# Positions in `classes` whose name is NOT one of the listed non-plankton categories.
plankton_classes = np.nonzero(
    np.isin(
        classes,
        ['mix','mix_elongated','detritus','bad', 'bead', 'bubble', 'other_interaction', 'pollen', 'spore'],
        invert=True,
    )
)[0]

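# Disabled alternative: pool the calibration and test outputs and re-split them at random
# (10% of plankton and 50% of non-plankton examples go to calibration). Uncomment to use.
#combined_preds = np.concatenate([calib_preds, test_preds])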
#combined_labels = np.concatenate([calib_labels, test_labels])
#plankton_loc = np.where(np.isin(combined_labels, plankton_classes))[0]
#nonplankton_loc = np.where(np.bitwise_not(np.isin(combined_labels, plankton_classes)))[0]
#
#num_plankton_total = plankton_loc.shape[0]
#num_nonplankton_total = combined_preds.shape[0]-num_plankton_total
#num_calib_plankton = int(num_plankton_total/10.0)
#num_calib_nonplankton = int(num_nonplankton_total/2.0)
#
#shuffled_plankton_idxs = np.random.choice(num_plankton_total, size=(num_plankton_total), replace=False)
#shuffled_nonplankton_idxs = np.random.choice(num_nonplankton_total, size=(num_nonplankton_total), replace=False)
#calib_plankton_idx, test_plankton_idx = shuffled_plankton_idxs[:num_calib_plankton], shuffled_plankton_idxs[num_calib_plankton:]
#calib_nonplankton_idx, test_nonplankton_idx = shuffled_nonplankton_idxs[:num_calib_nonplankton], shuffled_nonplankton_idxs[num_calib_nonplankton:]
#
#calib_preds = np.concatenate([combined_preds[plankton_loc][calib_plankton_idx], combined_preds[nonplankton_loc][calib_nonplankton_idx]])
#calib_labels = np.concatenate([combined_labels[plankton_loc][calib_plankton_idx], combined_labels[nonplankton_loc][calib_nonplankton_idx]])
#test_preds = np.concatenate([combined_preds[plankton_loc][test_plankton_idx], combined_preds[nonplankton_loc][test_nonplankton_idx]])
#test_labels = np.concatenate([combined_labels[plankton_loc][test_plankton_idx], combined_labels[nonplankton_loc][test_nonplankton_idx]])

In [3]:
# Convert to binary
def _to_plankton_indicator(values):
    """Return a boolean array that is True where `values` is a plankton class index.

    This cell overwrites its own inputs, so a second run would otherwise pass the
    boolean arrays back through np.isin, where True/False are compared as the
    class indices 1/0. Arrays that are already boolean are therefore returned
    unchanged, which makes the cell safe to re-run.
    """
    values = np.asarray(values)
    if values.dtype == np.bool_:
        return values  # already binarized by an earlier run of this cell
    return np.isin(values, plankton_classes)


calib_preds = _to_plankton_indicator(calib_preds)
calib_labels = _to_plankton_indicator(calib_labels)
test_preds = _to_plankton_indicator(test_preds)
test_labels = _to_plankton_indicator(test_labels)

In [8]:
# Calibration-set totals per true class (y=1 plankton, y=0 non-plankton).
n_y1 = calib_labels.sum()
n_y0 = len(calib_labels) - n_y1

# Calibration confusion counts n_f{prediction}y{label}, unpacked in the order
# (f=0,y=0), (f=0,y=1), (f=1,y=0), (f=1,y=1).
n_f0y0, n_f0y1, n_f1y0, n_f1y1 = (
    ((calib_labels == y) & (calib_preds == f)).sum()
    for f in (0, 1)
    for y in (0, 1)
)

# Test-set size and how many test items were predicted as plankton / non-plankton.
N = len(test_preds)
N_f1 = test_preds.sum()
N_f0 = N - N_f1

print(f"calib acc among non-plankton: {n_f0y0/n_y0:.3f} | calib acc among plankton: {n_f1y1/n_y1:.3f} | predicted fraction plankton: {N_f1/N:.3f}")

calib acc among non-plankton: 0.980 | calib acc among plankton: 0.848 | predicted fraction plankton: 0.067


In [9]:
# Test-set totals per true class (y=1 plankton, y=0 non-plankton).
test_n_y1 = test_labels.sum()
test_n_y0 = len(test_labels) - test_n_y1

# Test confusion counts test_n_f{prediction}y{label}, unpacked in the order
# (f=0,y=0), (f=0,y=1), (f=1,y=0), (f=1,y=1).
test_n_f0y0, test_n_f0y1, test_n_f1y0, test_n_f1y1 = (
    ((test_labels == y) & (test_preds == f)).sum()
    for f in (0, 1)
    for y in (0, 1)
)

print(f"test acc among non-plankton: {test_n_f0y0/test_n_y0:.3f} | test acc among plankton: {test_n_f1y1/test_n_y1:.3f}")

test acc among non-plankton: 0.991 | test acc among plankton: 0.821


In [13]:
# Run MAI, estimating confusion matrix.
# Every one-sided bound below is taken at level delta/8 (eight such bounds in total:
# two each for c0, c1 and f1, plus the two binomial quantiles for the count).
delta = 0.05

# Root functions for brentq: each is zero at the rate r where the binomial CDF of
# the observed count hits the target level. The CDF decreases in r, so the
# (1 - delta/8) target yields the smaller root (lower bound) and the delta/8
# target the larger root (upper bound).
# NOTE(review): the textbook Clopper-Pearson lower bound evaluates the CDF at k-1
# rather than k; the difference is small for large n, but confirm this is intended.


def invert_for_lb_0(r):
    """Lower-bound root function for the rate n_f0y0 / n_y0."""
    return binom.cdf(n_f0y0, n_y0, r) - (1 - delta / 8)


def invert_for_lb_1(r):
    """Lower-bound root function for the rate n_f1y1 / n_y1."""
    return binom.cdf(n_f1y1, n_y1, r) - (1 - delta / 8)


def invert_for_ub_0(r):
    """Upper-bound root function for the rate n_f0y0 / n_y0."""
    return binom.cdf(n_f0y0, n_y0, r) - (delta / 8)


def invert_for_ub_1(r):
    """Upper-bound root function for the rate n_f1y1 / n_y1."""
    return binom.cdf(n_f1y1, n_y1, r) - (delta / 8)


def invert_for_lb_f(r):
    """Lower-bound root function for the predicted-plankton rate N_f1 / N."""
    return binom.cdf(N_f1, N, r) - (1 - delta / 8)


def invert_for_ub_f(r):
    """Upper-bound root function for the predicted-plankton rate N_f1 / N."""
    return binom.cdf(N_f1, N, r) - (delta / 8)


# Solve each pair of root functions on [0, 1].
c0_lb, c0_ub = (brentq(fn, 0, 1) for fn in (invert_for_lb_0, invert_for_ub_0))
c1_lb, c1_ub = (brentq(fn, 0, 1) for fn in (invert_for_lb_1, invert_for_ub_1))
f1_lb, f1_ub = (brentq(fn, 0, 1) for fn in (invert_for_lb_f, invert_for_ub_f))

# Empirical column-normalized confusion matrix: rows = prediction, columns = label.
A = np.array([
    [n_f0y0 / n_y0, n_f0y1 / n_y1],
    [n_f1y0 / n_y0, n_f1y1 / n_y1],
])
# Bounding matrices: A_lb pairs c0_lb with c1_ub, A_ub pairs c0_ub with c1_lb.
A_lb = np.array([
    [c0_lb, 1 - c1_ub],
    [1 - c0_lb, c1_ub],
])
A_ub = np.array([
    [c0_ub, 1 - c1_lb],
    [1 - c0_ub, c1_lb],
])

# Invert the bounding matrices against the bounded predicted-class fractions.
qyhat_lb = np.linalg.inv(A_lb) @ np.array([1 - f1_lb, f1_lb])
qyhat_ub = np.linalg.inv(A_ub) @ np.array([1 - f1_ub, f1_ub])

# Turn the bounded plankton fraction (index 1) into bounds on the count out of N.
count_plankton_lb = binom.ppf(delta / 8, N, qyhat_lb[1])
count_plankton_ub = binom.ppf(1 - delta / 8, N, qyhat_ub[1])

for heading, lower, upper in (
    ("c0_lb, c0_ub", c0_lb, c0_ub),
    ("c1_lb, c1_ub", c1_lb, c1_ub),
    ("f1_lb, f1_ub", f1_lb, f1_ub),
):
    print(heading)
    print(lower, upper, "\n")
print("A")
print(A)
print("Ainv@qfhat")
print(np.linalg.inv(A) @ np.array([N_f0 / N, N_f1 / N]))
print("A_lb, A_ub")
print(A_lb, "\n\n", A_ub, "\n")
print(qyhat_lb, "\n", qyhat_ub)

c0_lb, c0_ub
0.9794139279621336 0.9805717323590207 

c1_lb, c1_ub
0.8439714266221167 0.8515319007473106 

f1_lb, f1_ub
0.0658275321190476 0.06800087496151788 

A
[[0.97999737 0.15222948]
 [0.02000263 0.84777052]]
Ainv@qfhat
[0.94333658 0.05666342]
A_lb, A_ub
[[0.97941393 0.1484681 ]
 [0.02058607 0.8515319 ]] 

 [[0.98057173 0.15602857]
 [0.01942827 0.84397143]] 

[0.94555426 0.05444574] 
 [0.94109149 0.05890851]


In [None]:
# Report the model-assisted intervals and compare them with the ground truth.
# Whether each value falls inside the count interval is computed from the bounds
# rather than asserted in the text, so the message stays correct if the data,
# delta or the model outputs change.
count_lb, count_ub = int(count_plankton_lb), int(count_plankton_ub)
true_count = test_labels.sum()
point_estimate = test_preds.sum()


def _interval_verdict(value):
    """Return 'lies' if count_lb <= value <= count_ub, otherwise 'does not lie'."""
    return "lies" if count_lb <= value <= count_ub else "does not lie"


print(f"The model-assisted confidence interval for the fraction of plankton observed in 2014 is [{qyhat_lb[1]:.4f},{qyhat_ub[1]:.4f}]; the true fraction was {test_labels.mean():.4f}")
print(f"The model-assisted confidence interval for the number of plankton observed in 2014 is [{count_lb},{count_ub}].")
print(f"The true number of plankton observed in 2014 was {true_count}, which {_interval_verdict(true_count)} in the interval.")
print(f"The point estimate was {point_estimate}, which {_interval_verdict(point_estimate)} in the interval.")