Skip to content

Commit

Permalink
Merge pull request #578 from ValerioB88/fix-softmax-batching
Browse files Browse the repository at this point in the history
Squeeze spike record sum to compute softmax on the correct dimension
  • Loading branch information
Hananel-Hazan committed Sep 16, 2022
2 parents 61fda79 + 10507df commit a94bc98
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion bindsnet/pipeline/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def select_softmax(pipeline: EnvironmentPipeline, **kwargs) -> int:
pipeline, "spike_record"
), "EnvironmentPipeline is missing the attribute: spike_record."

spikes = torch.sum(pipeline.spike_record[output], dim=0)
spikes = torch.sum(pipeline.spike_record[output], dim=0).squeeze()
probabilities = torch.softmax(spikes, dim=0)
return torch.multinomial(probabilities, num_samples=1).item()

Expand Down

0 comments on commit a94bc98

Please sign in to comment.