Skip to content

Commit

Permalink
fix: hotfix for a cyclic import in the keras-encoder, also added some…
Browse files Browse the repository at this point in the history
… minor documentation
  • Loading branch information
MLRichter committed Jan 14, 2022
1 parent f7e4132 commit e8100fd
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 14 deletions.
25 changes: 25 additions & 0 deletions README.md
Expand Up @@ -83,6 +83,31 @@ unproductive for an input resolution of 32x32 pixels:

Keep in mind that the Graph is reverse-engineerd from the PyTorch JIT-compiler, therefore no looping-logic is
allowed within the forward pass of the model.
Also, the use of the functional-library for stateful-operations like pooling or convolutional layers is discouraged,
since it can cause malformed graphs during the reverse engineering of the compute graph.

#### Tensorflow / Keras

Similar to the PyTorch-import, models can be imported from TensorFlow as well.

The code below will create a graph of VGG16 and visualize it using GraphViz and color all layers that are predicted to be
unproductive for an input resolution of 32x32 pixels. This is analog to the PyTorch-variant depicted above:

```python

from keras.applications.vgg19 import VGG19
from rfa_toolbox import create_graph_from_tensorflow_model

model = VGG19(include_top=True, weights=None)
graph: EnrichedNetworkNode = create_graph_from_tensorflow_model(model)
visualize_architecture(graph, "VGG16", input_res=32).view()

```

This will create the following visualization:
![vgg16.PNG](https://github.com/MLRichter/receptive_field_analysis_toolbox/blob/main/images/vgg16.PNG?raw=true)

Currently, only the use of the keras.Model object is supported.

#### Custom

Expand Down
2 changes: 1 addition & 1 deletion rfa_toolbox/__init__.py
Expand Up @@ -15,7 +15,7 @@ def create_graph_from_pytorch_model(*args, **kwargs):
from rfa_toolbox.encodings.tensorflow_keras.ingest_architecture import (
create_graph_from_model as create_graph_from_tensorflow_model,
)
except ImportError:
except ValueError:

def create_graph_from_tensorflow_model(*args, **kwargs):
raise ImportError("This function is not available, tensorflow not installed")
Expand Down
10 changes: 0 additions & 10 deletions rfa_toolbox/encodings/tensorflow_keras/ingest_architecture.py
Expand Up @@ -3,7 +3,6 @@

from tensorflow.keras.models import Model

from rfa_toolbox import visualize_architecture
from rfa_toolbox.encodings.tensorflow_keras.layer_handlers import (
AnyHandler,
DenseHandler,
Expand Down Expand Up @@ -138,12 +137,3 @@ def create_graph_from_model(model: Model) -> EnrichedNetworkNode:
"""
model_dict = keras_model_to_dict(model)
return model_dict_to_enriched_graph(model_dict)


if __name__ == "__main__":
from tensorflow.keras.applications.inception_resnet_v2 import InceptionResNetV2

graph: EnrichedNetworkNode = create_graph_from_model(
InceptionResNetV2(weights=None)
)
visualize_architecture(graph, "InceptionResNetV2", input_res=32).view()
10 changes: 7 additions & 3 deletions rfa_toolbox/vis_script.py
Expand Up @@ -170,10 +170,14 @@ def vgg19_perf2() -> EnrichedNetworkNode:


if __name__ == "__main__":
m = vgg19_perf2
# m = vgg19_perf2

dot = visualize_architecture(m(), "vgg19_perf", input_res=16).view()
# dot = visualize_architecture(m(), "vgg19_perf", input_res=16).view()

from keras.applications.vgg19 import VGG19

VGG19(weights=None, input_shape=(224, 224, 3), include_top=True).to_json()
from rfa_toolbox import create_graph_from_tensorflow_model

model = VGG19(include_top=True, weights=None)
graph: EnrichedNetworkNode = create_graph_from_tensorflow_model(model)
visualize_architecture(graph, "VGG16", input_res=32).view()

0 comments on commit e8100fd

Please sign in to comment.