In [1]:
import timm
import torch.nn as nn

def extract_submodule(model_name: str, layer_path: str) -> nn.Module:
    """
    Extract a submodule from a timm Vision Transformer model up to a specified layer.

    Args:
        model_name (str): Name of the ViT model in the timm library.
        layer_path (str): The layer path up to which the submodule should be extracted, separated by slashes.

    Returns:
        nn.Module: The extracted submodule.
    """
    # Load the model
    model = timm.create_model(model_name, pretrained=True)
   
    # Split the layer path into parts
    layer_parts = layer_path.split('/')
   
    # Create a new Sequential container to hold the submodule
    submodule = nn.Sequential()
   
    current_module = model
    for part in layer_parts:
        # If part is a block index, add the blocks sequentially
        if part.startswith('blocks'):
            blocks_idx = int(part.split('.')[1])
            blocks = nn.Sequential()
            for i, block in enumerate(current_module.blocks.children()):
                blocks.add_module(f'block_{i}', block)
                if i == blocks_idx:
                    submodule.add_module('blocks', blocks)
                    return submodule
        else:
            current_module = getattr(current_module, part)
            submodule.add_module(part, current_module)
   
    return submodule

# Example usage
model_name = 'vit_base_patch16_224'
layer_path = 'blocks/10/mlp/fc1'  # Specify the path to the layer you want to extract up to
submodule = extract_submodule(model_name, layer_path)

  from .autonotebook import tqdm as notebook_tqdm


IndexError: list index out of range

: 