In [1]:
import torch
from torch import Tensor
from torch import nn
from torchvision import models
from functools import partial
from xpyutils import lazy_property

In [None]:
class SimpleNet(nn.Module):
    """Two-layer linear network that records the output of its first layer.

    A forward hook registered on ``_linear1`` stores that layer's detached
    output under the key ``'linear1'`` every time the layer runs. The recorded
    tensors are exposed through :attr:`out`.
    """

    def __init__(self) -> None:
        super().__init__()
        self._linear1 = nn.Linear(2, 3)
        self._linear2 = nn.Linear(3, 1)
        # Filled by ``record_output``; stays empty until the first forward pass.
        self._out: dict[str, Tensor] = {}

        # ``partial`` pre-binds the storage key so the callable matches the
        # ``hook(module, input, output)`` signature PyTorch invokes.
        # The returned handle is kept so the hook can be detached later with
        # ``self._linear1_hook.remove()``.
        self._linear1_hook = self._linear1.register_forward_hook(
            partial(self.record_output, output_key='linear1')
        )

    @property
    def out(self) -> dict[str, Tensor]:
        """Mapping from output key to the most recently recorded tensor.

        This is the internal dictionary itself, not a copy.
        """
        return self._out

    def forward(self, x: Tensor) -> Tensor:
        """Apply ``_linear1`` followed by ``_linear2`` and return the result."""
        x = self._linear1(x)
        x = self._linear2(x)

        return x

    def record_output(
        self,
        module: nn.Module,
        input: tuple[Tensor, ...],
        output: Tensor,
        output_key: str,
    ) -> None:
        """Forward-hook callback that stores ``output`` under ``output_key``.

        Args:
            module: The module whose forward pass just ran (unused).
            input: Positional inputs given to the module; PyTorch passes them
                as a tuple (unused).
            output: The module's output tensor.
            output_key: Key under which the detached output is stored.
        """
        # detach() drops the autograd history so the stored tensor does not
        # keep the computation graph alive.
        self._out[output_key] = output.detach()


In [None]:
net = SimpleNet()
# torch.tensor is the documented factory for building a tensor from data;
# the float literals give float32, which matches nn.Linear's default dtype.
x = torch.tensor([1.0, 2.0])

In [None]:
# Forward pass; as a side effect the hook on _linear1 stores that layer's
# detached output in net.out under 'linear1'.
net(x)

In [None]:
# Outputs recorded by the forward hook, keyed by name ('linear1').
net.out

In [None]:
def f(a, b, c):
    """Print the three arguments on one line, each labelled with its name."""
    message = f'a: {a}, b: {b}, c: {c}'
    print(message)

# partial pre-binds b and c as keyword arguments; the positional 100 then
# fills the remaining parameter a.
partial(f, b=1, c=5)(100)

In [None]:
# Set literals build the sets directly instead of via a throwaway list.
A = {1, 2, 3}
B = {2, 3, 4}
# intersection() accepts any iterable, not only another set.
A.intersection([3, 1])

In [2]:
import utils
from deepdream.models import VGG16DeepDreamer

In [5]:
# Layer names handed to watch_layers (presumably the activations to be
# recorded — confirm against deepdream.models).
watched_layers = {
    'relu1_2',
    'relu2_2',
    'relu3_3',
    'relu4_2',
    'relu4_3',
    'relu5_3',
}
dreamer = VGG16DeepDreamer().watch_layers(watched_layers)

# Bare expression so the notebook displays the dreamer's repr.
dreamer

VGG16DeepDreamer(
  (_model): VGG(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
      (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (6): ReLU(inplace=True)
      (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): ReLU(inplace=True)
      (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (13): ReLU(inplace=True)
      (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (15): ReLU(inplace=True)
  

In [6]:
# NOTE(review): 'relu5_2' is not among the names passed to watch_layers when
# the dreamer was built, and this cell sits before the dreamer(t) forward
# pass — confirm that querying it here is intentional.
dreamer.get_output('relu5_2')

In [7]:
# Standard-normal random tensor of shape (1, 3, 224, 224); no seed is set
# here, so the values differ from run to run.
t = torch.randn((1, 3, 224, 224))
t.shape

torch.Size([1, 3, 224, 224])

In [8]:
# Run the dreamer on the random input; the returned tensor is displayed in
# full as the cell output. Presumably this call is also what populates the
# values later read through get_output — confirm in deepdream.models.
dreamer(t)

tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.2608]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0

In [10]:
# Show only the shape of what get_output returns for 'relu4_2', not the values.
dreamer.get_output('relu4_2').shape

torch.Size([1, 512, 28, 28])

In [14]:
# Displays the full value returned for 'relu2_2'; the output is large, so
# consider inspecting .shape instead, as done for 'relu4_2'.
dreamer.get_output('relu2_2')

tensor([[[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  2.0347],
          ...,
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  7.4537,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  1.0228,  0.0889,  ...,  0.0000,  0.0000,  0.0000]],

         [[ 2.9610, 31.3792, 21.0972,  ..., 20.0930, 13.6775, 18.1205],
          [ 0.0000, 17.8857,  1.8780,  ...,  0.0000,  0.0000,  7.5708],
          [ 0.0000, 18.3544,  0.5855,  ...,  9.5420,  0.0000,  8.2746],
          ...,
          [ 0.0000, 26.1811, 13.1080,  ..., 10.9471,  0.0000, 17.2182],
          [ 0.0000, 18.7376,  6.7627,  ...,  2.4514,  0.0000, 18.9446],
          [ 0.0000,  6.9951,  0.0000,  ...,  0.0000,  0.0000,  5.1146]],

         [[ 0.0000,  0.0000,  0.0000,  ...,  1.2753,  0.0000,  2.6125],
          [ 0.0000,  6.4526, 1