### Configure and verify the environment

In [None]:
# uncomment the following lines to install the required libraries

# pip install transformers
# pip install datasets
# pip install torch==2.0.1 torchvision==0.15.2
# pip install accelerate

### Preprocessing


In [12]:
# import libs
import os
import re
import sys
import time
import numpy as np
import torch
import torchvision
import transformers
import datasets
from datasets import load_dataset, load_metric
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer

In [None]:
task = 'mnli'
# task = 'mnli-mm'

# "mnli-mm" is not a GLUE configuration name: the mismatched validation set is
# a split of the "mnli" configuration, so both task names load "mnli".
actual_task = "mnli" if task == "mnli-mm" else task

dataset = load_dataset("glue", actual_task)
metric = load_metric("glue", actual_task)

# Texts must be tokenized before they are fed to the model; the tokenizer that
# belongs to the chosen checkpoint takes care of that.
batch_size = 16

# model type: BERT or DistilBERT
# model_checkpoint = "bert-base-uncased"
model_checkpoint = "distilbert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)

# Names of the two text columns that form the sentence pair for each task.
task_to_keys = {
    "mnli": ("premise", "hypothesis"),
    "mnli-mm": ("premise", "hypothesis"),
}
sentence1_key, sentence2_key = task_to_keys[task]

def preprocess_function(samples):
  """Tokenize the sentence pairs in ``samples``, truncating over-long inputs."""
  first_sentences = samples[sentence1_key]
  second_sentences = samples[sentence2_key]
  return tokenizer(first_sentences, second_sentences, truncation=True)

# A single batched map call tokenizes every split of the dataset.
encoded_dataset = dataset.map(preprocess_function, batched=True)

# MNLI is a three-way classification task.
num_labels = 3
metric_name = 'accuracy'
model_name = model_checkpoint.split("/")[-1]
model = AutoModelForSequenceClassification.from_pretrained(
    model_checkpoint,
    num_labels=num_labels,
)

In [None]:
# Trainer configuration: evaluate and checkpoint once per epoch, then reload
# the checkpoint that scored best on `metric_name`.
args = TrainingArguments(
    output_dir=f"{model_name}_output",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    # Reuse the notebook-wide batch_size instead of repeating the literal.
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=1,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
)

def compute_metrics(eval_pred):
    """Score an evaluation result with the module-level ``metric``.

    ``eval_pred`` unpacks into per-class scores and reference labels; the
    highest-scoring class along axis 1 is taken as the prediction.
    """
    class_scores, references = eval_pred
    predicted_classes = np.argmax(class_scores, axis=1)
    return metric.compute(predictions=predicted_classes, references=references)

# Choose the validation split that belongs to the selected task variant.
if task == "mnli-mm":
    validation_key = "validation_mismatched"
else:
    validation_key = "validation_matched"

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset[validation_key],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
# Fine-tune the full model with the Trainer configured above.
trainer.train()

In [None]:
# Report the serialized size of a model
def print_size_of_model(model):
    """Print and return the size in MB of ``model``'s serialized state dict.

    The state dict is written with ``torch.save`` to a uniquely named
    temporary file, so an existing file in the working directory is never
    overwritten. The temporary file is removed even if saving fails.

    Args:
        model: Object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).

    Returns:
        The file size in megabytes (bytes / 1e6) as a float.
    """
    import tempfile

    fd, tmp_path = tempfile.mkstemp(suffix=".p")
    # Only the path is needed; torch.save opens the file itself.
    os.close(fd)
    try:
        torch.save(model.state_dict(), tmp_path)
        size_mb = os.path.getsize(tmp_path) / 1e6
    finally:
        os.remove(tmp_path)

    print('Size (MB):', size_mb)
    return size_mb

In [None]:
# define an evaluation function
def evaluate(model, encoded_dataset, mnli_dataset, test_dataset):
  """Measure the accuracy of ``model`` one sample at a time.

  Each sample's premise/hypothesis pair is tokenized with the module-level
  ``tokenizer``, passed through ``model``, and the arg-max class is compared
  with the reference label.

  Args:
    model: Called as ``model(**encoding)``; its output must expose ``.logits``.
      It is switched to eval mode by this function.
    encoded_dataset: Sized iterable of samples providing ``'premise'``,
      ``'hypothesis'`` and ``'idx'`` entries.
    mnli_dataset: Indexed with a sample's ``idx``;
      ``mnli_dataset[idx]['label']`` is taken as the reference label.
    test_dataset: When truthy, a prediction counts as correct if it is class 2,
      and a message is printed for every label that is not -1.

  Returns:
    Fraction of samples counted as correct, or 0.0 for an empty dataset.
  """
  matched = 0
  N = len(encoded_dataset)
  print(f'Total matched samples: {N}')
  if N == 0:
    # Nothing to score; avoid dividing by zero below.
    return 0.0

  # Encoded class ids:
  #   entailment => 0
  #   neutral => 1
  #   contradiction => 2

  # Inference only: disable dropout and skip autograd bookkeeping.
  model.eval()
  with torch.no_grad():
    for i, sample in enumerate(encoded_dataset):
      premise = sample['premise']
      hypothesis = sample['hypothesis']
      label = mnli_dataset[sample['idx']]['label']

      # Truncate like the training-time preprocessing so that over-long pairs
      # cannot exceed the model's maximum input length.
      encode_input = tokenizer(premise, hypothesis, truncation=True, return_tensors='pt')
      output = model(**encode_input)
      # Tensor.cpu() copies the logits to host memory before NumPy sees them.
      pred = int(np.argmax(output.logits.detach().cpu().numpy(), axis=1)[0])

      if test_dataset:
        # NOTE(review): this branch assumes every test sample should be
        # predicted as class 2 and carries the placeholder label -1 --
        # confirm against the dataset actually passed in.
        if pred == 2:
          matched += 1
        if label != -1:
          print('exception in test dataset')
      elif pred == label:
        matched += 1

      # Report running accuracy over the samples processed so far.
      seen = i + 1
      if seen % 500 == 0:
        print(f'Step at: {seen // 500}, accu: {matched / seen}, matched {matched} out of {seen}')

  return matched / N

In [None]:
def time_model_evaluation(model, encoded_dataset, mnli_dataset, test_dataset):
  """Run ``evaluate`` and report its wall-clock duration and accuracy.

  Args:
    model, encoded_dataset, mnli_dataset, test_dataset: Forwarded unchanged
      to ``evaluate``.

  Returns:
    The accuracy returned by ``evaluate``, so callers can keep it
    (e.g. ``acc = time_model_evaluation(...)``).
  """
  # perf_counter is monotonic, which makes it the right clock for durations.
  eval_start_time = time.perf_counter()
  acc = evaluate(model, encoded_dataset, mnli_dataset, test_dataset)
  eval_duration_time = time.perf_counter() - eval_start_time
  print("\nEND INFO:")
  print("Evaluate total time (seconds): {0:.1f}".format(eval_duration_time))
  print(f'Evaluate end accuracy is {acc}')
  return acc

In [None]:
# Serialized size of the fine-tuned model, before pruning or quantization.
print_size_of_model(model)

In [None]:
# Move the fine-tuned model to the CPU and time its evaluation on the matched
# validation split; reference labels are looked up in the raw dataset.
device = torch.device('cpu')
model.to(device)
acc = time_model_evaluation(model, encoded_dataset['validation_matched'], dataset["validation_matched"], test_dataset=False)

### PrunBERT

In [None]:
# Indices (as strings) of the encoder layers that survive pruning; every other
# numbered layer is dropped and the kept ones are renumbered consecutively.
# encoder_layers_to_keep = ['0', '1', '2', '3', '4', '5']
encoder_layers_to_keep = ['0', '1', '2', '4']

def prune_state_dict(state_dict, layers_to_keep=None):
    """Return a new state dict that contains only the kept encoder layers.

    Keys without a ``.layer.<number>.`` component (embeddings, classifier
    head, ...) are copied unchanged. Keys of a kept layer are copied with the
    layer number replaced by the layer's position among the kept layers, so
    the result is numbered 0..k-1. Keys of all other layers are dropped.

    Args:
        state_dict: Mapping from parameter name to value.
        layers_to_keep: Iterable of layer indices (ints or numeric strings).
            Defaults to the module-level ``encoder_layers_to_keep``.

    Returns:
        A new dict; the values are the same objects as in ``state_dict``.
    """
    if layers_to_keep is None:
        layers_to_keep = encoder_layers_to_keep

    # Old layer number -> new consecutive layer number, both as strings.
    kept_layers = sorted(int(layer) for layer in layers_to_keep)
    renumbering = {str(old): str(new) for new, old in enumerate(kept_layers)}

    layer_pattern = re.compile(r"\.layer\.(\d+)\.")

    new_state_dict = {}
    for layer_name, value in state_dict.items():
        match = layer_pattern.search(layer_name)
        # No layer number: a supporting parameter such as an embedding.
        if match is None:
            new_state_dict[layer_name] = value
            continue

        new_layer_number = renumbering.get(match.group(1))
        if new_layer_number is None:
            # This layer is pruned away.
            continue

        # Replace exactly the matched number. Searching the name for the digit
        # would hit earlier digits and mishandle multi-digit layer numbers.
        new_state_key = (
            layer_name[: match.start(1)]
            + new_layer_number
            + layer_name[match.end(1):]
        )
        new_state_dict[new_state_key] = value

    return new_state_dict

In [None]:
def load_state_dict(state_dict, strict=True):
  """Return ``state_dict`` after layer pruning via ``prune_state_dict``.

  ``strict`` is accepted but not used.
  """
  return prune_state_dict(state_dict)

# Prune the weights of the fine-tuned model.
pruned_state_dict = load_state_dict(model.state_dict())

In [None]:
# Instantiate the same checkpoint with one hidden layer per kept encoder layer.
pruned_model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels, num_hidden_layers=len(encoder_layers_to_keep))

In [None]:
# Load the pruned, renumbered weights of the fine-tuned model into pruned_model.
pruned_model.load_state_dict(pruned_state_dict)

In [None]:
# Trainer for the pruned model; it reuses the arguments, data splits,
# tokenizer and metric function of the full-model run.
prunBERT_trainer = Trainer(
    model=pruned_model,
    args=args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset[validation_key],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
# Fine-tune the pruned model.
prunBERT_trainer.train()

In [None]:
# Serialized size of the pruned model.
print_size_of_model(pruned_model)

In [None]:
# Time the pruned model on the matched validation split, on the same device
# as the baseline run.
pruned_model.to(device)
acc = time_model_evaluation(pruned_model, encoded_dataset['validation_matched'], dataset["validation_matched"], test_dataset=False)

### Quantization

In [None]:
# Dynamic quantization of the Linear layers to qint8, applied to both the
# full model and the pruned model.
quantized_model = torch.quantization.quantize_dynamic(
    model,
    qconfig_spec={torch.nn.Linear},
    dtype=torch.qint8,
)

# "triple" model: the pruned model with quantization applied on top.
triple_model = torch.quantization.quantize_dynamic(
    pruned_model,
    qconfig_spec={torch.nn.Linear},
    dtype=torch.qint8,
)

In [None]:
# Serialized sizes of the quantized full model and the quantized pruned model.
print_size_of_model(quantized_model)
print_size_of_model(triple_model)

In [None]:
# Time the quantized full model on the matched validation split.
quantized_model.to(device)
acc = time_model_evaluation(quantized_model, encoded_dataset['validation_matched'], dataset["validation_matched"], test_dataset=False)

In [None]:
# Time the pruned-and-quantized model on the matched validation split.
triple_model.to(device)
acc = time_model_evaluation(triple_model, encoded_dataset['validation_matched'], dataset["validation_matched"], test_dataset=False)