Skip to content

Commit

Permalink
TF 2.3 compatibility; 'SCALEFIG' env flag
Browse files Browse the repository at this point in the history
**FEATURES**:

  - Compatibility with TensorFlow 2.3.0 (see "TESTS")
  - `os.environ['SCALEFIG']` (default `'1'`) will scale all drawn figures; specify as tuple `(w, h)` to scale width & height separately.

**BREAKING**:

  - TF 2.3.0 won't work with `keras`; this isn't `see_rnn`, this is breaking changes to some basic TF ops.
  - Changed default `os.environ['TF_KERAS']` to `'1'`

**TESTS**:

  - Discontinued support for `keras` TF2.3.0+; `keras` TF2.2.0 Graph is still tested, but may discontinue entirely in the future, or reinstate depending on how [Keras](https://github.com/keras-team/keras) proceeds
  - `tf.keras` now only tested with TF 2.3.0+
  • Loading branch information
OverLordGoldDragon committed Aug 4, 2020
1 parent 61c3d45 commit 7026959
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 34 deletions.
5 changes: 2 additions & 3 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,9 @@ env:
matrix:
- TF_VERSION="1.14.0" KERAS_VERSION="2.2.5"
- TF_VERSION="1.14.0" KERAS_VERSION="2.2.5" TF_KERAS="1"
- TF_VERSION="2.2.0" KERAS_VERSION="2.3.1" TF_EAGER="1"
- TF_VERSION="2.2.0" KERAS_VERSION="2.3.1"
- TF_VERSION="2.2.0" KERAS_VERSION="2.3.1" TF_KERAS="1" TF_EAGER="1"
- TF_VERSION="2.2.0" KERAS_VERSION="2.3.1" TF_KERAS="1"
- TF_VERSION="2.3.0" KERAS_VERSION="2.3.1" TF_KERAS="1" TF_EAGER="1"
- TF_VERSION="2.3.0" KERAS_VERSION="2.3.1" TF_KERAS="1"
notifications:
webhooks:
- https://coveralls.io/webhook
Expand Down
25 changes: 24 additions & 1 deletion see_rnn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,26 @@
import os

def _get_scales():
s = os.environ.get('SCALEFIG', '1')
os.environ['SCALEFIG'] = s
if ',' in s:
w_scale, h_scale = map(float, s.strip('[()]').split(','))
else:
w_scale, h_scale = float(s), float(s)
return w_scale, h_scale

def scalefig(fig):
"""Used internally to scale figures according to env var 'SCALEFIG'.
os.environ['SCALEFIG'] can be an int, float, tuple, list, or bracketless
tuple, but must be a string: '1', '1.1', '(1, 1.1)', '1,1.1'.
"""
w, h = fig.get_size_inches()
w_scale, h_scale = _get_scales() # refresh in case env var changed
fig.set_size_inches(w * w_scale, h * h_scale)

##############################################################################

from . import visuals_gen
from . import visuals_rnn
from . import inspect_gen
Expand All @@ -8,4 +31,4 @@
from .inspect_gen import *
from .inspect_rnn import *

__version__ = '1.14.6'
__version__ = '1.15.0'
4 changes: 3 additions & 1 deletion see_rnn/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from termcolor import colored


TF_KERAS = os.environ.get("TF_KERAS", '0') == '1'
TF_KERAS = os.environ.get("TF_KERAS", '1') == '1'

WARN = colored("WARNING:", 'red')
NOTE = colored("NOTE:", 'blue')
Expand All @@ -11,6 +11,8 @@
if TF_KERAS:
from tensorflow.python.keras import backend as K
from tensorflow.keras.layers import Layer
from tensorflow.keras.models import Model
else:
from keras import backend as K
from keras.layers import Layer
from keras.models import Model
19 changes: 13 additions & 6 deletions see_rnn/inspect_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from copy import deepcopy
from .utils import _validate_args, _get_params, _layer_of_output
from ._backend import K, TF_KERAS
from ._backend import K, TF_KERAS, Model

if tf.executing_eagerly():
from tensorflow.python.distribute import parameter_server_strategy
Expand Down Expand Up @@ -60,13 +60,20 @@ def _get_outs_tensors(model, names, idxs, layers):
idxs, layers = None, None
one_requested = len(_id) == 1

layer_outs = _get_outs_tensors(model, names, idxs, layers)
lp = K.symbolic_learning_phase() if TF_KERAS else K.learning_phase()
outs_fn = K.function([*model.inputs, lp], layer_outs)

if not isinstance(input_data, (list, tuple)):
input_data = [input_data]
outs = outs_fn([*input_data, bool(learning_phase)])
layer_outs = _get_outs_tensors(model, names, idxs, layers)

if tf.executing_eagerly():
partial_model = Model(model.inputs, layer_outs)
outs = partial_model(input_data, training=bool(learning_phase))
if not isinstance(outs, (list, tuple)):
outs = [outs]
outs = [o.numpy() for o in outs]
else:
lp = K.symbolic_learning_phase() if TF_KERAS else K.learning_phase()
outs_fn = K.function([*model.inputs, lp], layer_outs)
outs = outs_fn([*input_data, bool(learning_phase)])

if as_dict:
return {get_full_name(model, i): x for i, x in zip(names or idxs, outs)}
Expand Down
7 changes: 7 additions & 0 deletions see_rnn/visuals_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from .utils import _kw_from_configs, clipnums
from ._backend import NOTE
from . import scalefig


def features_0D(data, marker='o', cmap='bwr', color=None, configs=None, **kwargs):
Expand Down Expand Up @@ -114,7 +115,9 @@ def _get_title(data, title):

fig, axes = plt.gcf(), plt.gca()
fig.set_size_inches(12 * w, 4 * h)
scalefig(fig)
plt.show()

if savepath is not None:
fig.savefig(savepath, **kw['save'])
return fig, axes
Expand Down Expand Up @@ -329,6 +332,7 @@ def _style_axis(ax, kw, show_borders, show_xy_ticks, annotations, xmax):
for ax in axes.flat:
[s.set_linewidth(borderwidth) for s in ax.spines.values()]

scalefig(fig)
plt.show()
if savepath is not None:
fig.savefig(savepath, **kw['save'])
Expand Down Expand Up @@ -545,6 +549,7 @@ def _style_axis(ax, kw, show_borders, show_xy_ticks):
s.set_linewidth(borderwidth)
if bordercolor is not None:
s.set_color(bordercolor)
scalefig(fig)
plt.show()
if savepath is not None:
fig.savefig(savepath, **kw['save'])
Expand Down Expand Up @@ -747,6 +752,7 @@ def _style_axis(ax, kw, show_borders, show_xy_ticks, xlims,
for ax in axes.flat:
[s.set_linewidth(borderwidth) for s in ax.spines.values()]

scalefig(fig)
plt.show()
if savepath is not None:
fig.savefig(savepath, **kw['save'])
Expand Down Expand Up @@ -946,6 +952,7 @@ def _style_axis(ax, kw, show_borders, show_xy_ticks, xlims, center_zero,
for ax in axes.flat:
[s.set_linewidth(borderwidth) for s in ax.spines.values()]

scalefig(fig)
plt.show()
if savepath is not None:
fig.savefig(savepath, **kw['save'])
Expand Down
11 changes: 8 additions & 3 deletions see_rnn/visuals_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from .utils import _process_rnn_args, _kw_from_configs, _save_rnn_fig
from .inspect_gen import detect_nans
from . import scalefig


def rnn_histogram(model, _id, layer=None, input_data=None, labels=None,
Expand Down Expand Up @@ -73,7 +74,7 @@ def _process_configs(configs, w, h, equate_axes):
'plot': dict(),
'subplot': dict(sharex=True, sharey=True, dpi=76, figsize=(9, 9)),
'tight': dict(),
'title': dict(weight='bold', fontsize=13, y=1.05),
'title': dict(weight='bold', fontsize=12, y=1.05),
'annot': dict(fontsize=12, weight='bold',
xy=(.90, .93), xycoords='axes fraction'),
'annot-nan': dict(fontsize=12, weight='bold', color='red',
Expand Down Expand Up @@ -257,6 +258,8 @@ def _make_subplots(show_bias, direction_name, d, kw):
x_new, y_new = _get_axes_extrema(subplots_axes)
_set_axes_limits(subplots_axes, x_new, y_new, d)

for fig in subplots_figs:
scalefig(fig)
plt.show()
if savepath is not None:
_save_rnn_fig(subplots_figs, savepath, kw['save'])
Expand Down Expand Up @@ -348,8 +351,8 @@ def _process_configs(configs, w, h):
'plot-bias': dict(interpolation='nearest'),
'subplot': dict(dpi=76, figsize=(14, 8)),
'tight': dict(),
'title': dict(weight='bold', fontsize=14, y=.98),
'subtitle': dict(weight='bold', fontsize=14),
'title': dict(weight='bold', fontsize=13, y=.98),
'subtitle': dict(weight='bold', fontsize=13),
'xlabel': dict(fontsize=12, weight='bold'),
'ylabel': dict(fontsize=12, weight='bold'),
'colorbar': dict(fraction=.03),
Expand Down Expand Up @@ -536,6 +539,8 @@ def _make_subplots(show_bias, direction_name, d, kw):
if kw['tight']:
fig.subplots_adjust(**kw['tight'])

for fig in subplots_figs:
scalefig(fig)
plt.show()
if savepath is not None:
_save_rnn_fig(subplots_figs, savepath, kw['save'])
Expand Down
4 changes: 2 additions & 2 deletions tests/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

#### Environment configs ######################################################
# for testing locally
os.environ['TF_KERAS'] = os.environ.get("TF_KERAS", '1')
os.environ['TF_EAGER'] = os.environ.get("TF_EAGER", '0')
os.environ['TF_KERAS'] = os.environ.get("TF_KERAS", '0')
os.environ['TF_EAGER'] = os.environ.get("TF_EAGER", '1')

TF_KERAS = bool(os.environ['TF_KERAS'] == '1')
TF_EAGER = bool(os.environ['TF_EAGER'] == '1')
Expand Down
41 changes: 23 additions & 18 deletions tests/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from backend import Model
from backend import l2
from backend import tempdir
import backend
from see_rnn import get_gradients, get_outputs, get_weights, get_rnn_weights
from see_rnn import get_weight_penalties, weights_norm, weight_loss
from see_rnn import features_0D, features_1D, features_2D
Expand All @@ -31,7 +32,7 @@


IMPORTS = dict(K=K, Input=Input, GRU=GRU,
Bidirectional=Bidirectional, Model=Model)
Bidirectional=Bidirectional, Model=Model)

if TF_2 and not TF_KERAS:
print(WARN, "LSTM, CuDNNLSTM, and CuDNNGRU imported `from keras` "
Expand All @@ -44,9 +45,9 @@ def test_main():
iterations = 20

kwargs1 = dict(batch_shape=batch_shape, units=units, bidirectional=False,
IMPORTS=IMPORTS)
IMPORTS=IMPORTS)
kwargs2 = dict(batch_shape=batch_shape, units=units, bidirectional=True,
IMPORTS=IMPORTS)
IMPORTS=IMPORTS)

if TF_2 and not TF_KERAS:
rnn_layers = GRU, SimpleRNN
Expand All @@ -57,10 +58,10 @@ def test_main():

model_names = [layer.__name__ for layer in rnn_layers]
model_names = [(prefix + name) for prefix in ("uni-", "bi-")
for name in model_names]
for name in model_names]

configs = [dict(rnn_layer=rnn_layer, **kwargs)
for kwargs in (kwargs1, kwargs2) for rnn_layer in rnn_layers]
for kwargs in (kwargs1, kwargs2) for rnn_layer in rnn_layers]

tests_ran = 0
for config, model_name in zip(configs, model_names):
Expand Down Expand Up @@ -95,9 +96,9 @@ def _test_outputs_gradients(model):
grads_last = get_gradients(model, 2, x, y, mode='outputs')

kwargs1 = dict(n_rows=None, show_xy_ticks=[0, 0], show_borders=True,
max_timesteps=50, title='grads')
max_timesteps=50, title='grads')
kwargs2 = dict(n_rows=2, show_xy_ticks=[1, 1], show_borders=False,
max_timesteps=None)
max_timesteps=None)

features_1D(grads_all[0], **kwargs1)
features_1D(grads_all[:1], **kwargs1)
Expand Down Expand Up @@ -142,7 +143,7 @@ def test_errors(): # test Exception cases

reset_seeds(reset_graph_with_backend=K)
model = make_model(GRU, batch_shape, activation='relu',
recurrent_dropout=0.3, IMPORTS=IMPORTS)
recurrent_dropout=0.3, IMPORTS=IMPORTS)
x, y, sw = make_data(batch_shape, units)
model.train_on_batch(x, y, sw)

Expand All @@ -164,7 +165,7 @@ def test_errors(): # test Exception cases
pass_on_error(get_layer, model)
pass_on_error(get_layer, model, 'capsule')
pass_on_error(rnn_heatmap, model, 1, input_data=x, labels=y,
mode='coffee')
mode='coffee')
pass_on_error(rnn_heatmap, model, 1, co='vid')
pass_on_error(rnn_heatmap, model, 1, norm=(0, 1, 2))
pass_on_error(rnn_heatmap, model, 1, mode='grads')
Expand Down Expand Up @@ -195,7 +196,7 @@ def test_misc(): # test miscellaneous functionalities

reset_seeds(reset_graph_with_backend=K)
model = make_model(GRU, batch_shape, activation='relu',
recurrent_dropout=0.3, IMPORTS=IMPORTS)
recurrent_dropout=0.3, IMPORTS=IMPORTS)
x, y, sw = make_data(batch_shape, units)
model.train_on_batch(x, y, sw)

Expand Down Expand Up @@ -238,8 +239,8 @@ def test_misc(): # test miscellaneous functionalities
savepath=os.path.join(dirpath, 'img.png'))
rnn_histogram(model, 1, equate_axes=False,
configs={'tight': dict(left=0, right=1),
'plot': dict(color='red'),
'title': dict(fontsize=14),})
'plot': dict(color='red'),
'title': dict(fontsize=14),})
rnn_heatmap(model, 1, cmap=None, normalize=True, show_borders=False)
rnn_heatmap(model, 1, cmap=None, norm='auto', absolute_value=True)
rnn_heatmap(model, 1, norm=None)
Expand Down Expand Up @@ -378,7 +379,7 @@ def test_inspect_gen():
batch_shape = (8, 100, 2 * units)

model = make_model(GRU, batch_shape, activation='relu', bidirectional=True,
recurrent_dropout=0.3, include_dense=True, IMPORTS=IMPORTS)
recurrent_dropout=0.3, include_dense=True, IMPORTS=IMPORTS)

assert bool(get_weight_penalties(model))
assert weight_loss(model) > 0
Expand Down Expand Up @@ -423,7 +424,7 @@ def make_data(batch_shape, n_batches):
model.train_on_batch(x, y)

l2_stats[epoch] = weights_norm(model, [1, 3], l2_stats[epoch],
omit_names='bias', verbose=1)
omit_names='bias', verbose=1)
print("Epoch", epoch + 1, "finished")
print()

Expand Down Expand Up @@ -459,8 +460,8 @@ def _merge_layers_and_weights(l2_stats):

## Plot ########
features_hist_v2(stats_merged, colnames=weight_names, title=suptitle,
xlims=xlims, ylim=ylim, side_annot=side_annot,
pad_xticks=True, configs=configs)
xlims=xlims, ylim=ylim, side_annot=side_annot,
pad_xticks=True, configs=configs)


def test_envs(): # pseudo-tests for coverage for different env flags
Expand All @@ -479,6 +480,7 @@ def test_envs(): # pseudo-tests for coverage for different env flags
reload(utils)
reload(inspect_gen)
reload(inspect_rnn)
reload(backend)
from see_rnn.inspect_gen import get_gradients as glg
from see_rnn.inspect_rnn import rnn_summary as rs
from see_rnn.utils import _validate_rnn_type as _vrt
Expand All @@ -498,8 +500,11 @@ def test_envs(): # pseudo-tests for coverage for different env flags
reset_seeds(reset_graph_with_backend=_K)
new_imports = dict(Input=Input, Bidirectional=Bidirectional,
Model=Model)
model = make_model(_GRU, batch_shape, new_imports=new_imports,
IMPORTS=IMPORTS)
try:
model = make_model(_GRU, batch_shape, new_imports=new_imports,
IMPORTS=IMPORTS)
except:
break # fails on case '0' in TF 2.3, doesn't matter to fix

pass_on_error(model, x, y, 1) # possibly _backend-induced err
pass_on_error(glg, model, 1, x, y)
Expand Down

0 comments on commit 7026959

Please sign in to comment.