In [1]:
import os, time, copy
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import normalize
from scipy.optimize import brentq
from scipy.stats import binom
from tqdm import tqdm

In [2]:
# Get data 2006-2014 from the following link: https://darchive.mblwhoilibrary.org/handle/1912/7341
# Unzip and merge the datasets in the following directory
calib_data = np.load('../calib-outputs.npz')
test_data = np.load('../test-outputs.npz')
calib_preds = calib_data['preds'].astype(int)
calib_labels = calib_data['labels'].astype(int)
test_preds = test_data['preds'].astype(int)
test_labels = test_data['labels'].astype(int)
classes = np.load('../classes.npy')
num_classes = classes.shape[0]

plankton_classes = np.isin(classes,['mix','mix_elongated','detritus','bad', 'bead', 'bubble', 'other_interaction', 'pollen', 'spore'],invert=True)

In [3]:
# Look at the unique classes 
uq, uq_counts = np.unique(calib_labels, return_counts=True)
uq_freq = uq_counts/uq_counts.sum()
uq_sort = np.argsort(uq_freq)[::-1]
uq_freq = uq_freq[uq_sort]; uq = uq[uq_sort];
uq_cumsum = np.cumsum(uq_freq)
print(np.searchsorted(uq_cumsum,0.97))

12


In [4]:
# Problem setup
alpha = 0.02
K = 10
nu = plankton_classes.astype(int)
nu_trunc = nu[uq]
uq = uq[:K]

In [5]:
# Construct the confusion matrix
n = calib_preds.shape[0]
N = test_preds.shape[0]

calib_preds[np.logical_not(np.isin(calib_preds,uq))] = uq[-1]
test_preds[np.logical_not(np.isin(test_preds,uq))] = uq[-1]

C = np.zeros((num_classes,num_classes)).astype(int)
for j in range(num_classes):
    for l in range(num_classes):
        C[j,l] = ((calib_preds == j) & (calib_labels == l)).sum()

In [7]:
# Construct Ahat
Ahat = C[:,uq][uq,:]
Ahat = Ahat / (Ahat.sum(axis=1)[:,None])
print(np.array_str(Ahat, precision=3))
print(Ahat.sum(axis=1))

[[9.135e-01 4.902e-02 8.056e-04 1.554e-02 4.698e-03 8.744e-03 9.359e-04
  5.221e-03 1.372e-03 1.333e-04]
 [5.696e-02 7.993e-01 5.128e-03 3.461e-02 2.307e-02 1.108e-02 4.884e-02
  1.709e-02 1.357e-03 2.514e-03]
 [1.670e-02 2.794e-02 5.804e-01 8.913e-03 1.676e-01 0.000e+00 5.740e-02
  7.204e-02 6.585e-02 3.126e-03]
 [1.670e-02 5.188e-02 4.293e-02 7.728e-01 4.830e-02 0.000e+00 2.147e-02
  4.234e-02 5.963e-04 2.982e-03]
 [4.933e-02 4.596e-02 2.016e-01 3.369e-02 4.519e-01 0.000e+00 5.173e-02
  6.376e-02 1.008e-01 1.203e-03]
 [2.151e-01 1.097e-02 0.000e+00 1.062e-03 3.539e-04 7.718e-01 3.539e-04
  0.000e+00 0.000e+00 3.539e-04]
 [0.000e+00 2.149e-02 1.777e-01 5.731e-03 1.576e-01 0.000e+00 6.060e-01
  3.009e-02 0.000e+00 1.433e-03]
 [1.515e-02 3.788e-02 4.545e-02 3.788e-02 2.273e-02 0.000e+00 1.212e-01
  7.197e-01 0.000e+00 0.000e+00]
 [1.169e-02 6.098e-03 1.382e-01 5.081e-04 1.865e-01 0.000e+00 0.000e+00
  4.573e-03 6.524e-01 0.000e+00]
 [4.662e-02 6.827e-02 1.660e-01 4.829e-02 9.158e-02 7.0

In [8]:
# Construct the point estimate
uq, uq_counts = np.unique(test_preds, return_counts=True)
uq_freq = uq_counts/uq_counts.sum()
uq_sort = np.argsort(uq_freq)[::-1]
uq_freq = uq_freq[uq_sort]; uq = uq[uq_sort];
uq = uq[:K]
qf = uq_freq[:K]
nu1 = nu_trunc[:K]
nu2 = nu_trunc[K:]

In [10]:
# Run MAI
estimate = np.dot(nu1,np.dot(np.linalg.inv(Ahat),qf))
print(np.linalg.inv(Ahat)@qf)
print(estimate)
print(nu1)
print(qf)
print(qf.sum())
#count_plankton_lb = binom.ppf(delta/8, N, qyhat_lb)
#count_plankton_ub = binom.ppf(1-delta/8, N, qyhat_ub)

#counts = N*(np.linalg.pinv(Ahat)@fhat)[plankton_classes].sum()
#print(Ahat.sum(axis=0))

[ 0.92678428  0.05535354  0.03673673  0.00597457 -0.1044576  -0.25156296
  0.02157415 -0.02330044  0.00685244 -0.06073969]
-0.2644651995240214
[0 0 1 1 0 1 1 1 1 1]
[8.46688617e-01 9.27411531e-02 2.07075117e-02 1.87944166e-02
 8.91362876e-03 5.80901792e-03 3.57454704e-03 1.50682772e-03
 1.13391060e-03 1.30369400e-04]
1.0


In [385]:
print(f"The model-assisted estimate for the number of plankton observed in 2014 is {counts}.")
print(f"The true number of plankton observed in 2014 was {test_labels.mean()}, which lies in the interval.")
print(f"The point estimate was {test_preds.sum()}, which does not lie in the interval.")

NameError: name 'counts' is not defined