Skip to content

Commit

Permalink
Merge pull request #499 from BindsNET/hananel
Browse files Browse the repository at this point in the history
Fix PyTorch 1.9 compatibility issues.
  • Loading branch information
Hananel-Hazan committed Jun 27, 2021
2 parents cf4af8d + 1d2fd17 commit 328102f
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 21 deletions.
6 changes: 3 additions & 3 deletions bindsnet/conversion/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,10 @@ def set_requires_grad(module, value):

extractor2 = FeatureExtractor(module)
all_activations2 = extractor2.forward(data)
for name2, module2 in module.named_children():
for name2, module in module.named_children():
activations = all_activations2[name2]

if isinstance(module2, nn.ReLU):
if isinstance(module, nn.ReLU):
if prev_module is not None:
scale_factor = np.percentile(activations.cpu(), percentile)

Expand All @@ -136,7 +136,7 @@ def set_requires_grad(module, value):
elif isinstance(module2, nn.Linear) or isinstance(module2, nn.Conv2d):
prev_module = module2

if isinstance(module2, nn.Linear):
if isinstance(module, nn.Linear):
if prev_module is not None:
scale_factor = np.percentile(activations.cpu(), percentile)
prev_module.weight *= prev_factor / scale_factor
Expand Down
3 changes: 2 additions & 1 deletion bindsnet/datasets/collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
"""

import torch
from torch._six import container_abcs, string_classes, int_classes
from torch._six import string_classes
import collections

from torch.utils.data._utils import collate as pytorch_collate

Expand Down
3 changes: 2 additions & 1 deletion bindsnet/pipeline/base_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from typing import Tuple, Dict, Any

import torch
from torch._six import container_abcs, string_classes
from torch._six import string_classes
import collections

from ..network import Network
from ..network.monitors import Monitor
Expand Down
12 changes: 6 additions & 6 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
foolbox
scipy>=1.1.0
numpy>=1.14.2
cython>=0.28.5
torch==1.8.1
torchvision==0.9.1
scipy>=1.5.4
numpy>=1.19.5
cython>=0.29.5
torch==1.9.0
torchvision==0.10.0
tensorboardX==2.2
tqdm>=4.19.9
tqdm>=4.60.0
setuptools>=39.0.1
matplotlib>=2.1.0
gym>=0.10.4
Expand Down
14 changes: 7 additions & 7 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,20 @@
packages=find_packages(),
zip_safe=False,
install_requires=[
"numpy>=1.14.2",
"torch==1.8.1",
"torchvision==0.9.1",
"numpy>=1.19.5",
"torch==1.9.0",
"torchvision==0.10.0",
"tensorboardX==2.2",
"tqdm>=4.19.9",
"tqdm>=4.60.0",
"matplotlib>=2.1.0",
"gym>=0.10.4",
"scikit-build>=0.11.1",
"scikit_image>=0.13.1",
"scikit_learn>=0.19.1",
"opencv-python>=3.4.0.12",
"pytest>=3.4.0",
"scipy>=1.1.0",
"cython>=0.28.5",
"pytest>=6.2.0",
"scipy>=1.5.4",
"cython>=0.29.0",
"pandas>=0.23.4",
],
)
13 changes: 10 additions & 3 deletions test/conversion/test_conversion.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

Expand All @@ -24,14 +25,20 @@ def forward(self, x):
return x


def test_conversion():
def test_conversion_1():
ann = FullyConnectedNetwork()
snn = ann_to_snn(ann, input_shape=(784,))


def main():
def test_conversion_2():
data = torch.rand(784, 20)
ann = FullyConnectedNetwork()
return ann_to_snn(ann, input_shape=(28, 28))
snn = ann_to_snn(ann, data=data, input_shape=(784,))


def main():
test_conversion_1()
test_conversion_2()


if __name__ == "__main__":
Expand Down

0 comments on commit 328102f

Please sign in to comment.