#### Notebook for Exploring Pruning of BERT: Inspired from: https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/text_classification.ipynb#scrollTo=YOCrQwPoIrJG

#### Task selected using Huggingface's Datasets

In [3]:
import torch
import torch.nn.utils.prune as prune
import datasets
import transformers
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


##### Task
CoLA outputs a binary label of gramatically correct or not. It was deemed the easiest place to start. 

In [4]:
# Setup chosen task and metric
task = "cola"
checkpoint = "bert-base-uncased"
batch_size = 16

dataset = datasets.load_dataset("glue", task)
metric = datasets.load_metric("glue", task)

  metric = datasets.load_metric("glue", task)
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


In [5]:
print(dataset)
print(dataset["train"][0])

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 8551
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1043
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1063
    })
})
{'sentence': "Our friends won't buy this analysis, let alone the next one we propose.", 'label': 1, 'idx': 0}


In [6]:
print(metric)

# Using metric
fake_preds = np.random.randint(0, 2, size=(64,))
fake_labels = np.random.randint(0, 2, size=(64,))
metric.compute(predictions=fake_preds, references=fake_labels)

Metric(name: "glue", features: {'predictions': Value(dtype='int64', id=None), 'references': Value(dtype='int64', id=None)}, usage: """
Compute GLUE evaluation metric associated to each GLUE dataset.
Args:
    predictions: list of predictions to score.
        Each translation should be tokenized into a list of tokens.
    references: list of lists of references for each translation.
        Each reference should be tokenized into a list of tokens.
Returns: depending on the GLUE subset, one or several of:
    "accuracy": Accuracy
    "f1": F1 score
    "pearson": Pearson Correlation
    "spearmanr": Spearman Correlation
    "matthews_correlation": Matthew Correlation
Examples:

    >>> glue_metric = datasets.load_metric('glue', 'sst2')  # 'sst2' or any of ["mnli", "mnli_mismatched", "mnli_matched", "qnli", "rte", "wnli", "hans"]
    >>> references = [0, 1]
    >>> predictions = [0, 1]
    >>> results = glue_metric.compute(predictions=predictions, references=references)
    >>> print(res

{'matthews_correlation': 0.0640210337986158}

In [7]:
# Define tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint, use_fast=True)
# Example
print(tokenizer("Hello, this one sentence!", "And this sentence goes with it."))

#print(f"Sentence: {dataset['train'][0]['sentence']}")

# preprocess function
def preprocess_function(examples):
    return tokenizer(examples['sentence'], truncation=True)

preprocess_function(dataset['train'][:5])
encoded_dataset = dataset.map(preprocess_function, batched=True)

{'input_ids': [101, 7592, 1010, 2023, 2028, 6251, 999, 102, 1998, 2023, 6251, 3632, 2007, 2009, 1012, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}




In [82]:
# Grab BERT for sequence classification

num_labels = 2
model = transformers.BertForSequenceClassification.from_pretrained(checkpoint, num_labels=num_labels)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [83]:
metric_name = "matthews_correlation"
model_name = checkpoint.split("/")[-1]

args = transformers.TrainingArguments(
    f"{model_name}-finetuned-{task}",
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=5,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    push_to_hub=False,
)

In [84]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return metric.compute(predictions=predictions, references=labels)

In [85]:
validation_key = "validation"
trainer = transformers.Trainer(
    model,
    args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset[validation_key],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [86]:
trainer.train()

  0%|          | 0/2675 [00:00<?, ?it/s]

 19%|█▉        | 503/2675 [00:23<01:40, 21.69it/s]

{'loss': 0.4892, 'grad_norm': 8.492907524108887, 'learning_rate': 1.6261682242990654e-05, 'epoch': 0.93}


 20%|█▉        | 533/2675 [00:24<01:27, 24.52it/s]
 20%|██        | 535/2675 [00:25<01:27, 24.52it/s]

{'eval_loss': 0.44444724917411804, 'eval_matthews_correlation': 0.5272175286888164, 'eval_runtime': 0.6017, 'eval_samples_per_second': 1733.381, 'eval_steps_per_second': 109.687, 'epoch': 1.0}


 38%|███▊      | 1004/2675 [00:49<01:14, 22.56it/s]

{'loss': 0.2934, 'grad_norm': 7.86383581161499, 'learning_rate': 1.2523364485981309e-05, 'epoch': 1.87}


 40%|████      | 1070/2675 [00:52<01:12, 22.08it/s]
 40%|████      | 1070/2675 [00:53<01:12, 22.08it/s]

{'eval_loss': 0.47043007612228394, 'eval_matthews_correlation': 0.5932805322494611, 'eval_runtime': 0.6902, 'eval_samples_per_second': 1511.161, 'eval_steps_per_second': 95.625, 'epoch': 2.0}


 56%|█████▌    | 1502/2675 [01:14<00:54, 21.38it/s]

{'loss': 0.1912, 'grad_norm': 0.32137614488601685, 'learning_rate': 8.785046728971963e-06, 'epoch': 2.8}


 60%|█████▉    | 1604/2675 [01:19<00:40, 26.19it/s]
 60%|██████    | 1605/2675 [01:19<00:40, 26.19it/s]

{'eval_loss': 0.6296312808990479, 'eval_matthews_correlation': 0.5680628967969402, 'eval_runtime': 0.5923, 'eval_samples_per_second': 1760.785, 'eval_steps_per_second': 111.421, 'epoch': 3.0}


 75%|███████▍  | 2003/2675 [01:40<00:26, 25.22it/s]

{'loss': 0.138, 'grad_norm': 0.15411821007728577, 'learning_rate': 5.046728971962617e-06, 'epoch': 3.74}


 80%|███████▉  | 2138/2675 [01:46<00:25, 20.78it/s]
 80%|████████  | 2140/2675 [01:47<00:25, 20.78it/s]

{'eval_loss': 0.9082818627357483, 'eval_matthews_correlation': 0.5572696682585848, 'eval_runtime': 0.5918, 'eval_samples_per_second': 1762.401, 'eval_steps_per_second': 111.523, 'epoch': 4.0}


 94%|█████████▎| 2502/2675 [02:06<00:07, 21.71it/s]

{'loss': 0.0874, 'grad_norm': 0.07196631282567978, 'learning_rate': 1.308411214953271e-06, 'epoch': 4.67}


100%|█████████▉| 2673/2675 [02:14<00:00, 23.54it/s]
100%|██████████| 2675/2675 [02:14<00:00, 23.54it/s]

{'eval_loss': 0.9003195762634277, 'eval_matthews_correlation': 0.5857509882742485, 'eval_runtime': 0.577, 'eval_samples_per_second': 1807.756, 'eval_steps_per_second': 114.393, 'epoch': 5.0}


100%|██████████| 2675/2675 [02:17<00:00, 19.43it/s]

{'train_runtime': 137.6442, 'train_samples_per_second': 310.62, 'train_steps_per_second': 19.434, 'train_loss': 0.22967105580267505, 'epoch': 5.0}





TrainOutput(global_step=2675, training_loss=0.22967105580267505, metrics={'train_runtime': 137.6442, 'train_samples_per_second': 310.62, 'train_steps_per_second': 19.434, 'total_flos': 454848611954580.0, 'train_loss': 0.22967105580267505, 'epoch': 5.0})

In [87]:
trainer.evaluate()

100%|██████████| 66/66 [00:00<00:00, 118.17it/s]


{'eval_loss': 0.47043007612228394,
 'eval_matthews_correlation': 0.5932805322494611,
 'eval_runtime': 0.5726,
 'eval_samples_per_second': 1821.65,
 'eval_steps_per_second': 115.272,
 'epoch': 5.0}

In [88]:
# Save model
trainer.save_model("bert_pruning_model")

In [79]:
#model == trained_model

#for p1, p2 in zip(model.parameters(), trained_model.parameters()):
#    if(p1.data.ne(p2.data).sum() > 0):
#        print('False')
#print('True')

True

In [147]:
# Load model
trained_model = transformers.BertForSequenceClassification.from_pretrained('bert_pruning_model', num_labels=num_labels)

In [148]:
# Calculate sparisty of the network
def calcSparsity(parameter_names):
    zeros = 0
    elements = 0

    for name, param in trained_model.named_parameters():

        # Layer was pruned if orig exists
        #parsed_name = name.split('_orig')
        # Layer was pruned
        #if(len(parsed_name) == 2 and parsed_name[0] in parameter_names):

        zero_count = torch.sum(param==0.0).item()
        zeros += zero_count

        element_count = param.numel()
        elements += element_count

    if(elements == 0):
        return 0

    return zeros / elements

In [149]:
# Select the parameters that you wish to prune
parameters_to_prune = []
parameter_names = []
for name, module in trained_model.named_modules():

    # Self-Attention
    #if isinstance(module, transformers.models.bert.modeling_bert.BertSelfAttention):
    #    parameters_to_prune.extend([
    #            (module.query, 'weight'),
    #            (module.key, 'weight'),
    #            (module.value, 'weight'),
    #        ])
    #print(name)
    #parameter_names.append(name)
    #parameter_names.extend([name + '.query', name + '.key', name + '.value'])

    # Linear: Includes linear layers in attention head as well
    if isinstance(module, torch.nn.Linear):
        parameters_to_prune.append((module, 'weight'))
        parameter_names.append(name + '.weight')

In [150]:
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.7
)

for module,name in parameters_to_prune:
    prune.remove(module, name)

In [152]:
calcSparsity(parameter_names)

0.5468226626231331