In [None]:
import matplotlib.pyplot as plt
import numpy as np

%matplotlib inline

In [None]:
# Names of the two models whose encoder weights are compared in this notebook;
# each is passed to `from_pretrained` by get_encoder_for_name.
# NOTE(review): the section headings below mention "initialized value" and
# "after fine-tuning", so model_name1 is presumably the base checkpoint and
# model_name2 a model fine-tuned from it -- confirm.
model_name1 = "microsoft/xtremedistil-l6-h256-uncased"
model_name2 = "tmills/tiny-dtr"

In [None]:
from transformers import AutoConfig, AutoModel

from cnlpt.CnlpModelForClassification import CnlpModelForClassification

# from cnlpt.CnlpModelForClassification import CnlpConfig

In [None]:
def get_encoder_for_name(model_name: str):
    """Load the encoder of the model identified by ``model_name``.

    The config is fetched first so the declared architecture can be inspected:
    a checkpoint that lists ``CnlpModelForClassification`` is loaded through
    that class and only its ``encoder`` attribute is returned; anything else
    is loaded directly with ``AutoModel``.

    Args:
        model_name: Name or path accepted by ``from_pretrained``.

    Returns:
        The loaded encoder model.
    """
    config = AutoConfig.from_pretrained(model_name)

    # ``architectures`` is optional in a config and may be None (or missing),
    # in which case a bare ``in`` test would raise TypeError. Treat that case
    # as "no declared architecture" and fall through to AutoModel.
    architectures = getattr(config, "architectures", None) or []

    if "CnlpModelForClassification" in architectures:
        classifier = CnlpModelForClassification.from_pretrained(
            model_name, config=config
        )
        return classifier.encoder

    return AutoModel.from_pretrained(model_name, config=config)

In [None]:
def plot_hist(values, num_bins, ax=None):
    """Draw a histogram of ``values`` as a bar chart.

    The counts come from ``np.histogram``; one bar is drawn per bin, centred
    on the bin and 70% as wide as the first bin so neighbouring bars do not
    touch. The width is taken from the first bin only, so it is only exact
    for equal-width bins.

    Args:
        values: Array-like of numbers to bin.
        num_bins: Forwarded to ``np.histogram`` as ``bins``.
        ax: Optional object with a ``bar`` method (e.g. a matplotlib Axes) to
            draw on. When omitted, ``plt.bar`` is used, i.e. the current
            pyplot axes -- identical to the previous behaviour.

    Returns:
        None.
    """
    counts, edges = np.histogram(values, bins=num_bins)
    bar_width = 0.7 * (edges[1] - edges[0])
    centers = (edges[:-1] + edges[1:]) / 2

    # Only fall back to the pyplot state machine when no axes was supplied.
    target = plt if ax is None else ax
    target.bar(centers, counts, align="center", width=bar_width)

In [None]:
# Load both encoders (model 1 first), then report how many parameters each has.
model1, model2 = map(get_encoder_for_name, (model_name1, model_name2))
print(model1.num_parameters())
print(model2.num_parameters())

In [None]:
def get_param_dict(model):
    """Index a model's parameters by name.

    Args:
        model: Any object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs.

    Returns:
        A ``(params, param_names)`` tuple: ``params`` maps each name to its
        parameter (in the order yielded by the model) and ``param_names`` is
        the list of yielded names sorted alphabetically.
    """
    named = list(model.named_parameters())
    params = {entry[0]: entry[1] for entry in named}
    param_names = sorted(entry[0] for entry in named)
    return params, param_names

In [None]:
# Index the parameters of each encoder by name.
model1_params, model1_param_names = get_param_dict(model1)
model2_params, model2_param_names = get_param_dict(model2)

# Only names present in both encoders can be compared; keep them sorted so
# later cells iterate over them in a stable order.
common_param_names = sorted(set(model1_param_names) & set(model2_param_names))
print(f"{len(common_param_names)} common parameter names")

In [None]:
# Spot-check the word-embedding weight of each encoder: both shapes, the rank
# of the first one, and the first one's raw values as a flat NumPy array.
s1, s2 = (
    params["embeddings.word_embeddings.weight"]
    for params in (model1_params, model2_params)
)
print(s1.shape, s2.shape, s1.ndim, sep="\n")
print(s1.detach().numpy().flatten())

In [None]:
# Build flat, index-aligned views of every shared parameter so the two
# encoders can be compared element by element:
#   m1_params / m2_params : values taken from model 1 / model 2
#   param_diffs           : |m1 - m2|
#   param_pct_diffs       : |m1 - m2| / (|m1| + eps)   (a ratio, not x100)
#   param_to_start_ind    : parameter name -> offset of its first element
#
# The per-parameter pieces are collected as NumPy arrays and concatenated once
# at the end; extending Python lists one scalar at a time creates millions of
# boxed objects and is much slower. The results are therefore 1-D arrays, which
# support the len(), slicing, min() and indexing used by the later cells.
param_to_start_ind = {}
m1_chunks = []
m2_chunks = []
diff_chunks = []
pct_diff_chunks = []

# Added to the *magnitude* of the model-1 value so the denominator is never
# zero. (Adding it before abs() shrinks the denominator for negative weights
# and makes it exactly zero when the weight equals -eps.)
eps = np.finfo(float).eps
offset = 0

for param_name in common_param_names:
    param_to_start_ind[param_name] = offset
    m1v = model1_params[param_name]
    m2v = model2_params[param_name]
    if m1v.ndim == 2:
        # The two matrices may have a different number of rows; compare only
        # the leading rows both have.
        # NOTE(review): presumably this covers differing vocabulary sizes in
        # the embedding matrix -- confirm.
        d1 = min(m1v.shape[0], m2v.shape[0])
        m1_flat = m1v[:d1, :].detach().numpy().flatten()
        m2_flat = m2v[:d1, :].detach().numpy().flatten()
    else:
        m1_flat = m1v.detach().numpy().flatten()
        m2_flat = m2v.detach().numpy().flatten()

    abs_diff = np.abs(m1_flat - m2_flat)

    m1_chunks.append(m1_flat)
    m2_chunks.append(m2_flat)
    diff_chunks.append(abs_diff)
    pct_diff_chunks.append(abs_diff / (np.abs(m1_flat) + eps))
    offset += len(abs_diff)

# np.concatenate rejects an empty list, so fall back to empty arrays when the
# two models share no parameter names.
m1_params = np.concatenate(m1_chunks) if m1_chunks else np.empty(0)
m2_params = np.concatenate(m2_chunks) if m2_chunks else np.empty(0)
param_diffs = np.concatenate(diff_chunks) if diff_chunks else np.empty(0)
param_pct_diffs = (
    np.concatenate(pct_diff_chunks) if pct_diff_chunks else np.empty(0)
)

print(len(param_diffs))

In [None]:
def get_param_fam_for_ind(ind, param_names=None, start_inds=None):
    """Return the name of the parameter that owns flat index ``ind``.

    Each parameter occupies a contiguous run of the flattened arrays that
    begins at its start offset and ends just before the next parameter's
    offset; the last parameter's run extends to the end. (Previously an index
    inside the last parameter fell through the loop and was reported as
    "Param not found".)

    Args:
        ind: Index into the flattened parameter arrays.
        param_names: Parameter names ordered by ascending start offset.
            Defaults to the notebook-level ``common_param_names``.
        start_inds: Mapping from parameter name to its start offset.
            Defaults to the notebook-level ``param_to_start_ind``.

    Returns:
        The owning parameter's name, or ``"Param not found"`` when ``ind``
        lies before the first parameter (e.g. a negative index) or there are
        no parameters at all.
    """
    if param_names is None:
        param_names = common_param_names
    if start_inds is None:
        start_inds = param_to_start_ind

    owner = "Param not found"
    for param_name in param_names:
        if ind < start_inds[param_name]:
            # Offsets are ascending, so no later parameter can contain ind.
            break
        owner = param_name

    return owner

## Histogram: Are changes in weights distributed normally or as a power law? (or something else?)

In [None]:
# Distribution of the absolute element-wise differences between the two
# models' shared parameters, in 20 bins.
plot_hist(param_diffs, num_bins=20)
plt.show()

## Histogram: How much do the weights change, as a percentage of their initialized value?

In [None]:
# Distribution of the relative differences (param_pct_diffs), in 20 bins.
# The values are ratios |m1 - m2| / |m1|; they are not multiplied by 100.
num_bins = 20
plot_hist(param_pct_diffs, num_bins)
# Show the bin counts on a logarithmic y-axis.
plt.yscale("log")
plt.show()
# Alternative kept for reference: matplotlib's built-in histogram.
# plt.hist(param_pct_diffs, 20)

In [None]:
# Sanity check: the first ten absolute and relative differences, followed by
# the smallest relative difference over all compared elements.
print(param_diffs[:10])
print(param_pct_diffs[:10])
print(min(param_pct_diffs))

In [None]:
# Side-by-side look at the first `num_to_check` flattened elements: model-1
# values, model-2 values, absolute differences and relative differences.
num_to_check = 30

print(m1_params[:num_to_check])
print(m2_params[:num_to_check])
print(param_diffs[:num_to_check])
print(param_pct_diffs[:num_to_check])

In [None]:
# Inspect one flat index in detail: the value in each model and the relative
# difference at that position.
num = 24

print(m1_params[num])
print(m2_params[num])
print(param_pct_diffs[num])

In [None]:
# Flat indices ordered by ascending relative difference, so the elements that
# changed the most (relatively) sit at the end of rank_inds.
rank_inds = np.argsort(param_pct_diffs)

In [None]:
# One rank entry per compared element, so this matches len(param_pct_diffs).
print(len(rank_inds))

## Where in the encoder are the biggest changes to the parameters?

In [None]:
# Walk down the ranking, starting at the nz_start-th largest relative
# difference, and print the value together with the parameter that owns it for
# each of the next num_to_check entries.
# NOTE(review): nz_start = 257 skips the 256 largest entries; presumably those
# are elements whose model-1 value is ~0, which makes the ratio blow up --
# confirm.
nz_start = 257
num_to_check = 100

for rank_ind in range(nz_start, nz_start + num_to_check):
    # rank_inds is ascending, so index -k is the k-th largest entry.
    p_ind = rank_inds[-rank_ind]
    param_fam = get_param_fam_for_ind(p_ind)
    print(f"{param_pct_diffs[p_ind]:f} : {param_fam}")

In [None]:
# Detail view of the 257th largest relative difference: its flat index, the
# value in each model, the relative difference, and the owning parameter.
ind = rank_inds[-257]
print(ind)
print(m1_params[ind])
print(m2_params[ind])
print(param_pct_diffs[ind])
print(get_param_fam_for_ind(ind))

## Next step: See how much performance we lose if we keep only the N largest deltas (N = 100, 1,000, ...) after fine-tuning
