In [1]:
import timm
import torch
from timm.models.vision_transformer import VisionTransformer, _cfg
# Random stand-in for a single RGB image: (batch=1, channels=3, height=224, width=224),
# matching the img_size=224 used by the ViT models built in this notebook.
image = torch.rand((1, 3, 224, 224))

# Base model

In [6]:
import os
import sys
import warnings

# Put the project root on the import path; the `model_features` package imported
# further down presumably lives under it.  The root is read from the
# BONNER_ROOT_PATH environment variable.
ROOT = os.getenv('BONNER_ROOT_PATH')
if ROOT is None:
    # os.getenv returns None when the variable is unset.  sys.path entries must be
    # strings, so warn instead of silently appending None.
    warnings.warn(
        "BONNER_ROOT_PATH is not set; the project root was not added to sys.path",
        RuntimeWarning,
    )
elif ROOT not in sys.path:
    # Skip the append when the entry is already present so that re-running the
    # cell does not keep growing sys.path.
    sys.path.append(ROOT)

import sys
import torchvision
import torch
from torch import nn
import pickle
import os
import timm
from timm.models.vision_transformer import VisionTransformer
from model_features.layer_operations.output import Output
# Seed the CPU and CUDA generators so randomly initialised weights are repeatable.
torch.manual_seed(0)
torch.cuda.manual_seed(0)
import torch
import torch.nn as nn

# Load an untrained ViT-Base.  In this notebook it serves as a template: the
# constants below are read off its blocks and reused when building other models.
untrained_model = timm.create_model('vit_base_patch16_224', pretrained=False)

NUM_LAYERS = len(untrained_model.blocks)                       # number of transformer blocks
NUM_HEADS = untrained_model.blocks[0].attn.num_heads           # attention heads in the first block
EMBED_DIM = untrained_model.blocks[0].attn.qkv.in_features     # input width of the qkv projection
OUT_FEATURES = untrained_model.blocks[0].mlp.fc1.out_features  # hidden width of the block MLP
MLP_RATIO = OUT_FEATURES // EMBED_DIM                          # MLP hidden width / embedding width

In [7]:
# Wide ViT variant: depth, head count and MLP ratio are those of the base model,
# while the token width is raised to the base model's MLP hidden size and the
# patch embedding takes 108 input channels.
model = VisionTransformer(
    img_size=224,
    patch_size=16,
    in_chans=108,
    embed_dim=OUT_FEATURES,   # token width = base model's MLP hidden width
    depth=NUM_LAYERS,         # same depth as the base model
    num_heads=NUM_HEADS,      # same head count as the base model
    mlp_ratio=MLP_RATIO,      # same MLP expansion ratio as the base model
    num_classes=1000,
)

In [8]:
# Dump the module tree to check the layer widths produced by the constructor arguments.
print(model)

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(108, 3072, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((3072,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=3072, out_features=9216, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=3072, out_features=3072, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((3072,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=3072, out_features=12288, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm):

In [64]:

class TimmViT(nn.Module):
    """Randomly initialised timm ``VisionTransformer`` with a configurable token width.

    Depth and MLP ratio are copied from the module-level template
    (``untrained_model``, ``OUT_FEATURES``, ``EMBED_DIM``); only the embedding
    width is chosen by the caller.

    Args:
        out_features: Embedding (token) width of the transformer.

    Raises:
        ValueError: If ``out_features`` is smaller than 1.
    """

    def __init__(self, out_features:int):
        super(TimmViT, self).__init__()

        self.out_features = out_features

        # Multi-head attention splits the embedding evenly across heads, so the
        # head count must divide the width.  Use the template's head count when
        # it fits and otherwise fall back to the largest divisor below it.
        max_heads = untrained_model.blocks[0].attn.num_heads
        num_heads = self._select_num_heads(self.out_features, max_heads)

        self.model = VisionTransformer(
            img_size=224,
            patch_size=16,
            embed_dim=self.out_features,
            depth=len(untrained_model.blocks),      # same depth as the template
            num_heads=num_heads,
            mlp_ratio=int(OUT_FEATURES/EMBED_DIM),  # same MLP ratio as the template
            num_classes=1000,
        )

    @staticmethod
    def _select_num_heads(embed_dim: int, max_heads: int) -> int:
        """Return the largest head count ``<= max_heads`` that divides ``embed_dim``."""
        if embed_dim < 1:
            raise ValueError(f"out_features must be a positive integer, got {embed_dim}")
        for heads in range(min(embed_dim, max_heads), 0, -1):
            if embed_dim % heads == 0:
                return heads
        # Only reachable when max_heads < 1; a single head always divides the width.
        return 1

    def forward(self, x):
        """Run the wrapped VisionTransformer on ``x`` and return its output."""
        return self.model(x)
    
    
    
class BaseModel(nn.Module):
    """ViT feature extractor that returns one transformer block's patch tokens.

    A ``TimmViT`` is built with the requested width, a forward hook is attached
    to every transformer block, and ``forward`` passes the activations of block
    ``block`` (class token removed) through ``last``.

    Args:
        out_features: Embedding width handed to ``TimmViT``.
        block: Index of the transformer block whose activations are returned.
        fixed_pos_encoding: Stored on the instance; this class does not act on it.
        last: Module applied to the selected activations.
        device: Device the model and inputs are moved to in ``forward``.
    """

    def __init__(self, out_features:int, block:int, fixed_pos_encoding:bool, last:nn.Module, device:str='cuda'):
        super(BaseModel, self).__init__()

        self.block = block
        self.out_features = out_features
        # Keep the flag visible on the instance instead of silently dropping it.
        self.fixed_pos_encoding = fixed_pos_encoding
        self.last = last
        self.device = device
        self.model = TimmViT(out_features = self.out_features)
        self.activations = {}
        # Handles of the hooks currently attached, so they can be removed again.
        self._hook_handles = []

        # Register hooks
        self.register_hooks()

    def _hook_fn(self, idx, module, input, output):
        """ Hook function to capture activations without the class token. """
        # Remove the class token (first token) from the output
        activations_without_class_token = output[:, 1:, :]  # Exclude the first token
        self.activations[f'block_{idx}'] = activations_without_class_token

    def register_hooks(self):
        """Attach one capture hook per transformer block.

        Safe to call repeatedly: hooks from a previous call are removed first, so
        a block never ends up running the capture hook more than once.
        """
        for handle in self._hook_handles:
            handle.remove()
        # idx=i binds the block index at definition time (avoids late binding).
        self._hook_handles = [
            blk.register_forward_hook(lambda m, inp, out, idx=i: self._hook_fn(idx, m, inp, out))
            for i, blk in enumerate(self.model.model.blocks)
        ]

    def forward(self, x):
        """Return ``last`` applied to the selected block's patch-token activations.

        Raises:
            ValueError: If no activations were captured for ``self.block``.
        """
        x = x.to(self.device)
        self.model.to(self.device)

        # Drop results of any earlier call so a stale tensor can never be returned.
        self.activations.clear()
        _ = self.model(x)  # Perform the forward pass to populate activations

        key = f'block_{self.block}'
        if key not in self.activations:
            raise ValueError(f"No activations found for block {self.block}")

        return self.last(self.activations[key])
    
    
class ViTBase:
    """Configuration holder that builds a ``BaseModel`` feature extractor.

    Args:
        out_features: Embedding width of the transformer.
        fixed_pos_encoding: Flag forwarded to ``BaseModel``.
        block: Index of the transformer block whose activations are returned.
        device: Device the built model runs on.
    """

    def __init__(self,out_features:int, fixed_pos_encoding:bool=False, block=11,device='cuda'):

        self.block = block
        self.out_features = out_features
        self.fixed_pos_encoding = fixed_pos_encoding
        self.device = device

    def Build(self):
        """Create the ``BaseModel`` described by this configuration."""
        last = Output()

        return BaseModel(
            out_features=self.out_features,
            fixed_pos_encoding=self.fixed_pos_encoding,
            block=self.block,
            last=last,
            # Forward the configured device; otherwise BaseModel falls back to its
            # own default and the `device` argument of this class has no effect.
            device=self.device,
        )

In [66]:
# Smoke test with a token width of 6 (TimmViT uses 6 attention heads for this width).
# The recorded shape (1, 1176) equals 196 patch tokens x 6 features; the flattening
# presumably happens inside Output -- confirm against model_features.
model = ViTBase(out_features = 6, device='cuda').Build()
model(torch.randn(1, 3, 224, 224)).shape

torch.Size([1, 1176])

In [69]:
# Smoke test with a token width of 60 (a multiple of the 12 attention heads).
# The recorded shape (1, 11760) equals 196 patch tokens x 60 features.
model = ViTBase(out_features = 12*5, device='cuda').Build()
model(torch.randn(1, 3, 224, 224)).shape

torch.Size([1, 11760])

In [67]:
# Smoke test with a token width of 600 (a multiple of the 12 attention heads).
# The recorded shape (1, 117600) equals 196 patch tokens x 600 features.
model = ViTBase(out_features = 12*50, device='cuda').Build()
model(torch.randn(1, 3, 224, 224)).shape

torch.Size([1, 117600])

In [37]:
# Example Usage
# NOTE(review): ViTLarge is not defined in any cell shown in this notebook; it
# presumably comes from an earlier kernel session or another module -- confirm.
vit_large = ViTLarge(out_features=3, device='cuda')
model = vit_large.Build()
# NOTE(review): if the built model already registers its hooks in the constructor
# (as BaseModel does), this extra call may attach duplicates -- verify for ViTLarge.
model.register_hooks()  # Register hooks on the transformer blocks

dummy_input = torch.randn(1, 3, 224, 224)  # Example input
# The cell ends in an assignment, so this statement by itself displays nothing.
output = model(dummy_input.to('cuda'))

torch.Size([1, 108])


In [38]:
# Inspect the activations stored under the key 'block_11'.
# NOTE(review): BaseModel in this notebook exposes an `activations` dict but no
# `get_activations` method; this call presumably targets the ViTLarge model -- confirm.
model.get_activations('block_11')

tensor([[[ 0.6006, -1.3674, -0.1669,  ..., -0.5781, -0.0098,  0.5785],
         [-1.2437,  0.1974, -0.8219,  ...,  0.1891, -0.7614,  0.2268],
         [-0.7659, -0.1718,  0.3430,  ...,  0.2475,  1.0051,  0.9157],
         ...,
         [-0.7101, -0.0846, -0.6696,  ..., -0.0134,  0.1675,  0.1802],
         [ 0.5027, -1.3741,  0.2895,  ..., -1.0228,  0.1562,  0.0765],
         [-0.8894,  0.2291, -0.4392,  ..., -0.5029, -0.1683,  0.6237]]],
       device='cuda:0', grad_fn=<SliceBackward0>)

In [20]:
# Forward the random `image` batch; which `model` is bound here depends on the
# order in which the cells were executed.
model(image).shape

torch.Size([1, 108])


torch.Size([1, 108])

In [3]:
# Load a standard Vision Transformer model for reference
# (untrained weights; the classifier head is swapped for an identity, and the
# recorded output below is a 768-dimensional vector per image).
reference_model = timm.create_model('vit_base_patch16_224', pretrained=False)
reference_model.head = torch.nn.Identity()

# Within this cell the pretrained copy is used only for its positional embedding.
pretrained_model = timm.create_model('vit_base_patch16_224', pretrained=True)
# Plain attribute assignment: the reference model now points at the pretrained
# tensor itself, no copy is made.
reference_model.pos_embed = pretrained_model.pos_embed # use learned positional embeddings
reference_model(image).shape

torch.Size([1, 768])

In [4]:
print(reference_model)

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity(

In [19]:
# Read the architecture hyperparameters off the reference model.
embed_dim = reference_model.blocks[0].attn.qkv.in_features       # input width of the qkv projection
num_heads = reference_model.blocks[0].attn.num_heads             # attention heads in the first block
out_features = reference_model.blocks[0].mlp.fc1.out_features    # hidden width of the block MLP
mlp_ratio = out_features // embed_dim                            # MLP hidden width / embedding width
num_layers = len(reference_model.blocks)                         # number of transformer blocks

# Report the values once all of them have been computed.
print('embed_dim:', embed_dim)
print('num_heads:', num_heads)
print('mlp_out_features:', out_features)
print('mlp_ratio:', mlp_ratio)
print('num_layers:', num_layers)

embed_dim: 768
num_heads: 12
mlp_out_features: 3072
mlp_ratio: 4
num_layers: 12


# Base Model Test

In [2]:
import os
import sys
import warnings

# Put the project root on the import path; the `model_features` package imported
# below presumably lives under it.  The root is read from the BONNER_ROOT_PATH
# environment variable.
ROOT = os.getenv('BONNER_ROOT_PATH')
if ROOT is None:
    # os.getenv returns None when the variable is unset.  sys.path entries must be
    # strings, so warn instead of silently appending None.
    warnings.warn(
        "BONNER_ROOT_PATH is not set; the project root was not added to sys.path",
        RuntimeWarning,
    )
elif ROOT not in sys.path:
    # Skip the append when the entry is already present so that re-running the
    # cell does not keep growing sys.path.
    sys.path.append(ROOT)
print(ROOT)
import sys
import torchvision
import torch
from model_features.models.ViT import ViTBase
from torch import nn


/home/akazemi3/Desktop/untrained_models_of_visual_cortex/


# Larger transformer

In [26]:
# Scaled-up variant of the reference architecture: same embedding width, twice the
# depth, four times the attention heads and twice the MLP ratio.
large_model = VisionTransformer(
    img_size=224,
    patch_size=16,
    embed_dim=embed_dim,
    depth=num_layers*2, # more layers
    num_heads=num_heads*4, #more attention heads
    mlp_ratio=mlp_ratio*2, # higher mlp ratio
    num_classes=1000  
)
# NOTE(review): this loads the pretrained weights again even if `pretrained_model`
# is still bound from the reference-model cell.
pretrained_model = timm.create_model('vit_base_patch16_224', pretrained=True)
large_model.pos_embed = pretrained_model.pos_embed # use learned positional embeddings
# Swap the 1000-way classifier for a linear layer with 108000 outputs.
large_model.head = torch.nn.Linear(embed_dim, 108000)
large_model(image).shape

torch.Size([1, 108000])

In [24]:
# Hyperparameters of the enlarged model: the base values scaled by the same
# factors that are passed to the VisionTransformer constructor for `large_model`.
print('embed_dim:',embed_dim)
print('num_heads:',num_heads*4)
print('mlp_ratio:',mlp_ratio*2)
print('mlp_out_features:',out_features*2)
print('num_layers:',num_layers*2)

embed_dim: 768
num_heads: 48
mlp_ratio: 8
mlp_out_features: 6144
num_layers: 24
