Skip to content

Commit

Permalink
Drop TF as requirement to allow general use
Browse files Browse the repository at this point in the history
Only `see_rnn.visuals_gen` will work
  • Loading branch information
OverLordGoldDragon committed Nov 11, 2020
1 parent 7026959 commit 0044e3a
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 25 deletions.
1 change: 1 addition & 0 deletions requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ coverage
pytest
pytest-cov
pycodestyle
tensorflow
Keras
3 changes: 1 addition & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
numpy
matplotlib
tensorflow
matplotlib
20 changes: 12 additions & 8 deletions see_rnn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,17 @@ def scalefig(fig):
##############################################################################

from . import visuals_gen
from . import visuals_rnn
from . import inspect_gen
from . import inspect_rnn

from .visuals_gen import *
from .visuals_rnn import *
from .inspect_gen import *
from .inspect_rnn import *
try:
from . import visuals_rnn
from .visuals_rnn import *
from . import inspect_gen
from .inspect_gen import *
from . import inspect_rnn
from .inspect_rnn import *
except:
# handled in _backend.py
pass


__version__ = '1.15.0'
__version__ = '1.15.1'
23 changes: 13 additions & 10 deletions see_rnn/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,19 @@


TF_KERAS = os.environ.get("TF_KERAS", '1') == '1'

WARN = colored("WARNING:", 'red')
NOTE = colored("NOTE:", 'blue')


if TF_KERAS:
from tensorflow.python.keras import backend as K
from tensorflow.keras.layers import Layer
from tensorflow.keras.models import Model
else:
from keras import backend as K
from keras.layers import Layer
from keras.models import Model
try:
if TF_KERAS:
from tensorflow.python.keras import backend as K
from tensorflow.keras.layers import Layer
from tensorflow.keras.models import Model
else:
from keras import backend as K
from keras.layers import Layer
from keras.models import Model
except:
print("WARNING: failed to import TensorFlow or Keras; functionality "
"is restricted to see_rnn.visuals_gen")
K, Layer, Model = None, None, None
5 changes: 3 additions & 2 deletions see_rnn/inspect_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ def _get_by_name(model, name):
if as_dict:
return weights
weights = list(weights.values())
return weights[0] if len(_ids) == 1 else weights
return weights[0] if (len(_ids) == 1 and len(weights) == 1) else weights


def detect_nans(data, include_inf=True):
Expand Down Expand Up @@ -664,7 +664,8 @@ def _cell_penalties(cell):

if _lambda is not None:
weight_name = cell.weights[weight_idx].name
l1_l2 = (float(_lambda.l1), float(_lambda.l2))
l1_l2 = (float(getattr(_lambda, 'l1', 0)),
float(getattr(_lambda, 'l2', 0)))
penalties.append([weight_name, l1_l2])
return penalties

Expand Down
7 changes: 5 additions & 2 deletions see_rnn/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import numpy as np
import tensorflow as tf
from copy import deepcopy
from pathlib import Path

from ._backend import WARN, NOTE, TF_KERAS, Layer

try:
import tensorflow as tf
except:
pass # handled in __init__ via _backend.py


def _kw_from_configs(configs, defaults):
def _fill_absent_defaults(kw, defaults):
Expand Down
2 changes: 1 addition & 1 deletion see_rnn/visuals_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import matplotlib.pyplot as plt
from matplotlib import cm

from .utils import _kw_from_configs, clipnums
from ._backend import NOTE
from .utils import _kw_from_configs, clipnums
from . import scalefig


Expand Down

0 comments on commit 0044e3a

Please sign in to comment.