Register 'grid' as non-persistent buffer
PyTorch added support for non-persistent buffers, which are
excluded from the state_dict [1]. A buffer is made non-persistent
by passing `persistent=False` to `register_buffer`.

This is a simple improvement that removes the associated workaround.

[1] pytorch/pytorch#18056
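
A minimal sketch of the API (the module and buffer names below are illustrative, not from this commit):

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        # persistent buffer: moves with .to()/.cuda() and is saved in the state_dict
        self.register_buffer('kept', torch.zeros(3))
        # non-persistent buffer: still moves with the module, but is never serialized
        self.register_buffer('dropped', torch.zeros(3), persistent=False)

print(list(Demo().state_dict().keys()))  # ['kept'] -- 'dropped' is omitted
```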
JJGO committed Sep 17, 2021
1 parent 3d004df commit 72b80d5
Showing 2 changed files with 4 additions and 12 deletions.
9 changes: 3 additions & 6 deletions voxelmorph/torch/layers.py
@@ -20,12 +20,9 @@ def __init__(self, size, mode='bilinear'):
         grid = torch.unsqueeze(grid, 0)
         grid = grid.type(torch.FloatTensor)

-        # registering the grid as a buffer cleanly moves it to the GPU, but it also
-        # adds it to the state dict. this is annoying since everything in the state dict
-        # is included when saving weights to disk, so the model files are way bigger
-        # than they need to be. so far, there does not appear to be an elegant solution.
-        # see: https://discuss.pytorch.org/t/how-to-register-buffer-without-polluting-state-dict
-        self.register_buffer('grid', grid)
+        # registering the grid as a buffer cleanly moves it to the GPU;
+        # persistent=False prevents it from appearing in the state_dict
+        self.register_buffer('grid', grid, persistent=False)

     def forward(self, src, flow):
         # new locations
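
With this change, the grid is still registered on the module (so it follows `.to()`/`.cuda()` calls) but never enters the serialized state. A minimal check, assuming voxelmorph is importable under a PyTorch version that supports `persistent=False`:

```python
from voxelmorph.torch.layers import SpatialTransformer

st = SpatialTransformer(size=(32, 32))
assert hasattr(st, 'grid')            # the buffer still exists on the module
assert 'grid' not in st.state_dict()  # but it is excluded from serialization
```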
7 changes: 1 addition & 6 deletions voxelmorph/torch/modelio.py
@@ -59,12 +59,7 @@ def save(self, path):
         """
         Saves the model configuration and weights to a pytorch file.
         """
-        # don't save the transformer_grid buffers - see SpatialTransformer doc for more info
-        sd = self.state_dict().copy()
-        grid_buffers = [key for key in sd.keys() if key.endswith('.grid')]
-        for key in grid_buffers:
-            sd.pop(key)
-        torch.save({'config': self.config, 'model_state': sd}, path)
+        torch.save({'config': self.config, 'model_state': self.state_dict()}, path)

     @classmethod
     def load(cls, path, device):
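
Because non-persistent buffers are neither saved nor expected by `load_state_dict`, checkpoints written by the old `save` (which stripped the `.grid` keys by hand) remain loadable. A sketch of the round trip; the `VxmDense` class and the `VXM_BACKEND` selection are assumptions drawn from the wider voxelmorph codebase, not part of this commit:

```python
import os
os.environ['VXM_BACKEND'] = 'pytorch'  # select the PyTorch implementation

import torch
import voxelmorph as vxm

model = vxm.networks.VxmDense(inshape=(32, 32))
model.save('weights.pt')

# the checkpoint no longer carries any grid buffers
state = torch.load('weights.pt')['model_state']
assert not any(key.endswith('.grid') for key in state)

# load() rebuilds the model from its saved config; grids are re-created in __init__
restored = vxm.networks.VxmDense.load('weights.pt', 'cpu')
```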
