In [1]:
import torch
import torch.nn as nn 
import torchvision
from torchvision.models import vit_l_16
from torchvision.models import ViT_L_16_Weights

# Pretrained ViT-L/16 (ImageNet-1k weights), built through the directly imported factory.
model = vit_l_16(weights=ViT_L_16_Weights.IMAGENET1K_V1)

In [41]:
# named_children() is lazy: it hands back a generator, not a list.
children_iter = model.named_children()
print(type(children_iter))

# Walk the module tree recursively and show only the dotted module names
# (the root module itself is reported with an empty name).
for module_name, _ in model.named_modules():
    print(module_name)

<class 'generator'>

conv_proj
encoder
encoder.dropout
encoder.layers
encoder.layers.encoder_layer_0
encoder.layers.encoder_layer_0.ln_1
encoder.layers.encoder_layer_0.self_attention
encoder.layers.encoder_layer_0.self_attention.out_proj
encoder.layers.encoder_layer_0.dropout
encoder.layers.encoder_layer_0.ln_2
encoder.layers.encoder_layer_0.mlp
encoder.layers.encoder_layer_0.mlp.0
encoder.layers.encoder_layer_0.mlp.1
encoder.layers.encoder_layer_0.mlp.2
encoder.layers.encoder_layer_0.mlp.3
encoder.layers.encoder_layer_0.mlp.4
encoder.layers.encoder_layer_1
encoder.layers.encoder_layer_1.ln_1
encoder.layers.encoder_layer_1.self_attention
encoder.layers.encoder_layer_1.self_attention.out_proj
encoder.layers.encoder_layer_1.dropout
encoder.layers.encoder_layer_1.ln_2
encoder.layers.encoder_layer_1.mlp
encoder.layers.encoder_layer_1.mlp.0
encoder.layers.encoder_layer_1.mlp.1
encoder.layers.encoder_layer_1.mlp.2
encoder.layers.encoder_layer_1.mlp.3
encoder.layers.encoder_layer_1.mlp.4
enco

In [2]:
# how to use hooks 
# https://medium.com/the-dl/how-to-use-pytorch-hooks-5041d777f904

class tinymodel(nn.Module):
    """Pretrained ViT-L/16 wrapper that captures intermediate encoder outputs.

    Forward hooks are registered on the selected ``encoder_layer_<i>`` blocks of
    the torchvision ViT encoder. Every forward pass stores the output of each
    hooked block in ``self.features`` under the block's name.

    Args:
        hooks: indices of the encoder blocks to tap. Defaults to
            ``(4, 11, 17, 23)``.

    Attributes:
        hooks: list of tapped block indices.
        model: the wrapped torchvision ViT-L/16.
        encoder: shortcut to ``model.encoder.layers``.
        features: dict filled by the hooks, keyed by ``'encoder_layer_<i>'``.
        hook_handles: handles returned by ``register_forward_hook``; call
            ``.remove()`` on them to detach the hooks.
    """

    def __init__(self, hooks=(4, 11, 17, 23)):
        super().__init__()

        self.hooks = list(hooks)
        self.model = torchvision.models.vit_l_16(weights=ViT_L_16_Weights.IMAGENET1K_V1)

        # Access the encoder
        self.encoder = self.model.encoder.layers

        # Ensure the encoder has the layers we expect
        assert hasattr(self.encoder, 'encoder_layer_0'), "Encoder layer naming is inconsistent!"
        self.features = {}
        # Keep the handles so the hooks can be removed later if needed.
        self.hook_handles = []

        # Register forward hooks
        for hook in self.hooks:
            layer_name = f'encoder_layer_{hook}'
            if hasattr(self.encoder, layer_name):
                layer = getattr(self.encoder, layer_name)
                handle = layer.register_forward_hook(self.save_output_hook(layer_name))
                self.hook_handles.append(handle)
            else:
                # Best effort: a missing layer is reported, not fatal.
                print(f"Warning: {layer_name} not found in encoder layers.")

    def save_output_hook(self, layer_key):
        """Return a forward hook that stores a module's output under ``layer_key``."""
        def save_output(_, __, output):
            # Returning None leaves the module's output unchanged.
            self.features[layer_key] = output
        return save_output

    def forward(self, x):
        """Run the ViT on ``x`` and return the outputs captured by the hooks.

        Returns:
            A new dict mapping ``'encoder_layer_<i>'`` to the captured output of
            that block for this call.
        """
        # Drop captures from a previous call so the result reflects this pass only.
        self.features.clear()
        self.model(x)
        # Return a shallow copy: handing out the internal dict would let the next
        # forward call silently overwrite results the caller is still holding.
        return dict(self.features)



# Build the hooked feature extractor; its constructor loads the pretrained ViT-L/16 weights.
silly_model = tinymodel()


In [3]:
# torchvision models take batched tensors laid out as (B, C, H, W);
# here B=10, C=3 and H=W=224.
dummy = torch.randn(size=(10, 3, 224, 224))
out = silly_model(dummy)

# One captured feature tensor is expected per hooked encoder layer.
# The message must be a string: `print(...)` returns None, which would leave
# the AssertionError without any message.
assert len(out) == len(silly_model.hooks), f'size of the output : {len(out)}'


In [8]:
# The encoder output shape should be (B, number_of_tokens, embedding_dimension):
#   - B = 10, the batch size of the dummy input [OK]
#   - number of patches = (224 * 224) / (16 * 16) = 196, plus 1 class token = 197 tokens
#     (ATTENTION: the class token sits at index 0 of the token axis, i.e. [:, 0, :])
#   - embedding dimension = 1024 [OK]
out['encoder_layer_4'].shape 


torch.Size([10, 197, 1024])

In [12]:
# in the Reassemble block, we must get rid of the `class token`
def Reassemble(out):
    """
    Drop the class token and fold the patch tokens back into a 2-D feature map.

    out.shape = (B, 1 + Patches, D), with the class token at index 0 of dim 1.
    Returns a tensor of shape (B, D, S, S) where S = sqrt(Patches); for a
    224x224 input cut into 16x16 patches, Patches = 196 and S = 14.

    Raises ValueError if Patches is not a perfect square.
    """
    res = out[:,1:,:] # drop the class token, keep the patch tokens
    print(res.shape)
    batch, num_patches, dim = res.shape
    side = round(num_patches ** 0.5)
    if side * side != num_patches:
        raise ValueError(
            f"cannot arrange {num_patches} patch tokens on a square grid"
        )
    # Move the embedding axis in front of the token axis BEFORE folding the
    # tokens into (rows, cols). A bare reshape of (B, N, D) to (B, D, S, S)
    # only reinterprets memory and would mix tokens and channels together.
    res = res.transpose(1, 2).reshape(batch, dim, side, side)
    print(res.shape)
    return res

# Turn the tokens captured at encoder_layer_4 into a 4-D (B, D, 14, 14) tensor.
tmr = Reassemble(out['encoder_layer_4'])
# The captured tensor itself is not modified: it still has all 197 tokens.
print(out['encoder_layer_4'].shape)

torch.Size([10, 196, 1024])
torch.Size([10, 1024, 14, 14])
torch.Size([10, 197, 1024])
