In [11]:
import torch

from pytorch_probing import Prober, ParallelModuleDict

In [12]:
class ExampleModel(torch.nn.Module):
    """Two-layer MLP: Linear -> ReLU -> Linear.

    The sub-modules live in attributes named ``linear1``, ``relu`` and
    ``linear2``; the probing cells in this notebook use those names as keys.

    Args:
        input_size: number of input features.
        hidden_size: width of the hidden layer.
        output_size: number of output features.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Keep registration order (linear1 before linear2): parameter init then
        # draws from the RNG in the same order, and the module repr lists the
        # layers in forward order.
        self.linear1 = torch.nn.Linear(in_features=input_size, out_features=hidden_size)
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        """Map ``x`` of shape (..., input_size) to (..., output_size)."""
        # Every layer is called through its module (not a functional op), so
        # forward hooks registered on any sub-module still fire.
        return self.linear2(self.relu(self.linear1(x)))


In [13]:
# Seed torch's global RNG before anything random happens, so that the model
# initialisation, the probe initialisation and the `torch.randn` inputs drawn
# in later cells are reproducible under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

input_size = 2
hidden_size = 3
output_size = 1

model = ExampleModel(input_size, hidden_size, output_size)
model.eval()

ExampleModel(
  (linear1): Linear(in_features=2, out_features=3, bias=True)
  (relu): ReLU()
  (linear2): Linear(in_features=3, out_features=1, bias=True)
)

In [14]:
# Linear probe whose input width matches `hidden_size` (the output width of
# `linear1`); `probe_size` is the number of probe outputs.
probe_size = 2

probe = torch.nn.Linear(in_features=hidden_size, out_features=probe_size)
probe.eval()

Linear(in_features=3, out_features=2, bias=True)

In [15]:
# Map module names inside `model` to what should be attached there.
# In the displayed output of the next cells, "linear1" yields the probe's
# outputs (width `probe_size`), while "relu", mapped to None, yields the
# module's raw activations (width `hidden_size`).
probes = {
    "linear1": probe,
    "relu": None,
}

probed_model = Prober(model, probes)

In [16]:
# Random demo batch. Its width is tied to `input_size` rather than a literal 2,
# so this cell stays valid if the model configuration changes.
batch_size = 10
inputs = torch.randn(batch_size, input_size)

# Inference only: no_grad skips building the autograd graph.
with torch.no_grad():
    outputs = probed_model(inputs)

In [17]:
# Element 1 of the probed model's return value holds the probe results: as the
# displayed output shows, a dict keyed by the probed module names.
# NOTE(review): element 0 is presumably the model's own prediction — confirm.
outputs[1]

{'linear1': tensor([[ 0.7286, -0.2807],
         [ 0.7017, -0.2080],
         [ 0.3983,  0.7984],
         [ 0.6345, -0.2140],
         [ 0.5134, -0.0495],
         [ 0.8028, -0.5010],
         [ 0.6500,  0.2377],
         [ 0.7007,  0.2208],
         [ 0.6548,  0.3234],
         [ 0.8624, -0.7273]]),
 'relu': tensor([[0.0652, 0.7930, 0.4492],
         [0.0000, 0.7011, 0.3725],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.5843, 0.5463],
         [0.0000, 0.2685, 0.5175],
         [0.2736, 1.0580, 0.6998],
         [0.0000, 0.3413, 0.0000],
         [0.0000, 0.4423, 0.0000],
         [0.0000, 0.2985, 0.0000],
         [0.4734, 1.3006, 0.9972]])}

In [18]:
# Second probe over the same `hidden_size`-wide activations, with a single
# output unit.
probe2_size = 1

probe2 = torch.nn.Linear(in_features=hidden_size, out_features=probe2_size)
probe2.eval()

Linear(in_features=3, out_features=1, bias=True)

In [19]:
# Rebind `model` to the result of `reduce()` so it can be wrapped in a new
# Prober below.
# NOTE(review): presumably `reduce()` strips the attached probes and returns
# the underlying ExampleModel — confirm against pytorch_probing.
model = probed_model.reduce()

In [20]:
# Several probes on one module: group them in a ParallelModuleDict. In the
# displayed output below, the "linear1" result then becomes a nested dict
# keyed by probe name ("probe1", "probe2").
linear1_probes = ParallelModuleDict(
    {
        "probe1": probe,
        "probe2": probe2,
    }
)

probes = {"linear1": linear1_probes}

probed_model = Prober(model, probes)

In [21]:
# Fresh random demo batch for the re-probed model. Its width is tied to
# `input_size` rather than a literal 2, so this cell stays valid if the model
# configuration changes.
batch_size = 10
inputs = torch.randn(batch_size, input_size)

# Inference only: no_grad skips building the autograd graph.
with torch.no_grad():
    outputs = probed_model(inputs)

In [23]:
# Probe results again. As the displayed output shows, the "linear1" entry is
# now a nested dict with one tensor per probe in the ParallelModuleDict.
outputs[1]

{'linear1': {'probe1': tensor([[ 1.2929, -0.9426],
          [ 0.6545,  0.1600],
          [ 0.7082, -0.1651],
          [ 0.8528,  0.2046],
          [ 0.8608, -0.5736],
          [ 0.5992,  0.3035],
          [ 0.7458, -0.8529],
          [ 0.4288,  0.1822],
          [ 0.3741,  0.5822],
          [ 1.2560, -1.0933]]),
  'probe2': tensor([[-0.1078],
          [ 0.2343],
          [ 0.2032],
          [ 0.1320],
          [ 0.1199],
          [ 0.2645],
          [ 0.1767],
          [ 0.3516],
          [ 0.3840],
          [-0.0902]])}}