In [39]:
import os, time, copy
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import normalize
from scipy.optimize import brentq
from scipy.stats import binom
from tqdm import tqdm

In [40]:
# The 2006-2014 data can be downloaded from: https://darchive.mblwhoilibrary.org/handle/1912/7341
# Unzip and merge the datasets so that the relative paths below resolve
# against the current working directory.
calib_data, test_data = np.load('./calib-outputs.npz'), np.load('./test-outputs.npz')

# Each archive is read for two arrays, stored under the keys 'preds' and 'labels'.
calib_preds, calib_labels = calib_data['preds'], calib_data['labels']
test_preds, test_labels = test_data['preds'], test_data['labels']

# Array of class names; a later cell maps each name to its integer index.
classes = np.load('./classes.npy')

In [41]:
# Integer indices into `classes` of every entry that is NOT one of the
# excluded labels listed below.
plankton_classes = np.array(
    np.nonzero(
        np.isin(
            classes,
            ['mix', 'mix_elongated', 'detritus', 'bad', 'bead', 'bubble',
             'other_interaction', 'pollen', 'spore'],
            invert=True,
        )
    )[0]
)

# Print the class-name -> integer-index mapping for reference.
print({name: idx for idx, name in enumerate(classes)})

{'Akashiwo': 0, 'Amphidinium_sp': 1, 'Asterionellopsis': 2, 'Bacillaria': 3, 'Bidulphia': 4, 'Cerataulina': 5, 'Cerataulina_flagellate': 6, 'Ceratium': 7, 'Chaetoceros': 8, 'Chaetoceros_didymus': 9, 'Chaetoceros_didymus_flagellate': 10, 'Chaetoceros_flagellate': 11, 'Chaetoceros_other': 12, 'Chaetoceros_pennate': 13, 'Chrysochromulina': 14, 'Ciliate_mix': 15, 'Cochlodinium': 16, 'Corethron': 17, 'Coscinodiscus': 18, 'Cylindrotheca': 19, 'DactFragCerataul': 20, 'Dactyliosolen': 21, 'Delphineis': 22, 'Dictyocha': 23, 'Didinium_sp': 24, 'Dinobryon': 25, 'Dinophysis': 26, 'Ditylum': 27, 'Ditylum_parasite': 28, 'Emiliania_huxleyi': 29, 'Ephemera': 30, 'Eucampia': 31, 'Euglena': 32, 'Euplotes_sp': 33, 'G_delicatula_detritus': 34, 'G_delicatula_external_parasite': 35, 'G_delicatula_parasite': 36, 'Gonyaulax': 37, 'Guinardia_delicatula': 38, 'Guinardia_flaccida': 39, 'Guinardia_striata': 40, 'Gyrodinium': 41, 'Hemiaulus': 42, 'Heterocapsa_triquetra': 43, 'Karenia': 44, 'Katodinium_or_Torodiniu

In [42]:
# Per-class summary of the calibration split: distinct labels, how often each
# occurs, and how often the prediction equals the label.
calib_unique, calib_counts = np.unique(calib_labels, return_counts=True)

# Reorder both arrays so the most frequent label comes first.
argsort = np.argsort(calib_counts)[::-1]
calib_unique = calib_unique[argsort]
calib_counts = calib_counts[argsort]

# For each label: among the samples carrying that label, the fraction whose
# prediction equals it.
# NOTE(review): assumes calib_preds holds one hard class prediction per sample,
# aligned with calib_labels — confirm against how the .npz was written.
calib_accuracies = np.array([
    np.mean(calib_preds[calib_labels == cls] == cls)
    for cls in calib_unique
])

print(calib_unique)
print(np.round(calib_counts / np.sum(calib_counts), 3))  # relative class frequencies
print(np.round(calib_accuracies, 3))

[ 94.  88.  49.   8.  95.  90.   5.  66.  19.  38.  65.  77.  84.  15.
  20.  92.  17.  21.  97.  59.   2.  25.  78.  96.  52.  63.  34.   9.
  41.  40.  82.  13.  60. 101.  64.  27.  76.  23.  32.  99.  80.  83.
  22.  62.  89.  43.  35.  45.  18.  39.  58.  73.  72.  29.   6.  36.
  14.  31.   1.  54.  11.  71.  68.  91.  98.  30.   7.  75.  46.  12.
  26.  74.  81.  10.  51.  57.  28.  33.  70. 102.  56.  69.  47.  48.
  53.  55.  85.   3.  16.  61.  87.  67. 100.]
[0.725 0.117 0.027 0.02  0.019 0.014 0.01  0.01  0.008 0.007 0.007 0.006
 0.005 0.003 0.003 0.002 0.002 0.002 0.001 0.001 0.001 0.001 0.001 0.001
 0.001 0.001 0.001 0.    0.    0.    0.    0.    0.    0.    0.    0.
 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.
 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.
 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.
 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.
 0.    0.    0.    0.    0.    

In [43]:
# Per-class summary of the test split: distinct labels, how often each occurs,
# and how often the prediction equals the label.
test_unique, test_counts = np.unique(test_labels, return_counts=True)

# Reorder both arrays so the most frequent label comes first.
argsort = np.argsort(test_counts)[::-1]
test_unique = test_unique[argsort]
test_counts = test_counts[argsort]

# For each label: among the samples carrying that label, the fraction whose
# prediction equals it.
# NOTE(review): assumes test_preds holds one hard class prediction per sample,
# aligned with test_labels — confirm against how the .npz was written.
test_accuracies = np.array([
    np.mean(test_preds[test_labels == cls] == cls)
    for cls in test_unique
])

print(test_unique)
print(np.round(test_counts / np.sum(test_counts), 3))  # relative class frequencies
print(np.round(test_accuracies, 3))

[ 94.  88.  49.  90.  95.  19.  65.   8.  15.  38.  77.  25.  63.  21.
  40.  17.   5.  66.  97.  27.  96.  78.  92.  20.  59.   2.  52.  58.
  60.  83.  62.  80.  43.   1.  56.  23.  41.  55.  36.  22.  14.  72.
  31.  39.  76.  45.  50.  89.  99.  30. 101.  28.  64.  32.  98.  85.
  18.  51.  54.  26.  73.  82.   9.  57.  84.  12.  74.  46.  35.  29.
  68.  87.  24.  13.  71.   7.   6. 102.  91.  16.  11.  93. 100.  33.
  48.  53.  61.  67.  70.  34.  47.  10.  69.   0.]
[0.807 0.11  0.013 0.012 0.011 0.007 0.007 0.006 0.003 0.003 0.002 0.002
 0.002 0.002 0.001 0.001 0.001 0.001 0.001 0.001 0.001 0.001 0.001 0.001
 0.001 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.
 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.
 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.
 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.
 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.
 0.    0.    0.    0.    0