From 1dc9983f997c842e87b8084c4a2fca8ac57892b3 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Fri, 22 Jun 2018 09:53:23 -0400 Subject: [PATCH] Renaming to locally connected layer and adding reference to reflect literature. --- bindsnet/analysis/plotting.py | 9 +++++---- bindsnet/utils.py | 4 ++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/bindsnet/analysis/plotting.py b/bindsnet/analysis/plotting.py index 7f4e0619..c946031c 100644 --- a/bindsnet/analysis/plotting.py +++ b/bindsnet/analysis/plotting.py @@ -5,7 +5,7 @@ from mpl_toolkits.axes_grid1 import make_axes_locatable -from ..utils import reshape_fully_conv_weights +from ..utils import reshape_locally_connected_weights plt.ion() @@ -234,10 +234,11 @@ def plot_conv2d_weights(weights, wmin=0.0, wmax=1.0, im=None, figsize=(5, 5)): return im -def plot_fully_conv_weights(weights, n_filters, kernel_size, conv_size, locations, +def plot_locally_connected_weights(weights, n_filters, kernel_size, conv_size, locations, input_sqrt, wmin=0.0, wmax=1.0, im=None, figsize=(5, 5)): ''' - Plot a connection weight matrix of a Connection with fully convolutional structure. + Plot a connection weight matrix of a Connection with + `locally connected structure _. Inputs: @@ -257,7 +258,7 @@ def plot_fully_conv_weights(weights, n_filters, kernel_size, conv_size, location | (:code:`im` (:code:`matplotlib.image.AxesImage`): Used for re-drawing the weights plot. ''' - reshaped = reshape_fully_conv_weights(weights, n_filters, kernel_size, + reshaped = reshape_locally_connected_weights(weights, n_filters, kernel_size, conv_size, locations, input_sqrt) n_sqrt = int(np.ceil(np.sqrt(n_filters))) * conv_size diff --git a/bindsnet/utils.py b/bindsnet/utils.py index 65c01ace..d0918adf 100644 --- a/bindsnet/utils.py +++ b/bindsnet/utils.py @@ -91,9 +91,9 @@ def get_square_assignments(assignments, n_sqrt): return square_assignments -def reshape_fully_conv_weights(w, n_filters, kernel_size, conv_size, locations, input_sqrt): +def reshape_locally_connected_weights(w, n_filters, kernel_size, conv_size, locations, input_sqrt): ''' - Get the weights from a fully convolution layer + Get the weights from a locally connected layer and reshape them to be two-dimensional and square. ''' w_ = torch.zeros((n_filters * kernel_size, kernel_size * conv_size ** 2))