In [1]:
# Reload edited modules automatically before each cell runs (autoreload mode 2).
%load_ext autoreload
%autoreload 2
# Register the TensorBoard notebook extension, which provides the tensorboard magic used below.
%load_ext tensorboard

In [2]:
import torch
from torch import nn
from torch.nn import functional as F

from torch.utils.tensorboard import SummaryWriter

import pathlib

# Seed PyTorch's global RNG so the random input and the initial layer weights are reproducible.
torch.manual_seed(42)

<torch._C.Generator at 0x7fc4e81fd4d0>

In [3]:
class Net(nn.Module):
    """Three-layer fully connected network with ReLU after the first two layers.

    Args:
        in_features: Size of each input sample.
        hidden1: Width of the first hidden layer.
        hidden2: Width of the second hidden layer.
        out_features: Size of each output sample.

    The defaults reproduce the fixed 100 -> 200 -> 50 -> 5 layout, so ``Net()``
    builds the same architecture as before.
    """

    def __init__(self, in_features=100, hidden1=200, hidden2=50, out_features=5):
        super().__init__()
        # Keep this construction order: each nn.Linear draws its initial
        # weights from the global RNG, so the order affects seeded results.
        self.fc1 = nn.Linear(in_features, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, out_features)

    def forward(self, x):
        """Map ``x`` of shape ``(..., in_features)`` to ``(..., out_features)``."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # No activation on the last layer: the raw linear outputs are returned.
        return self.fc3(x)

In [4]:
tensorboard --logdir=tensorboard

In [5]:
# Write event files to ./tensorboard, the same directory the TensorBoard cell points at.
path = pathlib.Path.cwd() / 'tensorboard'
writer = SummaryWriter(path)

# A batch of 10 random samples with 100 features each (matches fc1's in_features).
# The input is drawn before the network is built; both consume the seeded global
# RNG, so swapping these two lines would change the sampled values.
x = torch.randn(10, 100)
net = Net()

In [6]:
class Hooks:
    """Attach TensorBoard histogram hooks to the Conv2d/Linear layers of a model.

    Args:
        model: Module whose ``Conv2d`` and ``Linear`` sub-modules get hooked.
        summary_writer: Optional object exposing ``add_histogram(tag, values)``.
            When omitted, the module-level ``writer`` is looked up each time a
            hook fires.

    Attributes:
        handles: Removable handles for every hook registered through this
            instance; ``remove()`` detaches them all.
    """

    # Layer types whose outputs and weight gradients are logged.
    HOOKED_TYPES = (torch.nn.Conv2d, torch.nn.Linear)

    def __init__(self, model, summary_writer=None):
        self.model = model
        self.summary_writer = summary_writer
        self.handles = []

    def _writer(self):
        """Return the explicit writer, falling back to the global ``writer``."""
        if self.summary_writer is not None:
            return self.summary_writer
        return writer

    def _hooked_modules(self):
        """Yield the sub-modules of ``self.model`` that should be instrumented."""
        # Iterate the model given to the constructor, not a notebook global.
        for module in self.model.modules():
            if isinstance(module, self.HOOKED_TYPES):
                yield module

    def forward_hooks(self):
        """Register a forward hook on every hooked layer to log its output."""
        for module in self._hooked_modules():
            self.handles.append(module.register_forward_hook(self.activation_hook))

    def activation_hook(self, module, inp, out):
        """Forward-hook callback: log the layer output as a histogram."""
        self._writer().add_histogram(f'Pre-Activations/{repr(module)}', out)

    def gradient_hooks(self):
        """Register a tensor hook on every hooked layer's weight to log its gradient."""
        for module in self._hooked_modules():
            print(repr(module))
            self.handles.append(module.weight.register_hook(self.grad_hook_wrapper(module)))

    def grad_hook_wrapper(self, module):
        """Build a gradient hook that remembers which module the weight belongs to."""
        def grad_hook(grad):
            # Returns None, so the gradient itself is left unmodified.
            self._writer().add_histogram(f'Gradients/{repr(module)}', grad)
        return grad_hook

    def remove(self):
        """Detach every hook registered through this instance."""
        for handle in self.handles:
            handle.remove()
        self.handles.clear()

In [7]:
# Instrument the network: forward hooks log each layer's output, gradient hooks
# log each layer's weight gradient (gradient_hooks also prints every hooked layer).
# NOTE: re-running this cell registers an additional set of hooks on the same modules.
hooks = Hooks(net)
hooks.forward_hooks()
hooks.gradient_hooks()

Linear(in_features=100, out_features=200, bias=True)
Linear(in_features=200, out_features=50, bias=True)
Linear(in_features=50, out_features=5, bias=True)


In [8]:
# Forward pass: the registered forward hooks log each hooked layer's output.
y = net(x)
# Reduce to a scalar so backward() needs no explicit gradient argument; the
# weight hooks log the gradient histograms during this call.
y.sum().backward()
# SummaryWriter buffers events; flush so the histograms are written to disk
# and become visible in TensorBoard without waiting for the periodic flush.
writer.flush()

In [None]:
# Scratch inspection: evaluates to the bound method only; without a call, no hook is registered.
net.fc1.weight.register_hook

In [10]:
# Displays the bound `eval` method together with the module structure. This is
# attribute access only; switching to evaluation mode would require `net.eval()`.
net.eval

<bound method Module.eval of Net(
  (fc1): Linear(in_features=100, out_features=200, bias=True)
  (fc2): Linear(in_features=200, out_features=50, bias=True)
  (fc3): Linear(in_features=50, out_features=5, bias=True)
)>