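# Inspect Extraction Model

Sanity checks for the `Extraction` classifier: a per-layer summary, a TensorBoard graph export, a forward pass, and the layout of its `state_dict`.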

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary
import sys
sys.path.append('../')
from torch.utils.tensorboard import SummaryWriter

In [2]:
from model.extraction import ExtractionConfig, Extraction

# Input resolution and class count; reused below for the summary input and by later cells.
img_h, img_w, n_class = 224, 224, 1000
model_args = {"img_h": img_h, "img_w": img_w, "n_class": n_class}

model_config = ExtractionConfig(**model_args)
model = Extraction(model_config)

# Per-layer summary for a single RGB image on CPU. verbose=2 also lists the
# individual parameter tensors (see output); the trailing ';' suppresses the
# cell's displayed return value so only the printed table appears.
batch_size = 1
summary(
    model,
    input_size=[(batch_size, 3, img_h, img_w)],
    device="cpu",
    col_names=(
        "input_size",
        "output_size",
        "num_params",
        "kernel_size",
        "mult_adds",
        "trainable",
    ),
    verbose=2,
    depth=4,
    row_settings=("depth", "var_names"),
);

Layer (type (var_name):depth-idx)             Input Shape               Output Shape              Param #                   Kernel Shape              Mult-Adds                 Trainable
Extraction (Extraction)                       [1, 3, 224, 224]          [1, 1000]                 --                        --                        --                        True
├─ExtractionBackbone (backbone): 1-1          [1, 3, 224, 224]          [1, 1024, 7, 7]           --                        --                        --                        True
│    └─conv1.conv.weight                                                                          ├─9,408                   [64, 3, 7, 7]
│    └─conv1.bn.weight                                                                            ├─64                      [64]
│    └─conv1.bn.bias                                                                              ├─64                      [64]
│    └─conv2.conv.weight                                

In [3]:
# Seed the global torch RNG so the dummy batch below (imgs/targets are reused by
# the forward-pass cell) is identical on every Restart & Run All.
torch.manual_seed(0)

# Dummy batch with the same shape as the summary input: random images plus one
# integer class label per image in [0, n_class).
imgs = torch.randn(batch_size, 3, img_h, img_w)
targets = torch.randint(0, n_class, (batch_size,))

# Trace the model on (imgs, targets) and log the graph for TensorBoard.
# Using the writer as a context manager closes (and flushes) it even if
# tracing raises, so re-running the cell does not leak open writers.
with SummaryWriter() as writer:
    writer.add_graph(model, [imgs, targets])

In [4]:
# Forward pass with targets supplied; the model returns a (logits, loss) pair.
logits, loss = model(imgs, targets)
print(logits.shape)
print(loss.shape)

# Pin the output contract instead of relying on eyeballing the prints:
# one row of class scores per image, and a scalar (0-dim) loss.
assert logits.shape == (batch_size, n_class), f"unexpected logits shape {tuple(logits.shape)}"
assert loss.ndim == 0, f"expected scalar loss, got shape {tuple(loss.shape)}"

torch.Size([1, 1000])
torch.Size([])


## Inspect the `state_dict`

In [4]:
import tempfile
from pathlib import Path

# Round-trip the state_dict through a real file on disk, then list its keys.
# A file inside a TemporaryDirectory is used rather than NamedTemporaryFile:
# reopening a still-open NamedTemporaryFile by name (as torch.save/torch.load
# would) is not possible on Windows.
with tempfile.TemporaryDirectory() as tmp_dir:
    ckpt_path = Path(tmp_dir) / "state_dict.pth"
    torch.save(model.state_dict(), ckpt_path)
    # weights_only=True restricts unpickling to tensors and plain containers;
    # assumed sufficient here since the listed entries are all tensors — TODO
    # confirm if the model ever adds non-tensor extra state.
    state_dict = torch.load(ckpt_path, weights_only=True)

# The reloaded dict should expose exactly the same entries, in the same order.
assert list(state_dict) == list(model.state_dict()), "state_dict keys changed in round-trip"

# Print the keys of the state_dict
for key in state_dict:
    print(key)


backbone.conv1.conv.weight
backbone.conv1.bn.weight
backbone.conv1.bn.bias
backbone.conv1.bn.running_mean
backbone.conv1.bn.running_var
backbone.conv1.bn.num_batches_tracked
backbone.conv2.conv.weight
backbone.conv2.bn.weight
backbone.conv2.bn.bias
backbone.conv2.bn.running_mean
backbone.conv2.bn.running_var
backbone.conv2.bn.num_batches_tracked
backbone.conv3.conv.weight
backbone.conv3.bn.weight
backbone.conv3.bn.bias
backbone.conv3.bn.running_mean
backbone.conv3.bn.running_var
backbone.conv3.bn.num_batches_tracked
backbone.conv4.conv.weight
backbone.conv4.bn.weight
backbone.conv4.bn.bias
backbone.conv4.bn.running_mean
backbone.conv4.bn.running_var
backbone.conv4.bn.num_batches_tracked
backbone.conv5.conv.weight
backbone.conv5.bn.weight
backbone.conv5.bn.bias
backbone.conv5.bn.running_mean
backbone.conv5.bn.running_var
backbone.conv5.bn.num_batches_tracked
backbone.conv6.conv.weight
backbone.conv6.bn.weight
backbone.conv6.bn.bias
backbone.conv6.bn.running_mean
backbone.conv6.bn.runnin