In [1]:
import random
import numpy as np
import torch
import matplotlib.pyplot as plt
import os
import pandas as pd
from nn_edm import *
from gru_nn import *

In [2]:
'''
Binary prediction skill (“bloom” or “no bloom”) is
evaluated by how well bloom predictions above a defined threshold
correspond to bloom observations above the same threshold
Input: Series of observed chlorophyll-a, series of predicted chlorophyll-1
Output: Arrray containing [Accuracy, True Positive Rate, False Positive Rate, True Negative Rate, False Negative Rate]
'''
def thresh_bloom_binary_prediction(obs,pred,threshold=8.03199999999999):
    obs_blooms = obs > threshold
    pred_blooms = pred > threshold
    Accuracy = 1 - (obs_blooms ^ pred_blooms).mean()
    True_pos = (obs_blooms & pred_blooms).sum() / obs_blooms.sum()
    False_pos = ((~obs_blooms) & pred_blooms).sum() / (~obs_blooms).sum()
    True_neg = ((~obs_blooms) & (~pred_blooms)).sum() / (~obs_blooms).sum()
    False_neg = (obs_blooms & (~pred_blooms)).sum() / obs_blooms.sum()
    
    return [Accuracy, True_pos, False_pos, True_neg, False_neg]

In [3]:
'''
Binary prediction skill (“bloom” or “no bloom”) is
evaluated by how well bloom predictions (5% largest predicted values)
correspond to bloom observations (5%largest observed values)
Input: Series of observed chlorophyll-a, series of predicted chlorophyll-1
Output: Arrray containing [Accuracy, True Positive Rate, False Positive Rate, True Negative Rate, False Negative Rate]
'''
def bloom_binary_prediction(obs,pred):
    obs_bloom_95 = np.percentile(obs, 95) #incorrect
    pred_bloom_95 = np.percentile(pred, 95) #incorrect
    obs_blooms = obs > obs_bloom_95
    pred_blooms = pred > pred_bloom_95
    Accuracy = 1 - (obs_blooms ^ pred_blooms).mean()
    True_pos = (obs_blooms & pred_blooms).sum() / obs_blooms.sum()
    False_pos = ((~obs_blooms) & pred_blooms).sum() / (~obs_blooms).sum()
    True_neg = ((~obs_blooms) & (~pred_blooms)).sum() / (~obs_blooms).sum()
    False_neg = (obs_blooms & (~pred_blooms)).sum() / obs_blooms.sum()
    
    return [Accuracy, True_pos, False_pos, True_neg, False_neg]

In [27]:
#Load embedding data
input_file_path = '../Data/cleaned_data.csv'
target = 'Avg_Chloro'
data = pd.read_csv(input_file_path)
data = data.set_index('time (UTC)')
data['Time'] = data.index.astype(int)
data = data.drop(columns=['Time'])

tau_lengths = [-1,-2,-3]
E = 6
X, y = get_data(data, E, tau_lengths, target=target)
embd_sz = len(data.columns) * E * len(tau_lengths)
X = torch.tensor(X[533:], dtype=torch.float) 
y = torch.tensor(y[533:], dtype=torch.float)
len(X)

284

In [22]:
model = NNEDMModel(embd_sz,hidden_size=100)
#model = GRUEDMModel(embd_sz=embd_sz)
model.load_state_dict(torch.load("base_model.pth"))

model.eval()
prediction = []
for inp in X:
    prediction.append(model(inp))
prediction = torch.tensor(prediction).numpy()

  model.load_state_dict(torch.load("base_model.pth"))


In [26]:
prediction

array([ 2.209867 ,  1.8998418,  1.8998704,  2.1029317,  2.1288238,
        2.0142043,  2.1412094,  1.9577843,  2.2683644,  2.164142 ,
        2.3551452,  2.1614013,  2.1952121,  2.3856342,  2.3262348,
        2.5009825,  1.9819975,  2.1345713,  2.044676 ,  2.325422 ,
        3.973245 ,  3.4538572, 15.843585 ,  4.7164826,  6.080781 ,
        3.2018096,  2.2384748,  1.9614384,  2.5669827,  2.1194422,
        2.114854 ,  2.0926416,  3.2585194,  2.4836943,  3.3147755,
        2.8412163,  2.186163 ,  2.0373578,  2.1411843,  2.1480374,
        2.179303 ,  1.9396598,  1.5985733,  1.8525819,  1.804415 ,
        1.7511569,  1.9507066,  2.0212023,  2.3076136,  2.0069497,
        2.3133547,  2.113256 ,  2.11911  ,  2.0236065,  2.018995 ,
        2.1007395,  1.90765  ,  1.9542669,  2.0014944,  2.0074935,
        1.9852434,  1.9764105,  1.9646956,  2.2082717,  2.029346 ,
        2.464256 ,  2.508235 ,  2.5208268,  2.5087414,  2.5187771,
        2.5194933,  2.4989066,  2.4446843,  2.4874933,  2.9144

In [23]:
y.numpy()

array([  0.7 ,   0.67,   1.41,   1.48,   0.97,   1.56,   0.74,   1.97,
         1.49,   2.26,   1.39,   1.54,   2.3 ,   2.01,   2.7 ,   0.45,
         1.43,   0.91,   2.21,   8.54,   5.82,  55.68,   6.28,  16.68,
         4.29,   1.86,   0.94,   3.26,   1.74,   1.92,   1.84,   6.26,
         2.69,   6.8 ,   4.42,   2.08,   1.78,   2.17,   2.36,   2.17,
         1.36,   0.16,   0.98,   0.84,   0.94,   1.28,   1.43,   2.42,
         1.28,   2.39,   1.49,   1.63,   1.11,   1.12,   1.42,   0.57,
         0.8 ,   0.99,   0.9 ,   0.84,   0.83,   0.72,   1.73,   0.86,
         2.71,   2.71,   2.71,   2.71,   2.71,   2.71,   2.71,   2.71,
         2.71,   4.56,   1.6 ,   2.35,   2.54,   2.08,   2.44,   3.29,
         2.79,   3.86,   1.59,   3.66,   0.93,   2.12,   1.66,   1.6 ,
         1.58,   1.76,   1.83,   1.58,   1.78,   1.61,   1.61,   1.45,
         2.02,   4.76,   2.93,   1.54,   1.5 ,   1.92,   2.88,   2.9 ,
         3.48,   4.22,   5.44,   2.47,   4.11,   7.96,   3.23,   2.46,
      

In [24]:
bloom_binary_prediction(y.numpy(),prediction)

ValueError: operands could not be broadcast together with shapes (283,) (284,) 