# Prober example

This example shows how to use `Prober` to couple already-trained probes to a model.

In [17]:
import torch # PyTorch

from pytorch_probing import Prober, ParallelModuleDict # Prober and dictionary of modules

We start by creating an example model, a simple MLP:

In [19]:
class ExampleModel(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()

        self.linear1 = torch.nn.Linear(input_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)

        return x


In [20]:
input_size = 2
hidden_size = 3
output_size = 1

model = ExampleModel(input_size, hidden_size, output_size)
model.eval()

ExampleModel(
  (linear1): Linear(in_features=2, out_features=3, bias=True)
  (relu): ReLU()
  (linear2): Linear(in_features=3, out_features=1, bias=True)
)

Next we create a probe. Any torch module can be used as a probe. In this example we use a simple `Linear` layer:

In [21]:
probe_size = 2

probe = torch.nn.Linear(hidden_size, probe_size)
probe.eval()

Linear(in_features=3, out_features=2, bias=True)

We then create the `Prober`, passing it the model and a dictionary that maps module paths to the probes that will be coupled to those modules' outputs. When a value is `None`, the `Prober` creates an `Identity` module, which simply passes its input through to its output:

In [22]:
probes = {"linear1": probe, "relu": None}  # None -> Identity probe on the relu output

probed_model = Prober(model, probes)
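
One way to picture this coupling is with PyTorch forward hooks. The sketch below is only an illustration under that assumption and is not `pytorch_probing`'s actual implementation: `attach_probes` and `captured` are hypothetical names, and the small `Sequential` model stands in for `ExampleModel`.

```python
from collections import OrderedDict

import torch


def attach_probes(model, probes):
    """Hypothetical helper: feed each named submodule's output to its probe."""
    captured = {}  # probe outputs from the most recent forward pass
    handles = []   # hook handles, kept so the hooks can be removed later

    for path, probe in probes.items():
        module = model.get_submodule(path)  # resolve a path such as "linear1"
        if probe is None:
            probe = torch.nn.Identity()     # None: record the raw module output

        # Default arguments bind the current path and probe to this hook.
        def hook(_module, _inputs, output, path=path, probe=probe):
            captured[path] = probe(output)  # returning None leaves the output unchanged

        handles.append(module.register_forward_hook(hook))

    return captured, handles


# Stand-in for the example MLP, with the same submodule names.
model = torch.nn.Sequential(OrderedDict([
    ("linear1", torch.nn.Linear(2, 3)),
    ("relu", torch.nn.ReLU()),
    ("linear2", torch.nn.Linear(3, 1)),
])).eval()

captured, handles = attach_probes(model, {"linear1": torch.nn.Linear(3, 2), "relu": None})

with torch.no_grad():
    y = model(torch.randn(10, 2))

for handle in handles:
    handle.remove()  # detach the hooks, leaving the model as it was

print(y.shape, {name: tuple(t.shape) for name, t in captured.items()})
```

With hooks, the probe outputs land in a side dictionary, whereas `Prober` returns them together with the model output from a single call.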

We pass a sample input to the probed model:

In [23]:
inputs = torch.randn([10, 2])

with torch.no_grad():
    outputs = probed_model(inputs)

The output is a tuple: the first element is the model output, and the second is a dictionary of the probe outputs, keyed by module path:

In [24]:
outputs[0]

tensor([[-0.8719],
        [-0.5281],
        [-0.6706],
        [-0.6388],
        [-0.6144],
        [-0.6609],
        [-0.7498],
        [-0.6969],
        [-0.8241],
        [-0.5516]])

In [25]:
outputs[1]

{'linear1': tensor([[-0.5845,  0.6322],
         [-0.1392,  0.1684],
         [-0.4905,  0.4023],
         [-0.3429,  0.4295],
         [-0.4883,  0.3331],
         [-0.3034,  0.4913],
         [-0.3489,  0.6187],
         [-0.3726,  0.5127],
         [-0.3951,  0.7203],
         [-0.1433,  0.3360]]),
 'relu': tensor([[0.0000, 1.6433, 0.2960],
         [0.1537, 0.0000, 0.0000],
         [0.2034, 0.7000, 0.2462],
         [0.0000, 0.5001, 0.0000],
         [0.3455, 0.4674, 0.3000],
         [0.0000, 0.6264, 0.0000],
         [0.0000, 1.1360, 0.0000],
         [0.0000, 0.8330, 0.0000],
         [0.0000, 1.5619, 0.0000],
         [0.0000, 0.0000, 0.0000]])}
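
To make the structure of that tuple concrete, the sketch below rebuilds it by hand with plain PyTorch layers. It assumes the probe on `linear1` receives the pre-activation output and the `None` probe on `relu` returns the activation unchanged, as described above. The layers are freshly initialized here, so the values differ from the printed ones; only the shapes and keys are meant to match.

```python
import torch

torch.manual_seed(0)

# Same layer sizes as the example: 2 -> 3 -> 1, with a probe of size 2.
linear1 = torch.nn.Linear(2, 3)
relu = torch.nn.ReLU()
linear2 = torch.nn.Linear(3, 1)
probe = torch.nn.Linear(3, 2)

inputs = torch.randn(10, 2)

with torch.no_grad():
    hidden = linear1(inputs)           # what the "linear1" probe sees
    activated = relu(hidden)           # what the "relu" Identity probe returns
    model_output = linear2(activated)  # the ordinary model output

    # Mirror the (model output, probe outputs) tuple layout.
    outputs = (model_output, {"linear1": probe(hidden), "relu": activated})

model_output, probe_outputs = outputs  # unpack as with the probed model
print(model_output.shape)              # torch.Size([10, 1])
print(probe_outputs["linear1"].shape)  # torch.Size([10, 2])
print(probe_outputs["relu"].shape)     # torch.Size([10, 3])
```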

## Multiple probes in the same place

We can also attach more than one probe to the same place. To show this, we create a second probe and reduce the probed model back to the original model:

In [26]:
probe2_size = 1

probe2 = torch.nn.Linear(hidden_size, probe2_size)
probe2.eval()

Linear(in_features=3, out_features=1, bias=True)

In [27]:
model = probed_model.reduce()

We can then create a `ParallelModuleDict` with the two probes. When called with an input, the `ParallelModuleDict` passes that input to all of its modules and returns a dictionary with each module's output:

In [28]:
linear1_probes = ParallelModuleDict({"probe1": probe, "probe2": probe2})

probes = {"linear1": linear1_probes}

probed_model = Prober(model, probes)
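
The fan-out behavior described above can be approximated with a small subclass of `torch.nn.ModuleDict`. This is a minimal sketch for illustration, not the library's `ParallelModuleDict`; the class name `ParallelDictSketch` is hypothetical.

```python
import torch


class ParallelDictSketch(torch.nn.ModuleDict):
    """Hypothetical stand-in: apply every contained module to the same input."""

    def forward(self, x):
        # One entry per module, keyed by the name it was registered under.
        return {name: module(x) for name, module in self.items()}


hidden_size = 3
parallel = ParallelDictSketch({
    "probe1": torch.nn.Linear(hidden_size, 2),
    "probe2": torch.nn.Linear(hidden_size, 1),
}).eval()

with torch.no_grad():
    result = parallel(torch.randn(10, hidden_size))

print({name: tuple(t.shape) for name, t in result.items()})
```

Because the container is itself a module, it can be used anywhere a single probe is expected, which is why the probe output for `linear1` becomes a nested dictionary.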

In [29]:
inputs = torch.randn([10, 2])

with torch.no_grad():
    outputs = probed_model(inputs)

In [30]:
outputs[1]

{'linear1': {'probe1': tensor([[-0.1818,  0.1917],
          [-0.4507,  0.5777],
          [-0.2739,  0.4999],
          [-0.3951,  0.2838],
          [-0.4525,  0.2892],
          [-0.3325,  0.3914],
          [-0.0384,  0.5197],
          [-0.3696,  0.6055],
          [-0.5103,  0.5117],
          [-0.2589,  0.6148]]),
  'probe2': tensor([[ 0.5407],
          [-0.1149],
          [ 0.2945],
          [ 0.0469],
          [-0.0836],
          [ 0.1753],
          [ 0.8246],
          [ 0.0651],
          [-0.2414],
          [ 0.3144]])}}