In [1]:
import torch
from timm.models.vision_transformer import vit_base_patch16_224
from torchvision.models.feature_extraction import get_graph_node_names
from src.pim_module_2 import PluginMoodel

In [2]:
# ViT-B/16 backbone (224x224 input per its name) built without loading a
# pretrained checkpoint; the PIM wrapper below taps its `blocks.*` layers.
backbone = vit_base_patch16_224(
    pretrained=False,
)

In [None]:
# Tap the last four transformer blocks of the ViT backbone as PIM feature layers.
# Keys are FX graph node names in the backbone; values are the names PIM uses.
return_nodes = {
    "blocks.8": "layer1",
    "blocks.9": "layer2",
    "blocks.10": "layer3",
    "blocks.11": "layer4",
}
# Per-layer selection count (presumably tokens kept by PIM's selector — TODO confirm).
# Derived from `return_nodes` so the two mappings cannot drift out of sync.
NUM_SELECTS_PER_LAYER = 32
num_selects = {layer: NUM_SELECTS_PER_LAYER for layer in return_nodes.values()}

IMG_SIZE = 224     # shared by the model config and the dummy input below
NUM_CLASSES = 10
BATCH_SIZE = 10
SEED = 0

# NOTE: `PluginMoodel` (sic) is the class name exported by src.pim_module_2.
model = PluginMoodel(
    backbone=backbone,
    return_nodes=return_nodes,
    img_size=IMG_SIZE,
    use_fpn=True,
    fpn_size=256,
    proj_type="Linear",
    upsample_type="Conv",
    use_selection=True,
    num_classes=NUM_CLASSES,
    num_selects=num_selects,
    use_combiner=True,
    comb_proj_size=None,
)
model.eval()

# Smoke-test the forward pass on a seeded random batch so reruns are reproducible.
torch.manual_seed(SEED)
with torch.no_grad():
    x = torch.randn(BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE)
    output = model(x)

# `output` is mapping-like (it has `.keys()`); display the available output names.
output.keys()

In [2]:
# Sanity-check a standalone ViT-B/16 forward pass on a single random image.
# `init_values` presumably enables timm's LayerScale with that initial value — confirm.
# NOTE(review): this rebinds `model`, replacing the PIM model built in the earlier cell.
model = vit_base_patch16_224(init_values=1e-5)
model.eval()
torch.manual_seed(0)  # make the random input reproducible across reruns
with torch.no_grad():
    # Named `dummy_input` rather than `input` to avoid shadowing the builtin.
    dummy_input = torch.randn(1, 3, 224, 224)
    output = model(dummy_input)
    print(output.shape)

torch.Size([1, 1000])


In [1]:
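# Uncomment to list the FX node names (train / eval mode) usable as `return_nodes` keys, e.g. "blocks.8":
# train_nodes, eval_nodes = get_graph_node_names(model)
# train_nodes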

In [5]:
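# NOTE: the breakdown below describes a small example CNN `Model` (conv1 / conv2 / pool / classifier)
# that is not defined in this notebook; it does not describe the ViT-B/16 used above, whose tapped
# nodes are `blocks.8`–`blocks.11`.
# The nodes in the `train_nodes` and `eval_nodes` lists represent the different layers and operations in that `Model` class.
# Here's a breakdown of what each node corresponds to: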
#
# 1. **'x'**: The input tensor to the model.
# 2. **'conv1.0'**: The first convolutional layer in `conv1` (Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))).
# 3. **'conv1.1'**: The first batch normalization layer in `conv1` (BatchNorm2d(64)).
# 4. **'conv1.2'**: The first ReLU activation in `conv1`.
# 5. **'conv1.3'**: The second convolutional layer in `conv1` (Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))).
# 6. **'conv1.4'**: The second batch normalization layer in `conv1` (BatchNorm2d(64)).
# 7. **'conv1.5'**: The second ReLU activation in `conv1`.
# 8. **'conv2.0'**: The first convolutional layer in `conv2` (Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))).
# 9. **'conv2.1'**: The first batch normalization layer in `conv2` (BatchNorm2d(128)).
# 10. **'conv2.2'**: The first ReLU activation in `conv2`.
# 11. **'conv2.3'**: The second convolutional layer in `conv2` (Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))).
# 12. **'conv2.4'**: The second batch normalization layer in `conv2` (BatchNorm2d(128)).
# 13. **'conv2.5'**: The second ReLU activation in `conv2`.
# 14. **'pool'**: The adaptive average pooling layer (AdaptiveAvgPool2d(output_size=(1, 1))).
# 15. **'flatten'**: The operation that flattens the tensor before passing it to the classifier.
# 16. **'classifier'**: The final linear layer that outputs the class scores (Linear(in_features=128, out_features=10, bias=True)).

# These nodes represent the sequence of operations that the input tensor goes through as it passes through the model.