In [1]:
from typing import Tuple, List, Dict

import torch
import numpy as np
from numpy.testing import assert_array_almost_equal

from pytorch_probing import Interceptor

In [2]:
class ExampleModel(torch.nn.Module):
    """Tiny two-layer MLP used as a probing target: Linear -> ReLU -> Linear.

    Submodules are exposed as the attributes ``linear1``, ``relu`` and
    ``linear2``; those names are what probing paths refer to.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Attribute assignment order fixes named_children() order and the
        # printed repr, so keep linear1 -> relu -> linear2.
        self.linear1 = torch.nn.Linear(in_features=input_size, out_features=hidden_size)
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        hidden = self.relu(self.linear1(x))
        return self.linear2(hidden)


In [5]:
# Weight initialisation (and the random inputs drawn later) is stochastic;
# seed it so Restart & Run All reproduces the displayed outputs.
SEED = 0
torch.manual_seed(SEED)

test_input_size = 2
test_hidden_size = 3
test_output_size = 1

example_model = ExampleModel(test_input_size, test_hidden_size, test_output_size)
example_model.eval()

ExampleModel(
  (linear1): Linear(in_features=2, out_features=3, bias=True)
  (relu): ReLU()
  (linear2): Linear(in_features=3, out_features=1, bias=True)
)

In [6]:
# Quick forward pass on a random batch of 10 samples; the feature width comes
# from the config cell so changing test_input_size does not break this cell.
inputs = torch.randn(10, test_input_size)
example_model(inputs)

tensor([[-0.3138],
        [-0.3138],
        [-0.2525],
        [-0.3138],
        [-0.3076],
        [-0.1679],
        [-0.2765],
        [-0.2830],
        [-0.3138],
        [-0.2656]], grad_fn=<AddmmBackward0>)

In [12]:
from torch.utils.data import Dataset, DataLoader

class ExampleDataset(Dataset):
    """Synthetic dataset whose sample ``idx`` is a pair of tensors filled with ``idx``.

    Handy for checking that saved batches line up with sample indices.

    Args:
        x_size: size passed to ``torch.empty`` for the input tensor.
        y_size: size passed to ``torch.empty`` for the target tensor.
        len: number of samples reported by ``__len__`` (shadows the builtin
            name; kept as-is so keyword callers keep working).
    """

    def __init__(self, x_size, y_size, len) -> None:
        super().__init__()
        self._len = len
        self._x_size, self._y_size = x_size, y_size

    def __len__(self) -> int:
        return self._len

    @staticmethod
    def _constant(size, value):
        # torch.empty allocates with the default float dtype; fill_ casts value into it.
        return torch.empty(size).fill_(value)

    def __getitem__(self, idx:int):
        x = self._constant(self._x_size, idx)
        y = self._constant(self._y_size, idx)
        return x, y

In [13]:
# 32 synthetic samples; sample i has both tensors filled with i.
dataset = ExampleDataset(x_size=test_input_size, y_size=test_output_size, len=32)

In [14]:
# Fixed order, batches of 4: saved chunk N will hold samples 4N .. 4N+3.
dataloader = DataLoader(dataset, batch_size=4, shuffle=False)

In [51]:
from typing import List, Tuple
import os
import datetime

import numpy as np

def collect(module:torch.nn.Module, paths:List[str], dataloader:DataLoader, 
            save_path:str|None = None, dataset_name:str|None = None,
            device:str|None=None, 
            save_target:bool=False, save_prediction:bool=False):
    
    original_mode = module.training
    module.eval()

    if save_path is None:
        save_path = "."
    
    if dataset_name is None:
        dataset_name = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f")

    dataset_path = os.path.join(save_path, dataset_name)

    if not os.path.exists(dataset_path):
        os.makedirs(dataset_path)
    


    if device is None:
        device = next(module.parameters()).device
    else:
        device = torch.device(device)
        module = module.to(device)

    with Interceptor(module, paths) as interceptor:
        with torch.no_grad():
            for chunk_index, (x, y) in enumerate(dataloader):
                x_device = x.to(device)

                pred : torch.Tensor | Tuple[torch.Tensor] = interceptor(x_device)

                intercepted_outputs = interceptor.outputs
                for name in intercepted_outputs:
                    if isinstance(intercepted_outputs[name], list):
                        for i in range(len(intercepted_outputs[name])):
                            intercepted_outputs[name][i] = intercepted_outputs[name][i].cpu()
                    else:
                        intercepted_outputs[name] = intercepted_outputs[name].cpu()  
                
                chunk = {"inputs":x, "intercepted_outputs":intercepted_outputs, "index":chunk_index}

                if save_target:
                    chunk["target"] = y
                if save_prediction:

                    pred_cpu = None
                    if isinstance(pred, tuple):
                        pred_cpu = []
                        for pred_item in pred:
                            pred_cpu.append(pred_item.cpu())
                    else:
                        pred_cpu = pred.cpu()

                    chunk["prediction"] = pred_cpu                        
                
                chunk_path = os.path.join(dataset_path, str(chunk_index))

                np.savez(chunk_path, **chunk)

    module.train(original_mode)


In [52]:
paths = ["linear1"]

collect(example_model, paths, dataloader)

In [67]:
chunk0 = np.load("2024-08-14-17-24-21-740127\\0.npz", allow_pickle=True)

In [68]:
# The NpzFile repr lists the keys stored in this chunk.
chunk0

NpzFile '2024-08-14-17-24-21-740127\\0.npz' with keys: inputs, intercepted_outputs, index

In [69]:
# Inputs of batch 0: with batch_size=4 and shuffle=False these are samples 0-3,
# each filled with its own index.
chunk0["inputs"]

array([[0., 0.],
       [1., 1.],
       [2., 2.],
       [3., 3.]], dtype=float32)

In [70]:
# The dict of intercepted tensors comes back wrapped in a 0-d object array
# (see output); call .item() on it to get the dict itself.
chunk0["intercepted_outputs"]

array({'linear1': tensor([[-0.4829, -0.1482, -0.0104],
        [-1.3781, -0.6280,  0.0804],
        [-2.2733, -1.1077,  0.1712],
        [-3.1684, -1.5875,  0.2620]])}, dtype=object)

In [71]:
# Batch index saved by collect(), returned as a 0-d integer array.
chunk0["index"]

array(0)

In [None]:
# np.load on an .npz returns an NpzFile that keeps the archive open;
# release it once inspection is done.
chunk0.close()