diff --git a/scripts/unshard.py b/scripts/unshard.py index bdd5faca8..6a4f8301d 100644 --- a/scripts/unshard.py +++ b/scripts/unshard.py @@ -80,7 +80,7 @@ def shard_size(shard_md): def objects_are_equal(a: Any, b: Any) -> bool: - if type(a) != type(b): + if type(a) is not type(b): return False if isinstance(a, ndarray): return np.array_equal(a, b) @@ -139,7 +139,7 @@ def _rebuild_from_type_v2_monkey(func, new_type, args, state): def unshard_object(os: List[Any]) -> Any: rank0_item = os[0] - assert all(type(o) == type(rank0_item) for o in os) + assert all(type(o) is type(rank0_item) for o in os) if isinstance(rank0_item, str): assert all(o == rank0_item for o in os) return rank0_item