#### Notebook for Exploring Pruning of BERT: Inspired from: https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/text_classification.ipynb#scrollTo=YOCrQwPoIrJG

#### Task selected using Huggingface's Datasets

In [43]:
import torch
import torch.nn.utils.prune as prune
import datasets
import transformers
import numpy as np

##### Task
CoLA outputs a binary label of gramatically correct or not. It was deemed the easiest place to start. 

In [2]:
# Setup chosen task and metric
task = "cola"
checkpoint = "bert-base-uncased"
batch_size = 16

dataset = datasets.load_dataset("glue", task)
metric = datasets.load_metric("glue", task)

  metric = datasets.load_metric("glue", task)
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


In [3]:
print(dataset)
print(dataset["train"][0])

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 8551
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1043
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1063
    })
})
{'sentence': "Our friends won't buy this analysis, let alone the next one we propose.", 'label': 1, 'idx': 0}


In [4]:
print(metric)

# Using metric
fake_preds = np.random.randint(0, 2, size=(64,))
fake_labels = np.random.randint(0, 2, size=(64,))
metric.compute(predictions=fake_preds, references=fake_labels)

Metric(name: "glue", features: {'predictions': Value(dtype='int64', id=None), 'references': Value(dtype='int64', id=None)}, usage: """
Compute GLUE evaluation metric associated to each GLUE dataset.
Args:
    predictions: list of predictions to score.
        Each translation should be tokenized into a list of tokens.
    references: list of lists of references for each translation.
        Each reference should be tokenized into a list of tokens.
Returns: depending on the GLUE subset, one or several of:
    "accuracy": Accuracy
    "f1": F1 score
    "pearson": Pearson Correlation
    "spearmanr": Spearman Correlation
    "matthews_correlation": Matthew Correlation
Examples:

    >>> glue_metric = datasets.load_metric('glue', 'sst2')  # 'sst2' or any of ["mnli", "mnli_mismatched", "mnli_matched", "qnli", "rte", "wnli", "hans"]
    >>> references = [0, 1]
    >>> predictions = [0, 1]
    >>> results = glue_metric.compute(predictions=predictions, references=references)
    >>> print(res

{'matthews_correlation': 0.2504897164340598}

In [5]:
# Define tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint, use_fast=True)
# Example
print(tokenizer("Hello, this one sentence!", "And this sentence goes with it."))

#print(f"Sentence: {dataset['train'][0]['sentence']}")

# preprocess function
def preprocess_function(examples):
    return tokenizer(examples['sentence'], truncation=True)

preprocess_function(dataset['train'][:5])
encoded_dataset = dataset.map(preprocess_function, batched=True)

{'input_ids': [101, 7592, 1010, 2023, 2028, 6251, 999, 102, 1998, 2023, 6251, 3632, 2007, 2009, 1012, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}




In [6]:
# Grab BERT for sequence classification

num_labels = 2
model = transformers.AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=num_labels)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
metric_name = "matthews_correlation"
model_name = checkpoint.split("/")[-1]

args = transformers.TrainingArguments(
    f"{model_name}-finetuned-{task}",
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=5,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    push_to_hub=False,
)

In [11]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return metric.compute(predictions=predictions, references=labels)

In [12]:
validation_key = "validation"
trainer = transformers.Trainer(
    model,
    args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset[validation_key],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [13]:
trainer.train()

 19%|█▉        | 502/2675 [00:22<01:37, 22.36it/s]

{'loss': 0.4991, 'grad_norm': 12.548709869384766, 'learning_rate': 1.6261682242990654e-05, 'epoch': 0.93}


                                                  
 20%|██        | 535/2675 [00:24<01:37, 21.93it/s]

{'eval_loss': 0.43717440962791443, 'eval_matthews_correlation': 0.5211138326512601, 'eval_runtime': 0.6222, 'eval_samples_per_second': 1676.417, 'eval_steps_per_second': 106.082, 'epoch': 1.0}


 37%|███▋      | 1003/2675 [00:47<01:18, 21.18it/s]

{'loss': 0.3088, 'grad_norm': 9.044013977050781, 'learning_rate': 1.2523364485981309e-05, 'epoch': 1.87}


                                                   
 40%|████      | 1070/2675 [00:51<01:15, 21.21it/s]

{'eval_loss': 0.4879399240016937, 'eval_matthews_correlation': 0.5288802917060816, 'eval_runtime': 0.5827, 'eval_samples_per_second': 1789.792, 'eval_steps_per_second': 113.256, 'epoch': 2.0}


 56%|█████▌    | 1504/2675 [01:12<00:49, 23.46it/s]

{'loss': 0.2058, 'grad_norm': 0.558449387550354, 'learning_rate': 8.785046728971963e-06, 'epoch': 2.8}


                                                   
 60%|██████    | 1605/2675 [01:18<00:49, 21.63it/s]

{'eval_loss': 0.5911318063735962, 'eval_matthews_correlation': 0.5691684038863919, 'eval_runtime': 0.5745, 'eval_samples_per_second': 1815.56, 'eval_steps_per_second': 114.887, 'epoch': 3.0}


 75%|███████▍  | 2002/2675 [01:38<00:31, 21.60it/s]

{'loss': 0.1555, 'grad_norm': 1.4555262327194214, 'learning_rate': 5.046728971962617e-06, 'epoch': 3.74}


                                                   
 80%|████████  | 2140/2675 [01:45<00:24, 22.18it/s]

{'eval_loss': 0.8835119009017944, 'eval_matthews_correlation': 0.5100410597431988, 'eval_runtime': 0.5725, 'eval_samples_per_second': 1821.778, 'eval_steps_per_second': 115.28, 'epoch': 4.0}


 94%|█████████▎| 2503/2675 [02:02<00:07, 23.28it/s]

{'loss': 0.1016, 'grad_norm': 0.17335541546344757, 'learning_rate': 1.308411214953271e-06, 'epoch': 4.67}


                                                   
100%|██████████| 2675/2675 [02:10<00:00, 21.75it/s]

{'eval_loss': 0.8546203374862671, 'eval_matthews_correlation': 0.5740921203389623, 'eval_runtime': 0.5747, 'eval_samples_per_second': 1814.738, 'eval_steps_per_second': 114.835, 'epoch': 5.0}


100%|██████████| 2675/2675 [02:12<00:00, 20.20it/s]

{'train_runtime': 132.4553, 'train_samples_per_second': 322.788, 'train_steps_per_second': 20.195, 'train_loss': 0.2443254174918772, 'epoch': 5.0}





TrainOutput(global_step=2675, training_loss=0.2443254174918772, metrics={'train_runtime': 132.4553, 'train_samples_per_second': 322.788, 'train_steps_per_second': 20.195, 'total_flos': 454848611954580.0, 'train_loss': 0.2443254174918772, 'epoch': 5.0})

In [14]:
trainer.evaluate()

100%|██████████| 66/66 [00:00<00:00, 114.56it/s]


{'eval_loss': 0.8546203374862671,
 'eval_matthews_correlation': 0.5740921203389623,
 'eval_runtime': 0.6025,
 'eval_samples_per_second': 1731.209,
 'eval_steps_per_second': 109.549,
 'epoch': 5.0}

In [15]:
# Save model
trained_model = trainer.model.module if hasattr(trainer.model, 'module') else trainer.model  # Take care of distributed/parallel training

In [17]:
print(trained_model)

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

In [113]:
# Calculate sparisty of the network
def calcSparsity(parameter_names, parameters_to_prune):
    zeros = 0
    elements = 0

    for name, param in trained_model.named_parameters():
        # Only check sparsity of layers that were pruned
        print(name)
        if (name + "_orig") in parameter_names:
            print(name)

            """
            zero_count = torch.sum(param==0.0).item()
            zeros += zero_count

            element_count = param.numel()
            elements += element_count
            """

    #return zeros / elements

In [116]:
# Select the parameters that you wish to prune
parameters_to_prune = []
parameter_names = []
for name, module in trained_model.named_modules():

    # Self-Attention
    if isinstance(module, transformers.models.bert.modeling_bert.BertSelfAttention):
        parameters_to_prune.extend([
                (module.query, 'weight'),
                (module.key, 'weight'),
                (module.value, 'weight'),
            ])
        parameter_names.append(name)

    # Linear
    elif isinstance(module, torch.nn.Linear):
        parameters_to_prune.append((module, 'weight'))
        parameter_names.append(name)

In [117]:
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.5
)

for module,name in parameters_to_prune:
    prune.remove(module, name)

ValueError: Parameter 'weight' of module Linear(in_features=768, out_features=768, bias=True) has to be pruned before pruning can be removed

In [114]:
calcSparsity(parameter_names, parameters_to_prune)

bert.embeddings.word_embeddings.weight
bert.embeddings.position_embeddings.weight
bert.embeddings.token_type_embeddings.weight
bert.embeddings.LayerNorm.weight
bert.embeddings.LayerNorm.bias
bert.encoder.layer.0.attention.self.query.bias
bert.encoder.layer.0.attention.self.query.weight_orig
bert.encoder.layer.0.attention.self.key.bias
bert.encoder.layer.0.attention.self.key.weight_orig
bert.encoder.layer.0.attention.self.value.bias
bert.encoder.layer.0.attention.self.value.weight_orig
bert.encoder.layer.0.attention.output.dense.bias
bert.encoder.layer.0.attention.output.dense.weight_orig
bert.encoder.layer.0.attention.output.LayerNorm.weight
bert.encoder.layer.0.attention.output.LayerNorm.bias
bert.encoder.layer.0.intermediate.dense.bias
bert.encoder.layer.0.intermediate.dense.weight_orig
bert.encoder.layer.0.output.dense.bias
bert.encoder.layer.0.output.dense.weight_orig
bert.encoder.layer.0.output.LayerNorm.weight
bert.encoder.layer.0.output.LayerNorm.bias
bert.encoder.layer.1.attent