In [1]:
from typing import Tuple, List, Dict

import torch
import numpy as np
from numpy.testing import assert_array_almost_equal

from pytorch_probing import collect

In [2]:
class ExampleModel(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()

        self.linear1 = torch.nn.Linear(input_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)

        return x


In [3]:
test_input_size = 2
test_hidden_size = 3
test_output_size = 1

example_model = ExampleModel(test_input_size, test_hidden_size, test_output_size)
example_model.eval()

ExampleModel(
  (linear1): Linear(in_features=2, out_features=3, bias=True)
  (relu): ReLU()
  (linear2): Linear(in_features=3, out_features=1, bias=True)
)

In [4]:
inputs = torch.randn([10, 2])
example_model(inputs)

tensor([[-0.1559],
        [-0.1385],
        [ 0.2696],
        [ 0.0836],
        [ 0.0540],
        [ 0.2701],
        [ 0.0978],
        [ 0.1786],
        [ 0.0248],
        [ 0.0812]], grad_fn=<AddmmBackward0>)

In [5]:
from torch.utils.data import Dataset, DataLoader

class ExampleDataset(Dataset):
    def __init__(self, x_size, y_size, len) -> None:
        super().__init__()

        self._x_size = x_size
        self._y_size = y_size
        self._len = len

    def __len__(self) -> int:
        return self._len
    
    def __getitem__(self, idx:int):
        return torch.empty(self._x_size).fill_(idx), torch.empty(self._y_size).fill_(idx)

In [6]:
dataset = ExampleDataset(test_input_size, test_output_size, 32)

In [7]:
dataloader = DataLoader(dataset, 4, shuffle=False)

In [8]:
paths = ["linear1"]

dataset_path = collect(example_model, paths, dataloader)

In [9]:
chunk0 = torch.load(dataset_path+"\\0.pt")

In [10]:
chunk0["intercepted_outputs"]

{'linear1': tensor([[-0.3236,  0.5352,  0.3620],
         [-0.4746,  0.0420, -0.3842],
         [-0.6257, -0.4512, -1.1303],
         [-0.7767, -0.9444, -1.8765]])}

In [11]:
chunk0["index"]

0

In [12]:
dataset_path = collect(example_model, paths, dataloader, save_prediction=True, save_target=True, save_input=True)

In [13]:
chunk0 = torch.load(dataset_path+"\\0.pt")

In [14]:
list(chunk0.keys())

['intercepted_outputs', 'index', 'input', 'target', 'prediction']

In [15]:
chunk0["target"]

tensor([[0.],
        [1.],
        [2.],
        [3.]])

In [16]:
chunk0["prediction"]

tensor([[0.0958],
        [0.2809],
        [0.2833],
        [0.2833]])

In [17]:
from pytorch_probing import CollectedDataset

In [18]:
collected_dataset = CollectedDataset(dataset_path, True, True, True)

In [19]:
collected_dataset[0]

({'linear1': tensor([-0.3236,  0.5352,  0.3620])},
 tensor([0.]),
 tensor([0.0958]),
 tensor([0., 0.]))

In [25]:
chunk0

{'intercepted_outputs': {'linear1': tensor([[-0.3236,  0.5352,  0.3620],
          [-0.4746,  0.0420, -0.3842],
          [-0.6257, -0.4512, -1.1303],
          [-0.7767, -0.9444, -1.8765]])},
 'index': 0,
 'input': tensor([[0., 0.],
         [1., 1.],
         [2., 2.],
         [3., 3.]]),
 'target': tensor([[0.],
         [1.],
         [2.],
         [3.]]),
 'prediction': tensor([[0.0958],
         [0.2809],
         [0.2833],
         [0.2833]])}

In [21]:
chunk0 = dict(chunk0)

In [22]:
torch.save(chunk0, "test.pt")

In [23]:
a = torch.load("test.pt")