### PyTorch의 forward hook을 사용하면 네트워크 중간 layer의 출력을 얻을 수 있다

In [22]:
import torch
import torch.nn as nn
import torchsummary

In [48]:
class Savelayer5:
    """Forward hook that records each output of the module it is registered on.

    Register an instance with ``module.register_forward_hook(hook)``; every
    forward pass through that module appends its output to ``self.outputs``
    (entries accumulate until ``clear()`` is called).

    Args:
        detach: if True, tensor outputs are stored as ``module_out.detach()``
            so the saved activations do not keep the autograd graph alive.
            Defaults to False, which stores the output object unchanged.
    """

    def __init__(self, detach=False):
        self.detach = detach
        self.outputs = []  # one entry per forward call, oldest first

    def __call__(self, module, module_in, module_out):
        # Called by PyTorch as hook(module, inputs, output); returning None
        # leaves the module's output unchanged.
        if self.detach and isinstance(module_out, torch.Tensor):
            module_out = module_out.detach()
        self.outputs.append(module_out)

    def clear(self):
        """Forget all recorded outputs by rebinding ``outputs`` to a new empty list."""
        self.outputs = []
        
class Model(nn.Module):
    """Stack of 9 Linear layers (100 -> 90 -> ... -> 10) with a forward hook on one layer.

    A ``Savelayer5`` instance, exposed as ``self.hook``, records the output of
    ``self.fc_[hook_layer]`` on every forward pass; the hook's removal handles
    are kept in ``self.hook_handles``.

    Args:
        hook_layer: index into ``self.fc_`` of the layer whose output is
            recorded. Defaults to 4, the fifth Linear layer (60 -> 50).
    """

    def __init__(self, hook_layer=4):
        super().__init__()

        # Feature widths 100, 90, ..., 10. Kept as a tensor attribute for
        # compatibility, but nn.Linear is given plain Python ints as its API
        # expects (otherwise in_features/out_features end up as 0-d tensors).
        self.dim = torch.arange(100, 9, -10)
        widths = self.dim.tolist()
        layers = [nn.Linear(in_, out_) for in_, out_ in zip(widths[:-1], widths[1:])]

        self.fc_ = nn.Sequential(*layers)

        self.hook = Savelayer5()
        self.hook_handles = []
        handle = self.fc_[hook_layer].register_forward_hook(self.hook)
        self.hook_handles.append(handle)

    def remove_hooks(self):
        """Unregister every stored forward hook; ``self.hook.outputs`` is left as is."""
        for handle in self.hook_handles:
            handle.remove()
        self.hook_handles = []

    def forward(self, x):
        """Run ``x`` through the Linear stack; the hooked layer records its output."""
        return self.fc_(x)

In [49]:
# Build the network; Model's constructor also registers the layer-output hook.
model = Model()

In [50]:
# Print a per-layer table of output shapes and parameter counts.
# NOTE(review): input_size is the per-sample shape; the table's leading -1 is
# the batch dim added by torchsummary, so (1, 100) yields [-1, 1, ...] rows,
# whereas the real sample fed later is (1, 100) with 1 as the batch dim.
torchsummary.summary(
    model=model,
    input_size=(1, 100),
    device='cpu',
)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                [-1, 1, 90]           9,090
            Linear-2                [-1, 1, 80]           7,280
            Linear-3                [-1, 1, 70]           5,670
            Linear-4                [-1, 1, 60]           4,260
            Linear-5                [-1, 1, 50]           3,050
            Linear-6                [-1, 1, 40]           2,040
            Linear-7                [-1, 1, 30]           1,230
            Linear-8                [-1, 1, 20]             620
            Linear-9                [-1, 1, 10]             210
Total params: 33,450
Trainable params: 33,450
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.13
Estimated Total Size (MB): 0.13
---------------------------------------------

In [55]:
# The hook appends on every forward pass (torchsummary.summary runs the model
# to obtain its shapes, and this cell may be re-executed), so drop anything
# captured earlier: afterwards hook.outputs holds only this sample's activation.
model.hook.clear()
test_sample_in = torch.randn(1, 100)
test_sample_out = model(test_sample_in)

In [56]:
print("Input size", test_sample_in.shape)
print("Output size", test_sample_out.shape)

# The hook appends one entry per forward pass through the hooked layer, so the
# most recent entry ([-1]) belongs to test_sample_in; [0] would be the oldest
# recorded activation, possibly from an earlier, unrelated forward pass.
print("layer 5 size", model.hook.outputs[-1].shape)

Input size torch.Size([1, 100])
Output size torch.Size([1, 10])
layer 5 size torch.Size([1, 50])
