From 04a5a2796b50c5fd611d1250a03ca0e022989f93 Mon Sep 17 00:00:00 2001
From: ananyahjha93
Date: Thu, 20 Aug 2020 15:28:47 +0000
Subject: [PATCH 1/9] fix

---
 pytorch_lightning/trainer/trainer.py | 25 +++++++++++++++++++++++--
 1 file changed, 23 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index f7f7eff35997f..18ce8400cc5c5 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -769,6 +769,15 @@ def use_type(x):
                 use_type = Trainer._allowed_type
                 arg_default = Trainer._arg_default
 
+            # hack for types in (int, float)
+            if len(arg_types) == 2 and int in set(arg_types) and float in set(arg_types):
+                use_type = float
+
+            # hack for track_grad_norm
+            if arg == 'track_grad_norm':
+                use_type = Trainer._grad_norm_allowed_type
+                arg_default = Trainer._grad_norm_arg_default
+
             parser.add_argument(
                 f'--{arg}',
                 dest=arg,
@@ -780,18 +789,30 @@ def use_type(x):
 
         return parser
 
-    def _allowed_type(x) -> Union[int, str]:
+    def _gpus_allowed_type(x) -> Union[int, str]:
         if ',' in x:
             return str(x)
         else:
             return int(x)
 
-    def _arg_default(x) -> Union[int, str]:
+    def _gpus_arg_default(x) -> Union[int, str]:
         if ',' in x:
             return str(x)
         else:
             return int(x)
 
+    def _grad_norm_allowed_type(x) -> Union[int, float, str]:
+        if 'inf' in x:
+            return str(x)
+        else:
+            return float(x)
+
+    def _grad_norm_arg_default(x) -> Union[int, float, str]:
+        if 'inf' in x:
+            return str(x)
+        else:
+            return float(x)
+
     @classmethod
     def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
         """Parse CLI arguments, required for custom bool types."""

From ed30c803dac90997853eeec11e1ca73f5147c4a1 Mon Sep 17 00:00:00 2001
From: ananyahjha93
Date: Thu, 20 Aug 2020 15:31:44 +0000
Subject: [PATCH 2/9] fix

---
 pytorch_lightning/trainer/trainer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 18ce8400cc5c5..e4f2744f74c31 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -766,8 +766,8 @@ def use_type(x):
                 use_type = arg_types[0]
 
             if arg == 'gpus' or arg == 'tpu_cores':
-                use_type = Trainer._allowed_type
-                arg_default = Trainer._arg_default
+                use_type = Trainer._gpus_allowed_type
+                arg_default = Trainer._gpus_arg_default
 
             # hack for types in (int, float)
             if len(arg_types) == 2 and int in set(arg_types) and float in set(arg_types):

From 3d55dea50ccfe99b5f2c2f1f8e898f759d35e96b Mon Sep 17 00:00:00 2001
From: ananyahjha93
Date: Thu, 20 Aug 2020 15:46:34 +0000
Subject: [PATCH 3/9] fix

---
 pytorch_lightning/trainer/trainer.py | 13 +++----------
 1 file changed, 3 insertions(+), 10 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index e4f2744f74c31..413dee67b78d8 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -442,9 +442,8 @@ def __init__(
         self.gradient_clip_val = gradient_clip_val
         self.check_val_every_n_epoch = check_val_every_n_epoch
 
-        if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != 'inf':
+        if not (isinstance(track_grad_norm, (int, float)) or callable(track_grad_norm)) and track_grad_norm != 'inf':
             raise MisconfigurationException("track_grad_norm can be an int, a float or 'inf' (infinity norm).")
-        self.track_grad_norm = float(track_grad_norm)
 
         self.tpu_cores = _parse_tpu_cores(tpu_cores)
         self.on_tpu = self.tpu_cores is not None
@@ -802,16 +801,10 @@ def _gpus_arg_default(x) -> Union[int, str]:
             return int(x)
 
     def _grad_norm_allowed_type(x) -> Union[int, float, str]:
-        if 'inf' in x:
-            return str(x)
-        else:
-            return float(x)
+        return float(x)
 
     def _grad_norm_arg_default(x) -> Union[int, float, str]:
-        if 'inf' in x:
-            return str(x)
-        else:
-            return float(x)
+        return float(x)
 
     @classmethod
     def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
         """Parse CLI arguments, required for custom bool types."""
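[Context note, not part of the patch series] Patches 1-3 route `--gpus` and `--tpu_cores` through a comma-aware type function so a single flag accepts either a device count or an explicit device list. A minimal, self-contained sketch of that behavior; the parser wiring below is illustrative, not the actual `add_argparse_args` code:

    from argparse import ArgumentParser
    from typing import Union

    def _gpus_allowed_type(x) -> Union[int, str]:
        # A comma means an explicit device list ('0,1'): keep the string.
        # Anything else is treated as a device count and coerced to int.
        if ',' in x:
            return str(x)
        else:
            return int(x)

    parser = ArgumentParser()
    parser.add_argument('--gpus', type=_gpus_allowed_type, default=None)

    print(parser.parse_args(['--gpus', '2']).gpus)    # -> 2 (int)
    print(parser.parse_args(['--gpus', '0,1']).gpus)  # -> '0,1' (str)

Argparse applies the type function only to the raw CLI string, which is why both return branches are safe here.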
From 11724d21d322663d82bdcd451b9692c21b091cd1 Mon Sep 17 00:00:00 2001
From: ananyahjha93
Date: Thu, 20 Aug 2020 15:58:03 +0000
Subject: [PATCH 4/9] fix

---
 pytorch_lightning/trainer/trainer.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 413dee67b78d8..ec7586b141a59 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -444,6 +444,7 @@ def __init__(
 
         if not (isinstance(track_grad_norm, (int, float)) or callable(track_grad_norm)) and track_grad_norm != 'inf':
             raise MisconfigurationException("track_grad_norm can be an int, a float or 'inf' (infinity norm).")
+        self.track_grad_norm = track_grad_norm
 
         self.tpu_cores = _parse_tpu_cores(tpu_cores)
         self.on_tpu = self.tpu_cores is not None

From 2870c07156c3875ac9d56b23ededf353c651232c Mon Sep 17 00:00:00 2001
From: ananyahjha93
Date: Thu, 20 Aug 2020 16:38:46 +0000
Subject: [PATCH 5/9] temp

---
 pytorch_lightning/trainer/trainer.py | 13 +++----------
 1 file changed, 3 insertions(+), 10 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index ec7586b141a59..8e0360eb33572 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -442,9 +442,9 @@ def __init__(
         self.gradient_clip_val = gradient_clip_val
         self.check_val_every_n_epoch = check_val_every_n_epoch
 
-        if not (isinstance(track_grad_norm, (int, float)) or callable(track_grad_norm)) and track_grad_norm != 'inf':
+        if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != 'inf':
             raise MisconfigurationException("track_grad_norm can be an int, a float or 'inf' (infinity norm).")
-        self.track_grad_norm = track_grad_norm
+        self.track_grad_norm = float(track_grad_norm)
 
         self.tpu_cores = _parse_tpu_cores(tpu_cores)
         self.on_tpu = self.tpu_cores is not None
@@ -775,8 +775,7 @@ def use_type(x):
 
             # hack for track_grad_norm
             if arg == 'track_grad_norm':
-                use_type = Trainer._grad_norm_allowed_type
-                arg_default = Trainer._grad_norm_arg_default
+                use_type = str
 
             parser.add_argument(
                 f'--{arg}',
@@ -801,12 +800,6 @@ def _gpus_arg_default(x) -> Union[int, str]:
         else:
             return int(x)
 
-    def _grad_norm_allowed_type(x) -> Union[int, float, str]:
-        return float(x)
-
-    def _grad_norm_arg_default(x) -> Union[int, float, str]:
-        return float(x)
-
     @classmethod
     def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
         """Parse CLI arguments, required for custom bool types."""

From ddf818acc40cdb5f1eced3024fdf20e2a5ead17e Mon Sep 17 00:00:00 2001
From: ananyahjha93
Date: Thu, 20 Aug 2020 16:42:03 +0000
Subject: [PATCH 6/9] fix

---
 pytorch_lightning/trainer/trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 8e0360eb33572..4297fbbc9d5c7 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -775,7 +775,7 @@ def use_type(x):
 
             # hack for track_grad_norm
             if arg == 'track_grad_norm':
-                use_type = str
+                use_type = float
 
             parser.add_argument(
                 f'--{arg}',
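[Context note, not part of the patch series] Patches 4-6 settle on `use_type = float` for `track_grad_norm` because Python's built-in `float()` already parses the string 'inf' to infinity, which makes the dedicated `_grad_norm_*` helpers redundant. A short sketch of why the plain coercion suffices (the default value below is only for illustration):

    from argparse import ArgumentParser

    parser = ArgumentParser()
    # float covers ints, floats, and the 'inf' (infinity norm) spelling,
    # since float('inf') returns IEEE 754 positive infinity.
    parser.add_argument('--track_grad_norm', type=float, default=-1)

    print(parser.parse_args(['--track_grad_norm', '2']).track_grad_norm)    # -> 2.0
    print(parser.parse_args(['--track_grad_norm', 'inf']).track_grad_norm)  # -> inf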
From 625360e8e6d269ada71216a797781fd09d671e29 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 20 Aug 2020 13:16:13 -0400
Subject: [PATCH 7/9] 0.9.0 readme

---
 pytorch_lightning/trainer/trainer.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 4297fbbc9d5c7..b4a3d86b716c9 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -771,7 +771,8 @@ def use_type(x):
 
             # hack for types in (int, float)
             if len(arg_types) == 2 and int in set(arg_types) and float in set(arg_types):
-                use_type = float
+                use_type = Trainer._int_or_float_type
+                arg_default = Trainer._int_or_float_default
 
             # hack for track_grad_norm
             if arg == 'track_grad_norm':
@@ -800,6 +801,18 @@ def _gpus_arg_default(x) -> Union[int, str]:
         else:
             return int(x)
 
+    def _int_or_float_type(x) -> Union[int, float]:
+        if '.' in x:
+            return float(x)
+        else:
+            return int(x)
+
+    def _int_or_float_default(x) -> Union[int, float]:
+        if '.' in x:
+            return float(x)
+        else:
+            return int(x)
+
     @classmethod
     def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
         """Parse CLI arguments, required for custom bool types."""

From bd94039cce29a77780201f05dcec8582aec2eecc Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 20 Aug 2020 13:25:50 -0400
Subject: [PATCH 8/9] 0.9.0 readme

---
 pytorch_lightning/trainer/trainer.py | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index b4a3d86b716c9..697b444aa5399 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -772,7 +772,6 @@ def use_type(x):
             # hack for types in (int, float)
             if len(arg_types) == 2 and int in set(arg_types) and float in set(arg_types):
                 use_type = Trainer._int_or_float_type
-                arg_default = Trainer._int_or_float_default
 
             # hack for track_grad_norm
             if arg == 'track_grad_norm':
@@ -807,12 +806,6 @@ def _int_or_float_type(x) -> Union[int, float]:
         else:
             return int(x)
 
-    def _int_or_float_default(x) -> Union[int, float]:
-        if '.' in x:
-            return float(x)
-        else:
-            return int(x)
-
     @classmethod
     def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
         """Parse CLI arguments, required for custom bool types."""

From e12933dfac8860ed72041157911d194b91ab44d3 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 20 Aug 2020 13:26:02 -0400
Subject: [PATCH 9/9] 0.9.0 readme

---
 pytorch_lightning/trainer/trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 697b444aa5399..9d462ef8f36aa 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -801,7 +801,7 @@ def _gpus_arg_default(x) -> Union[int, str]:
             return int(x)
 
     def _int_or_float_type(x) -> Union[int, float]:
-        if '.' in x:
+        if '.' in str(x):
             return float(x)
         else:
             return int(x)
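[Context note, not part of the patch series] Patches 7-9 replace the bare `use_type = float` for `(int, float)`-typed arguments with `_int_or_float_type`, which keeps integer-looking values as ints; patch 9 wraps the membership check in `str(x)`, presumably so the function also tolerates values that are already numbers rather than CLI strings. A standalone sketch:

    from typing import Union

    def _int_or_float_type(x) -> Union[int, float]:
        # Values containing a decimal point parse as float, the rest as int.
        # str(x) keeps the 'in' check safe when x is already numeric.
        if '.' in str(x):
            return float(x)
        else:
            return int(x)

    print(_int_or_float_type('10'))    # -> 10 (int)
    print(_int_or_float_type('0.25'))  # -> 0.25 (float)
    print(_int_or_float_type(1.5))     # -> 1.5 (no TypeError, thanks to str(x))

Note the decimal-point heuristic does not cover scientific notation such as '1e-3', which contains no '.' and would fail the int() coercion.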