Skip to content

Commit

Permalink
Merge pull request #518 from danielgafni/isort-and-autoflake-refactor
Browse files Browse the repository at this point in the history
Applied isort and autoflake; Converted all relative imports to absolute
  • Loading branch information
Hananel-Hazan committed Sep 20, 2021
2 parents 12f2302 + 4785155 commit f4b550b
Show file tree
Hide file tree
Showing 74 changed files with 490 additions and 382 deletions.
12 changes: 12 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
repos:
- repo: local
hooks:
- id: isort
name: isort
entry: isort . --settings-file pyproject.toml
language: system
pass_filenames: false
- id: black
name: black
entry: black .
language: system
pass_filenames: false
- id: autoflake
name: autoflake
entry: autoflake
language: system
types: [ python ]
args: [ --in-place, --remove-all-unused-imports, --remove-duplicate-keys ]
files: ^bindsnet/|test/
35 changes: 26 additions & 9 deletions bindsnet/__init__.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,35 @@
from pathlib import Path

from . import (
utils,
network,
models,
from bindsnet import (
analysis,
preprocessing,
conversion,
datasets,
encoding,
pipeline,
learning,
evaluation,
environment,
conversion,
evaluation,
learning,
models,
network,
pipeline,
preprocessing,
utils,
)

ROOT_DIR = Path(__file__).parents[0].parents[0]


__all__ = [
"utils",
"network",
"models",
"analysis",
"preprocessing",
"datasets",
"encoding",
"pipeline",
"learning",
"evaluation",
"environment",
"conversion",
"ROOT_DIR",
]
4 changes: 3 additions & 1 deletion bindsnet/analysis/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from . import plotting, visualization, pipeline_analysis
from bindsnet.analysis import pipeline_analysis, plotting, visualization

__all__ = ["plotting", "visualization", "pipeline_analysis"]
2 changes: 1 addition & 1 deletion bindsnet/analysis/dotTrace_plotter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import numpy as np
import glob
import sys

import matplotlib.pyplot as plt
import numpy as np

# Define grid dimensions globally
ROWS = 28
Expand Down
9 changes: 1 addition & 8 deletions bindsnet/analysis/pipeline_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from tensorboardX import SummaryWriter
from torchvision.utils import make_grid

from .plotting import plot_spikes, plot_voltages, plot_conv2d_weights
from ..utils import reshape_conv2d_weights
from .plotting import plot_conv2d_weights, plot_spikes, plot_voltages


class PipelineAnalyzer(ABC):
Expand All @@ -25,7 +25,6 @@ def finalize_step(self) -> None:
"""
Flush the output from the current step.
"""
pass

@abstractmethod
def plot_obs(self, obs: torch.Tensor, tag: str = "obs", step: int = None) -> None:
Expand All @@ -38,7 +37,6 @@ def plot_obs(self, obs: torch.Tensor, tag: str = "obs", step: int = None) -> Non
:param tag: A unique tag to associate the data with.
:param step: The step of the pipeline.
"""
pass

@abstractmethod
def plot_reward(
Expand All @@ -57,7 +55,6 @@ def plot_reward(
:param tag: A unique tag to associate the data with.
:param step: The step of the pipeline.
"""
pass

@abstractmethod
def plot_spikes(
Expand All @@ -75,7 +72,6 @@ def plot_spikes(
:param tag: A unique tag to associate the data with.
:param step: The step of the pipeline.
"""
pass

@abstractmethod
def plot_voltages(
Expand All @@ -96,7 +92,6 @@ def plot_voltages(
:param tag: A unique tag to associate the data with.
:param step: The step of the pipeline.
"""
pass

@abstractmethod
def plot_conv2d_weights(
Expand All @@ -110,7 +105,6 @@ def plot_conv2d_weights(
:param tag: A unique tag to associate the data with.
:param step: The step of the pipeline.
"""
pass


class MatplotlibAnalyzer(PipelineAnalyzer):
Expand Down Expand Up @@ -313,7 +307,6 @@ def finalize_step(self) -> None:
"""
No-op for ``TensorboardAnalyzer``.
"""
pass

def plot_obs(self, obs: torch.Tensor, tag: str = "obs", step: int = None) -> None:
# language=rst
Expand Down
14 changes: 7 additions & 7 deletions bindsnet/analysis/plotting.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import torch
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List, Optional, Sized, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.axes import Axes
from matplotlib.image import AxesImage
from torch.nn.modules.utils import _pair
from matplotlib.collections import PathCollection
from matplotlib.image import AxesImage
from mpl_toolkits.axes_grid1 import make_axes_locatable
from typing import Tuple, List, Optional, Sized, Dict, Union
from torch.nn.modules.utils import _pair

from ..utils import reshape_locally_connected_weights, reshape_conv2d_weights
from bindsnet.utils import reshape_conv2d_weights, reshape_locally_connected_weights

plt.ion()

Expand Down
10 changes: 5 additions & 5 deletions bindsnet/analysis/visualization.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from typing import List, Optional, Tuple

from typing import List, Tuple, Optional
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import torch


def plot_weights_movie(ws: np.ndarray, sample_every: int = 1) -> None:
Expand Down
21 changes: 16 additions & 5 deletions bindsnet/conversion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
from .conversion import (
Permute,
from bindsnet.conversion.conversion import (
ConstantPad2dConnection,
FeatureExtractor,
SubtractiveResetIFNodes,
PassThroughNodes,
Permute,
PermuteConnection,
ConstantPad2dConnection,
data_based_normalization,
SubtractiveResetIFNodes,
ann_to_snn,
data_based_normalization,
)

__all__ = [
"Permute",
"FeatureExtractor",
"SubtractiveResetIFNodes",
"PassThroughNodes",
"PermuteConnection",
"ConstantPad2dConnection",
"data_based_normalization",
"ann_to_snn",
]
15 changes: 6 additions & 9 deletions bindsnet/conversion/conversion.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
import torch
from copy import deepcopy
from typing import Dict, Optional, Sequence, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.nn.modules.utils import _pair

from copy import deepcopy
from typing import Union, Sequence, Optional, Tuple, Dict, Iterable

import bindsnet.network.nodes as nodes
import bindsnet.network.topology as topology

from bindsnet.conversion.nodes import PassThroughNodes, SubtractiveResetIFNodes
from bindsnet.conversion.topology import ConstantPad2dConnection, PermuteConnection
from bindsnet.network import Network
from .nodes import SubtractiveResetIFNodes, PassThroughNodes
from .topology import PermuteConnection, ConstantPad2dConnection


class Permute(nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion bindsnet/conversion/nodes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Iterable, Union
from typing import Iterable, Optional, Union

import torch

Expand Down
2 changes: 1 addition & 1 deletion bindsnet/conversion/topology.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Iterable, Union, Tuple
from typing import Iterable, Optional, Tuple, Union

import torch
import torch.nn.functional as F
Expand Down
55 changes: 47 additions & 8 deletions bindsnet/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from .torchvision_wrapper import create_torchvision_dataset_wrapper
from .spoken_mnist import SpokenMNIST
from .davis import Davis
from .alov300 import ALOV300

from .collate import time_aware_collate
from .dataloader import DataLoader

from bindsnet.datasets.alov300 import ALOV300
from bindsnet.datasets.collate import time_aware_collate
from bindsnet.datasets.dataloader import DataLoader
from bindsnet.datasets.davis import Davis
from bindsnet.datasets.spoken_mnist import SpokenMNIST
from bindsnet.datasets.torchvision_wrapper import create_torchvision_dataset_wrapper

CIFAR10 = create_torchvision_dataset_wrapper("CIFAR10")
CIFAR100 = create_torchvision_dataset_wrapper("CIFAR100")
Expand All @@ -31,3 +29,44 @@
SVHN = create_torchvision_dataset_wrapper("SVHN")
VOCDetection = create_torchvision_dataset_wrapper("VOCDetection")
VOCSegmentation = create_torchvision_dataset_wrapper("VOCSegmentation")


__all__ = [
"torchvision_wrapper",
"create_torchvision_dataset_wrapper",
"spoken_mnist",
"SpokenMNIST",
"davis",
"Davis",
"preprocess",
"alov300",
"ALOV300",
"collate",
"time_aware_collate",
"dataloader",
"DataLoader",
"CIFAR10",
"CIFAR100",
"Cityscapes",
"CocoCaptions",
"CocoDetection",
"DatasetFolder",
"EMNIST",
"FakeData",
"FashionMNIST",
"Flickr30k",
"Flickr8k",
"ImageFolder",
"KMNIST",
"LSUN",
"LSUNClass",
"MNIST",
"Omniglot",
"PhotoTour",
"SBU",
"SEMEION",
"STL10",
"SVHN",
"VOCDetection",
"VOCSegmentation",
]
10 changes: 3 additions & 7 deletions bindsnet/datasets/alov300.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,20 @@
import os
import sys
import time
import zipfile
import warnings
from glob import glob
import zipfile
from urllib.request import urlretrieve
from typing import Optional, Tuple, List, Iterable

import cv2
import torch
import numpy as np
from PIL import Image
from torch.utils.data import Dataset

from bindsnet.datasets.preprocess import (
cropPadImage,
BoundingBox,
crop_sample,
Rescale,
bgr2rgb,
crop_sample,
cropPadImage,
)

warnings.filterwarnings("ignore")
Expand Down
4 changes: 2 additions & 2 deletions bindsnet/datasets/collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
Modifications exist to have [time, batch, n_0, ... n_k] instead of batch in dimension 0.
"""

import torch
from torch._six import string_classes
import collections

import torch
from torch._six import string_classes
from torch.utils.data._utils import collate as pytorch_collate


Expand Down
2 changes: 1 addition & 1 deletion bindsnet/datasets/dataloader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch

from .collate import time_aware_collate
from bindsnet.datasets.collate import time_aware_collate


class DataLoader(torch.utils.data.DataLoader):
Expand Down
6 changes: 3 additions & 3 deletions bindsnet/datasets/davis.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import os
import shutil
import sys
import time
import shutil
import zipfile
from glob import glob
from collections import defaultdict
from glob import glob
from urllib.request import urlretrieve

import torch
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

Expand Down
3 changes: 1 addition & 2 deletions bindsnet/datasets/preprocess.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import math
import random
import warnings

import cv2
import torch
import numpy as np
import torch
from torchvision import transforms


Expand Down

0 comments on commit f4b550b

Please sign in to comment.