Skip to content

Commit

Permalink
Merge pull request #87 from Hananel-Hazan/dan
Browse files Browse the repository at this point in the history
Renaming to locally connected layer and adding reference to reflect literature.
  • Loading branch information
Dan Saunders committed Jun 22, 2018
2 parents 317e4f3 + 1dc9983 commit 136abad
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
9 changes: 5 additions & 4 deletions bindsnet/analysis/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from mpl_toolkits.axes_grid1 import make_axes_locatable

from ..utils import reshape_fully_conv_weights
from ..utils import reshape_locally_connected_weights


plt.ion()
Expand Down Expand Up @@ -234,10 +234,11 @@ def plot_conv2d_weights(weights, wmin=0.0, wmax=1.0, im=None, figsize=(5, 5)):
return im


def plot_fully_conv_weights(weights, n_filters, kernel_size, conv_size, locations,
def plot_locally_connected_weights(weights, n_filters, kernel_size, conv_size, locations,
input_sqrt, wmin=0.0, wmax=1.0, im=None, figsize=(5, 5)):
'''
Plot a connection weight matrix of a Connection with fully convolutional structure.
Plot a connection weight matrix of a Connection with
`locally connected structure <http://yann.lecun.com/exdb/publis/pdf/gregor-nips-11.pdf>_.
Inputs:
Expand All @@ -257,7 +258,7 @@ def plot_fully_conv_weights(weights, n_filters, kernel_size, conv_size, location
| (:code:`im` (:code:`matplotlib.image.AxesImage`): Used for re-drawing the weights plot.
'''
reshaped = reshape_fully_conv_weights(weights, n_filters, kernel_size,
reshaped = reshape_locally_connected_weights(weights, n_filters, kernel_size,
conv_size, locations, input_sqrt)

n_sqrt = int(np.ceil(np.sqrt(n_filters))) * conv_size
Expand Down
4 changes: 2 additions & 2 deletions bindsnet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ def get_square_assignments(assignments, n_sqrt):

return square_assignments

def reshape_fully_conv_weights(w, n_filters, kernel_size, conv_size, locations, input_sqrt):
def reshape_locally_connected_weights(w, n_filters, kernel_size, conv_size, locations, input_sqrt):
'''
Get the weights from a fully convolution layer
Get the weights from a locally connected layer
and reshape them to be two-dimensional and square.
'''
w_ = torch.zeros((n_filters * kernel_size, kernel_size * conv_size ** 2))
Expand Down

0 comments on commit 136abad

Please sign in to comment.