In [1]:
import logging

import numpy as np
from tabulate import tabulate
from tqdm.notebook import tqdm

from nems import log
from nems.modelspec import eval_ms_layer
from nems.tf.cnnlink import eval_tf_layer

# Raise the nems log threshold to ERROR so INFO/WARNING chatter does not flood
# the comparison loop's output. Messages emitted while the imports above run
# (e.g. the "Saving log messages" / "importing xforms" lines in this cell's
# output) happen before this call and are therefore still shown.
# NOTE(review): assumes `nems.log` is the package's logging.Logger — confirm.
log.setLevel(logging.ERROR)

[nems.configs.defaults INFO] Saving log messages to /tmp/nems\NEMS 2020-04-09 153935.log
[nems.registry INFO] importing xforms function: ldcol


In [13]:
# Random test input fed identically to both backends. Seeded so that the
# "max diff" / "allclose" table further down is reproducible on
# Restart & Run All (the original used the unseeded global RNG).
# NOTE(review): the meaning of the (20, 100, 18) axes is defined by
# eval_ms_layer / eval_tf_layer; only the last axis is used below.
SEED = 0
rng = np.random.default_rng(SEED)
array = rng.random((20, 100, 18))

kern_size = array.shape[-1]  # last-axis size, substituted into every spec
out_size = 4                 # the `x{out_size}` part of the wc/fir/do specs

# One keyword string per layer type under test; each is evaluated by both
# the modelspec and the TF implementation in the next cell.
layer_specs = [
    f'wc.{kern_size}x{out_size}.g',
    f'wc.{kern_size}x{out_size}.b',
    f'fir.{kern_size}x{out_size}',
    f'do.{kern_size}x{out_size}',
    f'stategain.{kern_size}x3',
    f'relu.{kern_size}',
    f'dlog.c{kern_size}',
    f'stp.{kern_size}',
    f'dexp.{kern_size}',
]

In [14]:
# Evaluate every layer spec with both the modelspec (ms) and TensorFlow (tf)
# implementations and record, per layer: whether each backend ran, whether the
# outputs agree within tolerance, and the max absolute difference. All result
# lists are index-aligned with `layer_specs`; 'nan' marks "not comparable".
RTOL = 1e-05
ATOL = 1e-05

allclose_results = []
ms_succs = []
tf_succs = []
max_diffs = []
# Failure message per layer ("ExcType: message"), or None when it ran, so the
# "False" rows in the summary table can be diagnosed.
ms_errors = []
tf_errors = []

pbar = tqdm(layer_specs)

for layer_spec in pbar:
    pbar.set_description(f'Evaluating "{layer_spec}"')

    # Reset each iteration so a failed evaluation can never leave the previous
    # layer's response in scope.
    ms_resp, tf_resp = None, None
    ms_error, tf_error = None, None

    # Catch Exception rather than using a bare `except:` so KeyboardInterrupt
    # still stops the loop, and keep the reason instead of discarding it.
    try:
        ms_resp = eval_ms_layer(array, layer_spec)
    except Exception as exc:
        ms_error = f'{type(exc).__name__}: {exc}'

    try:
        tf_resp = eval_tf_layer(array, layer_spec)
    except Exception as exc:
        tf_error = f'{type(exc).__name__}: {exc}'

    ms_succ = ms_error is None
    tf_succ = tf_error is None
    ms_succs.append(ms_succ)
    tf_succs.append(tf_succ)
    ms_errors.append(ms_error)
    tf_errors.append(tf_error)

    if not (ms_succ and tf_succ):
        allclose_results.append('nan')
        max_diffs.append('nan')
        continue

    ms_resp = np.asarray(ms_resp)
    tf_resp = np.asarray(tf_resp)
    if ms_resp.shape != tf_resp.shape:
        # allclose / subtraction would silently broadcast (or raise) on
        # mismatched shapes, so report the mismatch explicitly instead.
        allclose_results.append('False')
        max_diffs.append(f'shape {ms_resp.shape} vs {tf_resp.shape}')
        continue

    allclose = np.allclose(ms_resp, tf_resp, rtol=RTOL, atol=ATOL)
    allclose_results.append(str(allclose))

    max_diff = np.max(np.abs(ms_resp - tf_resp))
    max_diffs.append(f'{max_diff:.2E}')

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




In [15]:
# Summary of the comparison loop: one row per layer spec, columns aligned with
# the result lists built above. disable_numparse stops tabulate from
# reinterpreting the pre-formatted strings ('True', 'nan', '1.94E-07') as numbers.
table_headers = ['layer', 'ms ran', 'tf ran', 'allclose', 'max diff']
table_rows = zip(layer_specs, ms_succs, tf_succs, allclose_results, max_diffs)
print(tabulate(table_rows, headers=table_headers, disable_numparse=True))

layer           ms ran    tf ran    allclose    max diff
--------------  --------  --------  ----------  ----------
wc.18x4.g       True      True      True        1.94E-07
wc.18x4.b       True      True      True        2.99E-08
fir.18x4        True      True      False       5.93E-01
do.18x4         False     False     nan         nan
stategain.18x3  False     False     nan         nan
relu.18         True      True      True        2.98E-08
dlog.c18        True      True      True        9.08E-08
stp.18          True      True      False       3.02E-02
dexp.18         True      True      True        1.19E-07
