s1eeveW committed May 20, 2023
1 parent cd3fc65 commit 1d0a8ed
Showing 5 changed files with 81 additions and 10 deletions.
8 changes: 7 additions & 1 deletion examples/distributed_training.py
@@ -87,7 +87,13 @@ def main():
model=MMResNet50(),
work_dir='./work_dir',
train_dataloader=train_dataloader,
optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),

# optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
optim_wrapper=dict(
optimizer=dict(type=SGD, lr=0.001, momentum=0.9),
paramwise_cfg=dict(custom_keys=dict(
conv1=dict(lr_mult=0.1)))),

train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
val_dataloader=val_dataloader,
val_cfg=dict(),
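For context, the example now passes paramwise_cfg with a custom_keys entry so that any parameter whose name contains conv1 trains with lr_mult=0.1, i.e. a tenth of the base learning rate. A minimal hand-built sketch of the resulting groups using plain PyTorch (the two-layer model and the explicit group layout are illustrative, not the constructor's exact output):

import torch.nn as nn
from torch.optim import SGD

model = nn.Sequential()
model.add_module('conv1', nn.Conv2d(3, 8, 3))
model.add_module('conv2', nn.Conv2d(8, 8, 3))

base_lr = 0.001
optimizer = SGD(
    [
        # parameters under conv1 get base_lr * lr_mult
        {'params': model.conv1.parameters(), 'lr': base_lr * 0.1},
        # everything else keeps the base learning rate
        {'params': model.conv2.parameters()},
    ],
    lr=base_lr,
    momentum=0.9)

print([group['lr'] for group in optimizer.param_groups])  # -> [0.0001, 0.001] (up to float rounding)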
26 changes: 26 additions & 0 deletions mmengine/optim/optimizer/amp_optimizer_wrapper.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from contextlib import contextmanager
from typing import Union

@@ -147,8 +148,33 @@ def load_state_dict(self, state_dict: dict):
"""
if 'loss_scaler' in state_dict:
self.loss_scaler.load_state_dict(state_dict.pop('loss_scaler'))

# remove the current state tracker before loading the optimizer state
if self.optimizer.param_groups[-1].get('is_state_tracker', False):
self.optimizer.param_groups.pop()

# remove the state tracker from state_dict if it exists;
# save it and add it back after loading
state_tracker = None
if state_dict['param_groups'][-1].get('is_state_tracker', False):
state_tracker = state_dict['param_groups'].pop()

# load state_dict of optimizer
self.optimizer.load_state_dict(state_dict)

# add the state tracker back
if state_tracker is None:
last_param = copy.deepcopy(self.optimizer.param_groups[-1])
last_param.pop('params')
new_param_settings = {
'params': torch.tensor([0.0], requires_grad=True),
'is_state_tracker': True,
**last_param
}
self.optimizer.param_groups.append(new_param_settings)
else:
self.optimizer.param_groups.append(state_tracker)

@contextmanager
def optim_context(self, model: nn.Module):
"""Enables the context for mixed precision training, and enables the
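Both wrappers' load_state_dict now follow the same pattern: drop the wrapper's own tracker group, strip (and remember) any tracker group inside the incoming state_dict, load the remaining state, then re-append a tracker. A rough standalone sketch of that round trip against a plain torch.optim.SGD (the tiny Linear model and the checkpoint are made up for illustration; the real code lives inside the wrappers' load_state_dict methods):

import copy

import torch
from torch.optim import SGD

net = torch.nn.Linear(4, 2)
optimizer = SGD(net.parameters(), lr=0.1, momentum=0.9)

# mimic the wrapper: append a dummy 'state tracker' group carrying the defaults
optimizer.param_groups.append({
    'params': torch.tensor([0.0], requires_grad=True),
    'is_state_tracker': True,
    **optimizer.defaults,
})

# a checkpoint produced without a tracker group
incoming = SGD(net.parameters(), lr=0.1, momentum=0.9).state_dict()

# 1. drop the live tracker so the group counts match the checkpoint
if optimizer.param_groups[-1].get('is_state_tracker', False):
    optimizer.param_groups.pop()

# 2. strip a tracker from the incoming state_dict if present, keeping it around
state_tracker = None
if incoming['param_groups'][-1].get('is_state_tracker', False):
    state_tracker = incoming['param_groups'].pop()

optimizer.load_state_dict(incoming)

# 3. re-append the tracker, rebuilding it from the last real group if needed
if state_tracker is None:
    settings = copy.deepcopy(optimizer.param_groups[-1])
    settings.pop('params')
    optimizer.param_groups.append({
        'params': torch.tensor([0.0], requires_grad=True),
        'is_state_tracker': True,
        **settings,
    })
else:
    optimizer.param_groups.append(state_tracker)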
40 changes: 39 additions & 1 deletion mmengine/optim/optimizer/optimizer_wrapper.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import logging
from contextlib import contextmanager
from typing import Dict, List, Optional
@@ -160,6 +161,13 @@ def __init__(self,
# last few iterations. If `_max_counts` has not been initialized,
# the loss factor will always be the same as `_accumulative_counts`.
self._remainder_counts = -1
if hasattr(self.optimizer, 'defaults'):
new_param_settings = {
'params': torch.tensor([0.0], requires_grad=True),
'is_state_tracker': True,
**self.optimizer.defaults
}
self.optimizer.param_groups.append(new_param_settings)

def update_params(self,
loss: torch.Tensor,
@@ -265,8 +273,31 @@ def load_state_dict(self, state_dict: dict) -> None:
Args:
state_dict (dict): The state dictionary of :attr:`optimizer`.
"""
if self.optimizer.param_groups[-1].get('is_state_tracker', False):
self.optimizer.param_groups.pop()

# remove the state tracker from state_dict if it exists;
# save it and add it back after loading
state_tracker = None
if state_dict['param_groups'][-1].get('is_state_tracker', False):
state_tracker = state_dict['param_groups'].pop()

# load state_dict of optimizer
self.optimizer.load_state_dict(state_dict)

# add the state tracker back
if state_tracker is None:
last_param = copy.deepcopy(self.optimizer.param_groups[-1])
last_param.pop('params')
new_param_settings = {
'params': torch.tensor([0.0], requires_grad=True),
'is_state_tracker': True,
**last_param
}
self.optimizer.param_groups.append(new_param_settings)
else:
self.optimizer.param_groups.append(state_tracker)

@property
def param_groups(self) -> List[dict]:
"""A wrapper of ``Optimizer.param_groups``.
@@ -297,7 +328,11 @@ def get_lr(self) -> Dict[str, List[float]]:
Returns:
Dict[str, List[float]]: Learning rate of the optimizer.
"""
lr = [group['lr'] for group in self.param_groups]
# lr = [group['lr'] for group in self.param_groups]
lr = [
group['lr'] for group in self.param_groups if
'is_state_tracker' in group and group['is_state_tracker'] is True
]
return dict(lr=lr)

def get_momentum(self) -> Dict[str, List[float]]:
@@ -310,6 +345,9 @@ def get_momentum(self) -> Dict[str, List[float]]:
"""
momentum = []
for group in self.param_groups:
if 'is_state_tracker' not in group or group[
'is_state_tracker'] is False:
continue
# Get momentum of SGD.
if 'momentum' in group.keys():
momentum.append(group['momentum'])
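After these changes, get_lr and get_momentum report values only from the tracker group that __init__ appends with the optimizer defaults, instead of from every param group. A small sketch of that filtering against hand-built groups (values illustrative; run directly on the optimizer rather than through OptimWrapper itself):

import torch
import torch.nn as nn
from torch.optim import SGD

model = nn.Linear(4, 2)
optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)

# mimic OptimWrapper.__init__: append a tracker group holding the defaults
optimizer.param_groups.append({
    'params': torch.tensor([0.0], requires_grad=True),
    'is_state_tracker': True,
    **optimizer.defaults,
})

# mimic the updated get_lr(): keep only the tracker group's learning rate
lr = [
    group['lr'] for group in optimizer.param_groups
    if group.get('is_state_tracker', False)
]
print(dict(lr=lr))  # {'lr': [0.001]}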
15 changes: 7 additions & 8 deletions tests/test_optim/test_optimizer/test_optimizer.py
@@ -151,7 +151,7 @@ def _check_sgd_optimizer(self,
assert optimizer.defaults['momentum'] == self.momentum
assert optimizer.defaults['weight_decay'] == self.base_wd
model_parameters = list(model.parameters())
assert len(param_groups) == len(model_parameters)
assert len(param_groups) == len(model_parameters)+1
for i, param in enumerate(model_parameters):
param_group = param_groups[i]
assert torch.equal(param_group['params'][0], param)
@@ -598,9 +598,8 @@ def test_default_optimizer_constructor_bypass_duplicate(self):
optim_constructor(model)
optim_wrapper = optim_constructor(model)
model_parameters = list(model.parameters())
num_params = 14 if MMCV_FULL_AVAILABLE else 11
assert len(optim_wrapper.optimizer.param_groups) == len(
model_parameters) == num_params
num_params = 15 if MMCV_FULL_AVAILABLE else 12
assert len(optim_wrapper.optimizer.param_groups) == len(model_parameters) + 1 == num_params
self._check_sgd_optimizer(optim_wrapper.optimizer, model,
**paramwise_cfg)

@@ -612,9 +611,8 @@ def test_default_optimizer_constructor_bypass_duplicate(self):
optim_constructor(model)
optim_wrapper = optim_constructor(model)
model_parameters = list(model.parameters())
num_params = 14 if MMCV_FULL_AVAILABLE else 11
assert len(optim_wrapper.optimizer.param_groups) == len(
model_parameters) == num_params
num_params = 15 if MMCV_FULL_AVAILABLE else 12
assert len(optim_wrapper.optimizer.param_groups) == len(model_parameters) + 1 == num_params
self._check_sgd_optimizer(optim_wrapper.optimizer, model,
**paramwise_cfg)

@@ -706,7 +704,7 @@ def test_default_optimizer_constructor_custom_key(self):
'weight_decay': self.base_wd
})

num_params = 14 if MMCV_FULL_AVAILABLE else 11
num_params = 15 if MMCV_FULL_AVAILABLE else 12
assert len(param_groups) == num_params
for i, (name, param) in enumerate(self.model.named_parameters()):
assert torch.equal(param_groups[i]['params'][0], param)
@@ -784,6 +782,7 @@ def _check_default_optimizer(self, optimizer, model):
self.assertEqual(optimizer.defaults['momentum'], self.momentum)
self.assertEqual(optimizer.defaults['weight_decay'], self.base_wd)
param_groups = optimizer.param_groups
param_groups.pop()  # remove the trailing state-tracker group
params_set = set(model.parameters())
self.assertEqual(
sum(len(param_group['params']) for param_group in param_groups),
2 changes: 2 additions & 0 deletions tests/test_optim/test_optimizer/test_optimizer_wrapper.py
@@ -238,6 +238,7 @@ def test_load_state_dict(self):
optimizer = SGD(model.parameters(), lr=0.1)
optim_wrapper.load_state_dict(optimizer.state_dict())

optim_wrapper.param_groups.pop()
self.assertEqual(optim_wrapper.state_dict(), optimizer.state_dict())

def test_param_groups(self):
@@ -505,6 +506,7 @@ def test_load_state_dict(self):
optimizer = SGD(self.model.parameters(), lr=0.1)
amp_optim_wrapper.load_state_dict(optimizer.state_dict())

amp_optim_wrapper.param_groups.pop()
self.assertDictEqual(optimizer.state_dict(),
amp_optim_wrapper.optimizer.state_dict())
# Test load from optim_wrapper
