In [1]:
%load_ext autoreload
%autoreload 2
%cd ~/austin_big_vision

import importlib
import io
import os

import jax
import jax.numpy as jnp
import ml_collections
import numpy as np

import big_vision.utils as u
from big_vision.models.vit import scan_to_pyloop

  bkms = self.shell.db.get('bookmarks', {})
  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


/home/austinwang/austin_big_vision


2024-10-10 22:37:13.406566: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-10-10 22:37:13.429018: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-10-10 22:37:13.435935: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


# Convert the Flax DINOv2 model to big_vision format

In [2]:
def inspect(params):
    """Print one line per parameter: its flattened name, shape and dtype.

    Args:
        params: parameter tree accepted by `u.tree_flatten_with_names`; every
            flattened value must expose `.shape` and `.dtype`.
    """
    named_leaves = u.tree_flatten_with_names(params)[0]
    for leaf_name, leaf in named_leaves:
        print(leaf_name, leaf.shape, leaf.dtype)

print("Loading Flax checkpoint...")
from transformers import FlaxDinov2Model
flax_model = FlaxDinov2Model.from_pretrained("facebook/dinov2-base")
# Read weights through the public `params` property instead of the private
# `_params` attribute.
flax_params = flax_model.params

# Take the architecture sizes from the checkpoint's own config rather than
# hard-coding 12 layers / 12 heads / width 768 in the reshapes below.
num_layers = flax_model.config.num_hidden_layers
num_heads = flax_model.config.num_attention_heads
hidden_size = flax_model.config.hidden_size
head_dim = hidden_size // num_heads

print("Converting parameters...")
bv_params = {}

# Embedding layer
patch_proj = flax_params['embeddings']['patch_embeddings']['projection']
# Copy the conv kernel unchanged. The former jnp.transpose(..., (3, 2, 0, 1))
# maps an (H, W, in, out) kernel to (out, in, H, W), which is the PyTorch
# layout; the source is already a Flax kernel, and Flax convs use
# (H, W, in, out). Any .npz exported with the old layout must be regenerated.
# NOTE(review): assumes big_vision's `embedding` layer is a Flax conv — confirm.
patch_kernel = patch_proj['kernel']
assert patch_kernel.ndim == 4 and patch_kernel.shape[-1] == hidden_size, (
    f"expected an (H, W, in, {hidden_size}) patch kernel, got {patch_kernel.shape}")
bv_params['embedding/kernel'] = patch_kernel
bv_params['embedding/bias'] = patch_proj['bias']

# Position embedding
# NOTE(review): the DINOv2 position embedding presumably includes a CLS slot,
# while big_vision may expect patch positions only — confirm before loading.
bv_params['pos_embedding'] = flax_params['embeddings']['position_embeddings']

# Transformer blocks
# NOTE(review): only layer norms, attention and MLP weights are copied; any
# other parameters in the checkpoint are not converted — confirm this is
# intended.
for i in range(num_layers):
    prefix = f'Transformer/encoderblock_{i}/'
    layer = flax_params['encoder']['layer'][str(i)]

    # Layer Norm
    for bv_name, flax_name in (('LayerNorm_0', 'norm1'), ('LayerNorm_1', 'norm2')):
        bv_params[f'{prefix}{bv_name}/scale'] = layer[flax_name]['scale']
        bv_params[f'{prefix}{bv_name}/bias'] = layer[flax_name]['bias']

    # Multi-Head Attention: split the fused hidden axis into (heads, head_dim).
    mha_prefix = prefix + 'MultiHeadDotProductAttention_0/'
    self_attn = layer['attention']['attention']
    for proj_name in ('query', 'key', 'value'):
        bv_params[f'{mha_prefix}{proj_name}/kernel'] = (
            self_attn[proj_name]['kernel'].reshape(hidden_size, num_heads, head_dim))
        bv_params[f'{mha_prefix}{proj_name}/bias'] = (
            self_attn[proj_name]['bias'].reshape(num_heads, head_dim))
    attn_out = layer['attention']['output']['dense']
    bv_params[mha_prefix + 'out/kernel'] = attn_out['kernel'].reshape(num_heads, head_dim, hidden_size)
    bv_params[mha_prefix + 'out/bias'] = attn_out['bias']

    # MLP
    for bv_name, flax_name in (('Dense_0', 'fc1'), ('Dense_1', 'fc2')):
        bv_params[f'{prefix}MlpBlock_0/{bv_name}/kernel'] = layer['mlp'][flax_name]['kernel']
        bv_params[f'{prefix}MlpBlock_0/{bv_name}/bias'] = layer['mlp'][flax_name]['bias']

# Final Layer Norm
bv_params['Transformer/encoder_norm/scale'] = flax_params['layernorm']['scale']
bv_params['Transformer/encoder_norm/bias'] = flax_params['layernorm']['bias']

inspect(bv_params)

Loading Flax checkpoint...
Converting parameters...
Transformer/encoder_norm/bias (768,) float32
Transformer/encoder_norm/scale (768,) float32
Transformer/encoderblock_0/LayerNorm_0/bias (768,) float32
Transformer/encoderblock_0/LayerNorm_0/scale (768,) float32
Transformer/encoderblock_0/LayerNorm_1/bias (768,) float32
Transformer/encoderblock_0/LayerNorm_1/scale (768,) float32
Transformer/encoderblock_0/MlpBlock_0/Dense_0/bias (3072,) float32
Transformer/encoderblock_0/MlpBlock_0/Dense_0/kernel (768, 3072) float32
Transformer/encoderblock_0/MlpBlock_0/Dense_1/bias (768,) float32
Transformer/encoderblock_0/MlpBlock_0/Dense_1/kernel (3072, 768) float32
Transformer/encoderblock_0/MultiHeadDotProductAttention_0/key/bias (12, 64) float32
Transformer/encoderblock_0/MultiHeadDotProductAttention_0/key/kernel (768, 12, 64) float32
Transformer/encoderblock_0/MultiHeadDotProductAttention_0/out/bias (768,) float32
Transformer/encoderblock_0/MultiHeadDotProductAttention_0/out/kernel (12, 64, 768) 

In [3]:
# Set to True to force re-exporting the converted parameters.
save_bv_vit = False
local_np_save_path = '/home/austinwang/bigvision_dinov2.npz'

# Export when requested, and also when no checkpoint exists yet, so that a
# fresh "Restart & Run All" does not fail on the load below.
if save_bv_vit or not os.path.exists(local_np_save_path):
    ckpt = {'params': {'img': bv_params}}
    names_and_vals, _ = u.tree_flatten_with_names(ckpt)
    # Serialize into memory first so a failure inside np.savez cannot leave a
    # truncated file on disk.
    io_buffer = io.BytesIO()
    np.savez(io_buffer, **dict(names_and_vals))

    with open(local_np_save_path, 'wb') as f:
        f.write(io_buffer.getvalue())

# Reload from disk and list every stored array as a sanity check.
big_vision_vit = u.npload(local_np_save_path)
for key in big_vision_vit.keys():
    arr = big_vision_vit[key]
    print(key, arr.shape, arr.dtype)

params/img/Transformer/encoder_norm/bias (768,) float32
params/img/Transformer/encoder_norm/scale (768,) float32
params/img/Transformer/encoderblock_0/LayerNorm_0/bias (768,) float32
params/img/Transformer/encoderblock_0/LayerNorm_0/scale (768,) float32
params/img/Transformer/encoderblock_0/LayerNorm_1/bias (768,) float32
params/img/Transformer/encoderblock_0/LayerNorm_1/scale (768,) float32
params/img/Transformer/encoderblock_0/MlpBlock_0/Dense_0/bias (3072,) float32
params/img/Transformer/encoderblock_0/MlpBlock_0/Dense_0/kernel (768, 3072) float32
params/img/Transformer/encoderblock_0/MlpBlock_0/Dense_1/bias (768,) float32
params/img/Transformer/encoderblock_0/MlpBlock_0/Dense_1/kernel (3072, 768) float32
params/img/Transformer/encoderblock_0/MultiHeadDotProductAttention_0/key/bias (12, 64) float32
params/img/Transformer/encoderblock_0/MultiHeadDotProductAttention_0/key/kernel (768, 12, 64) float32
params/img/Transformer/encoderblock_0/MultiHeadDotProductAttention_0/out/bias (768,) 

# Big Vision ViT

In [4]:
# Model configuration used to build the big_vision ViT and initialise it from
# the converted checkpoint.
config = ml_collections.ConfigDict()
config.model_name = 'vit'
config.model_load = {}
# Reuse the export path defined with the save/load cell instead of repeating
# the literal, so the init checkpoint cannot drift from the exported file.
config.model_init = local_np_save_path

config.model = dict(
    variant = 'B/14',
    posemb = 'learn',
    rep_size = False,
    dropout = 0.0,
    pool_type = 'gap',
    head_zeroinit = False,
    mask = None, # fully visible mask
    normalize_qk = False,
    scan = False,
    remat_policy = 'nothing_saveable',
    dtype_mm = 'float32',
    proj_bias = False,
)

In [5]:
# Import the model implementation module selected by `config.model_name`.
model_mod = importlib.import_module(f"big_vision.models.{config.model_name}")
# Model construction is currently disabled:
# model = model_mod.Model(**config.get("model", {}))