From 72b80d5d718a9e34f7ae70906b7117c5aa3ef9df Mon Sep 17 00:00:00 2001 From: Jose Javier <3844846+JJGO@users.noreply.github.com> Date: Fri, 17 Sep 2021 11:35:18 -0400 Subject: [PATCH] Register 'grid' as non-persistent buffer PyTorch added support for non-persistent buffers, which are not included in the state_dict [1]. This is done by specifying `persistent=False` when registering the buffer. This is just an improvement that removes the associated workaround. [1] https://github.com/pytorch/pytorch/issues/18056 --- voxelmorph/torch/layers.py | 9 +++------ voxelmorph/torch/modelio.py | 7 +------ 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/voxelmorph/torch/layers.py b/voxelmorph/torch/layers.py index 7cddf59a..2c30bd79 100644 --- a/voxelmorph/torch/layers.py +++ b/voxelmorph/torch/layers.py @@ -20,12 +20,9 @@ def __init__(self, size, mode='bilinear'): grid = torch.unsqueeze(grid, 0) grid = grid.type(torch.FloatTensor) - # registering the grid as a buffer cleanly moves it to the GPU, but it also - # adds it to the state dict. this is annoying since everything in the state dict - # is included when saving weights to disk, so the model files are way bigger - # than they need to be. so far, there does not appear to be an elegant solution. - # see: https://discuss.pytorch.org/t/how-to-register-buffer-without-polluting-state-dict - self.register_buffer('grid', grid) + # registering the grid as a buffer cleanly moves it to the GPU + # persistent=False, prevents it from appearing in the state_dict + self.register_buffer('grid', grid, persistent=False) def forward(self, src, flow): # new locations diff --git a/voxelmorph/torch/modelio.py b/voxelmorph/torch/modelio.py index 0f4b6292..2b0b8997 100644 --- a/voxelmorph/torch/modelio.py +++ b/voxelmorph/torch/modelio.py @@ -59,12 +59,7 @@ def save(self, path): """ Saves the model configuration and weights to a pytorch file. 
""" - # don't save the transformer_grid buffers - see SpatialTransformer doc for more info - sd = self.state_dict().copy() - grid_buffers = [key for key in sd.keys() if key.endswith('.grid')] - for key in grid_buffers: - sd.pop(key) - torch.save({'config': self.config, 'model_state': sd}, path) + torch.save({'config': self.config, 'model_state': self.state_dict()}, path) @classmethod def load(cls, path, device):