<a href="https://colab.research.google.com/github/CalculatedContent/WeightWatcher/blob/master/examples/WW_BERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# How to use WeightWatcher to pick the right model

This notebook compares the distributions of the layer Alphas ($\alpha$) for three pretrained transformer models:
- BERT
- RoBERTa
- XLNet

This analysis is discussed on the [CalculatedContent Blog](https://calculatedcontent.com/2022/07/22/better-than-bert-pick-your-best-model/).

The WeightWatcher Power-Law (PL) metric Alpha ($\alpha$) is a DNN model quality metric: smaller is better. The histogram produced at the end of this notebook displays all the layer Alpha ($\alpha$) values for the 3 models. It is immediately clear that the XLNet layers look much better than those of BERT or RoBERTa. Their Alpha ($\alpha$) values are smaller on average, and no Alpha exceeds 5 ($\alpha \le 5$).

In contrast, the BERT and RoBERTa Alphas are much larger on average, and both models have too many large Alphas.

This is consistent with the published results: in the [original XLNet paper (from Carnegie Mellon University and Google Brain)](https://arxiv.org/abs/1906.08237), XLNet outperforms BERT on 20 different NLP tasks.
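To make the Alpha metric more concrete, here is a rough outline of how it works. For a layer weight matrix $\mathbf{W}$, WeightWatcher looks at the eigenvalues of the correlation matrix $\mathbf{W}^T\mathbf{W}$ (the empirical spectral density, ESD), up to a normalization. It then fits the tail of that spectrum to a power law $\rho(\lambda) \sim \lambda^{-\alpha}$.

The sketch below is a simplified, numpy-only illustration of that idea and is *not* WeightWatcher's implementation. It uses the standard continuous maximum-likelihood estimator $\hat\alpha = 1 + n / \sum_i \ln(\lambda_i/\lambda_{\min})$, and it picks $\lambda_{\min}$ by minimizing the Kolmogorov-Smirnov (KS) distance, following Clauset, Shalizi and Newman. The helper names (`esd`, `mle_alpha`, `fit_powerlaw_alpha`) and the `min_tail` parameter are hypothetical.

```python
import numpy as np

def esd(W):
    """Eigenvalues of W^T W (the squared singular values of W), sorted ascending."""
    evals = np.linalg.svd(np.asarray(W, dtype=float), compute_uv=False) ** 2
    return np.sort(evals[evals > 0])

def mle_alpha(tail, xmin):
    """Continuous power-law MLE: alpha = 1 + n / sum(ln(x_i / xmin)), for x_i >= xmin."""
    tail = np.asarray(tail, dtype=float)
    return 1.0 + len(tail) / np.sum(np.log(tail / xmin))

def fit_powerlaw_alpha(evals, min_tail=50):
    """Choose xmin by minimizing the KS distance and return (alpha, xmin).

    Simplified sketch: every eigenvalue that leaves at least `min_tail`
    points in the tail is tried as a candidate xmin.
    """
    evals = np.sort(np.asarray(evals, dtype=float))
    best = (np.nan, np.nan, np.inf)  # (alpha, xmin, ks)
    for i in range(len(evals) - min_tail + 1):
        xmin, tail = evals[i], evals[i:]
        if np.all(tail == xmin):  # degenerate tail: the log-sum would be zero
            continue
        alpha = mle_alpha(tail, xmin)
        n = len(tail)
        # Empirical tail CDF vs. fitted power-law CDF F(x) = 1 - (x / xmin)^(1 - alpha)
        emp_cdf = np.arange(1, n + 1) / n
        fit_cdf = 1.0 - (tail / xmin) ** (1.0 - alpha)
        ks = np.max(np.abs(emp_cdf - fit_cdf))
        if ks < best[2]:
            best = (alpha, xmin, ks)
    return best[0], best[1]
```

For any 2-D weight matrix `W` (for example, a PyTorch `Linear` layer's `weight.detach().numpy()`), `fit_powerlaw_alpha(esd(W))` returns an estimate of that layer's $\alpha$. WeightWatcher's own fits differ in details such as normalization, eigenvalue filtering and the choice of $\lambda_{\min}$, so expect similar but not identical numbers.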



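In [None]:
# Install the dependencies before importing them.
# The matplotlib==3.1.3 pin may fail to install on recent Python versions; if so, remove the pin.
!pip install transformers weightwatcher gwpy matplotlib==3.1.3

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt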


In [None]:
from transformers import BertModel, RobertaModel, XLNetModel

# Download the pretrained base checkpoints from the Hugging Face Hub
bert = BertModel.from_pretrained('bert-base-uncased')
roberta = RobertaModel.from_pretrained("roberta-base")
xlnet = XLNetModel.from_pretrained("xlnet-base-cased")

In [None]:
import weightwatcher as ww
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(ww.__NAME__) 
logger.setLevel(logging.INFO)


In [None]:
#%%capture

import warnings
warnings.filterwarnings('ignore')

# describe() lists the layers WeightWatcher will analyze, without fitting any power laws
watcher = ww.WeightWatcher(model=bert)
bert_details = watcher.describe()

watcher = ww.WeightWatcher(model=roberta)
roberta_details = watcher.describe()

watcher = ww.WeightWatcher(model=xlnet)
xlnet_details = watcher.describe()

In [None]:
#%%capture

import warnings
warnings.filterwarnings('ignore')

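# analyze() fits a power law to each layer's ESD and returns a DataFrame with one row per layer;
# min_evals=50 skips layers with fewer than 50 eigenvalues
watcher = ww.WeightWatcher(model=bert)
bert_details = watcher.analyze(min_evals=50)

watcher = ww.WeightWatcher(model=roberta)
roberta_details = watcher.analyze(min_evals=50)

watcher = ww.WeightWatcher(model=xlnet)
xlnet_details = watcher.analyze(min_evals=50)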

In [None]:
MAX_ALPHA = 8

# Histogram of layer alphas per model; dashed lines mark each model's mean alpha.
# Layers without a valid fit (alpha <= 0) and outliers (alpha >= MAX_ALPHA) are excluded.
B = bert_details[(bert_details.alpha < MAX_ALPHA) & (bert_details.alpha > 0)]
B.alpha.plot.hist(bins=100, label='BERT', density=True, color='blue')
plt.axvline(B.alpha.mean(), color='blue', linestyle='dashed')

R = roberta_details[(roberta_details.alpha < MAX_ALPHA) & (roberta_details.alpha > 0)]
R.alpha.plot.hist(bins=100, label='RoBERTa', alpha=0.5, density=True, color='red')
plt.axvline(R.alpha.mean(), color='red', linestyle='dashed')

# Filter XLNet on its own alphas (not RoBERTa's) and mark XLNet's own mean
X = xlnet_details[(xlnet_details.alpha < MAX_ALPHA) & (xlnet_details.alpha > 0)]
X.alpha.plot.hist(bins=100, label='XLNet', alpha=0.5, density=True, color='green')
plt.axvline(X.alpha.mean(), color='green', linestyle='dashed')

plt.xlabel(r'layer $\alpha$')
plt.legend()
plt.show()

# Layer alpha vs. layer index; dashed lines mark each model's mean alpha
for df, name, color in [(B, 'BERT', 'blue'), (R, 'RoBERTa', 'red'), (X, 'XLNet', 'green')]:
    x = df.layer_id.to_numpy()
    y = df.alpha.to_numpy()
    plt.scatter(x, y, color=color, label=name)
    plt.axhline(np.mean(y), color=color, linestyle='dashed')

plt.xlabel('layer_id')
plt.ylabel(r'$\alpha$')
plt.legend()
plt.show()
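The claims in the introduction can also be checked numerically rather than by eye. Those claims are that XLNet has a smaller mean Alpha and no Alpha above 5, while BERT and RoBERTa have many large Alphas.

Below is a minimal sketch that assumes only the `alpha` column of the DataFrames returned by `watcher.analyze()`. The helper `summarize_alphas` and its default thresholds are hypothetical. They were chosen to mirror the `0 < alpha < MAX_ALPHA` filter used in the plots and the $\alpha \le 5$ cutoff mentioned in the text.

```python
import numpy as np

def summarize_alphas(alphas, max_alpha=8.0, large_alpha=5.0):
    """Summarize layer alphas using the same 0 < alpha < max_alpha filter as the plots.

    Returns the number of layers kept, their mean and max alpha,
    and how many kept layers have alpha > large_alpha. NaN alphas are dropped.
    """
    a = np.asarray(alphas, dtype=float)
    a = a[(a > 0) & (a < max_alpha)]  # NaN compares False, so it is filtered out too
    return {
        "n_layers": int(a.size),
        "mean_alpha": float(a.mean()) if a.size else float("nan"),
        "max_alpha": float(a.max()) if a.size else float("nan"),
        "n_large": int(np.sum(a > large_alpha)),
    }

# Hypothetical usage with the DataFrames computed in this notebook:
# for name, df in [("BERT", bert_details), ("RoBERTa", roberta_details), ("XLNet", xlnet_details)]:
#     print(name, summarize_alphas(df.alpha))
```

This reports numbers rather than shapes, so differences between the models do not depend on histogram binning.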


