#### Notebook for Exploring Pruning of BERT: Inspired from: https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/text_classification.ipynb#scrollTo=YOCrQwPoIrJG

#### Task selected using Huggingface's Datasets

In [9]:
import torch
import datasets
import transformers
import numpy as np

##### Task
CoLA outputs a binary label of gramatically correct or not. It was deemed the easiest place to start. 

In [5]:
# Setup chosen task and metric
task = "cola"
checkpoint = "bert-base-uncased"
batch_size = 16

dataset = datasets.load_dataset("glue", task)
metric = datasets.load_metric("glue", task)

Downloading readme: 100%|██████████| 35.3k/35.3k [00:00<00:00, 2.08MB/s]
Downloading data: 100%|██████████| 251k/251k [00:00<00:00, 748kB/s]
Downloading data: 100%|██████████| 37.6k/37.6k [00:00<00:00, 204kB/s]
Downloading data: 100%|██████████| 37.7k/37.7k [00:00<00:00, 190kB/s]
Generating train split: 100%|██████████| 8551/8551 [00:00<00:00, 886005.27 examples/s]
Generating validation split: 100%|██████████| 1043/1043 [00:00<00:00, 369793.67 examples/s]
Generating test split: 100%|██████████| 1063/1063 [00:00<00:00, 421891.10 examples/s]
  metric = datasets.load_metric("glue", task)
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
Downloading builder script: 5.76kB [00:00, 8.91MB/s]                   


In [7]:
print(dataset)
print(dataset["train"][0])

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 8551
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1043
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1063
    })
})
{'sentence': "Our friends won't buy this analysis, let alone the next one we propose.", 'label': 1, 'idx': 0}


In [10]:
print(metric)

# Using metric
fake_preds = np.random.randint(0, 2, size=(64,))
fake_labels = np.random.randint(0, 2, size=(64,))
metric.compute(predictions=fake_preds, references=fake_labels)

Metric(name: "glue", features: {'predictions': Value(dtype='int64', id=None), 'references': Value(dtype='int64', id=None)}, usage: """
Compute GLUE evaluation metric associated to each GLUE dataset.
Args:
    predictions: list of predictions to score.
        Each translation should be tokenized into a list of tokens.
    references: list of lists of references for each translation.
        Each reference should be tokenized into a list of tokens.
Returns: depending on the GLUE subset, one or several of:
    "accuracy": Accuracy
    "f1": F1 score
    "pearson": Pearson Correlation
    "spearmanr": Spearman Correlation
    "matthews_correlation": Matthew Correlation
Examples:

    >>> glue_metric = datasets.load_metric('glue', 'sst2')  # 'sst2' or any of ["mnli", "mnli_mismatched", "mnli_matched", "qnli", "rte", "wnli", "hans"]
    >>> references = [0, 1]
    >>> predictions = [0, 1]
    >>> results = glue_metric.compute(predictions=predictions, references=references)
    >>> print(res

{'matthews_correlation': 0.2135744251723958}

In [19]:
# Define tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint, use_fast=True)
# Example
print(tokenizer("Hello, this one sentence!", "And this sentence goes with it."))

#print(f"Sentence: {dataset['train'][0]['sentence']}")

# preprocess function
def preprocess_function(examples):
    return tokenizer(examples['sentence'], truncation=True)

preprocess_function(dataset['train'][:5])
encoded_dataset = dataset.map(preprocess_function, batched=True)



{'input_ids': [101, 7592, 1010, 2023, 2028, 6251, 999, 102, 1998, 2023, 6251, 3632, 2007, 2009, 1012, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}


Map: 100%|██████████| 8551/8551 [00:00<00:00, 39751.79 examples/s]
Map: 100%|██████████| 1043/1043 [00:00<00:00, 50726.57 examples/s]
Map: 100%|██████████| 1063/1063 [00:00<00:00, 51302.49 examples/s]
