Skip to content

Commit

Permalink
V0.3 pretrained rep fix (#64)
Browse files Browse the repository at this point in the history
* fixed error in base_nets for R3MConv

* added error messages for Sequential

* added doc string for Sequential init
  • Loading branch information
MBronars committed Jul 3, 2023
1 parent e8c97eb commit bbfab6e
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions robomimic/models/base_nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,20 @@ class Sequential(torch.nn.Sequential, Module):
"""
Compose multiple Modules together (defined above).
"""
def __init__(self, *args):
def __init__(self, *args, has_output_shape = True):
"""
Args:
has_output_shape (bool, optional): indicates whether output_shape can be called on the Sequential module.
torch.nn modules do not have an output_shape, but Modules (defined above) do. Defaults to True.
"""
for arg in args:
assert isinstance(arg, Module)
if has_output_shape:
assert isinstance(arg, Module)
else:
assert isinstance(arg, nn.Module)
torch.nn.Sequential.__init__(self, *args)
self.fixed = False
self.has_output_shape = has_output_shape

def output_shape(self, input_shape=None):
"""
Expand All @@ -106,6 +115,8 @@ def output_shape(self, input_shape=None):
Returns:
out_shape ([int]): list of integers corresponding to output shape
"""
if not self.has_output_shape:
raise NotImplementedError("Output shape is not defined for this module")
out_shape = input_shape
for module in self:
out_shape = module.output_shape(out_shape)
Expand Down Expand Up @@ -574,7 +585,7 @@ def __init__(
transforms.CenterCrop(224),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
)
self.nets = nn.Sequential(*([preprocess] + list(net.module.convnet.children())))
self.nets = Sequential(*([preprocess] + list(net.module.convnet.children())), has_output_shape = False)
if freeze:
self.nets.freeze()

Expand Down

0 comments on commit bbfab6e

Please sign in to comment.