From 3ca676b82f44d851b63a564dd6a2289e920b516c Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Tue, 1 Sep 2020 16:37:09 -0400 Subject: [PATCH] Black formatting with default line length. --- bindsnet/analysis/pipeline_analysis.py | 16 +-- bindsnet/analysis/plotting.py | 83 +++++-------- bindsnet/analysis/visualization.py | 23 +--- bindsnet/conversion/conversion.py | 44 ++----- bindsnet/conversion/nodes.py | 8 +- bindsnet/datasets/alov300.py | 33 ++---- bindsnet/datasets/collate.py | 16 +-- bindsnet/datasets/davis.py | 63 +++------- bindsnet/datasets/preprocess.py | 72 +++--------- bindsnet/datasets/spoken_mnist.py | 52 ++------ bindsnet/encoding/encoders.py | 4 +- bindsnet/encoding/encodings.py | 18 +-- bindsnet/environment/environment.py | 16 +-- bindsnet/evaluation/evaluation.py | 7 +- bindsnet/learning/learning.py | 103 ++++------------ bindsnet/learning/reward.py | 9 +- bindsnet/models/models.py | 8 +- bindsnet/network/monitors.py | 46 ++------ bindsnet/network/network.py | 10 +- bindsnet/network/nodes.py | 137 ++++++---------------- bindsnet/network/topology.py | 65 ++++------ bindsnet/pipeline/action.py | 8 +- bindsnet/pipeline/base_pipeline.py | 20 +--- bindsnet/pipeline/environment_pipeline.py | 8 +- bindsnet/utils.py | 33 ++---- examples/breakout/breakout.py | 8 +- examples/breakout/breakout_stdp.py | 8 +- examples/mnist/batch_eth_mnist.py | 23 ++-- examples/mnist/conv_mnist.py | 26 +--- examples/mnist/eth_mnist.py | 22 +--- examples/mnist/reservoir.py | 22 +--- examples/mnist/supervised_mnist.py | 28 +---- test/encoding/test_encoding.py | 16 +-- test/models/test_models.py | 15 +-- test/network/test_connections.py | 14 +-- test/network/test_learning.py | 45 ++----- test/network/test_monitors.py | 8 +- test/network/test_nodes.py | 36 +----- 38 files changed, 304 insertions(+), 869 deletions(-) diff --git a/bindsnet/analysis/pipeline_analysis.py b/bindsnet/analysis/pipeline_analysis.py index 665b479..bb030e3 100644 --- a/bindsnet/analysis/pipeline_analysis.py +++ b/bindsnet/analysis/pipeline_analysis.py @@ -28,9 +28,7 @@ def finalize_step(self) -> None: pass @abstractmethod - def plot_obs( - self, obs: torch.Tensor, tag: str = "obs", step: int = None - ) -> None: + def plot_obs(self, obs: torch.Tensor, tag: str = "obs", step: int = None) -> None: # language=rst """ Pulls the observation from PyTorch and sets up for Matplotlib @@ -139,9 +137,7 @@ def __init__(self, **kwargs) -> None: plt.ion() self.plots = {} - def plot_obs( - self, obs: torch.Tensor, tag: str = "obs", step: int = None - ) -> None: + def plot_obs(self, obs: torch.Tensor, tag: str = "obs", step: int = None) -> None: # language=rst """ Pulls the observation off of torch and sets up for Matplotlib @@ -260,9 +256,7 @@ def plot_voltages( """ if tag not in self.plots: self.plots[tag] = plot_voltages( - voltage_record, - plot_type=self.volts_type, - thresholds=thresholds, + voltage_record, plot_type=self.volts_type, thresholds=thresholds, ) else: v_im, v_ax = self.plots[tag] @@ -321,9 +315,7 @@ def finalize_step(self) -> None: """ pass - def plot_obs( - self, obs: torch.Tensor, tag: str = "obs", step: int = None - ) -> None: + def plot_obs(self, obs: torch.Tensor, tag: str = "obs", step: int = None) -> None: # language=rst """ Pulls the observation off of torch and sets up for Matplotlib diff --git a/bindsnet/analysis/plotting.py b/bindsnet/analysis/plotting.py index 21de17b..4c8a2f7 100644 --- a/bindsnet/analysis/plotting.py +++ b/bindsnet/analysis/plotting.py @@ -115,8 +115,7 @@ def plot_spikes( for i, datum in enumerate(spikes.items()): spikes = ( datum[1][ - time[0] : time[1], - n_neurons[datum[0]][0] : n_neurons[datum[0]][1], + time[0] : time[1], n_neurons[datum[0]][0] : n_neurons[datum[0]][1], ] .detach() .clone() @@ -144,19 +143,14 @@ def plot_spikes( ax.set_aspect("auto") plt.setp( - axes, - xticks=[], - yticks=[], - xlabel="Simulation time", - ylabel="Neuron index", + axes, xticks=[], yticks=[], xlabel="Simulation time", ylabel="Neuron index", ) plt.tight_layout() else: for i, datum in enumerate(spikes.items()): spikes = ( datum[1][ - time[0] : time[1], - n_neurons[datum[0]][0] : n_neurons[datum[0]][1], + time[0] : time[1], n_neurons[datum[0]][0] : n_neurons[datum[0]][1], ] .detach() .clone() @@ -375,16 +369,11 @@ def plot_assignments( if classes is None: color = plt.get_cmap("RdBu", 11) - im = ax.matshow( - locals_assignments, cmap=color, vmin=-1.5, vmax=9.5 - ) + im = ax.matshow(locals_assignments, cmap=color, vmin=-1.5, vmax=9.5) else: color = plt.get_cmap("RdBu", len(classes) + 1) im = ax.matshow( - locals_assignments, - cmap=color, - vmin=-1.5, - vmax=len(classes) - 0.5, + locals_assignments, cmap=color, vmin=-1.5, vmax=len(classes) - 0.5, ) div = make_axes_locatable(ax) @@ -509,14 +498,12 @@ def plot_voltages( ) ) - if thresholds is not None and thresholds[ - v[0] - ].size() == torch.Size([]): + if thresholds is not None and thresholds[v[0]].size() == torch.Size( + [] + ): ims.append( axes.axhline( - y=thresholds[v[0]].item(), - c="r", - linestyle="--", + y=thresholds[v[0]].item(), c="r", linestyle="--", ) ) else: @@ -540,10 +527,7 @@ def plot_voltages( time[0], time[1], ) - plt.title( - "%s voltages for neurons (%d - %d) from t = %d to %d " - % args - ) + plt.title("%s voltages for neurons (%d - %d) from t = %d to %d " % args) plt.xlabel("Time (ms)") if plot_type == "line": @@ -566,14 +550,12 @@ def plot_voltages( ] ) ) - if thresholds is not None and thresholds[ - v[0] - ].size() == torch.Size([]): + if thresholds is not None and thresholds[v[0]].size() == torch.Size( + [] + ): ims.append( axes[i].axhline( - y=thresholds[v[0]].item(), - c="r", - linestyle="--", + y=thresholds[v[0]].item(), c="r", linestyle="--", ) ) else: @@ -597,8 +579,7 @@ def plot_voltages( time[1], ) axes[i].set_title( - "%s voltages for neurons (%d - %d) from t = %d to %d " - % args + "%s voltages for neurons (%d - %d) from t = %d to %d " % args ) for ax in axes: @@ -621,23 +602,19 @@ def plot_voltages( v[1] .cpu() .numpy()[ - n_neurons[v[0]][0] : n_neurons[v[0]][1], - time[0] : time[1], + n_neurons[v[0]][0] : n_neurons[v[0]][1], time[0] : time[1], ] ) - if thresholds is not None and thresholds[ - v[0] - ].size() == torch.Size([]): - axes.axhline( - y=thresholds[v[0]].item(), c="r", linestyle="--" - ) + if thresholds is not None and thresholds[v[0]].size() == torch.Size( + [] + ): + axes.axhline(y=thresholds[v[0]].item(), c="r", linestyle="--") else: axes.matshow( v[1] .cpu() .numpy()[ - n_neurons[v[0]][0] : n_neurons[v[0]][1], - time[0] : time[1], + n_neurons[v[0]][0] : n_neurons[v[0]][1], time[0] : time[1], ] .T, cmap=cmap, @@ -650,8 +627,7 @@ def plot_voltages( time[1], ) axes.set_title( - "%s voltages for neurons (%d - %d) from t = %d to %d " - % args + "%s voltages for neurons (%d - %d) from t = %d to %d " % args ) axes.set_aspect("auto") @@ -664,13 +640,12 @@ def plot_voltages( v[1] .cpu() .numpy()[ - n_neurons[v[0]][0] : n_neurons[v[0]][1], - time[0] : time[1], + n_neurons[v[0]][0] : n_neurons[v[0]][1], time[0] : time[1], ] ) - if thresholds is not None and thresholds[ - v[0] - ].size() == torch.Size([]): + if thresholds is not None and thresholds[v[0]].size() == torch.Size( + [] + ): axes[i].axhline( y=thresholds[v[0]].item(), c="r", linestyle="--" ) @@ -679,8 +654,7 @@ def plot_voltages( v[1] .cpu() .numpy()[ - n_neurons[v[0]][0] : n_neurons[v[0]][1], - time[0] : time[1], + n_neurons[v[0]][0] : n_neurons[v[0]][1], time[0] : time[1], ] .T, cmap=cmap, @@ -693,8 +667,7 @@ def plot_voltages( time[1], ) axes[i].set_title( - "%s voltages for neurons (%d - %d) from t = %d to %d " - % args + "%s voltages for neurons (%d - %d) from t = %d to %d " % args ) for ax in axes: diff --git a/bindsnet/analysis/visualization.py b/bindsnet/analysis/visualization.py index 806c1e1..2d03fe2 100644 --- a/bindsnet/analysis/visualization.py +++ b/bindsnet/analysis/visualization.py @@ -18,18 +18,14 @@ def plot_weights_movie(ws: np.ndarray, sample_every: int = 1) -> None: # Obtain samples from the weights for every example. for i in range(ws.shape[0]): - sub_sampled_weight = ws[ - i, :, :, range(0, ws[i].shape[2], sample_every) - ] + sub_sampled_weight = ws[i, :, :, range(0, ws[i].shape[2], sample_every)] weights.append(sub_sampled_weight) else: weights = np.concatenate(weights, axis=0) # Initialize plot. fig = plt.figure() - im = plt.imshow( - weights[0, :, :], cmap="hot_r", animated=True, vmin=0, vmax=1 - ) + im = plt.imshow(weights[0, :, :], cmap="hot_r", animated=True, vmin=0, vmax=1) plt.axis("off") plt.colorbar(im) @@ -68,9 +64,7 @@ def plot_spike_trains_for_example( plt.figure() if top_k is None and indices is None: # Plot all neurons' spiking activity - spike_per_neuron = [ - np.argwhere(i == 1).flatten() for i in spikes[n_ex, :, :] - ] + spike_per_neuron = [np.argwhere(i == 1).flatten() for i in spikes[n_ex, :, :]] plt.title("Spiking activity for all %d neurons" % spikes.shape[1]) elif top_k is None: # Plot based on indices parameter @@ -82,12 +76,9 @@ def plot_spike_trains_for_example( elif indices is None: # Plot based on top_k parameter assert top_k is not None # Obtain the top k neurons that fired the most - top_k_loc = np.argsort(np.sum(spikes[n_ex, :, :], axis=1), axis=0)[ - ::-1 - ] + top_k_loc = np.argsort(np.sum(spikes[n_ex, :, :], axis=1), axis=0)[::-1] spike_per_neuron = [ - np.argwhere(i == 1).flatten() - for i in spikes[n_ex, top_k_loc[0:top_k], :] + np.argwhere(i == 1).flatten() for i in spikes[n_ex, top_k_loc[0:top_k], :] ] plt.title("Spiking activity for top %d neurons" % top_k) @@ -133,9 +124,7 @@ def plot_voltage( plt.plot(voltage[n_ex, n_neuron, timer]) plt.xlabel("Simulation Time") plt.ylabel("Voltage") - plt.title( - "Membrane voltage of neuron %d for example %d" % (n_neuron, n_ex + 1) - ) + plt.title("Membrane voltage of neuron %d for example %d" % (n_neuron, n_ex + 1)) locs, labels = plt.xticks() locs = range(int(locs[1]), int(locs[-1]), 10) plt.xticks(locs, time_ticks) diff --git a/bindsnet/conversion/conversion.py b/bindsnet/conversion/conversion.py index e1f0f7e..bf9b367 100644 --- a/bindsnet/conversion/conversion.py +++ b/bindsnet/conversion/conversion.py @@ -126,26 +126,22 @@ def set_requires_grad(module, value): if isinstance(module2, nn.ReLU): if prev_module is not None: - scale_factor = np.percentile( - activations.cpu(), percentile - ) + scale_factor = np.percentile(activations.cpu(), percentile) prev_module.weight *= prev_factor / scale_factor prev_module.bias /= scale_factor prev_factor = scale_factor - elif isinstance(module2, nn.Linear) or isinstance( - module2, nn.Conv2d - ): + elif isinstance(module2, nn.Linear) or isinstance(module2, nn.Conv2d): prev_module = module2 if isinstance(module2, nn.Linear): - if prev_module is not None: - scale_factor = np.percentile(activations.cpu(), percentile) - prev_module.weight *= prev_factor / scale_factor - prev_module.bias /= scale_factor - prev_factor = scale_factor + if prev_module is not None: + scale_factor = np.percentile(activations.cpu(), percentile) + prev_module.weight *= prev_factor / scale_factor + prev_module.bias /= scale_factor + prev_factor = scale_factor else: activations = all_activations[name] @@ -158,9 +154,7 @@ def set_requires_grad(module, value): prev_factor = scale_factor - elif isinstance(module, nn.Linear) or isinstance( - module, nn.Conv2d - ): + elif isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): prev_module = module return ann @@ -187,9 +181,7 @@ def _ann_to_snn_helper(prev, current, node_type, last=False, **kwargs): sum_input=last, **kwargs, ) - bias = ( - current.bias if current.bias is not None else torch.zeros(layer.n) - ) + bias = current.bias if current.bias is not None else torch.zeros(layer.n) connection = topology.Connection( source=prev, target=layer, w=current.weight.t(), b=bias ) @@ -209,11 +201,7 @@ def _ann_to_snn_helper(prev, current, node_type, last=False, **kwargs): layer = node_type( shape=shape, reset=0, thresh=1, refrac=0, sum_input=last, **kwargs ) - bias = ( - current.bias - if current.bias is not None - else torch.zeros(layer.shape[1]) - ) + bias = current.bias if current.bias is not None else torch.zeros(layer.shape[1]) connection = topology.Conv2dConnection( source=prev, target=layer, @@ -259,9 +247,7 @@ def _ann_to_snn_helper(prev, current, node_type, last=False, **kwargs): ] ) - connection = PermuteConnection( - source=prev, target=layer, dims=current.dims - ) + connection = PermuteConnection(source=prev, target=layer, dims=current.dims) elif isinstance(current, nn.ConstantPad2d): layer = PassThroughNodes( @@ -317,9 +303,7 @@ def ann_to_snn( if data is None: import warnings - warnings.warn( - "Data is None. Weights will not be scaled.", RuntimeWarning - ) + warnings.warn("Data is None. Weights will not be scaled.", RuntimeWarning) else: ann = data_based_normalization( ann=ann, data=data.detach(), percentile=percentile @@ -342,9 +326,7 @@ def ann_to_snn( prev = input_layer while i < len(children) - 1: current, nxt = children[i : i + 2] - layer, connection = _ann_to_snn_helper( - prev, current, node_type, **kwargs - ) + layer, connection = _ann_to_snn_helper(prev, current, node_type, **kwargs) i += 1 diff --git a/bindsnet/conversion/nodes.py b/bindsnet/conversion/nodes.py index 37a573b..27b1d1c 100644 --- a/bindsnet/conversion/nodes.py +++ b/bindsnet/conversion/nodes.py @@ -115,12 +115,8 @@ def set_batch_size(self, batch_size) -> None: :param batch_size: Mini-batch size. """ super().set_batch_size(batch_size=batch_size) - self.v = self.reset * torch.ones( - batch_size, *self.shape, device=self.v.device - ) - self.refrac_count = torch.zeros_like( - self.v, device=self.refrac_count.device - ) + self.v = self.reset * torch.ones(batch_size, *self.shape, device=self.v.device) + self.refrac_count = torch.zeros_like(self.v, device=self.refrac_count.device) class PassThroughNodes(nodes.Nodes): diff --git a/bindsnet/datasets/alov300.py b/bindsnet/datasets/alov300.py index 658e9d6..27e2b89 100644 --- a/bindsnet/datasets/alov300.py +++ b/bindsnet/datasets/alov300.py @@ -99,9 +99,7 @@ def _parse_data(self, root, target_dir): f = open(vid_ann, "r") annotations = f.readlines() f.close() - frame_idxs = [ - int(ann.split(" ")[0]) - 1 for ann in annotations - ] + frame_idxs = [int(ann.split(" ")[0]) - 1 for ann in annotations] frames = np.array(frames) num_anno += len(annotations) for i in range(len(frame_idxs) - 1): @@ -130,9 +128,7 @@ def get_sample(self, idx): curr_img = self.get_orig_sample(idx, 1)["image"] currbb = self.get_orig_sample(idx, 1)["bb"] prevbb = self.get_orig_sample(idx, 0)["bb"] - bbox_curr_shift = BoundingBox( - prevbb[0], prevbb[1], prevbb[2], prevbb[3] - ) + bbox_curr_shift = BoundingBox(prevbb[0], prevbb[1], prevbb[2], prevbb[3]) ( rand_search_region, rand_search_location, @@ -142,10 +138,7 @@ def get_sample(self, idx): bbox_curr_gt = BoundingBox(currbb[0], currbb[1], currbb[2], currbb[3]) bbox_gt_recentered = BoundingBox(0, 0, 0, 0) bbox_gt_recentered = bbox_curr_gt.recenter( - rand_search_location, - edge_spacing_x, - edge_spacing_y, - bbox_gt_recentered, + rand_search_location, edge_spacing_x, edge_spacing_y, bbox_gt_recentered, ) curr_sample["image"] = rand_search_region curr_sample["bb"] = bbox_gt_recentered.get_bb_list() @@ -208,9 +201,7 @@ def show(self, idx, is_current=1): image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) bb = sample["bb"] bb = [int(val) for val in bb] - image = cv2.rectangle( - image, (bb[0], bb[1]), (bb[2], bb[3]), (0, 255, 0), 2 - ) + image = cv2.rectangle(image, (bb[0], bb[1]), (bb[2], bb[3]), (0, 255, 0), 2) cv2.imshow("alov dataset sample: " + str(idx), image) cv2.waitKey(0) @@ -270,18 +261,12 @@ def _download(self): # Grabs the correct zip url based on parameters self.frame_zip_path = os.path.join(self.root, "frame.zip") self.text_zip_path = os.path.join(self.root, "text.zip") - frame_zip_url = ( - f"http://isis-data.science.uva.nl/alov/alov300++_frames.zip" - ) - text_zip_url = ( - f"http://isis-data.science.uva.nl/alov/alov300++GT_txtFiles.zip" - ) + frame_zip_url = f"http://isis-data.science.uva.nl/alov/alov300++_frames.zip" + text_zip_url = f"http://isis-data.science.uva.nl/alov/alov300++GT_txtFiles.zip" # Downloads the relevant dataset print("\nDownloading ALOV300++ frame set from " + frame_zip_url + "\n") - urlretrieve( - frame_zip_url, self.frame_zip_path, reporthook=self.progress - ) + urlretrieve(frame_zip_url, self.frame_zip_path, reporthook=self.progress) print("\nDownloading ALOV300++ text set from " + text_zip_url + "\n") urlretrieve(text_zip_url, self.text_zip_path, reporthook=self.progress) @@ -300,9 +285,7 @@ def _download(self): os.remove(self.text_zip_path) # Renames the folders containing the dataset - box_folder = os.path.join( - self.root, "alov300++_rectangleAnnotation_full/" - ) + box_folder = os.path.join(self.root, "alov300++_rectangleAnnotation_full/") frame_folder = os.path.join(self.root, "imagedata++") os.rename(box_folder, self.box_path) diff --git a/bindsnet/datasets/collate.py b/bindsnet/datasets/collate.py index 7ca4c6c..1707973 100644 --- a/bindsnet/datasets/collate.py +++ b/bindsnet/datasets/collate.py @@ -64,9 +64,7 @@ def time_aware_collate(batch): is not None ): raise TypeError( - pytorch_collate.default_collate_err_msg_format.format( - elem.dtype - ) + pytorch_collate.default_collate_err_msg_format.format(elem.dtype) ) return time_aware_collate([torch.as_tensor(b) for b in batch]) @@ -79,17 +77,11 @@ def time_aware_collate(batch): elif isinstance(elem, string_classes): return batch elif isinstance(elem, container_abcs.Mapping): - return { - key: time_aware_collate([d[key] for d in batch]) for key in elem - } + return {key: time_aware_collate([d[key] for d in batch]) for key in elem} elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple - return elem_type( - *(time_aware_collate(samples) for samples in zip(*batch)) - ) + return elem_type(*(time_aware_collate(samples) for samples in zip(*batch))) elif isinstance(elem, container_abcs.Sequence): transposed = zip(*batch) return [time_aware_collate(samples) for samples in transposed] - raise TypeError( - pytorch_collate.default_collate_err_msg_format.format(elem_type) - ) + raise TypeError(pytorch_collate.default_collate_err_msg_format.format(elem_type)) diff --git a/bindsnet/datasets/davis.py b/bindsnet/datasets/davis.py index af5f5c5..2360bc3 100644 --- a/bindsnet/datasets/davis.py +++ b/bindsnet/datasets/davis.py @@ -52,9 +52,7 @@ def __init__( if subset not in self.SUBSET_OPTIONS: raise ValueError(f"Subset should be in {self.SUBSET_OPTIONS}") if task not in self.TASKS: - raise ValueError( - f"The only tasks that are supported are {self.TASKS}" - ) + raise ValueError(f"The only tasks that are supported are {self.TASKS}") if resolution not in self.RESOLUTION_OPTIONS: raise ValueError( f"You may only use one of these resolutions: {self.RESOLUTION_OPTIONS}" @@ -89,13 +87,9 @@ def __init__( self.zip_path = os.path.join(self.root, "repo.zip") self.img_path = os.path.join(self.root, "JPEGImages", resolution) annotations_folder = ( - "Annotations" - if task == "semi-supervised" - else "Annotations_unsupervised" - ) - self.mask_path = os.path.join( - self.root, annotations_folder, resolution + "Annotations" if task == "semi-supervised" else "Annotations_unsupervised" ) + self.mask_path = os.path.join(self.root, annotations_folder, resolution) year = ( "2019" if task == "unsupervised" @@ -134,17 +128,11 @@ def __init__( # Sets the images and masks for each sequence resizing for the given size for seq in self.sequences_names: - images = np.sort( - glob(os.path.join(self.img_path, seq, "*.jpg")) - ).tolist() + images = np.sort(glob(os.path.join(self.img_path, seq, "*.jpg"))).tolist() if len(images) == 0 and not self.codalab: - raise FileNotFoundError( - f"Images for sequence {seq} not found." - ) + raise FileNotFoundError(f"Images for sequence {seq} not found.") self.sequences[seq]["images"] = images - masks = np.sort( - glob(os.path.join(self.mask_path, seq, "*.png")) - ).tolist() + masks = np.sort(glob(os.path.join(self.mask_path, seq, "*.png"))).tolist() masks.extend([-1] * (len(images) - len(masks))) self.sequences[seq]["masks"] = masks @@ -168,9 +156,7 @@ def _convert_sequences(self): Creates a new root for the dataset to be converted and placed into, then copies each image and mask into the given size and stores correctly. """ - os.makedirs( - os.path.join(self.converted_imagesets_path, f"{self.subset}.txt") - ) + os.makedirs(os.path.join(self.converted_imagesets_path, f"{self.subset}.txt")) os.makedirs(self.converted_img_path) os.makedirs(self.converted_mask_path) @@ -183,34 +169,24 @@ def _convert_sequences(self): for seq in tqdm(self.sequences_names): os.makedirs(os.path.join(self.converted_img_path, seq)) os.makedirs(os.path.join(self.converted_mask_path, seq)) - images = np.sort( - glob(os.path.join(self.img_path, seq, "*.jpg")) - ).tolist() + images = np.sort(glob(os.path.join(self.img_path, seq, "*.jpg"))).tolist() if len(images) == 0 and not self.codalab: - raise FileNotFoundError( - f"Images for sequence {seq} not found." - ) + raise FileNotFoundError(f"Images for sequence {seq} not found.") for ind, img in enumerate(images): im = Image.open(img) im.thumbnail(self.size, Image.ANTIALIAS) im.save( os.path.join( - self.converted_img_path, - seq, - str(ind).zfill(5) + ".jpg", + self.converted_img_path, seq, str(ind).zfill(5) + ".jpg", ) ) - masks = np.sort( - glob(os.path.join(self.mask_path, seq, "*.png")) - ).tolist() + masks = np.sort(glob(os.path.join(self.mask_path, seq, "*.png"))).tolist() for ind, msk in enumerate(masks): im = Image.open(msk) im.thumbnail(self.size, Image.ANTIALIAS) im.convert("RGB").save( os.path.join( - self.converted_mask_path, - seq, - str(ind).zfill(5) + ".png", + self.converted_mask_path, seq, str(ind).zfill(5) + ".png", ) ) @@ -231,16 +207,12 @@ def _check_directories(self): f"DAVIS not found in the specified directory, download it from " f"{self.DATASET_WEB} or add download=True to your call" ) - if not os.path.exists( - os.path.join(self.imagesets_path, f"{self.subset}.txt") - ): + if not os.path.exists(os.path.join(self.imagesets_path, f"{self.subset}.txt")): raise FileNotFoundError( f"Subset sequences list for {self.subset} not found, download the " f"missing subset for the {self.task} task from {self.DATASET_WEB}" ) - if self.subset in ["train", "val"] and not os.path.exists( - self.mask_path - ): + if self.subset in ["train", "val"] and not os.path.exists(self.mask_path): raise FileNotFoundError( f"Annotations folder for the {self.task} task not found, " f"download it from {self.DATASET_WEB}" @@ -254,8 +226,7 @@ def _check_directories(self): def get_frames(self, sequence): for img, msk in zip( - self.sequences[sequence]["images"], - self.sequences[sequence]["masks"], + self.sequences[sequence]["images"], self.sequences[sequence]["masks"], ): image = np.array(Image.open(img)) mask = None if msk is None else np.array(Image.open(msk)) @@ -263,9 +234,7 @@ def get_frames(self, sequence): def _get_all_elements(self, sequence, obj_type): obj = np.array(Image.open(self.sequences[sequence][obj_type][0])) - all_objs = np.zeros( - (len(self.sequences[sequence][obj_type]), *obj.shape) - ) + all_objs = np.zeros((len(self.sequences[sequence][obj_type]), *obj.shape)) obj_id = [] for i, obj in enumerate(self.sequences[sequence][obj_type]): all_objs[i, ...] = np.array(Image.open(obj)) diff --git a/bindsnet/datasets/preprocess.py b/bindsnet/datasets/preprocess.py index 96328bd..7bcf3ce 100644 --- a/bindsnet/datasets/preprocess.py +++ b/bindsnet/datasets/preprocess.py @@ -129,10 +129,7 @@ def shift_crop_training_sample(sample, bb_params): bbox_curr_gt = BoundingBox(currbb[0], currbb[1], currbb[2], currbb[3]) bbox_gt_recentered = BoundingBox(0, 0, 0, 0) bbox_gt_recentered = bbox_curr_gt.recenter( - rand_search_location, - edge_spacing_x, - edge_spacing_y, - bbox_gt_recentered, + rand_search_location, edge_spacing_x, edge_spacing_y, bbox_gt_recentered, ) output_sample["image"] = rand_search_region output_sample["bb"] = bbox_gt_recentered.get_bb_list() @@ -155,12 +152,9 @@ def crop_sample(sample): opts = {} image, bb = sample["image"], sample["bb"] orig_bbox = BoundingBox(bb[0], bb[1], bb[2], bb[3]) - ( - output_image, - pad_image_location, - edge_spacing_x, - edge_spacing_y, - ) = cropPadImage(orig_bbox, image) + (output_image, pad_image_location, edge_spacing_x, edge_spacing_y,) = cropPadImage( + orig_bbox, image + ) new_bbox = BoundingBox(0, 0, 0, 0) new_bbox = new_bbox.recenter( pad_image_location, edge_spacing_x, edge_spacing_y, new_bbox @@ -195,13 +189,10 @@ def cropPadImage(bbox_tight, image): int(roi_left + err) : int(roi_left + roi_width), ] output_width = max(math.ceil(bbox_tight.compute_output_width()), roi_width) - output_height = max( - math.ceil(bbox_tight.compute_output_height()), roi_height - ) + output_height = max(math.ceil(bbox_tight.compute_output_height()), roi_height) if image.ndim > 2: output_image = np.zeros( - (int(output_height), int(output_width), image.shape[2]), - dtype=image.dtype, + (int(output_height), int(output_width), image.shape[2]), dtype=image.dtype, ) else: output_image = np.zeros( @@ -285,11 +276,7 @@ def print_bb(self): print("------Bounding-box-------") print("(x1, y1): ({}, {})".format(self.x1, self.y1)) print("(x2, y2): ({}, {})".format(self.x2, self.y2)) - print( - "(w, h) : ({}, {})".format( - self.x2 - self.x1 + 1, self.y2 - self.y1 + 1 - ) - ) + print("(w, h) : ({}, {})".format(self.x2 - self.x1 + 1, self.y2 - self.y1 + 1)) print("--------------------------") def get_bb_list(self): @@ -339,21 +326,13 @@ def unscale(self, image): self.y1 = self.y1 * height self.y2 = self.y2 * height - def uncenter( - self, raw_image, search_location, edge_spacing_x, edge_spacing_y - ): + def uncenter(self, raw_image, search_location, edge_spacing_x, edge_spacing_y): self.x1 = max(0.0, self.x1 + search_location.x1 - edge_spacing_x) self.y1 = max(0.0, self.y1 + search_location.y1 - edge_spacing_y) - self.x2 = min( - raw_image.shape[1], self.x2 + search_location.x1 - edge_spacing_x - ) - self.y2 = min( - raw_image.shape[0], self.y2 + search_location.y1 - edge_spacing_y - ) + self.x2 = min(raw_image.shape[1], self.x2 + search_location.x1 - edge_spacing_x) + self.y2 = min(raw_image.shape[0], self.y2 + search_location.y1 - edge_spacing_y) - def recenter( - self, search_loc, edge_spacing_x, edge_spacing_y, bbox_gt_recentered - ): + def recenter(self, search_loc, edge_spacing_x, edge_spacing_y, bbox_gt_recentered): bbox_gt_recentered.x1 = self.x1 - search_loc.x1 + edge_spacing_x bbox_gt_recentered.y1 = self.y1 - search_loc.y1 + edge_spacing_y bbox_gt_recentered.x2 = self.x2 - search_loc.x1 + edge_spacing_x @@ -406,14 +385,11 @@ def shift( ): if shift_motion_model: width_scale_factor = max( - min_scale, - min(max_scale, sample_exp_two_sides(lambda_scale_frac)), + min_scale, min(max_scale, sample_exp_two_sides(lambda_scale_frac)), ) else: rand_num = sample_rand_uniform() - width_scale_factor = ( - rand_num * (max_scale - min_scale) + min_scale - ) + width_scale_factor = rand_num * (max_scale - min_scale) + min_scale new_width = width * (1 + width_scale_factor) new_width = max(1.0, min((image.shape[1] - 1), new_width)) @@ -426,14 +402,11 @@ def shift( ): if shift_motion_model: height_scale_factor = max( - min_scale, - min(max_scale, sample_exp_two_sides(lambda_scale_frac)), + min_scale, min(max_scale, sample_exp_two_sides(lambda_scale_frac)), ) else: rand_num = sample_rand_uniform() - height_scale_factor = ( - rand_num * (max_scale - min_scale) + min_scale - ) + height_scale_factor = rand_num * (max_scale - min_scale) + min_scale new_height = height * (1 + height_scale_factor) new_height = max(1.0, min((image.shape[0] - 1), new_height)) @@ -452,9 +425,7 @@ def shift( ) and (num_tries_x < kMaxNumTries): if shift_motion_model: - new_x_temp = center_x + width * sample_exp_two_sides( - lambda_shift_frac - ) + new_x_temp = center_x + width * sample_exp_two_sides(lambda_shift_frac) else: rand_num = sample_rand_uniform() new_x_temp = center_x + rand_num * (2 * new_width) - new_width @@ -478,18 +449,13 @@ def shift( ) and (num_tries_y < kMaxNumTries): if shift_motion_model: - new_y_temp = center_y + height * sample_exp_two_sides( - lambda_shift_frac - ) + new_y_temp = center_y + height * sample_exp_two_sides(lambda_shift_frac) else: rand_num = sample_rand_uniform() - new_y_temp = ( - center_y + rand_num * (2 * new_height) - new_height - ) + new_y_temp = center_y + rand_num * (2 * new_height) - new_height new_center_y = min( - image.shape[0] - new_height / 2, - max(new_height / 2, new_y_temp), + image.shape[0] - new_height / 2, max(new_height / 2, new_y_temp), ) first_time_y = False num_tries_y = num_tries_y + 1 diff --git a/bindsnet/datasets/spoken_mnist.py b/bindsnet/datasets/spoken_mnist.py index e070e30..ab1cf66 100644 --- a/bindsnet/datasets/spoken_mnist.py +++ b/bindsnet/datasets/spoken_mnist.py @@ -27,9 +27,7 @@ class SpokenMNIST(torch.utils.data.Dataset): for digit in range(10): for speaker in ["jackson", "nicolas", "theo"]: for example in range(50): - files.append( - "_".join([str(digit), speaker, str(example)]) + ".wav" - ) + files.append("_".join([str(digit), speaker, str(example)]) + ".wav") n_files = len(files) @@ -82,9 +80,7 @@ def __getitem__(self, ind): return {"audio": audio, "label": label} - def _get_train( - self, split: float = 0.8 - ) -> Tuple[torch.Tensor, torch.Tensor]: + def _get_train(self, split: float = 0.8) -> Tuple[torch.Tensor, torch.Tensor]: # language=rst """ Gets the Spoken MNIST training audio and labels. @@ -93,22 +89,16 @@ def _get_train( :return: Spoken MNIST training audio and labels. """ split_index = int(split * SpokenMNIST.n_files) - path = os.path.join( - self.path, "_".join([SpokenMNIST.train_pickle, str(split)]) - ) + path = os.path.join(self.path, "_".join([SpokenMNIST.train_pickle, str(split)])) - if not all( - [os.path.isfile(os.path.join(self.path, f)) for f in self.files] - ): + if not all([os.path.isfile(os.path.join(self.path, f)) for f in self.files]): # Download data if it isn't on disk. if self.download: print("Downloading Spoken MNIST data.\n") self._download() # Process data into audio, label (input, output) pairs. - audio, labels = self.process_data( - SpokenMNIST.files[:split_index] - ) + audio, labels = self.process_data(SpokenMNIST.files[:split_index]) # Serialize image data on disk for next time. torch.save((audio, labels), open(path, "wb")) @@ -138,9 +128,7 @@ def _get_train( return audio, torch.Tensor(labels) - def _get_test( - self, split: float = 0.8 - ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + def _get_test(self, split: float = 0.8) -> Tuple[torch.Tensor, List[torch.Tensor]]: # language=rst """ Gets the Spoken MNIST training audio and labels. @@ -149,22 +137,16 @@ def _get_test( :return: The Spoken MNIST test audio and labels. """ split_index = int(split * SpokenMNIST.n_files) - path = os.path.join( - self.path, "_".join([SpokenMNIST.test_pickle, str(split)]) - ) + path = os.path.join(self.path, "_".join([SpokenMNIST.test_pickle, str(split)])) - if not all( - [os.path.isfile(os.path.join(self.path, f)) for f in self.files] - ): + if not all([os.path.isfile(os.path.join(self.path, f)) for f in self.files]): # Download data if it isn't on disk. if self.download: print("Downloading Spoken MNIST data.\n") self._download() # Process data into audio, label (input, output) pairs. - audio, labels = self.process_data( - SpokenMNIST.files[split_index:] - ) + audio, labels = self.process_data(SpokenMNIST.files[split_index:]) # Serialize image data on disk for next time. torch.save((audio, labels), open(path, "wb")) @@ -202,9 +184,7 @@ def _download(self) -> None: z.extractall(path=self.path) z.close() - path = os.path.join( - self.path, "free-spoken-digit-dataset-master", "recordings" - ) + path = os.path.join(self.path, "free-spoken-digit-dataset-master", "recordings") for f in os.listdir(path): shutil.move(os.path.join(path, f), os.path.join(self.path)) @@ -249,9 +229,7 @@ def process_data( # Make sure that we have at least 1 frame num_frames = int( - np.ceil( - float(np.abs(signal_length - frame_length)) / frame_step - ) + np.ceil(float(np.abs(signal_length - frame_length)) / frame_step) ) pad_signal_length = num_frames * frame_step + frame_length @@ -272,9 +250,7 @@ def process_data( # Fast Fourier Transform and Power Spectrum NFFT = 512 - mag_frames = np.absolute( - np.fft.rfft(frames, NFFT) - ) # Magnitude of the FFT + mag_frames = np.absolute(np.fft.rfft(frames, NFFT)) # Magnitude of the FFT pow_frames = (1.0 / NFFT) * (mag_frames ** 2) # Power Spectrum # Log filter banks @@ -286,9 +262,7 @@ def process_data( mel_points = np.linspace( low_freq_mel, high_freq_mel, nfilt + 2 ) # Equally spaced in Mel scale - hz_points = 700 * ( - 10 ** (mel_points / 2595) - 1 - ) # Convert Mel to Hz + hz_points = 700 * (10 ** (mel_points / 2595) - 1) # Convert Mel to Hz bin = np.floor((NFFT + 1) * hz_points / sample_rate) fbank = np.zeros((nfilt, int(np.floor(NFFT / 2 + 1)))) diff --git a/bindsnet/encoding/encoders.py b/bindsnet/encoding/encoders.py index 6cfd8fc..5014875 100644 --- a/bindsnet/encoding/encoders.py +++ b/bindsnet/encoding/encoders.py @@ -35,9 +35,7 @@ def __call__(self, img): class SingleEncoder(Encoder): - def __init__( - self, time: int, dt: float = 1.0, sparsity: float = 0.5, **kwargs - ): + def __init__(self, time: int, dt: float = 1.0, sparsity: float = 0.5, **kwargs): # language=rst """ Creates a callable SingleEncoder which encodes as defined in diff --git a/bindsnet/encoding/encodings.py b/bindsnet/encoding/encodings.py index 4f59eb5..76c6177 100644 --- a/bindsnet/encoding/encodings.py +++ b/bindsnet/encoding/encodings.py @@ -5,11 +5,7 @@ def single( - datum: torch.Tensor, - time: int, - dt: float = 1.0, - sparsity: float = 0.5, - **kwargs + datum: torch.Tensor, time: int, dt: float = 1.0, sparsity: float = 0.5, **kwargs ) -> torch.Tensor: # language=rst """ @@ -33,9 +29,7 @@ def single( return torch.Tensor(s).byte() -def repeat( - datum: torch.Tensor, time: int, dt: float = 1.0, **kwargs -) -> torch.Tensor: +def repeat(datum: torch.Tensor, time: int, dt: float = 1.0, **kwargs) -> torch.Tensor: # language=rst """ :param datum: Repeats a tensor along a new dimension in the 0th position for @@ -70,9 +64,7 @@ def bernoulli( # Setting kwargs. max_prob = kwargs.get("max_prob", 1.0) - assert ( - 0 <= max_prob <= 1 - ), "Maximum firing probability must be in range [0, 1]" + assert 0 <= max_prob <= 1, "Maximum firing probability must be in range [0, 1]" assert (datum >= 0).all(), "Inputs must be non-negative" shape, size = datum.shape, datum.numel() @@ -96,9 +88,7 @@ def bernoulli( return spikes.byte() -def poisson( - datum: torch.Tensor, time: int, dt: float = 1.0, **kwargs -) -> torch.Tensor: +def poisson(datum: torch.Tensor, time: int, dt: float = 1.0, **kwargs) -> torch.Tensor: # language=rst """ Generates Poisson-distributed spike trains based on input intensity. Inputs must be diff --git a/bindsnet/environment/environment.py b/bindsnet/environment/environment.py index 8480388..fba9d76 100644 --- a/bindsnet/environment/environment.py +++ b/bindsnet/environment/environment.py @@ -64,9 +64,7 @@ class GymEnvironment(Environment): A wrapper around the OpenAI ``gym`` environments. """ - def __init__( - self, name: str, encoder: Encoder = NullEncoder(), **kwargs - ) -> None: + def __init__(self, name: str, encoder: Encoder = NullEncoder(), **kwargs) -> None: # language=rst """ Initializes the environment wrapper. This class makes the @@ -105,9 +103,7 @@ def __init__( if self.history_length is not None and self.delta is not None: self.history = { i: torch.Tensor() - for i in range( - 1, self.history_length * self.delta + 1, self.delta - ) + for i in range(1, self.history_length * self.delta + 1, self.delta) } else: self.history = {} @@ -212,9 +208,7 @@ def preprocess(self) -> None: self.obs = self.obs[26:104, :] self.obs = binary_image(self.obs) elif self.name == "BreakoutDeterministic-v4": - self.obs = subsample( - gray_scale(crop(self.obs, 34, 194, 0, 160)), 80, 80 - ) + self.obs = subsample(gray_scale(crop(self.obs, 34, 194, 0, 160)), 80, 80) self.obs = binary_image(self.obs) else: # Default pre-processing step. pass @@ -260,6 +254,4 @@ def update_index(self) -> None: self.history_index += self.delta else: # Wrap around the history. - self.history_index = ( - self.history_index % max(self.history.keys()) - ) + 1 + self.history_index = (self.history_index % max(self.history.keys())) + 1 diff --git a/bindsnet/evaluation/evaluation.py b/bindsnet/evaluation/evaluation.py index 5e53e7a..e653ad4 100644 --- a/bindsnet/evaluation/evaluation.py +++ b/bindsnet/evaluation/evaluation.py @@ -76,9 +76,7 @@ def logreg_fit( return logreg -def logreg_predict( - spikes: torch.Tensor, logreg: LogisticRegression -) -> torch.Tensor: +def logreg_predict(spikes: torch.Tensor, logreg: LogisticRegression) -> torch.Tensor: # language=rst """ Predicts classes according to spike data summed over time. @@ -166,8 +164,7 @@ def proportion_weighting( # Compute layer-wise firing rate for this label. rates[:, i] += ( - torch.sum((proportions[:, i] * spikes)[:, indices], 1) - / n_assigns + torch.sum((proportions[:, i] * spikes)[:, indices], 1) / n_assigns ) # Predictions are arg-max of layer-wise firing rates. diff --git a/bindsnet/learning/learning.py b/bindsnet/learning/learning.py index e4a426b..736bfd0 100644 --- a/bindsnet/learning/learning.py +++ b/bindsnet/learning/learning.py @@ -76,9 +76,7 @@ def update(self) -> None: if ( self.connection.wmin != -np.inf or self.connection.wmax != np.inf ) and not isinstance(self, NoOp): - self.connection.w.clamp_( - self.connection.wmin, self.connection.wmax - ) + self.connection.w.clamp_(self.connection.wmin, self.connection.wmax) class NoOp(LearningRule): @@ -206,11 +204,7 @@ def _conv2d_connection_update(self, **kwargs) -> None: # Reshaping spike traces and spike occurrences. source_x = im2col_indices( - self.source.x, - kernel_height, - kernel_width, - padding=padding, - stride=stride, + self.source.x, kernel_height, kernel_width, padding=padding, stride=stride, ) target_x = self.target.x.view(batch_size, out_channels, -1) source_s = im2col_indices( @@ -227,18 +221,14 @@ def _conv2d_connection_update(self, **kwargs) -> None: pre = self.reduction( torch.bmm(target_x, source_s.permute((0, 2, 1))), dim=0 ) - self.connection.w -= self.nu[0] * pre.view( - self.connection.w.size() - ) + self.connection.w -= self.nu[0] * pre.view(self.connection.w.size()) # Post-synaptic update. if self.nu[1]: post = self.reduction( torch.bmm(target_s, source_x.permute((0, 2, 1))), dim=0 ) - self.connection.w += self.nu[1] * post.view( - self.connection.w.size() - ) + self.connection.w += self.nu[1] * post.view(self.connection.w.size()) super().update() @@ -278,9 +268,7 @@ def __init__( **kwargs ) - assert ( - self.source.traces - ), "Pre-synaptic nodes must record spike traces." + assert self.source.traces, "Pre-synaptic nodes must record spike traces." assert ( connection.wmin != -np.inf and connection.wmax != np.inf ), "Connection must define finite wmin and wmax." @@ -314,21 +302,13 @@ def _connection_update(self, **kwargs) -> None: # Pre-synaptic update. if self.nu[0]: - outer_product = self.reduction( - torch.bmm(source_s, target_x), dim=0 - ) - update -= ( - self.nu[0] * outer_product * (self.connection.w - self.wmin) - ) + outer_product = self.reduction(torch.bmm(source_s, target_x), dim=0) + update -= self.nu[0] * outer_product * (self.connection.w - self.wmin) # Post-synaptic update. if self.nu[1]: - outer_product = self.reduction( - torch.bmm(source_x, target_s), dim=0 - ) - update += ( - self.nu[1] * outer_product * (self.wmax - self.connection.w) - ) + outer_product = self.reduction(torch.bmm(source_x, target_s), dim=0) + update += self.nu[1] * outer_product * (self.wmax - self.connection.w) self.connection.w += update @@ -352,11 +332,7 @@ def _conv2d_connection_update(self, **kwargs) -> None: # Reshaping spike traces and spike occurrences. source_x = im2col_indices( - self.source.x, - kernel_height, - kernel_width, - padding=padding, - stride=stride, + self.source.x, kernel_height, kernel_width, padding=padding, stride=stride, ) target_x = self.target.x.view(batch_size, out_channels, -1) source_s = im2col_indices( @@ -478,11 +454,7 @@ def _conv2d_connection_update(self, **kwargs) -> None: # Reshaping spike traces and spike occurrences. source_x = im2col_indices( - self.source.x, - kernel_height, - kernel_width, - padding=padding, - stride=stride, + self.source.x, kernel_height, kernel_width, padding=padding, stride=stride, ) target_x = self.target.x.view(batch_size, out_channels, -1) source_s = im2col_indices( @@ -495,15 +467,11 @@ def _conv2d_connection_update(self, **kwargs) -> None: target_s = self.target.s.view(batch_size, out_channels, -1).float() # Pre-synaptic update. - pre = self.reduction( - torch.bmm(target_x, source_s.permute((0, 2, 1))), dim=0 - ) + pre = self.reduction(torch.bmm(target_x, source_s.permute((0, 2, 1))), dim=0) self.connection.w += self.nu[0] * pre.view(self.connection.w.size()) # Post-synaptic update. - post = self.reduction( - torch.bmm(target_s, source_x.permute((0, 2, 1))), dim=0 - ) + post = self.reduction(torch.bmm(target_s, source_x.permute((0, 2, 1))), dim=0) self.connection.w += self.nu[1] * post.view(self.connection.w.size()) super().update() @@ -581,9 +549,7 @@ def _connection_update(self, **kwargs) -> None: if not hasattr(self, "p_minus"): self.p_minus = torch.zeros(batch_size, *self.target.shape) if not hasattr(self, "eligibility"): - self.eligibility = torch.zeros( - batch_size, *self.connection.w.shape - ) + self.eligibility = torch.zeros(batch_size, *self.connection.w.shape) # Reshape pre- and post-synaptic spikes. source_s = self.source.s.view(batch_size, -1).float() @@ -628,9 +594,7 @@ def _conv2d_connection_update(self, **kwargs) -> None: # Initialize eligibility. if not hasattr(self, "eligibility"): - self.eligibility = torch.zeros( - batch_size, *self.connection.w.shape - ) + self.eligibility = torch.zeros(batch_size, *self.connection.w.shape) # Parse keyword arguments. reward = kwargs["reward"] @@ -658,9 +622,7 @@ def _conv2d_connection_update(self, **kwargs) -> None: ) if not hasattr(self, "p_minus"): self.p_minus = torch.zeros(batch_size, *self.target.shape) - self.p_minus = self.p_minus.view( - batch_size, out_channels, -1 - ).float() + self.p_minus = self.p_minus.view(batch_size, out_channels, -1).float() # Reshaping spike occurrences. source_s = im2col_indices( @@ -775,9 +737,7 @@ def _connection_update(self, **kwargs) -> None: # Calculate value of eligibility trace based on the value # of the point eligibility value of the past timestep. - self.eligibility_trace *= torch.exp( - -self.connection.dt / self.tc_e_trace - ) + self.eligibility_trace *= torch.exp(-self.connection.dt / self.tc_e_trace) self.eligibility_trace += self.eligibility / self.tc_e_trace # Compute weight update. @@ -815,13 +775,9 @@ def _conv2d_connection_update(self, **kwargs) -> None: # Initialize eligibility and eligibility trace. if not hasattr(self, "eligibility"): - self.eligibility = torch.zeros( - batch_size, *self.connection.w.shape - ) + self.eligibility = torch.zeros(batch_size, *self.connection.w.shape) if not hasattr(self, "eligibility_trace"): - self.eligibility_trace = torch.zeros( - batch_size, *self.connection.w.shape - ) + self.eligibility_trace = torch.zeros(batch_size, *self.connection.w.shape) # Parse keyword arguments. reward = kwargs["reward"] @@ -830,15 +786,11 @@ def _conv2d_connection_update(self, **kwargs) -> None: # Calculate value of eligibility trace based on the value # of the point eligibility value of the past timestep. - self.eligibility_trace *= torch.exp( - -self.connection.dt / self.tc_e_trace - ) + self.eligibility_trace *= torch.exp(-self.connection.dt / self.tc_e_trace) # Compute weight update. update = reward * self.eligibility_trace - self.connection.w += ( - self.nu[0] * self.connection.dt * torch.sum(update, dim=0) - ) + self.connection.w += self.nu[0] * self.connection.dt * torch.sum(update, dim=0) out_channels, _, kernel_height, kernel_width = self.connection.w.size() padding, stride = self.connection.padding, self.connection.stride @@ -855,9 +807,7 @@ def _conv2d_connection_update(self, **kwargs) -> None: ) if not hasattr(self, "p_minus"): self.p_minus = torch.zeros(batch_size, *self.target.shape) - self.p_minus = self.p_minus.view( - batch_size, out_channels, -1 - ).float() + self.p_minus = self.p_minus.view(batch_size, out_channels, -1).float() # Reshaping spike occurrences. source_s = im2col_indices( @@ -868,9 +818,7 @@ def _conv2d_connection_update(self, **kwargs) -> None: stride=stride, ) target_s = ( - self.target.s.permute(1, 2, 3, 0) - .view(batch_size, out_channels, -1) - .float() + self.target.s.permute(1, 2, 3, 0).view(batch_size, out_channels, -1).float() ) # Update P^+ and P^- values. @@ -978,10 +926,7 @@ def _connection_update(self, **kwargs) -> None: self.eligibility_trace *= 1 - self.connection.dt / self.tc_e_trace self.eligibility_trace += ( target_s - - ( - target_s_prob - / (1.0 + self.tc_c / self.connection.dt * target_s_prob) - ) + - (target_s_prob / (1.0 + self.tc_c / self.connection.dt * target_s_prob)) ) * source_x[:, None] # Compute weight update. diff --git a/bindsnet/learning/reward.py b/bindsnet/learning/reward.py index 26c54d1..c70f91e 100644 --- a/bindsnet/learning/reward.py +++ b/bindsnet/learning/reward.py @@ -40,9 +40,7 @@ def __init__(self, **kwargs) -> None: Constructor for EMA reward prediction error. """ self.reward_predict = torch.tensor(0.0) # Predicted reward (per step). - self.reward_predict_episode = torch.tensor( - 0.0 - ) # Predicted reward per episode. + self.reward_predict_episode = torch.tensor(0.0) # Predicted reward per episode. self.rewards_predict_episode = ( [] ) # List of predicted rewards per episode (used for plotting). @@ -87,7 +85,6 @@ def update(self, **kwargs) -> None: 1 - 1 / ema_window ) * self.reward_predict + 1 / ema_window * reward self.reward_predict_episode = ( - (1 - 1 / ema_window) * self.reward_predict_episode - + 1 / ema_window * accumulated_reward - ) + 1 - 1 / ema_window + ) * self.reward_predict_episode + 1 / ema_window * accumulated_reward self.rewards_predict_episode.append(self.reward_predict_episode.item()) diff --git a/bindsnet/models/models.py b/bindsnet/models/models.py index 64e1cc6..d7d17dc 100644 --- a/bindsnet/models/models.py +++ b/bindsnet/models/models.py @@ -53,9 +53,7 @@ def __init__( self.n_neurons = n_neurons self.dt = dt - self.add_layer( - Input(n=self.n_inpt, traces=True, tc_trace=20.0), name="X" - ) + self.add_layer(Input(n=self.n_inpt, traces=True, tc_trace=20.0), name="X") self.add_layer( LIFNodes( n=self.n_neurons, @@ -395,9 +393,7 @@ def __init__( x1, y1 = i // self.n_sqrt, i % self.n_sqrt x2, y2 = j // self.n_sqrt, j % self.n_sqrt - inhib = self.start_inhib * np.sqrt( - euclidean([x1, y1], [x2, y2]) - ) + inhib = self.start_inhib * np.sqrt(euclidean([x1, y1], [x2, y2])) w[i, j] = -min(self.max_inhib, inhib) recurrent_output_conn = Connection( diff --git a/bindsnet/network/monitors.py b/bindsnet/network/monitors.py index 566ae49..372132e 100644 --- a/bindsnet/network/monitors.py +++ b/bindsnet/network/monitors.py @@ -112,17 +112,13 @@ def __init__( super().__init__() self.network = network - self.layers = ( - layers if layers is not None else list(self.network.layers.keys()) - ) + self.layers = layers if layers is not None else list(self.network.layers.keys()) self.connections = ( connections if connections is not None else list(self.network.connections.keys()) ) - self.state_vars = ( - state_vars if state_vars is not None else ("v", "s", "w") - ) + self.state_vars = state_vars if state_vars is not None else ("v", "s", "w") self.time = time if self.time is not None: @@ -149,15 +145,13 @@ def __init__( for l in self.layers: if hasattr(self.network.layers[l], v): self.recording[l][v] = torch.zeros( - self.time, - *getattr(self.network.layers[l], v).size() + self.time, *getattr(self.network.layers[l], v).size() ) for c in self.connections: if hasattr(self.network.connections[c], v): self.recording[c][v] = torch.zeros( - self.time, - *getattr(self.network.connections[c], v).size() + self.time, *getattr(self.network.connections[c], v).size() ) def get(self) -> Dict[str, Dict[str, Union[Nodes, AbstractConnection]]]: @@ -180,20 +174,14 @@ def record(self) -> None: for v in self.state_vars: for l in self.layers: if hasattr(self.network.layers[l], v): - data = ( - getattr(self.network.layers[l], v) - .unsqueeze(0) - .float() - ) + data = getattr(self.network.layers[l], v).unsqueeze(0).float() self.recording[l][v] = torch.cat( (self.recording[l][v], data), 0 ) for c in self.connections: if hasattr(self.network.connections[c], v): - data = getattr( - self.network.connections[c], v - ).unsqueeze(0) + data = getattr(self.network.connections[c], v).unsqueeze(0) self.recording[c][v] = torch.cat( (self.recording[c][v], data), 0 ) @@ -202,24 +190,16 @@ def record(self) -> None: for v in self.state_vars: for l in self.layers: if hasattr(self.network.layers[l], v): - data = ( - getattr(self.network.layers[l], v) - .float() - .unsqueeze(0) - ) + data = getattr(self.network.layers[l], v).float().unsqueeze(0) self.recording[l][v] = torch.cat( - (self.recording[l][v][1:].type(data.type()), data), - 0, + (self.recording[l][v][1:].type(data.type()), data), 0, ) for c in self.connections: if hasattr(self.network.connections[c], v): - data = getattr( - self.network.connections[c], v - ).unsqueeze(0) + data = getattr(self.network.connections[c], v).unsqueeze(0) self.recording[c][v] = torch.cat( - (self.recording[c][v][1:].type(data.type()), data), - 0, + (self.recording[c][v][1:].type(data.type()), data), 0, ) self.i += 1 @@ -290,13 +270,11 @@ def reset_state_variables(self) -> None: for l in self.layers: if hasattr(self.network.layers[l], v): self.recording[l][v] = torch.zeros( - self.time, - *getattr(self.network.layers[l], v).size() + self.time, *getattr(self.network.layers[l], v).size() ) for c in self.connections: if hasattr(self.network.connections[c], v): self.recording[c][v] = torch.zeros( - self.time, - *getattr(self.network.layers[c], v).size() + self.time, *getattr(self.network.layers[c], v).size() ) diff --git a/bindsnet/network/network.py b/bindsnet/network/network.py index 3438e6b..0b7f736 100644 --- a/bindsnet/network/network.py +++ b/bindsnet/network/network.py @@ -9,9 +9,7 @@ from ..learning.reward import AbstractReward -def load( - file_name: str, map_location: str = "cpu", learning: bool = None -) -> "Network": +def load(file_name: str, map_location: str = "cpu", learning: bool = None) -> "Network": # language=rst """ Loads serialized network object from disk. @@ -238,11 +236,7 @@ def _get_inputs(self, layers: Iterable = None) -> Dict[str, torch.Tensor]: return inputs def run( - self, - inputs: Dict[str, torch.Tensor], - time: int, - one_step=False, - **kwargs + self, inputs: Dict[str, torch.Tensor], time: int, one_step=False, **kwargs ) -> None: # language=rst """ diff --git a/bindsnet/network/nodes.py b/bindsnet/network/nodes.py index 3f577c6..b4c9a7b 100644 --- a/bindsnet/network/nodes.py +++ b/bindsnet/network/nodes.py @@ -59,8 +59,8 @@ def __init__( self.traces = traces # Whether to record synaptic traces. self.traces_additive = ( - traces_additive - ) # Whether to record spike traces additively. + traces_additive # Whether to record spike traces additively. + ) self.register_buffer("s", torch.ByteTensor()) # Spike occurrences. self.sum_input = sum_input # Whether to sum all inputs. @@ -79,9 +79,7 @@ def __init__( ) # Set in compute_decays. if self.sum_input: - self.register_buffer( - "summed", torch.FloatTensor() - ) # Summed inputs. + self.register_buffer("summed", torch.FloatTensor()) # Summed inputs. self.dt = None self.batch_size = None @@ -414,12 +412,8 @@ def set_batch_size(self, batch_size) -> None: :param batch_size: Mini-batch size. """ super().set_batch_size(batch_size=batch_size) - self.v = self.reset * torch.ones( - batch_size, *self.shape, device=self.v.device - ) - self.refrac_count = torch.zeros_like( - self.v, device=self.refrac_count.device - ) + self.v = self.reset * torch.ones(batch_size, *self.shape, device=self.v.device) + self.refrac_count = torch.zeros_like(self.v, device=self.refrac_count.device) class LIFNodes(Nodes): @@ -557,12 +551,8 @@ def set_batch_size(self, batch_size) -> None: :param batch_size: Mini-batch size. """ super().set_batch_size(batch_size=batch_size) - self.v = self.rest * torch.ones( - batch_size, *self.shape, device=self.v.device - ) - self.refrac_count = torch.zeros_like( - self.v, device=self.refrac_count.device - ) + self.v = self.rest * torch.ones(batch_size, *self.shape, device=self.v.device) + self.refrac_count = torch.zeros_like(self.v, device=self.refrac_count.device) class CurrentLIFNodes(Nodes): @@ -620,12 +610,8 @@ def __init__( ) self.register_buffer("rest", torch.tensor(rest)) # Rest voltage. - self.register_buffer( - "reset", torch.tensor(reset) - ) # Post-spike reset voltage. - self.register_buffer( - "thresh", torch.tensor(thresh) - ) # Spike threshold voltage. + self.register_buffer("reset", torch.tensor(reset)) # Post-spike reset voltage. + self.register_buffer("thresh", torch.tensor(thresh)) # Spike threshold voltage. self.register_buffer( "refrac", torch.tensor(refrac) ) # Post-spike refractory period. @@ -643,9 +629,7 @@ def __init__( ) # Set in compute_decays. self.register_buffer("v", torch.FloatTensor()) # Neuron voltages. - self.register_buffer( - "i", torch.FloatTensor() - ) # Synaptic input currents. + self.register_buffer("i", torch.FloatTensor()) # Synaptic input currents. self.register_buffer( "refrac_count", torch.FloatTensor() ) # Refractory period counters. @@ -716,13 +700,9 @@ def set_batch_size(self, batch_size) -> None: :param batch_size: Mini-batch size. """ super().set_batch_size(batch_size=batch_size) - self.v = self.rest * torch.ones( - batch_size, *self.shape, device=self.v.device - ) + self.v = self.rest * torch.ones(batch_size, *self.shape, device=self.v.device) self.i = torch.zeros_like(self.v, device=self.i.device) - self.refrac_count = torch.zeros_like( - self.v, device=self.refrac_count.device - ) + self.refrac_count = torch.zeros_like(self.v, device=self.refrac_count.device) class AdaptiveLIFNodes(Nodes): @@ -782,12 +762,8 @@ def __init__( ) self.register_buffer("rest", torch.tensor(rest)) # Rest voltage. - self.register_buffer( - "reset", torch.tensor(reset) - ) # Post-spike reset voltage. - self.register_buffer( - "thresh", torch.tensor(thresh) - ) # Spike threshold voltage. + self.register_buffer("reset", torch.tensor(reset)) # Post-spike reset voltage. + self.register_buffer("thresh", torch.tensor(thresh)) # Spike threshold voltage. self.register_buffer( "refrac", torch.tensor(refrac) ) # Post-spike refractory period. @@ -808,9 +784,7 @@ def __init__( ) # Set in compute_decays. self.register_buffer("v", torch.FloatTensor()) # Neuron voltages. - self.register_buffer( - "theta", torch.zeros(*self.shape) - ) # Adaptive thresholds. + self.register_buffer("theta", torch.zeros(*self.shape)) # Adaptive thresholds. self.register_buffer( "refrac_count", torch.FloatTensor() ) # Refractory period counters. @@ -881,12 +855,8 @@ def set_batch_size(self, batch_size) -> None: :param batch_size: Mini-batch size. """ super().set_batch_size(batch_size=batch_size) - self.v = self.rest * torch.ones( - batch_size, *self.shape, device=self.v.device - ) - self.refrac_count = torch.zeros_like( - self.v, device=self.refrac_count.device - ) + self.v = self.rest * torch.ones(batch_size, *self.shape, device=self.v.device) + self.refrac_count = torch.zeros_like(self.v, device=self.refrac_count.device) class DiehlAndCookNodes(Nodes): @@ -948,12 +918,8 @@ def __init__( ) self.register_buffer("rest", torch.tensor(rest)) # Rest voltage. - self.register_buffer( - "reset", torch.tensor(reset) - ) # Post-spike reset voltage. - self.register_buffer( - "thresh", torch.tensor(thresh) - ) # Spike threshold voltage. + self.register_buffer("reset", torch.tensor(reset)) # Post-spike reset voltage. + self.register_buffer("thresh", torch.tensor(thresh)) # Spike threshold voltage. self.register_buffer( "refrac", torch.tensor(refrac) ) # Post-spike refractory period. @@ -973,9 +939,7 @@ def __init__( "theta_decay", torch.empty_like(self.tc_theta_decay) ) # Set in compute_decays. self.register_buffer("v", torch.FloatTensor()) # Neuron voltages. - self.register_buffer( - "theta", torch.zeros(*self.shape) - ) # Adaptive thresholds. + self.register_buffer("theta", torch.zeros(*self.shape)) # Adaptive thresholds. self.register_buffer( "refrac_count", torch.FloatTensor() ) # Refractory period counters. @@ -1059,12 +1023,8 @@ def set_batch_size(self, batch_size) -> None: :param batch_size: Mini-batch size. """ super().set_batch_size(batch_size=batch_size) - self.v = self.rest * torch.ones( - batch_size, *self.shape, device=self.v.device - ) - self.refrac_count = torch.zeros_like( - self.v, device=self.refrac_count.device - ) + self.v = self.rest * torch.ones(batch_size, *self.shape, device=self.v.device) + self.refrac_count = torch.zeros_like(self.v, device=self.refrac_count.device) class IzhikevichNodes(Nodes): @@ -1115,9 +1075,7 @@ def __init__( ) self.register_buffer("rest", torch.tensor(rest)) # Rest voltage. - self.register_buffer( - "thresh", torch.tensor(thresh) - ) # Spike threshold voltage. + self.register_buffer("thresh", torch.tensor(thresh)) # Spike threshold voltage. self.lbound = lbound self.register_buffer("r", None) @@ -1184,9 +1142,7 @@ def __init__( self.S[:, ex:] = -torch.rand(n, inh) self.excitatory[ex:] = 0 - self.register_buffer( - "v", self.rest * torch.ones(n) - ) # Neuron voltages. + self.register_buffer("v", self.rest * torch.ones(n)) # Neuron voltages. self.register_buffer("u", self.b * self.v) # Neuron recovery. def forward(self, x: torch.Tensor) -> None: @@ -1206,24 +1162,13 @@ def forward(self, x: torch.Tensor) -> None: # Add inter-columnar input. if self.s.any(): x += torch.cat( - [ - self.S[:, self.s[i]].sum(dim=1)[None] - for i in range(self.s.shape[0]) - ], + [self.S[:, self.s[i]].sum(dim=1)[None] for i in range(self.s.shape[0])], dim=0, ) # Apply v and u updates. - self.v += ( - self.dt - * 0.5 - * (0.04 * self.v ** 2 + 5 * self.v + 140 - self.u + x) - ) - self.v += ( - self.dt - * 0.5 - * (0.04 * self.v ** 2 + 5 * self.v + 140 - self.u + x) - ) + self.v += self.dt * 0.5 * (0.04 * self.v ** 2 + 5 * self.v + 140 - self.u + x) + self.v += self.dt * 0.5 * (0.04 * self.v ** 2 + 5 * self.v + 140 - self.u + x) self.u += self.dt * self.a * (self.b * self.v - self.u) # Voltage clipping to lower bound. @@ -1249,9 +1194,7 @@ def set_batch_size(self, batch_size) -> None: :param batch_size: Mini-batch size. """ super().set_batch_size(batch_size=batch_size) - self.v = self.rest * torch.ones( - batch_size, *self.shape, device=self.v.device - ) + self.v = self.rest * torch.ones(batch_size, *self.shape, device=self.v.device) self.u = self.b * self.v @@ -1314,21 +1257,15 @@ def __init__( ) self.register_buffer("rest", torch.tensor(rest)) # Rest voltage. - self.register_buffer( - "reset", torch.tensor(reset) - ) # Post-spike reset voltage. - self.register_buffer( - "thresh", torch.tensor(thresh) - ) # Spike threshold voltage. + self.register_buffer("reset", torch.tensor(reset)) # Post-spike reset voltage. + self.register_buffer("thresh", torch.tensor(thresh)) # Spike threshold voltage. self.register_buffer( "refrac", torch.tensor(refrac) ) # Post-spike refractory period. self.register_buffer( "tc_decay", torch.tensor(tc_decay) ) # Time constant of neuron voltage decay. - self.register_buffer( - "decay", torch.tensor(tc_decay) - ) # Set in compute_decays. + self.register_buffer("decay", torch.tensor(tc_decay)) # Set in compute_decays. self.register_buffer( "eps_0", torch.tensor(eps_0) ) # Scaling factor for pre-synaptic spike contributions. @@ -1360,9 +1297,7 @@ def forward(self, x: torch.Tensor) -> None: # Compute (instantaneous) probabilities of spiking, clamp between 0 and 1 using exponentials. # Also known as 'escape noise', this simulates nearby neurons. - self.rho = self.rho_0 * torch.exp( - (self.v - self.thresh) / self.d_thresh - ) + self.rho = self.rho_0 * torch.exp((self.v - self.thresh) / self.d_thresh) self.s_prob = 1.0 - torch.exp(-self.rho * self.dt) # Decrement refractory counters. @@ -1410,9 +1345,5 @@ def set_batch_size(self, batch_size) -> None: :param batch_size: Mini-batch size. """ super().set_batch_size(batch_size=batch_size) - self.v = self.rest * torch.ones( - batch_size, *self.shape, device=self.v.device - ) - self.refrac_count = torch.zeros_like( - self.v, device=self.refrac_count.device - ) + self.v = self.rest * torch.ones(batch_size, *self.shape, device=self.v.device) + self.refrac_count = torch.zeros_like(self.v, device=self.refrac_count.device) diff --git a/bindsnet/network/topology.py b/bindsnet/network/topology.py index 9aca368..76c223c 100644 --- a/bindsnet/network/topology.py +++ b/bindsnet/network/topology.py @@ -155,21 +155,15 @@ def __init__( w = kwargs.get("w", None) if w is None: if self.wmin == -np.inf or self.wmax == np.inf: - w = torch.clamp( - torch.rand(source.n, target.n), self.wmin, self.wmax - ) + w = torch.clamp(torch.rand(source.n, target.n), self.wmin, self.wmax) else: - w = self.wmin + torch.rand(source.n, target.n) * ( - self.wmax - self.wmin - ) + w = self.wmin + torch.rand(source.n, target.n) * (self.wmax - self.wmin) else: if self.wmin != -np.inf or self.wmax != np.inf: w = torch.clamp(w, self.wmin, self.wmax) self.w = Parameter(w, requires_grad=False) - self.b = Parameter( - kwargs.get("b", torch.zeros(target.n)), requires_grad=False - ) + self.b = Parameter(kwargs.get("b", torch.zeros(target.n)), requires_grad=False) def compute(self, s: torch.Tensor) -> torch.Tensor: # language=rst @@ -285,7 +279,7 @@ def __init__( "(input_height - filter_height + 2 * padding_height) / stride_height + 1," "(input_width - filter_width + 2 * padding_width) / stride_width + 1" ) - + assert ( target.shape[0] == shape[1] and target.shape[1] == shape[2] @@ -296,9 +290,7 @@ def __init__( if w is None: if self.wmin == -np.inf or self.wmax == np.inf: w = torch.clamp( - torch.rand( - self.out_channels, self.in_channels, *self.kernel_size - ), + torch.rand(self.out_channels, self.in_channels, *self.kernel_size), self.wmin, self.wmax, ) @@ -313,8 +305,7 @@ def __init__( self.w = Parameter(w, requires_grad=False) self.b = Parameter( - kwargs.get("b", torch.zeros(self.out_channels)), - requires_grad=False, + kwargs.get("b", torch.zeros(self.out_channels)), requires_grad=False, ) def compute(self, s: torch.Tensor) -> torch.Tensor: @@ -351,8 +342,7 @@ def normalize(self) -> None: if self.norm is not None: # get a view and modify in place w = self.w.view( - self.w.size(0) * self.w.size(1), - self.w.size(2) * self.w.size(3), + self.w.size(0) * self.w.size(1), self.w.size(2) * self.w.size(3), ) for fltr in range(w.size(0)): @@ -551,9 +541,7 @@ def __init__( ) locations[k1, k2, c1, c2] = location - self.register_buffer( - "locations", locations.view(kernel_prod, conv_prod) - ) + self.register_buffer("locations", locations.view(kernel_prod, conv_prod)) w = kwargs.get("w", None) if w is None: @@ -562,15 +550,13 @@ def __init__( for c in range(conv_prod): for k in range(kernel_prod): if self.wmin == -np.inf or self.wmax == np.inf: - w[ - self.locations[k, c], f * conv_prod + c - ] = np.clip(np.random.rand(), self.wmin, self.wmax) + w[self.locations[k, c], f * conv_prod + c] = np.clip( + np.random.rand(), self.wmin, self.wmax + ) else: w[ self.locations[k, c], f * conv_prod + c - ] = self.wmin + np.random.rand() * ( - self.wmax - self.wmin - ) + ] = self.wmin + np.random.rand() * (self.wmax - self.wmin) else: if self.wmin != -np.inf or self.wmax != np.inf: w = torch.clamp(w, self.wmin, self.wmax) @@ -579,9 +565,7 @@ def __init__( self.register_buffer("mask", self.w == 0) - self.b = Parameter( - kwargs.get("b", torch.zeros(target.n)), requires_grad=False - ) + self.b = Parameter(kwargs.get("b", torch.zeros(target.n)), requires_grad=False) if self.norm is not None: self.norm *= kernel_prod @@ -596,10 +580,7 @@ def compute(self, s: torch.Tensor) -> torch.Tensor: decaying spike activation). """ # Compute multiplication of pre-activations by connection weights. - if ( - self.w.shape[0] == self.source.n - and self.w.shape[1] == self.target.n - ): + if self.w.shape[0] == self.source.n and self.w.shape[1] == self.target.n: return s.float().view(s.size(0), -1) @ self.w + self.b else: a_post = ( @@ -677,13 +658,9 @@ def __init__( w = kwargs.get("w", None) if w is None: if self.wmin == -np.inf or self.wmax == np.inf: - w = torch.clamp( - (torch.randn(1)[0] + 1) / 10, self.wmin, self.wmax - ) + w = torch.clamp((torch.randn(1)[0] + 1) / 10, self.wmin, self.wmax) else: - w = self.wmin + ((torch.randn(1)[0] + 1) / 10) * ( - self.wmax - self.wmin - ) + w = self.wmin + ((torch.randn(1)[0] + 1) / 10) * (self.wmax - self.wmin) else: if self.wmin != -np.inf or self.wmax != np.inf: w = torch.clamp(w, self.wmin, self.wmax) @@ -786,14 +763,12 @@ def __init__( self.wmax, ) else: - v = self.wmin + torch.rand(*source.shape, *target.shape)[ - i.byte() - ] * (self.wmax - self.wmin) + v = self.wmin + torch.rand(*source.shape, *target.shape)[i.byte()] * ( + self.wmax - self.wmin + ) w = torch.sparse.FloatTensor(i.nonzero().t(), v) elif w is not None and self.sparsity is None: - assert ( - w.is_sparse - ), "Weight matrix is not sparse (see torch.sparse module)" + assert w.is_sparse, "Weight matrix is not sparse (see torch.sparse module)" if self.wmin != -np.inf or self.wmax != np.inf: w = torch.clamp(w, self.wmin, self.wmax) diff --git a/bindsnet/pipeline/action.py b/bindsnet/pipeline/action.py index ab12850..d158e29 100644 --- a/bindsnet/pipeline/action.py +++ b/bindsnet/pipeline/action.py @@ -21,9 +21,7 @@ def select_multinomial(pipeline: EnvironmentPipeline, **kwargs) -> int: try: output = kwargs["output"] except KeyError: - raise KeyError( - 'select_multinomial() requires an "output" layer argument.' - ) + raise KeyError('select_multinomial() requires an "output" layer argument.') output = pipeline.network.layers[output] action_space = pipeline.env.action_space @@ -46,9 +44,7 @@ def select_multinomial(pipeline: EnvironmentPipeline, **kwargs) -> int: for i in range(action_space.n) ] ) - action = torch.multinomial((pop_spikes.float() / _sum).view(-1), 1)[ - 0 - ].item() + action = torch.multinomial((pop_spikes.float() / _sum).view(-1), 1)[0].item() return action diff --git a/bindsnet/pipeline/base_pipeline.py b/bindsnet/pipeline/base_pipeline.py index b6f3444..b1026c3 100644 --- a/bindsnet/pipeline/base_pipeline.py +++ b/bindsnet/pipeline/base_pipeline.py @@ -74,9 +74,7 @@ def __init__(self, network: Network, **kwargs) -> None: for l in self.network.layers: self.network.add_monitor( Monitor( - self.network.layers[l], - "s", - self.plot_config["data_length"], + self.network.layers[l], "s", self.plot_config["data_length"], ), name=f"{l}_spikes", ) @@ -138,16 +136,10 @@ def step(self, batch: Any, **kwargs) -> Any: self.plots(batch, step_out) - if ( - self.save_interval is not None - and self.step_count % self.save_interval == 0 - ): + if self.save_interval is not None and self.step_count % self.save_interval == 0: self.network.save(self.save_dir) - if ( - self.test_interval is not None - and self.step_count % self.test_interval == 0 - ): + if self.test_interval is not None and self.step_count % self.test_interval == 0: self.test() return step_out @@ -165,7 +157,7 @@ def get_spike_data(self) -> Dict[str, torch.Tensor]: } def get_voltage_data( - self + self, ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: # language=rst """ @@ -179,9 +171,7 @@ def get_voltage_data( threshold_value = {} for l in self.network.layers: if hasattr(self.network.layers[l], "v"): - voltage_record[l] = self.network.monitors[f"{l}_voltages"].get( - "v" - ) + voltage_record[l] = self.network.monitors[f"{l}_voltages"].get("v") if hasattr(self.network.layers[l], "thresh"): threshold_value[l] = self.network.layers[l].thresh diff --git a/bindsnet/pipeline/environment_pipeline.py b/bindsnet/pipeline/environment_pipeline.py index 1d2615b..21b847a 100644 --- a/bindsnet/pipeline/environment_pipeline.py +++ b/bindsnet/pipeline/environment_pipeline.py @@ -166,9 +166,7 @@ def step_( inputs = {k: obs.repeat(self.time, *obs_shape) for k in self.inputs} # Run the network on the spike train-encoded inputs. - self.network.run( - inputs=inputs, time=self.time, reward=reward, **kwargs - ) + self.network.run(inputs=inputs, time=self.time, reward=reward, **kwargs) if self.output is not None: self.spike_record[self.output] = ( @@ -194,9 +192,7 @@ def reset_state_variables(self) -> None: self.accumulated_reward = 0.0 self.step_count = 0 - def plots( - self, gym_batch: Tuple[torch.Tensor, float, bool, Dict], *args - ) -> None: + def plots(self, gym_batch: Tuple[torch.Tensor, float, bool, Dict], *args) -> None: # language=rst """ Plot the encoded input, layer spikes, and layer voltages. diff --git a/bindsnet/utils.py b/bindsnet/utils.py index 9d72bdc..dc55878 100644 --- a/bindsnet/utils.py +++ b/bindsnet/utils.py @@ -27,9 +27,7 @@ def im2col_indices( :param stride: Amount to stride over image by per convolution. :return: Input tensor reshaped to column-wise format. """ - return F.unfold( - x, (kernel_height, kernel_width), padding=padding, stride=stride - ) + return F.unfold(x, (kernel_height, kernel_width), padding=padding, stride=stride) def col2im_indices( @@ -53,11 +51,7 @@ def col2im_indices( :return: Image tensor in original image shape. """ return F.fold( - cols, - x_shape, - (kernel_height, kernel_width), - padding=padding, - stride=stride, + cols, x_shape, (kernel_height, kernel_width), padding=padding, stride=stride, ) @@ -157,13 +151,9 @@ def reshape_locally_connected_weights( n = n1 * c2 + n2 filter_ = w[ locations[:, n], - feature * (c1 * c2) - + (n // c2sqrt) * c2sqrt - + (n % c2sqrt), + feature * (c1 * c2) + (n // c2sqrt) * c2sqrt + (n % c2sqrt), ].view(k1, k2) - w_[ - feature * k1 : (feature + 1) * k1, n * k2 : (n + 1) * k2 - ] = filter_ + w_[feature * k1 : (feature + 1) * k1, n * k2 : (n + 1) * k2] = filter_ if c1 == 1 and c2 == 1: square = torch.zeros((i1 * fs, i2 * fs)) @@ -214,21 +204,16 @@ def reshape_conv2d_weights(weights: torch.Tensor) -> torch.Tensor: for j in range(sqrt1): for k in range(sqrt2): for l in range(sqrt2): - if i * sqrt1 + j < weights.size( - 0 - ) and k * sqrt2 + l < weights.size(1): - fltr = weights[i * sqrt1 + j, k * sqrt2 + l].view( - height, width - ) + if i * sqrt1 + j < weights.size(0) and k * sqrt2 + l < weights.size( + 1 + ): + fltr = weights[i * sqrt1 + j, k * sqrt2 + l].view(height, width) reshaped[ i * height + k * height * sqrt1 : (i + 1) * height + k * height * sqrt1, (j % sqrt1) * width - + (l % sqrt2) - * width - * sqrt1 : ((j % sqrt1) + 1) - * width + + (l % sqrt2) * width * sqrt1 : ((j % sqrt1) + 1) * width + (l % sqrt2) * width * sqrt1, ] = fltr diff --git a/examples/breakout/breakout.py b/examples/breakout/breakout.py index b36563c..29c8925 100644 --- a/examples/breakout/breakout.py +++ b/examples/breakout/breakout.py @@ -22,12 +22,8 @@ network.add_layer(inpt, name="Input Layer") network.add_layer(middle, name="Hidden Layer") network.add_layer(out, name="Output Layer") -network.add_connection( - inpt_middle, source="Input Layer", target="Hidden Layer" -) -network.add_connection( - middle_out, source="Hidden Layer", target="Output Layer" -) +network.add_connection(inpt_middle, source="Input Layer", target="Hidden Layer") +network.add_connection(middle_out, source="Hidden Layer", target="Output Layer") # Load the Breakout environment. environment = GymEnvironment("BreakoutDeterministic-v4") diff --git a/examples/breakout/breakout_stdp.py b/examples/breakout/breakout_stdp.py index 55584c1..6738d95 100644 --- a/examples/breakout/breakout_stdp.py +++ b/examples/breakout/breakout_stdp.py @@ -31,12 +31,8 @@ network.add_layer(inpt, name="Input Layer") network.add_layer(middle, name="Hidden Layer") network.add_layer(out, name="Output Layer") -network.add_connection( - inpt_middle, source="Input Layer", target="Hidden Layer" -) -network.add_connection( - middle_out, source="Hidden Layer", target="Output Layer" -) +network.add_connection(inpt_middle, source="Input Layer", target="Hidden Layer") +network.add_connection(middle_out, source="Hidden Layer", target="Output Layer") # Load the Breakout environment. environment = GymEnvironment("BreakoutDeterministic-v4") diff --git a/examples/mnist/batch_eth_mnist.py b/examples/mnist/batch_eth_mnist.py index 8cb9854..8fe6535 100644 --- a/examples/mnist/batch_eth_mnist.py +++ b/examples/mnist/batch_eth_mnist.py @@ -136,9 +136,7 @@ voltages = {} for layer in set(network.layers) - {"X"}: - voltages[layer] = Monitor( - network.layers[layer], state_vars=["v"], time=time - ) + voltages[layer] = Monitor(network.layers[layer], state_vars=["v"], time=time) network.add_monitor(voltages[layer], name="%s_voltages" % layer) inpt_ims, inpt_axes = None, None @@ -158,9 +156,7 @@ labels = [] if epoch % progress_interval == 0: - print( - "Progress: %d / %d (%.4f seconds)" % (epoch, n_epochs, t() - start) - ) + print("Progress: %d / %d (%.4f seconds)" % (epoch, n_epochs, t() - start)) start = t() # Create a dataloader to iterate and batch data @@ -184,9 +180,7 @@ # Get network predictions. all_activity_pred = all_activity( - spikes=spike_record, - assignments=assignments, - n_labels=n_classes, + spikes=spike_record, assignments=assignments, n_labels=n_classes, ) proportion_pred = proportion_weighting( spikes=spike_record, @@ -261,19 +255,18 @@ ) square_assignments = get_square_assignments(assignments, n_sqrt) spikes_ = { - layer: spikes[layer].get("s")[:, 0].contiguous() - for layer in spikes + layer: spikes[layer].get("s")[:, 0].contiguous() for layer in spikes } voltages = {"Ae": exc_voltages, "Ai": inh_voltages} inpt_axes, inpt_ims = plot_input( image, inpt, label=labels[step], axes=inpt_axes, ims=inpt_ims ) - spike_ims, spike_axes = plot_spikes( - spikes_, ims=spike_ims, axes=spike_axes - ) + spike_ims, spike_axes = plot_spikes(spikes_, ims=spike_ims, axes=spike_axes) weights_im = plot_weights(square_weights, im=weights_im) assigns_im = plot_assignments(square_assignments, im=assigns_im) - perf_ax = plot_performance(accuracy, x_scale=update_steps * batch_size, ax=perf_ax) + perf_ax = plot_performance( + accuracy, x_scale=update_steps * batch_size, ax=perf_ax + ) voltage_ims, voltage_axes = plot_voltages( voltages, ims=voltage_ims, axes=voltage_axes, plot_type="line" ) diff --git a/examples/mnist/conv_mnist.py b/examples/mnist/conv_mnist.py index b9ca38a..a50d34a 100644 --- a/examples/mnist/conv_mnist.py +++ b/examples/mnist/conv_mnist.py @@ -91,9 +91,7 @@ wmax=1.0, ) -w = torch.zeros( - n_filters, conv_size, conv_size, n_filters, conv_size, conv_size -) +w = torch.zeros(n_filters, conv_size, conv_size, n_filters, conv_size, conv_size) for fltr1 in range(n_filters): for fltr2 in range(n_filters): if fltr1 != fltr2: @@ -101,9 +99,7 @@ for j in range(conv_size): w[fltr1, i, j, fltr2, i, j] = -100.0 -w = w.view( - n_filters * conv_size * conv_size, n_filters * conv_size * conv_size -) +w = w.view(n_filters * conv_size * conv_size, n_filters * conv_size * conv_size) recurrent_conn = Connection(conv_layer, conv_layer, w=w) network.add_layer(input_layer, name="X") @@ -137,9 +133,7 @@ voltages = {} for layer in set(network.layers) - {"X"}: - voltages[layer] = Monitor( - network.layers[layer], state_vars=["v"], time=time - ) + voltages[layer] = Monitor(network.layers[layer], state_vars=["v"], time=time) network.add_monitor(voltages[layer], name="%s_voltages" % layer) # Train the network. @@ -156,17 +150,11 @@ for epoch in range(n_epochs): if epoch % progress_interval == 0: - print( - "Progress: %d / %d (%.4f seconds)" % (epoch, n_epochs, t() - start) - ) + print("Progress: %d / %d (%.4f seconds)" % (epoch, n_epochs, t() - start)) start = t() train_dataloader = torch.utils.data.DataLoader( - train_dataset, - batch_size=1, - shuffle=True, - num_workers=4, - pin_memory=gpu, + train_dataset, batch_size=1, shuffle=True, num_workers=4, pin_memory=gpu, ) for step, batch in enumerate(tqdm(train_dataloader)): @@ -195,9 +183,7 @@ inpt_axes, inpt_ims = plot_input( image, inpt, label=label, axes=inpt_axes, ims=inpt_ims ) - spike_ims, spike_axes = plot_spikes( - _spikes, ims=spike_ims, axes=spike_axes - ) + spike_ims, spike_axes = plot_spikes(_spikes, ims=spike_ims, axes=spike_axes) weights1_im = plot_conv2d_weights(weights1, im=weights1_im) voltage_ims, voltage_axes = plot_voltages( _voltages, ims=voltage_ims, axes=voltage_axes diff --git a/examples/mnist/eth_mnist.py b/examples/mnist/eth_mnist.py index 8c4b7fc..d3e0c04 100644 --- a/examples/mnist/eth_mnist.py +++ b/examples/mnist/eth_mnist.py @@ -137,9 +137,7 @@ voltages = {} for layer in set(network.layers) - {"X"}: - voltages[layer] = Monitor( - network.layers[layer], state_vars=["v"], time=time - ) + voltages[layer] = Monitor(network.layers[layer], state_vars=["v"], time=time) network.add_monitor(voltages[layer], name="%s_voltages" % layer) inpt_ims, inpt_axes = None, None @@ -157,18 +155,12 @@ labels = [] if epoch % progress_interval == 0: - print( - "Progress: %d / %d (%.4f seconds)" % (epoch, n_epochs, t() - start) - ) + print("Progress: %d / %d (%.4f seconds)" % (epoch, n_epochs, t() - start)) start = t() # Create a dataloader to iterate and batch data dataloader = torch.utils.data.DataLoader( - dataset, - batch_size=1, - shuffle=True, - num_workers=n_workers, - pin_memory=gpu, + dataset, batch_size=1, shuffle=True, num_workers=n_workers, pin_memory=gpu, ) for step, batch in enumerate(tqdm(dataloader)): @@ -183,9 +175,7 @@ # Get network predictions. all_activity_pred = all_activity( - spikes=spike_record, - assignments=assignments, - n_labels=n_classes, + spikes=spike_record, assignments=assignments, n_labels=n_classes, ) proportion_pred = proportion_weighting( spikes=spike_record, @@ -259,9 +249,7 @@ inpt_axes, inpt_ims = plot_input( image, inpt, label=batch["label"], axes=inpt_axes, ims=inpt_ims ) - spike_ims, spike_axes = plot_spikes( - spikes_, ims=spike_ims, axes=spike_axes - ) + spike_ims, spike_axes = plot_spikes(spikes_, ims=spike_ims, axes=spike_axes) weights_im = plot_weights(square_weights, im=weights_im) assigns_im = plot_assignments(square_assignments, im=assigns_im) perf_ax = plot_performance(accuracy, x_scale=update_interval, ax=perf_ax) diff --git a/examples/mnist/reservoir.py b/examples/mnist/reservoir.py index 9bcf08e..d2b99a7 100644 --- a/examples/mnist/reservoir.py +++ b/examples/mnist/reservoir.py @@ -74,16 +74,10 @@ network = Network(dt=dt) inpt = Input(784, shape=(1, 28, 28)) network.add_layer(inpt, name="I") -output = LIFNodes( - n_neurons, thresh=-52 + np.random.randn(n_neurons).astype(float) -) +output = LIFNodes(n_neurons, thresh=-52 + np.random.randn(n_neurons).astype(float)) network.add_layer(output, name="O") -C1 = Connection( - source=inpt, target=output, w=0.5 * torch.randn(inpt.n, output.n) -) -C2 = Connection( - source=output, target=output, w=0.5 * torch.randn(output.n, output.n) -) +C1 = Connection(source=inpt, target=output, w=0.5 * torch.randn(inpt.n, output.n)) +C2 = Connection(source=output, target=output, w=0.5 * torch.randn(output.n, output.n)) network.add_connection(C1, source="I", target="O") network.add_connection(C2, source="O", target="O") @@ -155,10 +149,7 @@ ims=spike_ims, ) voltage_ims, voltage_axes = plot_voltages( - { - layer: voltages[layer].get("v").view(-1, time) - for layer in voltages - }, + {layer: voltages[layer].get("v").view(-1, time) for layer in voltages}, ims=voltage_ims, axes=voltage_axes, ) @@ -239,10 +230,7 @@ def forward(self, x): ims=spike_ims, ) voltage_ims, voltage_axes = plot_voltages( - { - layer: voltages[layer].get("v").view(-1, 250) - for layer in voltages - }, + {layer: voltages[layer].get("v").view(-1, 250) for layer in voltages}, ims=voltage_ims, axes=voltage_axes, ) diff --git a/examples/mnist/supervised_mnist.py b/examples/mnist/supervised_mnist.py index 8c2d9b3..7fc826a 100644 --- a/examples/mnist/supervised_mnist.py +++ b/examples/mnist/supervised_mnist.py @@ -158,23 +158,15 @@ # Compute network accuracy according to available classification strategies. accuracy["all"].append( - 100 - * torch.sum(label.long() == all_activity_pred).item() - / update_interval + 100 * torch.sum(label.long() == all_activity_pred).item() / update_interval ) accuracy["proportion"].append( - 100 - * torch.sum(label.long() == proportion_pred).item() - / update_interval + 100 * torch.sum(label.long() == proportion_pred).item() / update_interval ) print( "\nAll activity accuracy: %.2f (last), %.2f (average), %.2f (best)" - % ( - accuracy["all"][-1], - np.mean(accuracy["all"]), - np.max(accuracy["all"]), - ) + % (accuracy["all"][-1], np.mean(accuracy["all"]), np.max(accuracy["all"]),) ) print( "Proportion weighting accuracy: %.2f (last), %.2f (average), %.2f (best)\n" @@ -186,9 +178,7 @@ ) # Assign labels to excitatory layer neurons. - assignments, proportions, rates = assign_labels( - spike_record, label, 10, rates - ) + assignments, proportions, rates = assign_labels(spike_record, label, 10, rates) # Run the network on the input. choice = np.random.choice(int(n_neurons / 10), size=n_clamp, replace=False) @@ -201,9 +191,7 @@ inh_voltages = inh_voltage_monitor.get("v") # Add to spikes recording. - spike_record[i % update_interval] = ( - spikes["Ae"].get("s").view(time, n_neurons) - ) + spike_record[i % update_interval] = spikes["Ae"].get("s").view(time, n_neurons) # Optionally plot various simulation information. if plot: @@ -216,11 +204,7 @@ voltages = {"Ae": exc_voltages, "Ai": inh_voltages} inpt_axes, inpt_ims = plot_input( - image.sum(1).view(28, 28), - inpt, - label=label, - axes=inpt_axes, - ims=inpt_ims, + image.sum(1).view(28, 28), inpt, label=label, axes=inpt_axes, ims=inpt_ims, ) spike_ims, spike_axes = plot_spikes( {layer: spikes[layer].get("s") for layer in spikes}, diff --git a/test/encoding/test_encoding.py b/test/encoding/test_encoding.py index 85f32bc..4484fd3 100644 --- a/test/encoding/test_encoding.py +++ b/test/encoding/test_encoding.py @@ -32,9 +32,7 @@ def test_bernoulli_loader(self): for m in [0.1, 1.0]: # maximum spiking probability for t in [1, 100]: # number of timesteps data = torch.empty(s, n).uniform_(0, 1) - spike_loader = bernoulli_loader( - data, time=t, max_prob=m - ) + spike_loader = bernoulli_loader(data, time=t, max_prob=m) for i, spikes in enumerate(spike_loader): assert spikes.size() == torch.Size((t, n)) @@ -42,9 +40,7 @@ def test_bernoulli_loader(self): def test_poisson(self): for n in [1, 100]: # number of nodes in layer for t in [1000]: # number of timesteps - datum = torch.empty(n).uniform_( - 20, 100 - ) # Generate firing rates. + datum = torch.empty(n).uniform_(20, 100) # Generate firing rates. spikes = poisson(datum, time=t) # Encode as spikes. assert spikes.size() == torch.Size((t, n)) @@ -53,12 +49,8 @@ def test_poisson_loader(self): for s in [1, 10]: # number of data samples for n in [1, 100]: # number of nodes in layer for t in [1000]: # number of timesteps - data = torch.empty(s, n).uniform_( - 20, 100 - ) # Generate firing rates. - spike_loader = poisson_loader( - data, time=t - ) # Encode as spikes. + data = torch.empty(s, n).uniform_(20, 100) # Generate firing rates. + spike_loader = poisson_loader(data, time=t) # Encode as spikes. for i, spikes in enumerate(spike_loader): assert spikes.size() == torch.Size((t, n)) diff --git a/test/models/test_models.py b/test/models/test_models.py index 0863955..465dac1 100644 --- a/test/models/test_models.py +++ b/test/models/test_models.py @@ -8,9 +8,7 @@ def test_init(self): for n_inpt in [50, 100, 200]: for n_neurons in [50, 100, 200]: for dt in [1.0, 2.0]: - network = TwoLayerNetwork( - n_inpt, n_neurons=n_neurons, dt=dt - ) + network = TwoLayerNetwork(n_inpt, n_neurons=n_neurons, dt=dt) assert network.n_inpt == n_inpt assert network.n_neurons == n_neurons @@ -24,13 +22,10 @@ def test_init(self): isinstance(network.layers["Y"], LIFNodes) and network.layers["Y"].n == n_neurons ) - assert isinstance( - network.connections[("X", "Y")], Connection - ) + assert isinstance(network.connections[("X", "Y")], Connection) assert ( network.connections[("X", "Y")].source.n == n_inpt - and network.connections[("X", "Y")].target.n - == n_neurons + and network.connections[("X", "Y")].target.n == n_neurons ) @@ -60,9 +55,7 @@ def test_init(self): and network.layers["X"].n == n_inpt ) assert ( - isinstance( - network.layers["Ae"], DiehlAndCookNodes - ) + isinstance(network.layers["Ae"], DiehlAndCookNodes) and network.layers["Ae"].n == n_neurons ) assert ( diff --git a/test/network/test_connections.py b/test/network/test_connections.py index a2dc3b6..6f76cb2 100644 --- a/test/network/test_connections.py +++ b/test/network/test_connections.py @@ -42,20 +42,12 @@ def test_transfer(self): conn_type, connection.state_dict().keys() ) ) - print( - "__dict__ in {} : {}".format( - conn_type, connection.__dict__.keys() - ) - ) + print("__dict__ in {} : {}".format(conn_type, connection.__dict__.keys())) print("Tensors in {} : {}".format(conn_type, connection_tensors)) - tensor_devs = [ - getattr(connection, k).device for k in connection_tensors - ] + tensor_devs = [getattr(connection, k).device for k in connection_tensors] print( - "Tensor devices {}".format( - list(zip(connection_tensors, tensor_devs)) - ) + "Tensor devices {}".format(list(zip(connection_tensors, tensor_devs))) ) for d in tensor_devs: diff --git a/test/network/test_learning.py b/test/network/test_learning.py index 239d8c8..6207602 100644 --- a/test/network/test_learning.py +++ b/test/network/test_learning.py @@ -34,16 +34,13 @@ def test_hebbian(self): target="output", ) network.run( - inputs={"input": torch.bernoulli(torch.rand(250, 100)).byte()}, - time=250, + inputs={"input": torch.bernoulli(torch.rand(250, 100)).byte()}, time=250, ) # Conv2dConnection test network = Network(dt=1.0) network.add_layer(Input(shape=[1, 10, 10], traces=True), name="input") - network.add_layer( - LIFNodes(shape=[32, 8, 8], traces=True), name="output" - ) + network.add_layer(LIFNodes(shape=[32, 8, 8], traces=True), name="output") network.add_connection( Conv2dConnection( source=network.layers["input"], @@ -58,9 +55,7 @@ def test_hebbian(self): ) # shape is [time, batch, channels, height, width] network.run( - inputs={ - "input": torch.bernoulli(torch.rand(250, 1, 1, 10, 10)).byte() - }, + inputs={"input": torch.bernoulli(torch.rand(250, 1, 1, 10, 10)).byte()}, time=250, ) @@ -80,16 +75,13 @@ def test_post_pre(self): target="output", ) network.run( - inputs={"input": torch.bernoulli(torch.rand(250, 100)).byte()}, - time=250, + inputs={"input": torch.bernoulli(torch.rand(250, 100)).byte()}, time=250, ) # Conv2dConnection test network = Network(dt=1.0) network.add_layer(Input(shape=[1, 10, 10], traces=True), name="input") - network.add_layer( - LIFNodes(shape=[32, 8, 8], traces=True), name="output" - ) + network.add_layer(LIFNodes(shape=[32, 8, 8], traces=True), name="output") network.add_connection( Conv2dConnection( source=network.layers["input"], @@ -103,9 +95,7 @@ def test_post_pre(self): target="output", ) network.run( - inputs={ - "input": torch.bernoulli(torch.rand(250, 1, 1, 10, 10)).byte() - }, + inputs={"input": torch.bernoulli(torch.rand(250, 1, 1, 10, 10)).byte()}, time=250, ) @@ -127,16 +117,13 @@ def test_weight_dependent_post_pre(self): target="output", ) network.run( - inputs={"input": torch.bernoulli(torch.rand(250, 100)).byte()}, - time=250, + inputs={"input": torch.bernoulli(torch.rand(250, 100)).byte()}, time=250, ) # Conv2dConnection test network = Network(dt=1.0) network.add_layer(Input(shape=[1, 10, 10], traces=True), name="input") - network.add_layer( - LIFNodes(shape=[32, 8, 8], traces=True), name="output" - ) + network.add_layer(LIFNodes(shape=[32, 8, 8], traces=True), name="output") network.add_connection( Conv2dConnection( source=network.layers["input"], @@ -152,9 +139,7 @@ def test_weight_dependent_post_pre(self): target="output", ) network.run( - inputs={ - "input": torch.bernoulli(torch.rand(250, 1, 1, 10, 10)).byte() - }, + inputs={"input": torch.bernoulli(torch.rand(250, 1, 1, 10, 10)).byte()}, time=250, ) @@ -197,9 +182,7 @@ def test_mstdp(self): ) network.run( - inputs={ - "input": torch.bernoulli(torch.rand(250, 1, 1, 10, 10)).byte() - }, + inputs={"input": torch.bernoulli(torch.rand(250, 1, 1, 10, 10)).byte()}, time=250, reward=1.0, ) @@ -243,9 +226,7 @@ def test_mstdpet(self): ) network.run( - inputs={ - "input": torch.bernoulli(torch.rand(250, 1, 1, 10, 10)).byte() - }, + inputs={"input": torch.bernoulli(torch.rand(250, 1, 1, 10, 10)).byte()}, time=250, reward=1.0, ) @@ -253,9 +234,7 @@ def test_mstdpet(self): def test_rmax(self): # Connection test network = Network(dt=1.0) - network.add_layer( - Input(n=100, traces=True, traces_additive=True), name="input" - ) + network.add_layer(Input(n=100, traces=True, traces_additive=True), name="input") network.add_layer(SRM0Nodes(n=100), name="output") network.add_connection( Connection( diff --git a/test/network/test_monitors.py b/test/network/test_monitors.py index 4d3a4ac..031ca10 100644 --- a/test/network/test_monitors.py +++ b/test/network/test_monitors.py @@ -25,9 +25,7 @@ class TestMonitor: _if_mon = Monitor(_if, state_vars=["s", "v"]) network.add_monitor(_if_mon, name="Y") - network.run( - inputs={"X": torch.bernoulli(torch.rand(100, inpt.n))}, time=100 - ) + network.run(inputs={"X": torch.bernoulli(torch.rand(100, inpt.n))}, time=100) assert inpt_mon.get("s").size() == torch.Size([100, 1, inpt.n]) assert _if_mon.get("s").size() == torch.Size([100, 1, _if.n]) @@ -40,9 +38,7 @@ class TestMonitor: _if_mon = Monitor(_if, state_vars=["s", "v"], time=500) network.add_monitor(_if_mon, name="Y") - network.run( - inputs={"X": torch.bernoulli(torch.rand(500, inpt.n))}, time=500 - ) + network.run(inputs={"X": torch.bernoulli(torch.rand(500, inpt.n))}, time=500) assert inpt_mon.get("s").size() == torch.Size([500, 1, inpt.n]) assert _if_mon.get("s").size() == torch.Size([500, 1, _if.n]) diff --git a/test/network/test_nodes.py b/test/network/test_nodes.py index 5af023b..51beb70 100644 --- a/test/network/test_nodes.py +++ b/test/network/test_nodes.py @@ -20,14 +20,7 @@ class TestNodes: def test_init(self): network = Network() for i, nodes in enumerate( - [ - Input, - McCullochPitts, - IFNodes, - LIFNodes, - AdaptiveLIFNodes, - SRM0Nodes, - ] + [Input, McCullochPitts, IFNodes, LIFNodes, AdaptiveLIFNodes, SRM0Nodes,] ): for n in [1, 100, 10000]: layer = nodes(n) @@ -54,12 +47,7 @@ def test_init(self): for nodes in [LIFNodes, AdaptiveLIFNodes]: for n in [1, 100, 10000]: layer = nodes( - n, - rest=0.0, - reset=-10.0, - thresh=10.0, - refrac=3, - tc_decay=1.5e3, + n, rest=0.0, reset=-10.0, thresh=10.0, refrac=3, tc_decay=1.5e3, ) network.add_layer(layer=layer, name=f"{i}_params_{n}") @@ -81,25 +69,15 @@ def test_transfer(self): layer.to(torch.device("cuda:0")) layer_tensors = [ - k - for k, v in layer.state_dict().items() - if isinstance(v, torch.Tensor) + k for k, v in layer.state_dict().items() if isinstance(v, torch.Tensor) ] tensor_devs = [getattr(layer, k).device for k in layer_tensors] - print( - "State dict in {} : {}".format( - nodes, layer.state_dict().keys() - ) - ) + print("State dict in {} : {}".format(nodes, layer.state_dict().keys())) print("__dict__ in {} : {}".format(nodes, layer.__dict__.keys())) print("Tensors in {} : {}".format(nodes, layer_tensors)) - print( - "Tensor devices {}".format( - list(zip(layer_tensors, tensor_devs)) - ) - ) + print("Tensor devices {}".format(list(zip(layer_tensors, tensor_devs)))) for d in tensor_devs: print(d, d == torch.device("cuda:0")) @@ -108,9 +86,7 @@ def test_transfer(self): print("Reset layer") layer.reset_state_variables() layer_tensors = [ - k - for k, v in layer.state_dict().items() - if isinstance(v, torch.Tensor) + k for k, v in layer.state_dict().items() if isinstance(v, torch.Tensor) ] tensor_devs = [getattr(layer, k).device for k in layer_tensors]