diff --git a/apex/contrib/gpu_direct_storage/__init__.py b/apex/contrib/gpu_direct_storage/__init__.py index 51503d83c..5a5acf09d 100644 --- a/apex/contrib/gpu_direct_storage/__init__.py +++ b/apex/contrib/gpu_direct_storage/__init__.py @@ -4,8 +4,8 @@ @contextmanager def GDSFile(filename, mode): - assert type(filename) == str - assert type(mode) == str + assert isinstance(filename, str) + assert isinstance(mode, str) try: from apex import deprecated_warning diff --git a/apex/contrib/optimizers/fused_adam.py b/apex/contrib/optimizers/fused_adam.py index 74a59fa18..37a77b3cf 100644 --- a/apex/contrib/optimizers/fused_adam.py +++ b/apex/contrib/optimizers/fused_adam.py @@ -106,7 +106,7 @@ def step(self, closure=None, grads=None, output_params=None, scale=1.0, grad_nor # assuming a list/generator of parameter means single group elif isinstance(grads, types.GeneratorType): grads_group = [grads] - elif type(grads[0]) != list: + elif not isinstance(grads[0], list): grads_group = [grads] else: grads_group = grads @@ -115,7 +115,7 @@ def step(self, closure=None, grads=None, output_params=None, scale=1.0, grad_nor output_params_group = [None] * len(self.param_groups) elif isinstance(output_params, types.GeneratorType): output_params_group = [output_params] - elif type(output_params[0]) != list: + elif not isinstance(output_params[0], list): output_params_group = [output_params] else: output_params_group = output_params diff --git a/apex/contrib/optimizers/fused_sgd.py b/apex/contrib/optimizers/fused_sgd.py index 62c9c2554..e2acfcbaa 100644 --- a/apex/contrib/optimizers/fused_sgd.py +++ b/apex/contrib/optimizers/fused_sgd.py @@ -157,7 +157,7 @@ def step(self, closure=None, grads=None, output_params=None, scale=1.0, grad_nor # assuming a list/generator of parameter means single group elif isinstance(grads, types.GeneratorType): grads_group = [grads] - elif type(grads[0]) != list: + elif not isinstance(grads[0], list): grads_group = [grads] else: grads_group = grads @@ -170,7 +170,7 @@ def step(self, closure=None, grads=None, output_params=None, scale=1.0, grad_nor ) elif isinstance(output_params, types.GeneratorType): output_params_group = [output_params] - elif type(output_params[0]) != list: + elif not isinstance(output_params[0], list): output_params_group = [output_params] else: output_params_group = output_params diff --git a/apex/contrib/test/openfold_triton/test_sync_triton_auto_tune_cache_across_gpus.py b/apex/contrib/test/openfold_triton/test_sync_triton_auto_tune_cache_across_gpus.py index 7cb59a123..0ec5aee35 100644 --- a/apex/contrib/test/openfold_triton/test_sync_triton_auto_tune_cache_across_gpus.py +++ b/apex/contrib/test/openfold_triton/test_sync_triton_auto_tune_cache_across_gpus.py @@ -45,9 +45,9 @@ def destroy_pg_upon_exit(self) -> bool: def _create_process_group_nccl(self): def maybe_export(env, val): - if not type(env) == str: + if not isinstance(env, str): raise ValueError(f"Type of type of env is expected to be str, but got {type(env)}") - if not type(val) == str: + if not isinstance(val, str): raise ValueError(f"Type of type of val is expected to be str, but got {type(val)}") if os.getenv(env) is None: os.environ[env] = val diff --git a/pyproject.toml b/pyproject.toml index 952f4df66..dfe13c0b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,6 @@ line-length = 100 ignore = [ # Sorted by occurrence count (ascending) - easier to fix first "E731", # lambda assignment (6 occurrences) - "E721", # type comparison should use isinstance (8 occurrences) "E741", # ambiguous variable name (8 occurrences) "E712", # comparison to True/False (9 occurrences) "F403", # star imports used (9 occurrences)