In [1]:
import os
from typing import Tuple, List, Dict

import torch
import numpy as np
from numpy.testing import assert_array_almost_equal

from pytorch_probing import collect

In [2]:
class ExampleModel(torch.nn.Module):
    """Two-layer perceptron: ``linear1`` -> ``relu`` -> ``linear2``.

    Args:
        input_size: Number of features of each input sample.
        hidden_size: Number of output features of ``linear1``.
        output_size: Number of output features of ``linear2``.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()

        # The submodules are registered under these attribute names, in this order.
        self.linear1 = torch.nn.Linear(
            in_features=input_size, out_features=hidden_size
        )
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(
            in_features=hidden_size, out_features=output_size
        )

    def forward(self, x):
        """Apply ``linear1``, ``relu`` and ``linear2`` in sequence and return the result."""
        hidden = self.relu(self.linear1(x))
        return self.linear2(hidden)


In [3]:
# Layer widths of the toy model: 2 inputs -> 3 hidden units -> 1 output.
test_input_size, test_hidden_size, test_output_size = 2, 3, 1

example_model = ExampleModel(
    input_size=test_input_size,
    hidden_size=test_hidden_size,
    output_size=test_output_size,
)
# Switch to evaluation mode; eval() returns the module, so the cell displays its summary.
example_model.eval()

ExampleModel(
  (linear1): Linear(in_features=2, out_features=3, bias=True)
  (relu): ReLU()
  (linear2): Linear(in_features=3, out_features=1, bias=True)
)

In [4]:
# Batch of 10 random samples. The feature width is taken from test_input_size
# instead of repeating the literal 2, so it cannot drift from the model's input size.
inputs = torch.randn([10, test_input_size])
example_model(inputs)

tensor([[0.3525],
        [0.4384],
        [0.4575],
        [0.4426],
        [0.4537],
        [0.1262],
        [0.3229],
        [0.1582],
        [0.4582],
        [0.4465]], grad_fn=<AddmmBackward0>)

In [5]:
from torch.utils.data import Dataset, DataLoader

class ExampleDataset(Dataset):
    """Synthetic map-style dataset whose samples encode their own index.

    Sample ``idx`` is a pair ``(x, y)`` of tensors created with
    ``torch.empty(x_size)`` / ``torch.empty(y_size)`` and filled with ``idx``.

    Args:
        x_size: Size passed to ``torch.empty`` for the input tensor.
        y_size: Size passed to ``torch.empty`` for the target tensor.
        len: Number of samples reported by ``__len__``.
    """

    def __init__(self, x_size, y_size, len) -> None:
        # ``len`` shadows the builtin inside this method only; the parameter
        # name is kept so that existing keyword callers keep working.
        super().__init__()

        self._x_size = x_size
        self._y_size = y_size
        self._len = len

    def __len__(self) -> int:
        """Return the number of samples in the dataset."""
        return self._len

    def __getitem__(self, idx: int):
        """Return the ``(x, y)`` pair for ``idx``.

        Raises:
            IndexError: If ``idx`` is outside ``[0, len(self))``. Without this
                check, plain iteration over the dataset (which relies on
                ``__getitem__`` raising ``IndexError``) would never stop.
        """
        if not 0 <= idx < self._len:
            raise IndexError(
                f"index {idx} is out of range for dataset of length {self._len}"
            )

        x = torch.empty(self._x_size).fill_(idx)
        y = torch.empty(self._y_size).fill_(idx)
        return x, y

In [6]:
# 32 synthetic samples shaped to match the model's input and output sizes.
dataset = ExampleDataset(
    x_size=test_input_size,
    y_size=test_output_size,
    len=32,
)

In [7]:
# Batches of 4 samples in dataset order (no shuffling).
dataloader = DataLoader(dataset, batch_size=4, shuffle=False)

In [8]:
# Module paths whose outputs should be intercepted; "linear1" is the first
# linear layer of ExampleModel.
paths = ["linear1"]

# collect() returns the location where the results were written; the chunk
# file "0.npz" is loaded from that location afterwards.
# NOTE(review): presumably collect() runs the model over every batch of the
# dataloader and writes one chunk per batch — confirm against its documentation.
dataset_path = collect(example_model, paths, dataloader)

In [9]:
# Load the first saved chunk. The path is built with os.path.join rather than
# a hard-coded "\\" separator so the cell also works on non-Windows systems.
# NOTE: allow_pickle=True unpickles arbitrary objects; only load chunks
# produced by a trusted run.
chunk0 = np.load(os.path.join(dataset_path, "0.npz"), allow_pickle=True)

In [10]:
# Display the loaded archive; its repr lists the keys stored in the chunk.
chunk0

NpzFile '.\\2024-08-15-14-26-25-534017\\0.npz' with keys: intercepted_outputs, index

In [11]:
# Intercepted outputs of this chunk: an object array wrapping a dict keyed by
# module path, holding one row per sample of the batch.
chunk0["intercepted_outputs"]

array({'linear1': tensor([[-0.5331,  0.3687,  0.3659],
        [-1.4146, -0.4742,  0.3194],
        [-2.2961, -1.3171,  0.2728],
        [-3.1777, -2.1600,  0.2262]])}, dtype=object)

In [12]:
# Index stored alongside the chunk (0 for this first file).
# NOTE(review): presumably the batch number within the dataloader — confirm.
chunk0["index"]

array(0)

In [13]:
# Collect again, this time also requesting the inputs, targets and predictions.
# The result is rebound to dataset_path, replacing the previous location.
dataset_path = collect(
    example_model,
    paths,
    dataloader,
    save_prediction=True,
    save_target=True,
    save_input=True,
)

In [14]:
# Load the first chunk of the new collection. os.path.join replaces the
# hard-coded "\\" separator so the cell also works on non-Windows systems.
# NOTE: allow_pickle=True unpickles arbitrary objects; only load chunks
# produced by a trusted run.
chunk0 = np.load(os.path.join(dataset_path, "0.npz"), allow_pickle=True)

In [15]:
# List the keys stored in the chunk; input, target and prediction now appear
# next to the intercepted outputs and the index.
list(chunk0.keys())

['intercepted_outputs', 'index', 'input', 'target', 'prediction']

In [16]:
# Targets saved for the chunk: one row per sample. ExampleDataset fills each
# target with the sample's own index.
chunk0["target"]

array([[0.],
       [1.],
       [2.],
       [3.]], dtype=float32)

In [17]:
# Model predictions saved for the same chunk, one row per sample.
chunk0["prediction"]

array([[0.25266975],
       [0.45033097],
       [0.45147413],
       [0.45261732]], dtype=float32)

In [18]:
from pytorch_probing import CollectedDataset

In [19]:
# Wrap the collected chunks in a dataset that can be indexed per sample.
# NOTE(review): the three positional True flags presumably enable loading of
# the saved input / target / prediction (order unverified) — confirm against
# CollectedDataset's signature and prefer keyword arguments.
collected_dataset = CollectedDataset(dataset_path, True, True, True)

In [20]:
# First collected sample. Per the recorded output it is a 4-tuple:
# (dict of intercepted outputs, target, prediction, input).
collected_dataset[0]

({'linear1': tensor([-0.5331,  0.3687,  0.3659])},
 array([0.], dtype=float32),
 array([0.25266975], dtype=float32),
 array([0., 0.], dtype=float32))

In [25]:
# Re-inspect the raw chunk: its first "linear1" row matches the intercepted
# output returned by collected_dataset[0].
chunk0["intercepted_outputs"]

array({'linear1': tensor([[-0.5331,  0.3687,  0.3659],
        [-1.4146, -0.4742,  0.3194],
        [-2.2961, -1.3171,  0.2728],
        [-3.1777, -2.1600,  0.2262]])}, dtype=object)