# Weight Similarity Between BERT and the Extended V&L Models
We first match the weight keys of BERT with those of each V&L model,  
and then calculate the similarity for each matched pair of weights.

In [1]:
from eval_vl_glue import transformers_volta
from collections import OrderedDict

In [2]:
def _new_format_rule(weight_keys):
    """Get a key mapping that converts old format 
    to new format if needed from a PyTorch state_dict
    """
    
    mapping = {}
    
    for key in weight_keys:
        new_key = None
        if "gamma" in key:
            new_key = key.replace("gamma", "weight")
        if "beta" in key:
            new_key = key.replace("beta", "bias")
        if new_key:
            mapping[key] = new_key
    
    return mapping

In [3]:
def _volta_rule(volta_config, weight_keys):
    """Get a key mapping that converts layer names of bert 
    to those of volta using the mapping defined in volta config
    """
    
    mapping = {}
    
    for key in weight_keys:
        new_key = None
        if ".layer." in key:
            layer_id = int(key.split(".layer.")[-1].split(".")[0])
            if ".attention." in key:
                layer_src = layer_id
                layer_dest = volta_config.bert_layer2attn_sublayer.get(str(layer_id), layer_id)
                new_key = key.replace(f".layer.{layer_src}.attention.", f".layer.{layer_dest}.attention_")
            elif ".intermediate." in key:
                layer_src = layer_id
                layer_dest = volta_config.bert_layer2ff_sublayer.get(str(layer_id), layer_id)
                new_key = key.replace(f".layer.{layer_src}.intermediate.", f".layer.{layer_dest}.intermediate.")
            elif ".output." in key:
                layer_src = layer_id
                layer_dest = volta_config.bert_layer2ff_sublayer.get(str(layer_id), layer_id)
                new_key = key.replace(f".layer.{layer_src}.output.", f".layer.{layer_dest}.output.")
        if new_key:
            mapping[key] = new_key
    
    return mapping

In [4]:
def _apply_mapping(weight_keys, mapping):
    """Get a new key list by applying key mapping.
    """
    
    new_weight_keys = []
    for key in weight_keys:
        new_weight_keys.append(mapping.get(key, key))
    
    return new_weight_keys

In [5]:
def _weight_stats(x, y):
    """Return ``(delta, cossim, volta_avg, bert_avg)`` for two same-shaped tensors.

    ``delta`` is the mean absolute element-wise difference, ``cossim`` the
    cosine similarity of the flattened tensors, and the two averages are mean
    absolute values. All four are plain Python floats.
    """
    # Work in float32 so integer buffers and half-precision weights are handled
    # uniformly (``.float()`` returns the tensor itself when already float32).
    x = x.detach().float()
    y = y.detach().float()
    delta = (x - y).abs().sum().item() / x.numel()
    # norm(x) * norm(y) == sqrt(sum(x*x) * sum(y*y)); NaN if either is all zeros.
    cossim = ((x * y).sum() / (x.norm() * y.norm())).item()
    volta_avg = x.abs().sum().item() / x.numel()
    bert_avg = y.abs().sum().item() / y.numel()
    return delta, cossim, volta_avg, bert_avg


def get_weight_table(bert_model, volta_model):
    """Compare the weights of ``bert_model`` and ``volta_model`` tensor by tensor.

    Both state dicts have their keys normalized (``gamma``/``beta`` become
    ``weight``/``bias``). BERT layer names are then rewritten into the volta
    layout using ``volta_model.config``. Each volta tensor whose normalized key
    matches a rewritten BERT key is compared with that BERT tensor.

    Returns:
        OrderedDict keyed by every volta state-dict key, in state-dict order.
        Each value is a dict with these keys:
        ``'transfarred'`` (spelling kept because callers read this exact key),
        ``'volta_weight'``, ``'bert_weight'``, ``'delta'``, ``'cossim'``,
        ``'volta_avg'`` and ``'bert_avg'``. The statistics are floats, or
        ``None`` when the tensor has no BERT counterpart.

    Raises:
        ValueError: if a matched pair of tensors has different shapes, since
            the element-wise statistics would otherwise broadcast silently.
    """
    # Original and normalized keys of the volta model.
    volta_state_dict = volta_model.state_dict()
    volta_original_keys = list(volta_state_dict.keys())
    volta_keys = _apply_mapping(volta_original_keys, _new_format_rule(volta_original_keys))

    # Original keys of the bert model: normalized, then renamed into the volta layer layout.
    bert_state_dict = bert_model.state_dict()
    bert_original_keys = list(bert_state_dict.keys())
    bert_keys = _apply_mapping(bert_original_keys, _new_format_rule(bert_original_keys))
    bert_keys = _apply_mapping(bert_keys, _volta_rule(volta_model.config, bert_keys))

    # Modified bert key -> original bert key, built once. Each volta key is then
    # matched with a dict lookup instead of a linear list scan. setdefault keeps
    # the first occurrence, which is what list.index() would return.
    bert_key_by_modified = {}
    for modified, original in zip(bert_keys, bert_original_keys):
        bert_key_by_modified.setdefault(modified, original)

    results = OrderedDict()
    for modified, v in zip(volta_keys, volta_original_keys):
        b = bert_key_by_modified.get(modified)
        if b is None:
            results[v] = {'transfarred': False, 'volta_weight': v, 'bert_weight': None, 'delta': None,
                          'cossim': None, 'volta_avg': None, 'bert_avg': None}
            continue

        x = volta_state_dict[v]
        y = bert_state_dict[b]
        # When type_vocab_size differs from bert-base-uncased, only the first
        # rows (as many as BERT has) were transferred from BERT.
        if v == 'embeddings.token_type_embeddings.weight':
            x = x[:y.shape[0]]
        if x.shape != y.shape:
            raise ValueError(
                f'shape mismatch between volta {v!r} {tuple(x.shape)} and bert {b!r} {tuple(y.shape)}'
            )
        delta, cossim, volta_avg, bert_avg = _weight_stats(x, y)
        results[v] = {'transfarred': True, 'volta_weight': v, 'bert_weight': b, 'delta': delta,
                      'cossim': cossim, 'volta_avg': volta_avg, 'bert_avg': bert_avg}

    return results

In [6]:
def summarize_weight_table(table):
    """Print aggregate weight-similarity statistics for one model.

    ``table`` is the mapping returned by ``get_weight_table``. Each value is a
    dict with a ``'transfarred'`` flag, the volta key under ``'volta_weight'``,
    and the ``'delta'``, ``'cossim'``, ``'volta_avg'`` and ``'bert_avg'``
    statistics.

    Transferred tensors are grouped by key suffix into LayerNorm weights and
    biases, other weights and biases, and ``all_weight`` / ``all_bias``
    totals. The per-group mean of each statistic is printed as a
    tab-separated table. Tensors whose key ends in neither ``.weight`` nor
    ``.bias`` are reported and skipped.
    """
    print('total_weights (tensors)', len(table))

    transferred_weights = [row for row in table.values() if row['transfarred']]
    print('total_transferred_weights (tensors)', len(transferred_weights))

    weight_types = ('layer_norm_bias', 'layer_norm_weight', 'normal_bias', 'normal_weight', 'all_bias', 'all_weight')
    stat_names = ('delta', 'cossim', 'volta_avg', 'bert_avg')
    # stats[stat_name][weight_type] -> list of per-tensor values
    stats = {name: {t: [] for t in weight_types} for name in stat_names}

    for w in transferred_weights:
        key = w['volta_weight']
        # Order matters: LayerNorm keys also end with '.weight' / '.bias'.
        if key.endswith('.LayerNorm.weight'):
            groups = ('layer_norm_weight', 'all_weight')
        elif key.endswith('.LayerNorm.bias'):
            groups = ('layer_norm_bias', 'all_bias')
        elif key.endswith('.weight'):
            groups = ('normal_weight', 'all_weight')
        elif key.endswith('.bias'):
            groups = ('normal_bias', 'all_bias')
        else:
            # No group applies; skip rather than reuse the previous tensor's group.
            print('unknown type:', key)
            continue

        for group in groups:
            for name in stat_names:
                # float() accepts both Python numbers and one-element tensors.
                stats[name][group].append(float(w[name]))

    def mean(values):
        # Empty groups print as 0 instead of raising ZeroDivisionError.
        return sum(values) / (len(values) or 1)

    print('type', 'n', 'volta_avg', 'bert_avg', 'delta_avg', 'cossim_avg', sep='\t')
    for t in weight_types:
        print(t, len(stats['delta'][t]),
              '%.5f' % mean(stats['volta_avg'][t]), '%.5f' % mean(stats['bert_avg'][t]),
              '%.5f' % mean(stats['delta'][t]), '%.5f' % mean(stats['cossim'][t]), sep='\t')

In [7]:
# Run the analysis for each V&L model, always against the same BERT baseline.
bert_model = transformers_volta.AutoModel.from_pretrained('bert-base-uncased')

# Checkpoint directories of the extended V&L models to compare with BERT.
model_paths = [
    '../../vl_models/pretrained/ctrl_visual_bert',
    '../../vl_models/pretrained/ctrl_uniter',
    '../../vl_models/pretrained/ctrl_vl_bert',
    '../../vl_models/pretrained/ctrl_lxmert',
    '../../vl_models/pretrained/ctrl_vilbert',
]

for model_path in model_paths:
    # Load one model at a time; print its path as a header, then its summary,
    # followed by a blank separator line.
    volta_model = transformers_volta.AutoModel.from_pretrained(model_path)
    table = get_weight_table(bert_model, volta_model)
    print(f'{model_path}:')
    summarize_weight_table(table)
    print()

../../vl_models/pretrained/ctrl_visual_bert:
total_weights (tensors) 399
total_transferred_weights (tensors) 197
type	n	volta_avg	bert_avg	delta_avg	cossim_avg
layer_norm_bias	25	0.08202	0.07927	0.00668	0.99726
layer_norm_weight	25	0.83292	0.83128	0.01116	0.99992
normal_bias	72	0.05829	0.05761	0.00447	0.99632
normal_weight	75	0.03053	0.02978	0.01194	0.92177
all_bias	97	0.06441	0.06319	0.00504	0.99657
all_weight	100	0.23113	0.23015	0.01174	0.94131

../../vl_models/pretrained/ctrl_uniter:
total_weights (tensors) 405
total_transferred_weights (tensors) 197
type	n	volta_avg	bert_avg	delta_avg	cossim_avg
layer_norm_bias	25	0.08183	0.07927	0.00661	0.99709
layer_norm_weight	25	0.83259	0.83128	0.01065	0.99991
normal_bias	72	0.05831	0.05761	0.00450	0.99655
normal_weight	75	0.03057	0.02978	0.01212	0.91972
all_bias	97	0.06437	0.06319	0.00505	0.99669
all_weight	100	0.23108	0.23015	0.01175	0.93977

../../vl_models/pretrained/ctrl_vl_bert:
total_weights (tensors) 404
total_transferred_weights (tenso