In [1]:
from typing import Tuple, List, Dict

import torch
import numpy as np
from numpy.testing import assert_array_almost_equal

from pytorch_probing import Prober

In [2]:
class ExampleModel(torch.nn.Module):
    """Two-layer MLP used as the probing target: Linear -> ReLU -> Linear.

    The submodules are exposed as ``linear1``, ``relu`` and ``linear2`` so they
    can be addressed by name (e.g. as probe attachment points).

    Args:
        input_size: number of input features.
        hidden_size: width of the hidden layer (output size of ``linear1``).
        output_size: number of output features.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        super().__init__()
        # Registration order determines named_children() order and the order in
        # which weight init consumes the RNG, so linear1 stays before linear2.
        self.linear1 = torch.nn.Linear(input_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply linear1, then ReLU, then linear2 to ``x``."""
        hidden = self.relu(self.linear1(x))
        return self.linear2(hidden)


In [3]:
# Seed the global torch RNG before anything random happens: the model and probe
# weights below are randomly initialised, and the test inputs are drawn with
# torch.randn later, so without a seed every Restart & Run All differs.
SEED = 0
torch.manual_seed(SEED)

test_input_size = 2
test_hidden_size = 3
test_output_size = 1
test_probe_size = 2

model = ExampleModel(test_input_size, test_hidden_size, test_output_size)
# ExampleModel has only Linear/ReLU layers, so eval() changes nothing numerically
# here; it is kept so the test mirrors inference-mode usage.
model.eval()

# The probe reads linear1's output, whose width is test_hidden_size.
probe = torch.nn.Linear(test_hidden_size, test_probe_size)

In [4]:
# Map submodule names of `model` to probes. "linear1" gets the trainable linear
# probe; "relu" gets None, which presumably means "capture the raw activation
# with no probe applied" (the relu check compares it to the plain ReLU output)
# — confirm against the Prober documentation.
probes = {
    "linear1": probe,
    "relu": None,
}

probed_model = Prober(model, probes)

In [5]:
test_batch_size = 10

# Derive the feature dimension from test_input_size rather than a hard-coded 2,
# so this cell stays valid if the model's input size is changed in the config.
inputs = torch.randn(test_batch_size, test_input_size)

# NOTE(review): outputs appears to be indexable with outputs[1] mapping probe
# names to captured/probed activations (see the checks below) — confirm Prober's
# return structure.
outputs = probed_model(inputs)

In [6]:
# Ground-truth values computed directly from the underlying layers. They are
# only compared against, never differentiated, so skip building an autograd
# graph instead of detaching afterwards.
with torch.no_grad():
    linear_output = model.linear1(inputs)
    probe_output = probe(linear_output)

In [7]:
# The "linear1" entry should equal the probe applied to linear1's activation.
probed_output = outputs[1]["linear1"].detach()
assert_array_almost_equal(probe_output.numpy(), probed_output.numpy())

In [8]:
# Expected value: ReLU applied by hand to the reference linear1 activation.
relu_output = model.relu(linear_output).detach()

# "relu" was registered with probe None; the captured value is compared directly
# to the plain ReLU output, i.e. presumably Prober passes it through unmodified.
probed_output = outputs[1]["relu"].detach()
assert_array_almost_equal(relu_output.numpy(), probed_output.numpy())