generated from fastai/nbdev_template
-
Notifications
You must be signed in to change notification settings - Fork 295
Open
Description
Hi,
I've found an issue when converting a PyTorch module with `torch.jit.trace`. I've isolated it and created the snippets below to reproduce it.
As a temporary workaround, I had to stop using `delegates` in the tsai library to avoid this issue, but I'd appreciate any help with this. Has anyone experienced this before?
THIS WORKS
import torch
import torch.nn as nn
from fastcore.meta import delegates
class DelegatesTest(nn.Module):
    """Minimal ``nn.Module`` subclass that forwards ``**kwargs`` to ``nn.Module.__init__``.

    In this working snippet the class is left undecorated and is never
    instantiated; it exists only so the two snippets differ by one decorator.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
class Net(nn.Module):
    """Single-layer repro model: one ``Conv2d(1, 1, 3)``, traced below with ``torch.jit.trace``."""

    def __init__(self):
        super(Net, self).__init__()
        # 1 input channel, 1 output channel, 3x3 kernel.
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)
# Eager forward pass on a random 1x1x3x3 input, then trace the same module.
# Per the issue, this trace succeeds (no @delegates-decorated class is defined).
n = Net()
inp = torch.rand(1, 1, 3, 3)
output = n(inp)
print(output)
module = torch.jit.trace(n, inp)
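print(module)

THIS DOESN'T WORK (identical to the snippet above except for the `@delegates(nn.Linear.__init__)` decorator on `DelegatesTest`, which is never instantiated):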
import torch
import torch.nn as nn
from fastcore.meta import delegates
# NOTE(review): likely cause of the traceback below -- verify against the
# installed fastcore version. With an explicit `to`, `delegates` re-signs the
# decorated object itself (here the class) and merges `to`'s annotations
# (which include `'return': None`) into its `__annotations__`. On Python < 3.10
# (the traceback shows 3.7), a class with no annotations of its own resolves
# `__annotations__` to the inherited `nn.Module.__annotations__`, so the
# `None` would leak into every Module subclass. That matches the `None`
# annotation that `torch._jit_internal.is_final` trips on during tracing.
# Decorating `__init__` instead of the class should avoid the mutation.
@delegates(nn.Linear.__init__)
class DelegatesTest(nn.Module):
    """Same as the working snippet's class, but decorated with ``delegates``.

    The class is never instantiated, so any effect on tracing ``Net`` comes
    from the decorator running at class-definition time.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
class Net(nn.Module):
    """Single-layer repro model: one ``Conv2d(1, 1, 3)``, traced below with ``torch.jit.trace``."""

    def __init__(self):
        super(Net, self).__init__()
        # 1 input channel, 1 output channel, 3x3 kernel.
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)
# Same driver as the working snippet. The eager forward pass runs; the
# AttributeError shown below is raised from the torch.jit.trace call.
n = Net()
inp = torch.rand(1, 1, 3, 3)
output = n(inp)
print(output)
module = torch.jit.trace(n, inp)
print(module)

It raises the following error:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
/var/folders/42/4hhwknbd5kzcbq48tmy_gbp00000gn/T/ipykernel_91620/2045013617.py in <module>
19 output = n(inp)
20 print(output)
---> 21 module = torch.jit.trace(n, inp)
22 print(module)
~/opt/anaconda3/envs/py37torch112/lib/python3.7/site-packages/torch/jit/_trace.py in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit)
757 strict,
758 _force_outplace,
--> 759 _module_class,
760 )
761
~/opt/anaconda3/envs/py37torch112/lib/python3.7/site-packages/torch/jit/_trace.py in trace_module(mod, inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit)
949 register_submods(mod, "__module")
950
--> 951 module = make_module(mod, _module_class, _compilation_unit)
952
953 for method_name, example_inputs in inputs.items():
~/opt/anaconda3/envs/py37torch112/lib/python3.7/site-packages/torch/jit/_trace.py in make_module(mod, _module_class, _compilation_unit)
575 if _module_class is None:
576 _module_class = TopLevelTracedModule
--> 577 return _module_class(mod, _compilation_unit=_compilation_unit)
578
579
~/opt/anaconda3/envs/py37torch112/lib/python3.7/site-packages/torch/jit/_trace.py in __init__(self, orig, id_set, _compilation_unit)
1075 continue
1076 tmp_module._modules[name] = make_module(
-> 1077 submodule, TracedModule, _compilation_unit=None
1078 )
1079
~/opt/anaconda3/envs/py37torch112/lib/python3.7/site-packages/torch/jit/_trace.py in make_module(mod, _module_class, _compilation_unit)
575 if _module_class is None:
576 _module_class = TopLevelTracedModule
--> 577 return _module_class(mod, _compilation_unit=_compilation_unit)
578
579
~/opt/anaconda3/envs/py37torch112/lib/python3.7/site-packages/torch/jit/_trace.py in __init__(self, orig, id_set, _compilation_unit)
1079
1080 script_module = torch.jit._recursive.create_script_module(
-> 1081 tmp_module, lambda module: (), share_types=False, is_tracing=True
1082 )
1083
~/opt/anaconda3/envs/py37torch112/lib/python3.7/site-packages/torch/jit/_recursive.py in create_script_module(nn_module, stubs_fn, share_types, is_tracing)
453 assert not isinstance(nn_module, torch.jit.RecursiveScriptModule)
454 check_module_initialized(nn_module)
--> 455 concrete_type = get_module_concrete_type(nn_module, share_types)
456 if not is_tracing:
457 AttributeTypeIsSupportedChecker().check(nn_module)
~/opt/anaconda3/envs/py37torch112/lib/python3.7/site-packages/torch/jit/_recursive.py in get_module_concrete_type(nn_module, share_types)
408 # Get a concrete type directly, without trying to re-use an existing JIT
409 # type from the type store.
--> 410 concrete_type_builder = infer_concrete_type_builder(nn_module, share_types)
411 concrete_type_builder.set_poisoned()
412 concrete_type = concrete_type_builder.build()
~/opt/anaconda3/envs/py37torch112/lib/python3.7/site-packages/torch/jit/_recursive.py in infer_concrete_type_builder(nn_module, share_types)
220 # Constants annotated via `Final[T]` rather than being added to `__constants__`
221 for name, ann in class_annotations.items():
--> 222 if torch._jit_internal.is_final(ann):
223 constants_set.add(name)
224
~/opt/anaconda3/envs/py37torch112/lib/python3.7/site-packages/torch/_jit_internal.py in is_final(ann)
941
942 def is_final(ann) -> bool:
--> 943 return ann.__module__ in {'typing', 'typing_extensions'} and \
944 (getattr(ann, '__origin__', None) is Final or isinstance(ann, type(Final)))
945
AttributeError: 'NoneType' object has no attribute '__module__'

Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels