In [11]:
from numba.core.errors import NumbaDeprecationWarning, NumbaPendingDeprecationWarning
import warnings

warnings.simplefilter('ignore', category=NumbaDeprecationWarning)
warnings.simplefilter('ignore', category=NumbaPendingDeprecationWarning)

from warnings import filterwarnings
filterwarnings(action='ignore', category=DeprecationWarning, message='`np.bool` is a deprecated alias')

In [12]:
import numpy as np
import sys, os
sys.path.append('../')
from numba import jit, float32, boolean, int32, float64
from utils.config import cfg
from utils.tools.evaluation import dBZ_to_pixel
from utils.tools.numba_accelerated import *

In [13]:
import numpy as np
from utils.config import cfg
import matplotlib as mpl
import matplotlib.pyplot as plt

In [14]:
from utils.tools.evaluation import get_GDL, get_hit_miss_counts, get_balancing_weights
from numpy.testing import assert_allclose, assert_almost_equal
import time

In [15]:
prediction = np.random.uniform(size=(10, 16, 1, 480, 480))
truth = np.random.uniform(size=(10, 16, 1, 480, 480))
mask = np.random.randint(low=0, high=2, size=(10, 16, 1, 480, 480)).astype(np.bool)

begin = time.time()
gdl = get_GDL(prediction=prediction, truth=truth, mask=mask)
end = time.time()
print("numpy gdl:", end - begin)

begin = time.time()
gdl_numba = get_GDL_numba(prediction=prediction, truth=truth, mask=mask)
end = time.time()
print("numba gdl:", end - begin)

# gdl_mx = mx_get_GDL(prediction=prediction, truth=truth, mask=mask)
# print gdl_mx
assert_allclose(gdl, gdl_numba, rtol=1E-4, atol=1E-3)

begin = time.time()
for i in range(5):
    hits, misses, false_alarms, true_negatives = get_hit_miss_counts(prediction, truth, mask)
end = time.time()
print("numpy hits misses:", end - begin)

begin = time.time()
for i in range(5):
    hits_numba, misses_numba, false_alarms_numba, true_negatives_numba = get_hit_miss_counts_numba(prediction, truth, mask)
end = time.time()
print("numba hits misses:", end - begin)

print(np.abs(hits - hits_numba).max())
print(np.abs(misses - misses_numba).max(), np.abs(misses - misses_numba).argmax())
print(np.abs(false_alarms - false_alarms_numba).max(),
        np.abs(false_alarms - false_alarms_numba).argmax())
print(np.abs(true_negatives - true_negatives_numba).max(),
        np.abs(true_negatives - true_negatives_numba).argmax())

begin = time.time()
for i in range(5):
    weights_npy = get_balancing_weights(data=truth, mask=mask,
                                        base_balancing_weights=None, thresholds=None)
end = time.time()
print("numpy balancing weights:", end - begin)

begin = time.time()
for i in range(5):
    weights_numba = get_balancing_weights_numba(data=truth, mask=mask,
                                                base_balancing_weights=None, thresholds=None)
end = time.time()
print("numba balancing weights:", end - begin)
print("Inconsistent Number:", (np.abs(weights_npy - weights_numba) > 1E-5).sum())

numpy gdl: 0.9822108745574951
numba gdl: 0.36331725120544434
numpy hits misses: 22.973840951919556
numba hits misses: 4.682291030883789
1
1 486
1 554
1 554
numpy balancing weights: 13.710271120071411
numba balancing weights: 3.3589236736297607
Inconsistent Number: 3
