In [7]:
%load_ext autoreload
%autoreload 2
%cd ~/austin_big_vision

import io
import jax
import importlib
import numpy as np
import ml_collections
import jax.numpy as jnp
import big_vision.utils as u
from big_vision.models.vit import scan_to_pyloop

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
/home/austinwang/austin_big_vision


  bkms = self.shell.db.get('bookmarks', {})
  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


# Convert FLAX model to big_vision

In [6]:
def inspect(params):
    names_and_vals, _ = u.tree_flatten_with_names(params)
    for n,v in names_and_vals:
        print(n, v.shape, v.dtype)

print("Loading Flax checkpoint...")
from transformers import FlaxDinov2Model
flax_model = FlaxDinov2Model.from_pretrained("facebook/dinov2-base")
flax_params = flax_model._params

print("Converting parameters...")
bv_params = {}
# Embedding layer
bv_params['embedding/kernel'] = flax_params['embeddings']['patch_embeddings']['projection']['kernel']
bv_params['embedding/bias'] = flax_params['embeddings']['patch_embeddings']['projection']['bias']
# Position embedding
bv_params['pos_embedding'] = flax_params['embeddings']['position_embeddings']
# Transformer blocks
for i in range(12):  # Assuming 12 layers
    prefix = f'Transformer/encoderblock_{i}/'
    flax_prefix = f'encoder/layer/{i}/'
    # Layer Norm
    bv_params[prefix + 'LayerNorm_0/scale'] = flax_params['encoder']['layer'][str(i)]['norm1']['scale']
    bv_params[prefix + 'LayerNorm_0/bias'] = flax_params['encoder']['layer'][str(i)]['norm1']['bias']
    bv_params[prefix + 'LayerNorm_1/scale'] = flax_params['encoder']['layer'][str(i)]['norm2']['scale']
    bv_params[prefix + 'LayerNorm_1/bias'] = flax_params['encoder']['layer'][str(i)]['norm2']['bias']
    # Multi-Head Attention
    bv_params[prefix + 'MultiHeadDotProductAttention_0/query/kernel'] = flax_params['encoder']['layer'][str(i)]['attention']['attention']['query']['kernel'].reshape(768, 12, 64)
    bv_params[prefix + 'MultiHeadDotProductAttention_0/query/bias'] = flax_params['encoder']['layer'][str(i)]['attention']['attention']['query']['bias'].reshape(12, 64)
    bv_params[prefix + 'MultiHeadDotProductAttention_0/key/kernel'] = flax_params['encoder']['layer'][str(i)]['attention']['attention']['key']['kernel'].reshape(768, 12, 64)
    bv_params[prefix + 'MultiHeadDotProductAttention_0/key/bias'] = flax_params['encoder']['layer'][str(i)]['attention']['attention']['key']['bias'].reshape(12, 64)
    bv_params[prefix + 'MultiHeadDotProductAttention_0/value/kernel'] = flax_params['encoder']['layer'][str(i)]['attention']['attention']['value']['kernel'].reshape(768, 12, 64)
    bv_params[prefix + 'MultiHeadDotProductAttention_0/value/bias'] = flax_params['encoder']['layer'][str(i)]['attention']['attention']['value']['bias'].reshape(12, 64)
    bv_params[prefix + 'MultiHeadDotProductAttention_0/out/kernel'] = flax_params['encoder']['layer'][str(i)]['attention']['output']['dense']['kernel'].reshape(12, 64, 768)
    bv_params[prefix + 'MultiHeadDotProductAttention_0/out/bias'] = flax_params['encoder']['layer'][str(i)]['attention']['output']['dense']['bias']
    # MLP
    bv_params[prefix + 'MlpBlock_0/Dense_0/kernel'] = flax_params['encoder']['layer'][str(i)]['mlp']['fc1']['kernel']
    bv_params[prefix + 'MlpBlock_0/Dense_0/bias'] = flax_params['encoder']['layer'][str(i)]['mlp']['fc1']['bias']
    bv_params[prefix + 'MlpBlock_0/Dense_1/kernel'] = flax_params['encoder']['layer'][str(i)]['mlp']['fc2']['kernel']
    bv_params[prefix + 'MlpBlock_0/Dense_1/bias'] = flax_params['encoder']['layer'][str(i)]['mlp']['fc2']['bias']
# Final Layer Norm
bv_params['Transformer/encoder_norm/scale'] = flax_params['layernorm']['scale']
bv_params['Transformer/encoder_norm/bias'] = flax_params['layernorm']['bias']

print("Inspecting parameters...")
inspect(bv_params)

Loading Flax checkpoint...


KeyboardInterrupt: 

In [None]:
save_bv_vit = False
local_np_save_path = '/home/austinwang/bigvision_dinov2.npz'

if save_bv_vit:
    ckpt = {'params': {'img': bv_params}}
    io_buffer = io.BytesIO()
    names_and_vals, _ = u.tree_flatten_with_names(ckpt)
    np.savez(io_buffer, **{k: v for k, v in names_and_vals})

    with open(local_np_save_path, 'wb') as f: f.write(io_buffer.getvalue())

big_vision_vit = u.npload(local_np_save_path)
for key in big_vision_vit.keys():
    print(key, big_vision_vit[key].shape, big_vision_vit[key].dtype)

# big_vision ViT CKPT Loading

In [8]:
config = ml_collections.ConfigDict()
config.model_name = 'vit'
config.model_load = {}
config.model_init = '/home/austinwang/bigvision_dinov2.npz:img'

config.model = dict(
    variant = 'B/14',
    posemb = 'learn',
    rep_size = False,
    dropout = 0.0,
    pool_type = 'gap',
    head_zeroinit = False,
    mask = None, # fully visible mask
    normalize_qk = False,
    scan = False,
    remat_policy = 'nothing_saveable',
    dtype_mm = 'float32',
    proj_bias = False,
)
model_mod = importlib.import_module(f"big_vision.models.{config.model_name}")
bv_model = model_mod.Model(**config.get("model", {}))
model_cfg = config.get("model")
# load ckpt weights
rng = jax.random.PRNGKey(42)
dummy_img = jnp.zeros([2, 224, 224, 3], jnp.float32)
init_params = jax.jit(bv_model.init, backend="cpu")(rng, dummy_img)['params']
params = model_mod.load(init_params, config.model_init, model_cfg, **config.get("model_load", {}))
jax.tree_map(lambda x: x.shape, params)

  jax.tree_map(lambda x: x.shape, params)


{'Transformer': {'encoder_norm': {'bias': (768,), 'scale': (768,)},
  'encoderblock_0': {'LayerNorm_0': {'bias': (768,), 'scale': (768,)},
   'LayerNorm_1': {'bias': (768,), 'scale': (768,)},
   'MlpBlock_0': {'Dense_0': {'bias': (3072,), 'kernel': (768, 3072)},
    'Dense_1': {'bias': (768,), 'kernel': (3072, 768)}},
   'MultiHeadDotProductAttention_0': {'key': {'bias': (12, 64),
     'kernel': (768, 12, 64)},
    'out': {'bias': (768,), 'kernel': (12, 64, 768)},
    'query': {'bias': (12, 64), 'kernel': (768, 12, 64)},
    'value': {'bias': (12, 64), 'kernel': (768, 12, 64)}}},
  'encoderblock_1': {'LayerNorm_0': {'bias': (768,), 'scale': (768,)},
   'LayerNorm_1': {'bias': (768,), 'scale': (768,)},
   'MlpBlock_0': {'Dense_0': {'bias': (3072,), 'kernel': (768, 3072)},
    'Dense_1': {'bias': (768,), 'kernel': (3072, 768)}},
   'MultiHeadDotProductAttention_0': {'key': {'bias': (12, 64),
     'kernel': (768, 12, 64)},
    'out': {'bias': (768,), 'kernel': (12, 64, 768)},
    'query':

# Output Similarity Test

In [35]:
from transformers import AutoImageProcessor, FlaxDinov2Model
from PIL import Image
import requests

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
flax_model = FlaxDinov2Model.from_pretrained("facebook/dinov2-base")

flax_inputs = image_processor(images=image, return_tensors="np")
flax_outputs = flax_model(**flax_inputs)
flax_last_hidden_states = flax_outputs.last_hidden_state
print(f"flax_last_hidden_states.shape: {flax_last_hidden_states.shape}")

bv_img = jnp.transpose(jnp.array(flax_inputs.pixel_values), (0, 2, 3, 1))
bv_outputs = bv_model.apply({'params':params}, bv_img, train=False)
bv_outputs[0].shape, bv_outputs[1].keys()

flax_last_hidden_states.shape: (1, 257, 768)


((1, 768),
 dict_keys(['stem', 'with_posemb', 'encoder', 'encoded', 'head_input', 'pre_logits_2d', 'pre_logits']))

In [36]:
FlaxDinov2Model

transformers.models.dinov2.modeling_flax_dinov2.FlaxDinov2Model

In [31]:
flax_outputs.keys()

odict_keys(['last_hidden_state', 'pooler_output'])

In [18]:
bv_outputs[1]['encoded'].shape, flax_last_hidden_states.shape

((1, 256, 768), (1, 257, 768))

In [33]:
# sim between bv_out[1]['encoded'] and flax_last_hidden_states
bv_outputs[1]['encoded'].sum(-1)[0,:5], flax_last_hidden_states.sum(-1)[1,:6]

(Array([5.180847 , 6.4842377, 9.029211 , 6.8737183, 5.3015556], dtype=float32),
 Array([  5.9947643, -26.457676 ,   3.7039318, -16.864552 ,   5.1114597,
         -7.4066467], dtype=float32))