Skip to content

Commit

Permalink
Merge pull request #500 from BindsNET/hananel
Browse files Browse the repository at this point in the history
Fix issues cause by torch._six upgrade
  • Loading branch information
Hananel-Hazan committed Jun 27, 2021
2 parents 328102f + a2c9d78 commit c9f3e9d
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
6 changes: 3 additions & 3 deletions bindsnet/datasets/collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,15 @@ def time_aware_collate(batch):
return torch.as_tensor(batch)
elif isinstance(elem, float):
return torch.tensor(batch, dtype=torch.float64)
elif isinstance(elem, int_classes):
elif isinstance(elem, int):
return torch.tensor(batch)
elif isinstance(elem, string_classes):
return batch
elif isinstance(elem, container_abcs.Mapping):
elif isinstance(elem, collections.Mapping):
return {key: time_aware_collate([d[key] for d in batch]) for key in elem}
elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple
return elem_type(*(time_aware_collate(samples) for samples in zip(*batch)))
elif isinstance(elem, container_abcs.Sequence):
elif isinstance(elem, collections.Sequence):
transposed = zip(*batch)
return [time_aware_collate(samples) for samples in transposed]

Expand Down
4 changes: 2 additions & 2 deletions bindsnet/pipeline/base_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ def recursive_to(item, device):
return item.to(device)
elif isinstance(item, (string_classes, int, float, bool)):
return item
elif isinstance(item, container_abcs.Mapping):
elif isinstance(item, collections.Mapping):
return {key: recursive_to(item[key], device) for key in item}
elif isinstance(item, tuple) and hasattr(item, "_fields"):
return type(item)(*(recursive_to(i, device) for i in item))
elif isinstance(item, container_abcs.Sequence):
elif isinstance(item, collections.Sequence):
return [recursive_to(i, device) for i in item]
else:
raise NotImplementedError(f"Target type {type(item)} not supported.")
Expand Down

0 comments on commit c9f3e9d

Please sign in to comment.