Turtle-dev3/neuronview

neuronview

A lightweight PyTorch library for visualizing neural network activations as heatmaps. Hook into any layer of your model and see what it "sees" in a few lines of code.

Installation

# From source (recommended during development)
git clone https://github.com/Turtle-dev3/neuronview.git
cd neuronview
pip install -e ".[dev]"

Quick Start

import torch
import torchvision.models as models
from neuronview import Inspector

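# 1. Load any PyTorch model (pretrained=True is deprecated in torchvision 0.13+)
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.eval()  # inference mode: fixes batch-norm statistics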

# 2. Create an Inspector and pick a layer to watch
inspector = Inspector(model)
inspector.watch("layer2.0.conv1")

# 3. Run a forward pass with your input
image = torch.randn(1, 3, 224, 224)  # replace with a real image
inspector.run(image)

# 4. Visualize!
inspector.heatmap()

Features

Discover layers

Not sure which layers your model has? List them all:

from neuronview import list_layers
import torchvision.models as models

model = models.resnet18()
for name in list_layers(model):
    print(name)
# conv1
# bn1
# relu
# maxpool
# layer1.0.conv1
# layer1.0.bn1
# ...
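
For comparison, a similar listing can be produced with plain PyTorch. This is a minimal sketch, not neuronview's actual implementation: leaf_layer_names and TinyNet are hypothetical names used only for illustration, and the sketch assumes that "hookable" means leaf modules (modules with no children), which is roughly what include_containers=False suggests.

```python
import torch.nn as nn


def leaf_layer_names(model: nn.Module) -> list:
    """Return the dot-path names of all modules that have no child modules."""
    names = []
    for name, module in model.named_modules():
        if not name:
            continue  # the empty name is the root model itself
        if len(list(module.children())) == 0:
            names.append(name)
    return names


class TinyNet(nn.Module):
    """Hypothetical toy model, used only to demonstrate the listing."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.block = nn.Sequential(
            nn.Conv2d(8, 16, kernel_size=3, padding=1),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.block(self.conv1(x))


names = leaf_layer_names(TinyNet())
print(names)  # ['conv1', 'block.0', 'block.1']
```

The container "block" is skipped because it has children; only its members appear, under their dot-path names.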

Watch multiple layers

You can hook into several layers at once using method chaining:

inspector = Inspector(model)
inspector.watch("layer1.0.conv1").watch("layer3.0.conv1")
inspector.run(image)

# Specify which layer to visualize
inspector.heatmap(layer_name="layer1.0.conv1")
inspector.heatmap(layer_name="layer3.0.conv1")

View a specific channel

Each convolutional layer has multiple channels (filters). See what an individual channel detects:

inspector.heatmap(channel=5)            # show channel 5
inspector.heatmap(channel=0, cmap="hot") # different colormap
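
To make the channel argument concrete, here is a rough sketch of how a (1, C, H, W) activation could be reduced to a 2D map before plotting. The helper name to_heatmap is hypothetical, and the min-max scaling to [0, 1] is an assumption; neuronview may normalize differently or not at all.

```python
import torch


def to_heatmap(activation, channel=None):
    """Reduce a (1, C, H, W) activation to an (H, W) map scaled to [0, 1]."""
    fmap = activation[0]  # drop the batch dimension -> (C, H, W)
    if channel is None:
        fmap = fmap.mean(dim=0)  # average over all channels
    else:
        fmap = fmap[channel]  # pick a single channel
    # Min-max scaling (an assumption made for this sketch).
    fmap = fmap - fmap.min()
    peak = fmap.max()
    if peak > 0:
        fmap = fmap / peak
    return fmap


# Toy activation with 2 channels of size 2x2.
activation = torch.arange(8, dtype=torch.float32).reshape(1, 2, 2, 2)
mean_map = to_heatmap(activation)            # channel average
single = to_heatmap(activation, channel=1)   # channel 1 only
```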

Overlay on the original image

See which part of your input image activated the layer most:

inspector.heatmap_overlay(
    original_image=image,
    alpha=0.6,
    cmap="jet",
)
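
Activation maps are usually much smaller than the input image, so an overlay has to upsample the map before blending. The sketch below shows one way this could be done in plain PyTorch; blend_overlay is a hypothetical helper, and bilinear upsampling with a grayscale blend is an assumption rather than a description of what heatmap_overlay does internally.

```python
import torch
import torch.nn.functional as F


def blend_overlay(heat, image, alpha=0.5):
    """Upsample an (h, w) map to the image size and alpha-blend it.

    heat:  (h, w) tensor with values in [0, 1]
    image: (3, H, W) tensor with values in [0, 1]
    """
    # F.interpolate expects (N, C, h, w), so add two leading dims and drop them again.
    heat_up = F.interpolate(
        heat[None, None],
        size=image.shape[-2:],
        mode="bilinear",
        align_corners=False,
    )[0, 0]
    # The grayscale map broadcasts over the three color channels. A real overlay
    # would typically pass it through a colormap (e.g. "jet") first.
    return (1 - alpha) * image + alpha * heat_up


heat = torch.rand(28, 28)          # stand-in for a normalized activation map
image = torch.rand(3, 224, 224)    # stand-in for an input image in [0, 1]
blended = blend_overlay(heat, image, alpha=0.6)
```

A higher alpha makes the heatmap more prominent; a lower alpha lets more of the original image show through.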

Save figures

inspector.heatmap(save_path="activation_map.png")

Clean up

inspector.clear()          # clear stored activations (keep hooks)
inspector.unwatch("conv1") # remove a specific hook
inspector.unwatch()        # remove all hooks
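
Under the hood, removing a hook in PyTorch means calling remove() on the handle that register_forward_hook returned. The sketch below illustrates that mechanism with a bare nn.Linear layer; it assumes unwatch() does something comparable, which the source does not spell out.

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
calls = []  # stands in for the stored activations

# register_forward_hook returns a RemovableHandle.
handle = layer.register_forward_hook(
    lambda module, inputs, output: calls.append(output.detach())
)

layer(torch.randn(1, 4))  # hook fires, one activation stored
handle.remove()           # analogous to removing a hook
layer(torch.randn(1, 4))  # hook no longer fires

calls_after_remove = len(calls)  # still 1
calls.clear()                    # analogous to clearing stored activations
```

Removing hooks you no longer need avoids holding on to activation tensors and keeps later forward passes free of extra work.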

API Reference

Inspector(model)

The main class. Wraps a PyTorch model and manages forward hooks.

| Method | Description |
| --- | --- |
| .watch(layer_name) | Hook into a layer by its dot-path name. Returns self for chaining. |
| .run(x) | Run a forward pass and capture activations. Returns model output. |
| .get_activations(layer_name=None) | Get the raw activation tensor. If only one layer is watched, layer_name can be omitted. |
| .heatmap(layer_name=None, channel=None, **kwargs) | Render a heatmap. Pass cmap, figsize, title, save_path. |
| .heatmap_overlay(original_image, layer_name=None, channel=None, alpha=0.5) | Overlay heatmap on the input image. |
| .layers() | List all hookable layer names in the model. |
| .unwatch(layer_name=None) | Remove hooks (all if no name given). |
| .clear() | Clear stored activations without removing hooks. |

list_layers(model, include_containers=False)

Standalone function to list all hookable layers in a model.

heatmap(activation, channel=None, cmap="viridis", ...)

Standalone function that renders any activation tensor as a heatmap.

heatmap_overlay(activation, original_image, channel=None, alpha=0.5, ...)

Standalone function that overlays an activation heatmap on an image.

Running Tests

pip install -e ".[dev]"
pytest

Project Structure

neuronview/
├── neuronview/
│   ├── __init__.py      # Public API exports
│   ├── inspector.py     # Inspector class (hooks + activation capture)
│   ├── visualize.py     # Heatmap rendering with matplotlib
│   └── utils.py         # Layer listing + lookup helpers
├── tests/
│   ├── test_inspector.py
│   └── test_visualize.py
├── pyproject.toml       # Package metadata + dependencies
└── README.md

How It Works

The core mechanism is PyTorch forward hooks. When you call inspector.watch("layer2.0.conv1") and then run the model, the following happens:

  1. neuronview walks the model's module tree to find that layer.
  2. It registers a callback (register_forward_hook) on the layer.
  3. When inspector.run(x) triggers a forward pass, the callback fires and captures the layer's output tensor.
  4. The captured tensor is detached from the autograd graph and moved to CPU.
  5. heatmap() averages across channels (or picks one) and renders the result with matplotlib.
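
The steps above can be sketched with plain PyTorch, without neuronview. This is an illustrative sketch under stated assumptions, not the library's source: the toy model, the captured dictionary, and the save_output helper are hypothetical, and plotting is left out so the example depends only on torch.

```python
import torch
import torch.nn as nn

# Hypothetical toy model; in an nn.Sequential, layers are named "0", "1", "2".
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, padding=1),
)
model.eval()

captured = {}  # layer name -> activation tensor


def save_output(name):
    """Build a forward hook that stores the layer output under `name`."""
    def hook(module, inputs, output):
        # Step 4: detach from the autograd graph and move to CPU.
        captured[name] = output.detach().cpu()
    return hook


# Step 1: find the layer by its dot-path name.
layer = model.get_submodule("2")
# Step 2: register the callback.
handle = layer.register_forward_hook(save_output("2"))
# Step 3: the forward pass triggers the hook.
with torch.no_grad():
    model(torch.randn(1, 3, 32, 32))
handle.remove()

# Step 5 (without plotting): average across channels to get a 2D map.
heat = captured["2"][0].mean(dim=0)
print(captured["2"].shape, heat.shape)
```

The resulting 2D tensor is what a plotting call such as matplotlib's imshow would then render as a heatmap.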

License

MIT
