Skip to content

Commit

Permalink
make load_state_dict_forgiving more forgiving
Browse files Browse the repository at this point in the history
  • Loading branch information
Vermeille committed May 10, 2023
1 parent 424bda4 commit 9db8fb1
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions torchelie/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ def kaiming_gain(m: T_Module,
Return the std needed to initialize a weight matrix with given parameters.
"""
if mode == 'fan_inout':
fan = (math.sqrt(nn.init._calculate_correct_fan(m.weight, 'fan_in'))
+ math.sqrt(nn.init._calculate_correct_fan(m.weight,
'fan_out'))) / 2
fan = (
math.sqrt(nn.init._calculate_correct_fan(m.weight, 'fan_in')) +
math.sqrt(nn.init._calculate_correct_fan(m.weight, 'fan_out'))) / 2
else:
fan = math.sqrt(nn.init._calculate_correct_fan(m.weight, mode))
gain = nn.init.calculate_gain(nonlinearity, param=a)
Expand Down Expand Up @@ -390,8 +390,12 @@ def load_state_dict_forgiving(dst, state_dict: dict, silent: bool = False):
failed.add(name)
if silent:
continue
print('error in', name, ': checkpoint has ', val.shape,
'-> model has', dst_dict[name].shape, '(', str(e), ')')
if name in dst_dict:
print('error in', name, ': checkpoint has ', val.shape,
'-> model has', dst_dict[name].shape, '(', str(e), ')')
else:
print('error in', name, ': checkpoint has ', val.shape,
'-> model has no such key')
return {
'dst_only': dst_names - from_dict,
'state_only': from_dict - dst_names,
Expand Down

0 comments on commit 9db8fb1

Please sign in to comment.