
Commit

Add model parameters / modules inspection helper. (#2466)
* Adds inspection method on models to view parameters and modules.

* Add a test.

* Fix a typo.

* Fix another typo.

* Fix mypy and docs.

* Revert changes on model.

* Add inspect_model_parameters and tests.

* Avert circular import (for now).

* Remove extra blank lines.

* Allow modules instead of restricting to only model.

* Add a blank line.

* Allow too many lines in util_test.py

* update inspection util.

* pylint and mypy.

* Update docstring.
HarshTrivedi authored and joelgrus committed Jun 20, 2019
1 parent 0fbd1ca commit cf247c6
Showing 3 changed files with 84 additions and 0 deletions.
35 changes: 35 additions & 0 deletions allennlp/nn/util.py
@@ -7,6 +7,7 @@
import logging
import copy
import math
import json

import torch

@@ -1378,3 +1379,37 @@ def uncombine_initial_dims(tensor: torch.Tensor, original_size: torch.Size) -> torch.Tensor:
    else:
        view_args = list(original_size) + [tensor.size(-1)]
        return tensor.view(*view_args)


def inspect_parameters(module: torch.nn.Module, quiet: bool = False) -> Dict[str, Any]:
    """
    Inspects the model/module parameters and their tunability. The output is structured
    in a nested dict so that parameters in the same sub-modules are grouped together.
    This can be helpful to set up module-path-based regexes, for example in initializers.
    By default it prints the inspection dict (pass ``quiet=True`` to suppress printing)
    and returns it. Example output::

        {
            "_text_field_embedder": {
                "token_embedder_tokens": {
                    "_projection": {
                        "bias": "tunable",
                        "weight": "tunable"
                    },
                    "weight": "frozen"
                }
            }
        }
    """
    results: Dict[str, Any] = {}
    for name, param in sorted(module.named_parameters()):
        keys = name.split(".")
        write_to = results
        for key in keys[:-1]:
            if key not in write_to:
                write_to[key] = {}
            write_to = write_to[key]
        write_to[keys[-1]] = "tunable" if param.requires_grad else "frozen"
    if not quiet:
        print(json.dumps(results, indent=4))
    return results
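
As a quick illustration of the helper's output (not part of the diff; the tiny module and its attribute names below are made up for illustration), freezing one parameter makes it show up as "frozen" in the nested dict:

    import torch

    from allennlp.nn import util


    class TinyModel(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self._projection = torch.nn.Linear(4, 2)
            self._embedding = torch.nn.Embedding(10, 4)
            # Freeze the embedding so its weight is reported as "frozen".
            self._embedding.weight.requires_grad = False


    # Prints the nested dict as JSON and also returns it; pass quiet=True to skip printing.
    inspection = util.inspect_parameters(TinyModel())
    # inspection == {
    #     "_embedding": {"weight": "frozen"},
    #     "_projection": {"bias": "tunable", "weight": "tunable"},
    # }
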
39 changes: 39 additions & 0 deletions allennlp/tests/fixtures/decomposable_attention/parameters_inspection.json
@@ -0,0 +1,39 @@
{
    "_aggregate_feedforward": {
        "_linear_layers": {
            "0": {
                "bias": "tunable",
                "weight": "tunable"
            }
        }
    },
    "_attend_feedforward": {
        "_module": {
            "_linear_layers": {
                "0": {
                    "bias": "tunable",
                    "weight": "tunable"
                }
            }
        }
    },
    "_compare_feedforward": {
        "_module": {
            "_linear_layers": {
                "0": {
                    "bias": "tunable",
                    "weight": "tunable"
                }
            }
        }
    },
    "_text_field_embedder": {
        "token_embedder_tokens": {
            "_projection": {
                "bias": "tunable",
                "weight": "tunable"
            },
            "weight": "frozen"
        }
    }
}
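
This fixture also shows the docstring's point about module-path-based regexes: with the nested paths above, it is easy to write regexes that target groups of parameters. A minimal sketch (the regexes and initializer choices below are illustrative assumptions, not part of this commit):

    # A config fragment one might write after inspecting the paths above;
    # regexes and initializer types here are illustrative assumptions.
    initializer_config = [
        [".*_linear_layers.*weight", {"type": "xavier_uniform"}],
        [".*_linear_layers.*bias", {"type": "normal", "mean": 0.0, "std": 0.01}],
    ]
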
10 changes: 10 additions & 0 deletions allennlp/tests/nn/util_test.py
@@ -1,4 +1,5 @@
# pylint: disable=invalid-name,no-self-use,too-many-public-methods,not-callable,too-many-lines,protected-access
import json
from typing import NamedTuple

import numpy
@@ -9,6 +10,7 @@
from allennlp.common.checks import ConfigurationError
from allennlp.common.testing import AllenNlpTestCase
from allennlp.nn import util
from allennlp.models import load_archive


class TestNnUtil(AllenNlpTestCase):
@@ -1023,6 +1025,14 @@ def test_uncombine_initial_dims(self):
        embedding = util.uncombine_initial_dims(embedding2d, torch.Size((4, 10, 20, 17, 5)))
        assert list(embedding.size()) == [4, 10, 20, 17, 5, 12]

    def test_inspect_model_parameters(self):
        model_archive = str(self.FIXTURES_ROOT / 'decomposable_attention' / 'serialization' / 'model.tar.gz')
        parameters_inspection = str(self.FIXTURES_ROOT / 'decomposable_attention' / 'parameters_inspection.json')
        model = load_archive(model_archive).model
        with open(parameters_inspection) as file:
            parameters_inspection_dict = json.load(file)
        assert parameters_inspection_dict == util.inspect_parameters(model)

    def test_move_to_device(self):
        # We're faking the tensor here so that we can test the calls to .cuda() without actually
        # needing a GPU.
