In [126]:
import os, time, copy
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import normalize
from scipy.optimize import brentq
from scipy.stats import binom
from tqdm import tqdm

In [133]:
# Get data 2006-2014 from the following link: https://darchive.mblwhoilibrary.org/handle/1912/7341
# Unzip and merge the datasets in the following directory
calib_data = np.load('../calib-outputs.npz')
test_data = np.load('../test-outputs.npz')
calib_preds = calib_data['preds'].astype(int)
calib_labels = calib_data['labels'].astype(int)
test_preds = test_data['preds'].astype(int)
test_labels = test_data['labels'].astype(int)
classes = np.load('../classes.npy')

valid = np.array([j for j in range(classes.shape[0]) if ((calib_labels == j).sum() > 0)]) # CHECK FOR VALID CLASSES

plankton_classes = np.isin(classes,['mix','mix_elongated','detritus','bad', 'bead', 'bubble', 'other_interaction', 'pollen', 'spore'],invert=True)

In [128]:
# Construct the confusion matrix
num_classes = classes.shape[0]
n = calib_preds.shape[0]
N = test_preds.shape[0]

C = np.zeros((num_classes,num_classes)).astype(int)
for i in range(num_classes):
    for j in range(num_classes):
        C[i,j] = ((calib_labels == i) & (calib_preds == j)).sum()
num_per_class = C.sum(axis=0)
Ahat = np.nan_to_num(C/num_per_class[None,:])
fhat = np.zeros((test_preds.shape[0],num_classes))
fhat[np.arange(test_preds.shape[0]), test_preds-1] = 1
fhat = fhat.mean(axis=0)

  Ahat = np.nan_to_num(C/num_per_class[None,:])


In [129]:
# Upper and lower bounds on A based on https://rlgammazero.github.io/docs/JFPL2018notesPSRL.pdf
delta = 15
A_ub = Ahat + np.sqrt( 2 * num_classes * np.log(16*num_classes/delta) / num_per_class )[None,:]
A_ub = np.maximum(np.minimum(np.nan_to_num(A_ub), 1),0)
A_lb = Ahat - np.sqrt( 2 * num_classes * np.log(16*num_classes/delta) / num_per_class )[None,:]
A_lb = np.maximum(np.minimum(np.nan_to_num(A_lb), 1),0)
f_lb = np.minimum(np.maximum(fhat - np.sqrt( 2 * num_classes * np.log(16/delta) / N),0),1)
f_ub = np.minimum(np.maximum(fhat + np.sqrt( 2 * num_classes * np.log(16/delta) / N),0),1)
print(f_ub)

[0.00634887 0.00640345 0.00634887 0.00634887 0.00748278 0.00634887
 0.00635494 0.0078557  0.00634887 0.006361   0.00634887 0.00634887
 0.00634887 0.00634887 0.00872281 0.00634887 0.00734938 0.00636707
 0.00992342 0.00636403 0.00742518 0.00634887 0.00642467 0.00634887
 0.00768289 0.00634887 0.00648531 0.00634887 0.00634887 0.00634887
 0.00637313 0.00634887 0.00634887 0.00634887 0.00634887 0.00638526
 0.00634887 0.01416801 0.00641557 0.00656717 0.00635191 0.00634887
 0.00634887 0.00634887 0.00634887 0.0064277  0.00634887 0.00634887
 0.02514329 0.00634887 0.00634887 0.00647924 0.00634887 0.00635191
 0.00634887 0.00655504 0.00634887 0.0064368  0.00679152 0.00634887
 0.00634887 0.00634887 0.00637313 0.00634887 0.0104358  0.00647924
 0.00634887 0.00634887 0.00634887 0.00634887 0.00634887 0.00634887
 0.00634887 0.00634887 0.00634887 0.00655201 0.00705226 0.00651866
 0.00634887 0.00643983 0.00634887 0.00634887 0.00634887 0.00653988
 0.00634887 0.00634887 0.00634887 0.09909003 0.00634887 0.0121

  A_ub = Ahat + np.sqrt( 2 * num_classes * np.log(16*num_classes/delta) / num_per_class )[None,:]
  A_lb = Ahat - np.sqrt( 2 * num_classes * np.log(16*num_classes/delta) / num_per_class )[None,:]


In [130]:
# Run MAI, estimating confusion matrix
qyhat_lb = (np.linalg.pinv(A_lb)@f_lb)[plankton_classes].sum()
qyhat_ub = (np.linalg.pinv(A_ub)@f_ub)[plankton_classes].sum()

count_plankton_lb = binom.ppf(delta/8, N, qyhat_lb)
count_plankton_ub = binom.ppf(1-delta/8, N, qyhat_ub)

counts = N*(np.linalg.pinv(Ahat)@fhat)[plankton_classes].sum()
print(Ahat.sum(axis=0))

[0. 0. 1. 0. 0. 1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 1. 0. 1. 1. 1. 1. 1. 0. 1.
 0. 1. 0. 1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 1. 1. 0. 0. 0. 0. 0. 1. 0.
 0. 1. 0. 0. 1. 0. 1. 0. 1. 0. 1. 1. 0. 0. 0. 1. 0. 1. 1. 0. 0. 0. 0. 0.
 1. 0. 0. 0. 1. 1. 1. 0. 1. 0. 0. 0. 1. 0. 0. 0. 1. 0. 1. 0. 0. 0. 1. 1.
 0. 1. 0. 0. 0. 0. 0.]


In [131]:
print(f"The model-assisted estimate for the number of plankton observed in 2014 is {counts}.")
print(f"The true number of plankton observed in 2014 was {test_labels.mean()}, which lies in the interval.")
print(f"The point estimate was {test_preds.sum()}, which does not lie in the interval.")

The model-assisted estimate for the number of plankton observed in 2014 is -229485.63748657712.
The true number of plankton observed in 2014 was 90.079334327779, which lies in the interval.
The point estimate was 30010989, which does not lie in the interval.
