In [1]:
import timm

timm_model = timm.create_model("hf_hub:timm/vit_base_patch14_reg4_dinov2.lvd142m", pretrained=True)
timm_model.named_parameters

  from .autonotebook import tqdm as notebook_tqdm


<bound method Module.named_parameters of VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(14, 14), stride=(14, 14))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): LayerScale()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.

In [2]:
from transformers import FlaxDinov2Model
flax_model = FlaxDinov2Model.from_pretrained("facebook/dinov2-base")
flax_model._params_shape_tree

{'embeddings': {'cls_token': ShapeDtypeStruct(shape=(1, 1, 768), dtype=float32),
  'mask_token': ShapeDtypeStruct(shape=(1, 768), dtype=float32),
  'patch_embeddings': {'projection': {'bias': ShapeDtypeStruct(shape=(768,), dtype=float32),
    'kernel': ShapeDtypeStruct(shape=(14, 14, 3, 768), dtype=float32)}},
  'position_embeddings': ShapeDtypeStruct(shape=(1, 1370, 768), dtype=float32)},
 'encoder': {'layer': {'0': {'attention': {'attention': {'key': {'bias': ShapeDtypeStruct(shape=(768,), dtype=float32),
       'kernel': ShapeDtypeStruct(shape=(768, 768), dtype=float32)},
      'query': {'bias': ShapeDtypeStruct(shape=(768,), dtype=float32),
       'kernel': ShapeDtypeStruct(shape=(768, 768), dtype=float32)},
      'value': {'bias': ShapeDtypeStruct(shape=(768,), dtype=float32),
       'kernel': ShapeDtypeStruct(shape=(768, 768), dtype=float32)}},
     'output': {'dense': {'bias': ShapeDtypeStruct(shape=(768,), dtype=float32),
       'kernel': ShapeDtypeStruct(shape=(768, 768), dtyp

In [4]:
import numpy as np
import jax.numpy as jnp
from flax.training import checkpoints
from flax.core.frozen_dict import unfreeze
import os

def load_flax_checkpoint():
    from transformers import FlaxDinov2Model
    flax_model = FlaxDinov2Model.from_pretrained("facebook/dinov2-base")
    return flax_model._params

def convert_dinov2_to_bigvision(flax_params):
    bv_params = {}
    
    # Embedding layer
    bv_params['embedding/kernel'] = jnp.transpose(flax_params['embeddings']['patch_embeddings']['projection']['kernel'], (3, 2, 0, 1))
    bv_params['embedding/bias'] = flax_params['embeddings']['patch_embeddings']['projection']['bias']
    
    # Position embedding
    bv_params['pos_embedding'] = flax_params['embeddings']['position_embeddings']
    
    # Transformer blocks
    for i in range(12):  # Assuming 12 layers
        prefix = f'Transformer/encoderblock_{i}/'
        flax_prefix = f'encoder/layer/{i}/'
        
        # Layer Norm
        bv_params[prefix + 'LayerNorm_0/scale'] = flax_params[flax_prefix + 'norm1']['scale']
        bv_params[prefix + 'LayerNorm_0/bias'] = flax_params[flax_prefix + 'norm1']['bias']
        bv_params[prefix + 'LayerNorm_1/scale'] = flax_params[flax_prefix + 'norm2']['scale']
        bv_params[prefix + 'LayerNorm_1/bias'] = flax_params[flax_prefix + 'norm2']['bias']
        
        # Multi-Head Attention
        bv_params[prefix + 'MultiHeadDotProductAttention_0/query/kernel'] = flax_params[flax_prefix + 'attention/attention/query/kernel'].reshape(768, 12, 64)
        bv_params[prefix + 'MultiHeadDotProductAttention_0/query/bias'] = flax_params[flax_prefix + 'attention/attention/query/bias'].reshape(12, 64)
        bv_params[prefix + 'MultiHeadDotProductAttention_0/key/kernel'] = flax_params[flax_prefix + 'attention/attention/key/kernel'].reshape(768, 12, 64)
        bv_params[prefix + 'MultiHeadDotProductAttention_0/key/bias'] = flax_params[flax_prefix + 'attention/attention/key/bias'].reshape(12, 64)
        bv_params[prefix + 'MultiHeadDotProductAttention_0/value/kernel'] = flax_params[flax_prefix + 'attention/attention/value/kernel'].reshape(768, 12, 64)
        bv_params[prefix + 'MultiHeadDotProductAttention_0/value/bias'] = flax_params[flax_prefix + 'attention/attention/value/bias'].reshape(12, 64)
        bv_params[prefix + 'MultiHeadDotProductAttention_0/out/kernel'] = flax_params[flax_prefix + 'attention/output/dense/kernel'].reshape(12, 64, 768)
        bv_params[prefix + 'MultiHeadDotProductAttention_0/out/bias'] = flax_params[flax_prefix + 'attention/output/dense/bias']
        
        # MLP
        bv_params[prefix + 'MlpBlock_0/Dense_0/kernel'] = flax_params[flax_prefix + 'mlp/fc1/kernel']
        bv_params[prefix + 'MlpBlock_0/Dense_0/bias'] = flax_params[flax_prefix + 'mlp/fc1/bias']
        bv_params[prefix + 'MlpBlock_0/Dense_1/kernel'] = flax_params[flax_prefix + 'mlp/fc2/kernel']
        bv_params[prefix + 'MlpBlock_0/Dense_1/bias'] = flax_params[flax_prefix + 'mlp/fc2/bias']
    
    # Final Layer Norm
    bv_params['Transformer/encoder_norm/scale'] = flax_params['layernorm']['scale']
    bv_params['Transformer/encoder_norm/bias'] = flax_params['layernorm']['bias']
    
    return bv_params

def save_bigvision_checkpoint(params, output_path):
    np.savez(output_path, **params)

def main():
    output_path = '/home/austinwang/bigvision_dinov2.npz'
    
    print("Loading Flax checkpoint...")
    flax_params = load_flax_checkpoint()
    
    print("Converting parameters...")
    bv_params = convert_dinov2_to_bigvision(unfreeze(flax_params))
    
    print("Saving Big Vision compatible checkpoint...")
    save_bigvision_checkpoint(bv_params, output_path)
    
    print(f"Conversion complete. Checkpoint saved to {output_path}")

if __name__ == "__main__":
    main()

Loading Flax checkpoint...
Converting parameters...


KeyError: 'encoder/layer/0/norm1'