<a href="https://colab.research.google.com/github/christopherkang/llm-sparsification-release-ctkang/blob/master/HW4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Impact of Global, Unstructured Pruning on LLMs
## Models
In this experiment, we took three large models (targeting more than 500M parameters each; see the footnotes for exceptions) and pruned them using PyTorch's global, unstructured pruning.

The models used were:
* (Encoder-only): [BERT-large-uncased](https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad?context=My+name+is+Clara+and+I+live+in+Berkeley.&question=What%27s+my+name%3F) with 336M parameters*
* (Decoder-only): [GPT2-XL](https://huggingface.co/gpt2-xl) with 1.5B parameters**
* (Enc-dec): [M2M-100](https://huggingface.co/facebook/m2m100_1.2B) with 418M parameters***

*We ran into repeated, consistent issues using Microsoft's DeBERTa. It was one of the only encoder-only models large enough, so we fell back to BERT's large version instead.

**this model barely fit on the GPUs...

***Again, HuggingFace's code raised errors when trying to import the 1.2B model, so we used the 418M version instead.


We pruned 0%, 10%, 50%, 90%, 95%, and 99% of the weights, then observed the performance as the pruning amount changed. 
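
For reference, here is a minimal sketch of global, unstructured pruning with `torch.nn.utils.prune`. The toy two-layer model, the `L1Unstructured` (magnitude) criterion, and the restriction to `nn.Linear` weights are assumptions for illustration; this notebook does not record which criterion or layers were actually used.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)

# Toy stand-in for an LLM (assumption: only Linear weights are pruned).
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

# Every (module, parameter_name) pair listed here shares one global threshold.
parameters_to_prune = [
    (module, "weight") for module in model.modules() if isinstance(module, nn.Linear)
]

# Zero out the 50% smallest-magnitude weights across all listed tensors at once.
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,  # assumed ranking criterion
    amount=0.5,
)


def global_sparsity(pairs):
    """Fraction of exactly-zero entries across all pruned tensors."""
    zeros = sum(int((getattr(m, name) == 0).sum()) for m, name in pairs)
    total = sum(getattr(m, name).numel() for m, name in pairs)
    return zeros / total


sparsity = global_sparsity(parameters_to_prune)
# Per-layer sparsity can differ from the global amount.
per_layer = [float((m.weight == 0).float().mean()) for m, _ in parameters_to_prune]
```

Because the threshold is computed over all tensors together, the values in `per_layer` need not each equal 0.5 even though the overall `sparsity` does.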

## Evaluation
For each model, we selected one of the datasets used in its paper and first ran the model at 0% pruning (i.e., unchanged). If the model's performance was comparable with the original paper, we then scaled up the pruning. If pruning ever yielded degenerate results (e.g., Exact Match (EM) = 0), further pruning was halted.
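
The halting rule above can be sketched as a small sweep loop. The function name `sweep_sparsity`, the `evaluate` callable, and the toy scores are all hypothetical and only illustrate the control flow. The sketch assumes a higher-is-better metric such as EM, F1, or BLEU; for perplexity the comparison would be reversed.

```python
from typing import Callable, Dict, Iterable


def sweep_sparsity(
    evaluate: Callable[[int], float],
    sparsities: Iterable[int],
    garbage_threshold: float = 0.0,
) -> Dict[int, float]:
    """Evaluate at increasing sparsity and stop after the first collapsed score."""
    results: Dict[int, float] = {}
    for sparsity in sorted(sparsities):
        score = evaluate(sparsity)
        results[sparsity] = score
        if score <= garbage_threshold:
            break  # the model has collapsed; skip the remaining levels
    return results


# Toy scores, made up for illustration only (not this notebook's results).
toy_scores = {0: 80.0, 10: 79.0, 50: 5.0, 90: 0.0, 95: 0.0, 99: 0.0}
results = sweep_sparsity(toy_scores.__getitem__, toy_scores.keys())
```

With these toy scores the sweep records the 90% level and then stops, so the 95% and 99% levels are never evaluated.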

* For BERT, we used the SQuAD dataset, a QA task. The relevant metrics are F1 and EM.
* For GPT2-XL, we used the wikitext-2 dataset as a text generation task. The relevant metric is perplexity (the exponential of the average per-token negative log-likelihood; lower is better).
* For M2M-100, we used the ccmatrix_de-en dataset, a dataset which includes a number of German phrases with their English translations. The relevant metric is BLEU, a measure of the overlap of the outputted translation with the reference translation. 
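
To make the metrics above concrete, here is a simplified sketch of EM, token-level F1, and perplexity. The normalization (lowercasing and whitespace splitting) is an assumption; the official SQuAD scoring script also strips punctuation and articles, and the scores in this notebook came from library implementations rather than from this code. BLEU is omitted because a faithful version does not fit in a few lines.

```python
import math
from collections import Counter
from typing import Sequence


def exact_match(prediction: str, reference: str) -> float:
    """1.0 if the normalized strings are identical, else 0.0."""
    return float(prediction.strip().lower() == reference.strip().lower())


def token_f1(prediction: str, reference: str) -> float:
    """Harmonic mean of token precision and recall (bag-of-words overlap)."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


def perplexity(token_log_probs: Sequence[float]) -> float:
    """exp of the mean negative natural-log probability per token."""
    return math.exp(-sum(token_log_probs) / len(token_log_probs))


em = exact_match("in Berkeley", "Berkeley")
f1 = token_f1("in Berkeley", "Berkeley")
ppl = perplexity([math.log(0.25)] * 4)
```

In this example the prediction contains one extra token, so EM is zero while F1 still gives partial credit. A model that assigns probability 0.25 to every token has a perplexity of about 4.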

In [29]:
import plotly.express as px
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [49]:
SPARSITY = [0, 10, 50, 90, 95, 99]

### BERT
Run over a sample set of n=300.

In [39]:
BERT_PERFORMANCE_F1 = [89.75174603174605, 89.93525480367586, 5.101793703923397, 2.5867006975238365, 4.257774368459605, 0.0]
BERT_PERFORMANCE_EM = [86.33333333333333, 86.33333333333333, 0.3333333333333333, 0.3333333333333333, 0.0, 0.0]
PAPER_F1 = 90.9
PAPER_EM = 84.1

In [52]:
df_BERT = pd.DataFrame({"Sparsity (%)": SPARSITY, "Performance (F1)": BERT_PERFORMANCE_F1, "Performance (EM)": BERT_PERFORMANCE_EM})
fig = make_subplots(specs=[[{"secondary_y": True}]])

# Add traces
fig.add_trace(
    go.Scatter(x=df_BERT["Sparsity (%)"], y=df_BERT["Performance (F1)"], name="Performance (F1)"),
    secondary_y=False,
)

fig.add_trace(
    go.Scatter(x=df_BERT["Sparsity (%)"], y=df_BERT["Performance (EM)"], name="Performance (EM)"),
    secondary_y=True,
)

# Add figure title
fig.update_layout(
    title_text="Performance of BERT after Sparsification in QA <br><sup>Paper F1 in red, Paper EM in green </sup>"
)

# Set x-axis title
fig.update_xaxes(title_text="Sparsity (%)")

# Set y-axes titles
fig.update_yaxes(title_text="Performance (F1)", secondary_y=False)
fig.update_yaxes(title_text="Performance (EM)", secondary_y=True)

fig.add_hline(y=PAPER_F1, line_color='#ff0000', name="Paper (F1)")
fig.add_hline(y=PAPER_EM, line_color='#00ff00', name="Paper (EM)")


fig.show()

### GPT2-XL

In [23]:
GPT2_XL_RESULTS = [14.7877, 14.7825, 36.9467, 4670.9092, np.inf, np.inf]  # sparsity 0, 10, 50, 90; 95 and 99 were omitted (np.inf is a placeholder)
PAPER_DEFAULT = 18.34

In [10]:
list(zip(SPARSITY, GPT2_XL_RESULTS))

[(0, 14.7877),
 (0.1, 14.7825),
 (0.5, 36.9467),
 (0.9, 4670.9092),
 (0.95, inf),
 (0.99, inf)]

In [51]:
df_gpt2 = pd.DataFrame({"Sparsity (%)": SPARSITY, "Performance (PPL)": GPT2_XL_RESULTS})

fig = px.line(df_gpt2, x="Sparsity (%)", y="Performance (PPL)", log_y=True, title="GPT2-XL Perplexity vs Sparsification <br><sup>Paper benchmark of 18.34 in red</sup>")
fig.add_hline(y=PAPER_DEFAULT, line_color='#ff0000', name="Paper")

### M2M-100

In [45]:
M2M_PERFORMANCE = [0.40251797333058786, 0.39995484202295456, 0.37629659926239045, 0.0, 0.0, 0.0]  # last two were not run
# No baseline is provided specifically for de-en

In [50]:
df_M2M = pd.DataFrame({"Sparsity (%)": SPARSITY, "Performance (BLEU)": M2M_PERFORMANCE})

# Note: on a log-scale y-axis the 0.0 BLEU points (90% sparsity and above) cannot be drawn.
fig = px.line(df_M2M, x="Sparsity (%)", y="Performance (BLEU)", log_y=True, title="M2M100 BLEU vs Sparsification on De-En<br><sup>No available De-En benchmark</sup>")
fig.show()

## Discussion

### Global unstructured pruning, masks, and memory consumption
The pruning method used (PyTorch's default global pruning) simply applies a mask to the original model which deactivates specific weights. While this approach is straightforward, it actually increases memory consumption because each mask must be stored alongside the original weights. (This was particularly noticeable during experiments: it became increasingly challenging to actually run sparsification on the models.)
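
The memory effect described above can be checked on a single layer. This is a small sketch, assuming float32 weights and a standalone `nn.Linear` rather than one of the models used here; `tensor_bytes` is a helper introduced only for this illustration.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune


def tensor_bytes(module: nn.Module) -> int:
    """Bytes held by all parameters and buffers registered on a module."""
    tensors = list(module.parameters()) + list(module.buffers())
    return sum(t.numel() * t.element_size() for t in tensors)


layer = nn.Linear(1024, 1024, bias=False)
dense_bytes = tensor_bytes(layer)

# After pruning, the module holds `weight_orig` (parameter) and `weight_mask` (buffer).
prune.global_unstructured(
    [(layer, "weight")], pruning_method=prune.L1Unstructured, amount=0.9
)
masked_bytes = tensor_bytes(layer)

# prune.remove bakes the zeros into `weight` and drops the mask.
prune.remove(layer, "weight")
removed_bytes = tensor_bytes(layer)
zero_fraction = float((layer.weight == 0).float().mean())
```

Here `masked_bytes` is twice `dense_bytes`, since the mask has the same shape and dtype as the weight. After `prune.remove`, the footprint returns to `dense_bytes` but does not drop below it, because the zeros are still stored densely.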

### Distributed device training and deployment
It's also strange how many challenges there are to distributed training and inference for these LLMs. The other GPUs on each node were often left idle because HuggingFace did not automatically parallelize the work across devices.

### Sparsification experiments
In the actual experiments, pruning 10% of the weights made no meaningful difference for any model. At 50%, the models diverged: BERT collapsed (F1 fell from 89.8 to 5.1), GPT2-XL degraded but stayed usable (perplexity rose from 14.8 to 36.9), and M2M still performed reasonably well (BLEU fell from 0.40 to 0.38). At ~90% sparsification, every model produced total garbage.

This implies two things. First, for global pruning, we should investigate the 10-50% range at higher granularity. The gap between the tested levels is simply too large to reveal where the optimal point is.

Second, this implies that global pruning is likely not the approach to take. A 10-30% reduction in memory would be substantive, but it's unclear what applications would justify such an expensive approach to memory reduction. I wonder if simply fine-tuning smaller models, or trying different pruning techniques which truly reduce memory consumption, would be more fruitful.