# Loading DINOv2 ckpts: Timm and FLAX

In [None]:
import timm

timm_model = timm.create_model("hf_hub:timm/vit_base_patch14_reg4_dinov2.lvd142m", pretrained=True)
timm_model.named_parameters

In [2]:
from transformers import FlaxDinov2Model
flax_model = FlaxDinov2Model.from_pretrained("facebook/dinov2-base")
flax_model._params_shape_tree

{'embeddings': {'cls_token': ShapeDtypeStruct(shape=(1, 1, 768), dtype=float32),
  'mask_token': ShapeDtypeStruct(shape=(1, 768), dtype=float32),
  'patch_embeddings': {'projection': {'bias': ShapeDtypeStruct(shape=(768,), dtype=float32),
    'kernel': ShapeDtypeStruct(shape=(14, 14, 3, 768), dtype=float32)}},
  'position_embeddings': ShapeDtypeStruct(shape=(1, 1370, 768), dtype=float32)},
 'encoder': {'layer': {'0': {'attention': {'attention': {'key': {'bias': ShapeDtypeStruct(shape=(768,), dtype=float32),
       'kernel': ShapeDtypeStruct(shape=(768, 768), dtype=float32)},
      'query': {'bias': ShapeDtypeStruct(shape=(768,), dtype=float32),
       'kernel': ShapeDtypeStruct(shape=(768, 768), dtype=float32)},
      'value': {'bias': ShapeDtypeStruct(shape=(768,), dtype=float32),
       'kernel': ShapeDtypeStruct(shape=(768, 768), dtype=float32)}},
     'output': {'dense': {'bias': ShapeDtypeStruct(shape=(768,), dtype=float32),
       'kernel': ShapeDtypeStruct(shape=(768, 768), dtyp

# Convert FLAX model to big_vision

In [1]:
import jax
import sys
sys.path.append('/home/austinwang/austin_big_vision')
import numpy as np
import jax.numpy as jnp
import big_vision.utils as u
from flax.core.frozen_dict import unfreeze


output_path = '/home/austinwang/bigvision_dinov2.npz'

def inspect(params):
    names_and_vals, _ = u.tree_flatten_with_names(params)
    for n,v in names_and_vals:
        print(n, v.shape, v.dtype)

def save_bigvision_checkpoint(params, output_path):
    np.savez(output_path, **params)

print("Loading Flax checkpoint...")
from transformers import FlaxDinov2Model
flax_params = FlaxDinov2Model.from_pretrained("facebook/dinov2-base")._params


2024-10-10 21:50:08.237158: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-10-10 21:50:08.259846: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-10-10 21:50:08.266671: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Loading Flax checkpoint...


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
print("Converting parameters...")
bv_params = {}
# Embedding layer
bv_params['embedding/kernel'] = jnp.transpose(flax_params['embeddings']['patch_embeddings']['projection']['kernel'], (3, 2, 0, 1))
bv_params['embedding/bias'] = flax_params['embeddings']['patch_embeddings']['projection']['bias']
# Position embedding
bv_params['pos_embedding'] = flax_params['embeddings']['position_embeddings']
inspect(bv_params)

Converting parameters...
embedding/bias (768,) float32
embedding/kernel (768, 3, 14, 14) float32
pos_embedding (1, 1370, 768) float32


In [None]:
# Transformer blocks
for i in range(12):  # Assuming 12 layers
    prefix = f'Transformer/encoderblock_{i}/'
    flax_prefix = f'encoder/layer/{i}/'
    
    # Layer Norm
    bv_params[prefix + 'LayerNorm_0/scale'] = flax_params[flax_prefix + 'norm1']['scale']
    bv_params[prefix + 'LayerNorm_0/bias'] = flax_params[flax_prefix + 'norm1']['bias']
    bv_params[prefix + 'LayerNorm_1/scale'] = flax_params[flax_prefix + 'norm2']['scale']
    bv_params[prefix + 'LayerNorm_1/bias'] = flax_params[flax_prefix + 'norm2']['bias']
    
    # Multi-Head Attention
    bv_params[prefix + 'MultiHeadDotProductAttention_0/query/kernel'] = flax_params[flax_prefix + 'attention/attention/query/kernel'].reshape(768, 12, 64)
    bv_params[prefix + 'MultiHeadDotProductAttention_0/query/bias'] = flax_params[flax_prefix + 'attention/attention/query/bias'].reshape(12, 64)
    bv_params[prefix + 'MultiHeadDotProductAttention_0/key/kernel'] = flax_params[flax_prefix + 'attention/attention/key/kernel'].reshape(768, 12, 64)
    bv_params[prefix + 'MultiHeadDotProductAttention_0/key/bias'] = flax_params[flax_prefix + 'attention/attention/key/bias'].reshape(12, 64)
    bv_params[prefix + 'MultiHeadDotProductAttention_0/value/kernel'] = flax_params[flax_prefix + 'attention/attention/value/kernel'].reshape(768, 12, 64)
    bv_params[prefix + 'MultiHeadDotProductAttention_0/value/bias'] = flax_params[flax_prefix + 'attention/attention/value/bias'].reshape(12, 64)
    bv_params[prefix + 'MultiHeadDotProductAttention_0/out/kernel'] = flax_params[flax_prefix + 'attention/output/dense/kernel'].reshape(12, 64, 768)
    bv_params[prefix + 'MultiHeadDotProductAttention_0/out/bias'] = flax_params[flax_prefix + 'attention/output/dense/bias']
    
    # MLP
    bv_params[prefix + 'MlpBlock_0/Dense_0/kernel'] = flax_params[flax_prefix + 'mlp/fc1/kernel']
    bv_params[prefix + 'MlpBlock_0/Dense_0/bias'] = flax_params[flax_prefix + 'mlp/fc1/bias']
    bv_params[prefix + 'MlpBlock_0/Dense_1/kernel'] = flax_params[flax_prefix + 'mlp/fc2/kernel']
    bv_params[prefix + 'MlpBlock_0/Dense_1/bias'] = flax_params[flax_prefix + 'mlp/fc2/bias']

# Final Layer Norm
bv_params['Transformer/encoder_norm/scale'] = flax_params['layernorm']['scale']
bv_params['Transformer/encoder_norm/bias'] = flax_params['layernorm']['bias']


In [None]:

print("Saving Big Vision compatible checkpoint...")
save_bigvision_checkpoint(bv_params, output_path)

print(f"Conversion complete. Checkpoint saved to {output_path}")