# Pruning Wav2Vec2 for English ASR

In [1]:
# Capture the output of `nvidia-smi` through an IPython shell assignment, then join it into one string.
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
# str.find returns -1 when the substring is absent, so >= 0 means the word 'failed' appeared in the output.
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

import torch 
# Last expression of the cell: displays the index of the current CUDA device.
# NOTE(review): presumably this raises on a runtime without CUDA, even after the friendly message above — confirm.
torch.cuda.current_device()

Sat May  7 02:50:27 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   74C    P8    34W / 149W |      0MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

0

In [2]:
%%capture
# Hugging Face libraries are pinned to the versions this notebook was written against.
!pip install datasets==1.13.3
!pip install transformers==4.11.3

# NOTE(review): librosa and jiwer are not version-pinned — consider pinning for reproducibility.
!pip install librosa # to load audio files
!pip install jiwer # to use wer metric
# NOTE(review): git-lfs is presumably needed for the Hub pushes further down — confirm.
!sudo apt install git-lfs

In [3]:
from huggingface_hub import notebook_login

# Authenticate against the Hugging Face Hub from inside the notebook.
notebook_login()

Login successful
Your token has been saved to /root/.huggingface/token
[1m[31mAuthenticated through git-crendential store but this isn't the helper defined on your machine.
You will have to re-authenticate when pushing to the Hugging Face Hub. Run the following command in your terminal to set it as the default

git config --global credential.helper store[0m


In [4]:
# Set git's global credential helper to `store`, as the login output above recommends.
!git config --global credential.helper store

## Part 1: Prepare Data, Tokenizer, Feature Extractor

### Create Wav2Vec2CTCTokenizer

Download the TIMIT dataset and apply text normalization.

In [5]:
from datasets import load_dataset, load_metric
from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML

# Load the TIMIT ASR dataset.
timit = load_dataset("timit_asr")
# Drop the listed metadata columns; later cells only read "text", "audio" and "file".
timit = timit.remove_columns(["phonetic_detail", "word_detail", "dialect_region", "id", "sentence_type", "speaker_id"])

def show_random_elements(dataset, num_examples=10):
    """Display `num_examples` distinct, randomly chosen rows of `dataset` as an HTML table.

    Args:
        dataset: Collection that supports `len()` and indexing with a list of row
            indices; the indexed result is handed to `pd.DataFrame`.
        num_examples: Number of distinct rows to display.

    Raises:
        AssertionError: If `num_examples` is larger than `len(dataset)`.
    """
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    # random.sample draws without replacement, so the indices are unique by construction
    # (no retry loop needed to reject duplicates).
    picks = random.sample(range(len(dataset)), num_examples)

    df = pd.DataFrame(dataset[picks])
    display(HTML(df.to_html()))

## text normalization    
import re
# Character class of the punctuation removed from transcripts. Written as a raw string so
# the backslashes reach the regex engine as-is instead of relying on Python keeping
# unrecognised string escapes (which emits a warning on newer interpreters).
chars_to_ignore_regex = r'[\,\?\.\!\-\;\:\"]'

def remove_special_characters(batch):
    """Normalize the transcript stored under ``batch["text"]``.

    Removes every character matched by `chars_to_ignore_regex`, lower-cases the result
    and appends a single trailing space. The batch is modified in place and returned.
    """
    batch["text"] = re.sub(chars_to_ignore_regex, '', batch["text"]).lower() + " "
    return batch

# Apply the transcript normalization to the dataset.
timit = timit.map(remove_special_characters)
# Preview random normalized training transcripts; the "audio" and "file" columns are dropped from the preview only.
show_random_elements(timit["train"].remove_columns(["audio", "file"]))

Downloading:   0%|          | 0.00/2.40k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.06k [00:00<?, ?B/s]

Downloading and preparing dataset timit_asr/clean (download: 828.75 MiB, generated: 7.90 MiB, post-processed: Unknown size, total: 836.65 MiB) to /root/.cache/huggingface/datasets/timit_asr/clean/2.0.1/5bebea6cd9df0fc2c8c871250de23293a94c1dc49324182b330b6759ae6718f8...


Downloading:   0%|          | 0.00/869M [00:00<?, ?B/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

Dataset timit_asr downloaded and prepared to /root/.cache/huggingface/datasets/timit_asr/clean/2.0.1/5bebea6cd9df0fc2c8c871250de23293a94c1dc49324182b330b6759ae6718f8. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/4620 [00:00<?, ?ex/s]

  0%|          | 0/1680 [00:00<?, ?ex/s]

Unnamed: 0,text
0,my desires are simple give me one informative paragraph on the subject
1,withdraw only as much money as you need
2,when did women begin to assert themselves sexually
3,the mango and the papaya are in a bowl
4,we got drenched from the uninterrupted rain
5,would you allow acts of violence
6,those answers will be straightforward if you think through them carefully first
7,andrei's skilled eye sized them up
8,do you hear the sleigh bells ringing
9,splendor by sorcery it's a horror


In [7]:
def extract_all_chars(batch):
    """Collect the distinct characters that occur in a batch of transcripts.

    The strings in ``batch["text"]`` are joined with single spaces. The returned dict
    maps "vocab" to a one-element list holding the unique characters of that joined
    string (in set iteration order) and "all_text" to a one-element list holding the
    joined string itself.
    """
    joined_text = " ".join(batch["text"])
    unique_chars = list(set(joined_text))
    return dict(vocab=[unique_chars], all_text=[joined_text])

# Run extract_all_chars over each split as a single batch (batch_size=-1), replacing the
# original columns with the "vocab" / "all_text" outputs.
vocabs = timit.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=timit.column_names["train"])
# Union of the train and test character sets. Sorted on purpose: set iteration order for
# strings can differ between interpreter sessions, which would assign different ids to
# the same characters and make vocab.json inconsistent with previously saved checkpoints.
vocab_list = sorted(set(vocabs["train"]["vocab"][0]) | set(vocabs["test"]["vocab"][0]))
vocab_dict = {v: k for k, v in enumerate(vocab_list)}
# Re-key the space character as "|", the word delimiter token given to the tokenizer below.
vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]
# Append the special tokens at the end of the id range.
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)

## instantiate an object of the Wav2Vec2CTCTokenizer class.
import json
from transformers import Wav2Vec2CTCTokenizer

with open('vocab.json', 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)

tokenizer = Wav2Vec2CTCTokenizer("./vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [8]:
# Hub repository name; also reused as the Trainer output directory further down.
repo_name = "wav2vec2-base-timit-demo-colab"
# Upload the tokenizer to that repository.
tokenizer.push_to_hub(repo_name)

Cloning https://huggingface.co/smileysky/wav2vec2-base-timit-demo-colab into local empty directory.
remote: Enforcing permissions...        
remote: Allowed refs: all        
To https://huggingface.co/smileysky/wav2vec2-base-timit-demo-colab
   66b54ec..507d830  main -> main



'https://huggingface.co/smileysky/wav2vec2-base-timit-demo-colab/commit/507d8308716b4f24e67695ad15e920a6ba51ce3d'

### Create Wav2Vec2 Feature Extractor

In [10]:
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Processor

# Feature extractor configured for 16 kHz input with normalization enabled and no attention mask returned.
feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False)
# Bundle the feature extractor and the tokenizer into a single processor object.
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

### Preprocess Data


In [11]:
import IPython.display as ipd
import numpy as np
import random

def prepare_dataset(batch):
    """Turn one raw example into model inputs and label ids.

    Reads ``batch["audio"]`` (its "array" and "sampling_rate" entries) and
    ``batch["text"]``; stores ``input_values``, ``input_length`` and ``labels`` in
    ``batch`` and returns it.
    """
    waveform = batch["audio"]

    # the processor output is batched; index 0 pulls out this single example
    features = processor(waveform["array"], sampling_rate=waveform["sampling_rate"])
    batch["input_values"] = features.input_values[0]
    batch["input_length"] = len(batch["input_values"])

    # the transcript is encoded while the processor is in target mode
    with processor.as_target_processor():
        encoded_text = processor(batch["text"])
        batch["labels"] = encoded_text.input_ids
    return batch

# Preprocess the dataset with 4 worker processes; the original columns are removed from the result.
timit = timit.map(prepare_dataset, remove_columns=timit.column_names["train"], num_proc=4)

# Drop training examples of 4 seconds or longer; the test split is not filtered.
max_input_length_in_sec = 4.0
# "input_length" is a sample count, so the threshold is seconds * sampling rate.
timit["train"] = timit["train"].filter(lambda x: x < max_input_length_in_sec * processor.feature_extractor.sampling_rate, input_columns=["input_length"])

  0%|          | 0/5 [00:00<?, ?ba/s]

In [12]:
from transformers import Wav2Vec2ForCTC

# Processor that belongs to the already fine-tuned checkpoint; used to decode that model's predictions.
pred_processor = Wav2Vec2Processor.from_pretrained("patrickvonplaten/wav2vec2-base-timit-demo")
# The fine-tuned CTC model itself, moved to the GPU.
finetuned_model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-base-timit-demo").to("cuda")

Downloading:   0%|          | 0.00/159 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/268 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/138 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/85.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.60k [00:00<?, ?B/s]

  "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 "
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Downloading:   0%|          | 0.00/360M [00:00<?, ?B/s]

In [13]:
wer_metric = load_metric("wer")
def compute_metrics(pred):
    """Compute the word error rate for a set of predictions.

    Reads ``pred.predictions`` (logits) and ``pred.label_ids``. Note that
    ``pred.label_ids`` is modified in place: entries equal to -100 are overwritten
    with the tokenizer's pad token id so the labels can be decoded.

    Returns:
        A dict with the single key "wer".
    """
    predicted_ids = np.argmax(pred.predictions, axis=-1)

    # -100 is the ignore value written by the data collator; swap it for the pad id
    ignored_positions = pred.label_ids == -100
    pred.label_ids[ignored_positions] = processor.tokenizer.pad_token_id

    hypotheses = processor.batch_decode(predicted_ids)
    # we do not want to group tokens when computing the metrics
    references = processor.batch_decode(pred.label_ids, group_tokens=False)

    return {"wer": wer_metric.compute(predictions=hypotheses, references=references)}

Downloading:   0%|          | 0.00/1.95k [00:00<?, ?B/s]

In [14]:
import torch 

def map_to_result(batch):
    """Transcribe one example with `finetuned_model` and decode its reference text.

    Stores the predicted transcript under ``batch["pred_str"]`` (decoded with
    `pred_processor`) and the reference transcript under ``batch["text"]`` (decoded
    from ``batch["labels"]`` with `processor`), then returns the batch.
    """
    # inference only, so gradient tracking is switched off
    with torch.no_grad():
        # unsqueeze(0) adds a leading dimension of size 1 before the forward pass
        model_input = torch.tensor(batch["input_values"], device="cuda").unsqueeze(0)
        logits = finetuned_model(model_input).logits

    # pick the highest-scoring id along the last axis
    best_ids = logits.argmax(dim=-1)
    batch["pred_str"] = pred_processor.batch_decode(best_ids)[0]
    batch["text"] = processor.decode(batch["labels"], group_tokens=False)

    return batch

# Decode the TIMIT test set; the input columns are replaced by "pred_str" and "text".
results = timit["test"].map(map_to_result, remove_columns=timit["test"].column_names)

  0%|          | 0/1680 [00:00<?, ?ex/s]

Let's compute the overall WER now, and look at some predictions made by the finetuned wav2vec2.

In [None]:
# Word error rate of the fine-tuned (unpruned) model over the decoded test set.
print("Test WER: {:.3f}".format(wer_metric.compute(predictions=results["pred_str"], references=results["text"])))

# Show a few random prediction/reference pairs.
show_random_elements(results)

Test WER: 0.186


Unnamed: 0,pred_str,text
0,beged that gard for one galon of gas,beg that guard for one gallon of gas
1,out would be a collossealashame to throuw away a story like this,it would be a colossal shame to throw away a story like this
2,lary's costume needed black glothes to be completely elegent,lori's costume needed black gloves to be completely elegant
3,he always seemed to have money in his pocket,he always seemed to have money in his pocket
4,i'd rather not buy these shoes than be overcharged,i'd rather not buy these shoes than be overcharged
5,don't ask me to carry an oily rag like that,don't ask me to carry an oily rag like that
6,only lawyers love milionars,only lawyers love millionaires
7,she said sharks have no bones and shrimp swam backward,she said sharks have no bones and shrimp swam backward
8,they enjoy it when i ad dition,they enjoy it when i audition
9,gues the question from the answer,guess the question from the answer


In [None]:
# Display the module tree of the wav2vec2 backbone inside the CTC model.
finetuned_model.wav2vec2

Wav2Vec2Model(
  (feature_extractor): Wav2Vec2FeatureExtractor(
    (conv_layers): ModuleList(
      (0): Wav2Vec2GroupNormConvLayer(
        (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
        (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
      )
      (1): Wav2Vec2NoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
      )
      (2): Wav2Vec2NoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
      )
      (3): Wav2Vec2NoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
      )
      (4): Wav2Vec2NoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
      )
      (5): Wav2Vec2NoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
      )
      (6): Wav2Vec2NoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(2,), 

In [None]:
import torch.nn.utils.prune as prune

def pruning_bert(model, px, model_type='wav2vec_small'):
    """
    prune out wav2vec 2.0 BERT: 12 transformer layers for BASE, and 24 
                                transformer layers for LARGE

    Applies global unstructured L1 pruning (via torch.nn.utils.prune) jointly over the
    weights and biases of the q/k/v/out attention projections and the two feed-forward
    dense layers of every transformer block.

    Args:
        model: object exposing ``encoder.layers[i].attention.{k,v,q,out}_proj`` and
            ``encoder.layers[i].feed_forward.{intermediate,output}_dense`` modules.
        px: value forwarded as ``amount`` to ``prune.global_unstructured``.
        model_type: 'wav2vec_small' (12 blocks), 'libri960_big' or 'xlsr_53_56k'
            (24 blocks).

    Raises:
        ValueError: if `model_type` is not one of the supported names.

    note: position encoding, projection heads, layernorm statistics are not pruned. 
    """
    if model_type == 'wav2vec_small':
        num_transformer_blocks = 12
    elif model_type == 'libri960_big' or model_type == 'xlsr_53_56k':
        num_transformer_blocks = 24
    else:
        # fail loudly here instead of printing and then crashing on an unbound variable
        raise ValueError('model type {} not supported'.format(model_type))
    print('num_transformer_blocks is', num_transformer_blocks)

    parameters_to_prune = []
    for ii in range(num_transformer_blocks):
        layer = model.encoder.layers[ii]
        modules = (
            layer.attention.k_proj,
            layer.attention.v_proj,
            layer.attention.q_proj,
            layer.attention.out_proj,
            layer.feed_forward.intermediate_dense,
            layer.feed_forward.output_dense,
        )
        for module in modules:
            parameters_to_prune.append((module, 'weight'))
            parameters_to_prune.append((module, 'bias'))

    prune.global_unstructured(
        tuple(parameters_to_prune),
        pruning_method=prune.L1Unstructured,
        amount=px,
    )
        
def unprune_bert(model, model_type='wav2vec_small'):
    """
    remove pruning forward pre-hook. This is useful when we want to tweak the learned pruned mask, which is used in PARP.

    Calls ``prune.remove`` on the 'weight' and 'bias' of the q/k/v/out attention
    projections and the two feed-forward dense layers of every transformer block.

    Args:
        model: object exposing ``encoder.layers[i].attention.{k,v,q,out}_proj`` and
            ``encoder.layers[i].feed_forward.{intermediate,output}_dense`` modules.
        model_type: 'wav2vec_small' (12 blocks), 'libri960_big' or 'xlsr_53_56k'
            (24 blocks).

    Raises:
        ValueError: if `model_type` is not one of the supported names.
    """
    if model_type == 'wav2vec_small':
        num_transformer_blocks = 12
    elif model_type == 'libri960_big' or model_type == 'xlsr_53_56k':
        num_transformer_blocks = 24
    else:
        # fail loudly here instead of printing and then crashing on an unbound variable
        raise ValueError('model type {} not supported'.format(model_type))
    print('num_transformer_blocks is', num_transformer_blocks)

    for ii in range(num_transformer_blocks):
        layer = model.encoder.layers[ii]
        modules = (
            layer.attention.k_proj,
            layer.attention.v_proj,
            layer.attention.q_proj,
            layer.attention.out_proj,
            layer.feed_forward.intermediate_dense,
            layer.feed_forward.output_dense,
        )
        for module in modules:
            # removing (not applying) the pruning re-parametrization of both tensors
            prune.remove(module, 'weight')
            prune.remove(module, 'bias')

def see_weight_rate(model, model_type='wav2vec_small'):
    """ check a model's zero rate 

    Counts the elements equal to zero in the weights and biases of the q/k/v/out
    attention projections and the two feed-forward dense layers of every transformer
    block, prints the percentage and returns it.

    Args:
        model: object exposing ``encoder.layers[i].attention.{k,v,q,out}_proj`` and
            ``encoder.layers[i].feed_forward.{intermediate,output}_dense`` modules.
        model_type: 'wav2vec_small' (12 blocks), 'libri960_big' or 'xlsr_53_56k'
            (24 blocks).

    Returns:
        Percentage (0-100) of inspected elements that are exactly zero.

    Raises:
        ValueError: if `model_type` is not one of the supported names.
    """
    if model_type == 'wav2vec_small':
        num_transformer_blocks = 12
    elif model_type == 'libri960_big' or model_type == 'xlsr_53_56k':
        num_transformer_blocks = 24
    else:
        # fail loudly here instead of printing and then crashing on an unbound variable
        raise ValueError('model type {} not supported'.format(model_type))
    print('num_transformer_blocks is', num_transformer_blocks)

    total_count, zero_count = 0, 0
    for ii in range(num_transformer_blocks):
        layer = model.encoder.layers[ii]
        modules = (
            layer.attention.k_proj,
            layer.attention.v_proj,
            layer.attention.q_proj,
            layer.attention.out_proj,
            layer.feed_forward.intermediate_dense,
            layer.feed_forward.output_dense,
        )
        for module in modules:
            for tensor in (module.weight, module.bias):
                total_count += tensor.nelement()
                zero_count += int(torch.sum(tensor == 0))

    bert_zero_rate = 100 * zero_count / total_count
    print('BERT zero rate is {0:.2f}'.format(bert_zero_rate))
    return bert_zero_rate

In [None]:
# Value passed as the pruning amount (0.5 -> the 50% zero rate reported below).
pruning_rate = 0.5
# Prune the fine-tuned model's wav2vec2 backbone in place, then report its zero rate.
pruning_bert(finetuned_model.wav2vec2, pruning_rate, model_type='wav2vec_small')
see_weight_rate(finetuned_model.wav2vec2)

num_transformer_blocks is 12
num_transformer_blocks is 12
BERT zero rate is 50.00


50.0

In [None]:
# Split the pruned model's state dict into mask entries and everything else.
mask_dict = {}; weight_dict = {}
model_dict = finetuned_model.state_dict()

for key in model_dict.keys():
    # entries whose name contains 'mask' go to mask_dict; all remaining entries go to weight_dict
    if 'mask' in key:
        mask_dict[key] = model_dict[key]
    else:
        weight_dict[key] = model_dict[key]

# Persist both dictionaries; the file names embed the pruning rate.
torch.save(mask_dict, 'pruned-w2v2_' + str(pruning_rate) + '_mask.pt')
torch.save(weight_dict, 'pruned-w2v2_' + str(pruning_rate) + '_weight.pt')

In [None]:
# Decode the TIMIT test set again. map_to_result reads the global `finetuned_model`,
# which was pruned in place above, so this measures the pruned model before any re-training.
results = timit["test"].map(map_to_result, remove_columns=timit["test"].column_names)

print("Test WER: {:.3f}".format(wer_metric.compute(predictions=results["pred_str"], references=results["text"])))

show_random_elements(results)



  0%|          | 0/1680 [00:00<?, ?ex/s]

Test WER: 0.928


Unnamed: 0,pred_str,text
0,rsorwr rry l tat shld andolgo t ru ay y theri,or certain words or rituals that child and adult go through may do the trick
1,don'task mey yaryanolrag la that,don't ask me to carry an oily rag like that
2,yits fomatyyu no,it's a formality you know
3,shy had your dark sut in greysywash waeral yr,she had your dark suit in greasy wash water all year
4,dotask my ty ryan oily rag li hat,don't ask me to carry an oily rag like that
5,yby as s rse l,a boring novel is a superb sleeping pill
6,thaywr bot yarylt,they were both very fluent
7,bapayperou iven ro rls,bob papered over the living room murals
8,propatof aiasa is n onpausn r any govrner,the prospect of cutting back spending is an unpleasant one for any governor
9,untayincolscosaidedwtth brdaf a ng sistompar,each untimely income loss coincided with the breakdown of a heating system part


In [None]:
def apply_pruning_mask(model, mask_dict, model_type='wav2vec_small'):
    """
    apply existing pruning mask to a pre-trained wav2vec 2.0. 

    For every transformer block, the weight and bias masks of the q/k/v/out attention
    projections and the two feed-forward dense layers are looked up in `mask_dict` and
    applied with ``prune.CustomFromMask``.

    Args:
        model: object exposing ``encoder.layers[i].attention.{k,v,q,out}_proj`` and
            ``encoder.layers[i].feed_forward.{intermediate,output}_dense`` modules.
        mask_dict: mapping with keys of the form
            ``'wav2vec2.encoder.layers.<i>.<module path>.weight_mask'`` and
            ``'...bias_mask'`` holding mask tensors.
        model_type: 'wav2vec_small' (12 blocks), 'libri960_big' or 'xlsr_53_56k'
            (24 blocks).

    Raises:
        ValueError: if `model_type` is not one of the supported names.
        KeyError: if a required mask is missing from `mask_dict` (raised before any
            module is modified).
    """
    if model_type == 'wav2vec_small':
        num_transformer_blocks = 12
    elif model_type == 'libri960_big' or model_type == 'xlsr_53_56k':
        num_transformer_blocks = 24
    else:
        # fail loudly here instead of printing and then crashing on an unbound variable
        raise ValueError('model type {} not supported'.format(model_type))
    print('num_transformer_blocks is', num_transformer_blocks)

    module_paths = (
        'attention.k_proj',
        'attention.v_proj',
        'attention.q_proj',
        'attention.out_proj',
        'feed_forward.intermediate_dense',
        'feed_forward.output_dense',
    )

    # First pass: resolve every module and its two masks, so a missing key aborts
    # before the model has been touched.
    targets = []
    for ii in range(num_transformer_blocks):
        layer = model.encoder.layers[ii]
        for path in module_paths:
            module = layer
            for attr in path.split('.'):
                module = getattr(module, attr)
            key_prefix = 'wav2vec2.encoder.layers.' + str(ii) + '.' + path
            targets.append((
                module,
                mask_dict[key_prefix + '.weight_mask'],
                mask_dict[key_prefix + '.bias_mask'],
            ))

    # Second pass: apply both weight and bias masks. Masks are moved to the parameter's
    # device first, since they may have been saved from a model living on another device.
    for module, weight_mask, bias_mask in targets:
        prune.CustomFromMask.apply(module, 'weight', mask=weight_mask.to(module.weight.device))
        prune.CustomFromMask.apply(module, 'bias', mask=bias_mask.to(module.bias.device))

In [None]:
from transformers import Wav2Vec2ForCTC

# load pre-trained model (not the finetuned one)
pretrained_model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base", 
    ctc_loss_reduction="mean", 
    pad_token_id=processor.tokenizer.pad_token_id,
)

# apply the 50% pruning mask (taken from the pruned fine-tuned model) to the pre-trained initialization
apply_pruning_mask(pretrained_model.wav2vec2, mask_dict)

# double-check the pre-trained model now has 50% sparsity 
see_weight_rate(pretrained_model.wav2vec2)

  "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 "
Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2ForCTC: ['quantizer.codevectors', 'project_q.bias', 'project_q.weight', 'project_hid.weight', 'quantizer.weight_proj.weight', 'quantizer.weight_proj.bias', 'project_hid.bias']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['lm_head.weight', 'lm

num_transformer_blocks is 12
num_transformer_blocks is 12
BERT zero rate is 50.00


50.0

## Part 3: Sparse wav2vec 2.0 subnetwork Re-Training

### Set-up Trainer

In [15]:
import torch

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:
    """
    Collate a list of examples into one padded batch for CTC training.

    Audio inputs and label ids are padded separately through ``processor.pad``; padded
    label positions are then replaced by -100.

    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for padding both inputs and labels.
        padding (:obj:`bool` or :obj:`str`, `optional`, defaults to :obj:`True`):
            Padding strategy forwarded to ``processor.pad`` for inputs and labels:
            * :obj:`True` or :obj:`'longest'`: pad to the longest sequence in the batch.
            * :obj:`'max_length'`: pad to the length given by :obj:`max_length` /
              :obj:`max_length_labels`.
            * :obj:`False` or :obj:`'do_not_pad'`: no padding.
        max_length (:obj:`int`, `optional`):
            Forwarded as ``max_length`` when padding the ``input_values``.
        max_length_labels (:obj:`int`, `optional`):
            Forwarded as ``max_length`` when padding the ``labels``.
        pad_to_multiple_of (:obj:`int`, `optional`):
            Forwarded as ``pad_to_multiple_of`` when padding the ``input_values``.
        pad_to_multiple_of_labels (:obj:`int`, `optional`):
            Forwarded as ``pad_to_multiple_of`` when padding the ``labels``.
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # inputs and labels differ in length and are padded by different methods,
        # so they are gathered into two separate lists
        audio_items = []
        for example in features:
            audio_items.append({"input_values": example["input_values"]})
        label_items = []
        for example in features:
            label_items.append({"input_ids": example["labels"]})

        batch = self.processor.pad(
            audio_items,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        # labels are padded while the processor is in target mode
        with self.processor.as_target_processor():
            padded_labels = self.processor.pad(
                label_items,
                padding=self.padding,
                max_length=self.max_length_labels,
                pad_to_multiple_of=self.pad_to_multiple_of_labels,
                return_tensors="pt",
            )

        # positions whose attention mask is not 1 are padding; mark them with -100
        # so they are ignored by the loss
        label_ids = padded_labels["input_ids"]
        is_padding = padded_labels.attention_mask.ne(1)
        batch["labels"] = label_ids.masked_fill(is_padding, -100)

        return batch
    
# Collator instance handed to the Trainer below; padding=True is forwarded to processor.pad for every batch.
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

In [None]:
from transformers import TrainingArguments

# Training configuration; checkpoints are written under the `repo_name` directory.
training_args = TrainingArguments(
  output_dir=repo_name,
  group_by_length=True,
  per_device_train_batch_size=32,
  gradient_accumulation_steps=1, 
  evaluation_strategy="steps",
  num_train_epochs=30,
  fp16=False,
  gradient_checkpointing=True,
  # save, evaluate and log on the same 100-step cadence
  save_steps=100,
  eval_steps=100,
  logging_steps=100,
  learning_rate=1e-4,
  weight_decay=0.005,
  warmup_steps=1000,
  # the training log shows older checkpoints being deleted because of this limit
  save_total_limit=5,
  push_to_hub=True,
)

In [None]:
from transformers import Trainer

# Re-train the masked pre-trained model (not the fine-tuned one) on TIMIT.
trainer = Trainer(
    model=pretrained_model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=timit["train"],
    eval_dataset=timit["test"],
    # the feature extractor is passed in the `tokenizer` slot; the training log shows
    # preprocessor_config.json being saved alongside each checkpoint
    tokenizer=processor.feature_extractor,
)

/data/sls/temp/clai24/lottery-ticket/fairseq/examples/wav2vec/wav2vec2-base-timit-demo-colab is already a clone of https://huggingface.co/jefflai108/wav2vec2-base-timit-demo-colab. Make sure you pull the latest changes with `repo.git_pull()`.


### Training

In [None]:
# Release cached GPU memory held by PyTorch before starting the training run.
torch.cuda.empty_cache()
trainer.train()

The following columns in the training set  don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.
***** Running training *****
  Num examples = 3978
  Num Epochs = 30
  Instantaneous batch size per device = 32
  Total train batch size (w. parallel, distributed & accumulation) = 32
  Gradient Accumulation steps = 1
  Total optimization steps = 3750


Step,Training Loss,Validation Loss,Wer
100,7.1385,5.30809,1.0
200,3.4432,3.229337,1.0
300,3.0762,3.147481,1.0
400,3.0206,3.231803,1.0
500,2.965,3.123697,1.0
600,2.7204,2.526277,1.000069
700,2.3124,1.961237,1.023775
800,1.9346,1.586578,0.936738
900,1.6304,1.366963,0.84784
1000,1.4367,1.172826,0.807732


The following columns in the evaluation set  don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.
***** Running Evaluation *****
  Num examples = 1680
  Batch size = 8
Saving model checkpoint to wav2vec2-base-timit-demo-colab/checkpoint-100
Configuration saved in wav2vec2-base-timit-demo-colab/checkpoint-100/config.json
Model weights saved in wav2vec2-base-timit-demo-colab/checkpoint-100/pytorch_model.bin
Configuration saved in wav2vec2-base-timit-demo-colab/checkpoint-100/preprocessor_config.json
Configuration saved in wav2vec2-base-timit-demo-colab/preprocessor_config.json
Several commits (5) will be pushed upstream.
The following columns in the evaluation set  don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.
***** Running Evaluation *****
  Num examples = 1680
  Batch size = 8
Saving model checkpoint to wav2vec2-base-timit-demo-colab/checkpoint-200
Configuration saved in wav2vec2-base

***** Running Evaluation *****
  Num examples = 1680
  Batch size = 8
Saving model checkpoint to wav2vec2-base-timit-demo-colab/checkpoint-1400
Configuration saved in wav2vec2-base-timit-demo-colab/checkpoint-1400/config.json
Model weights saved in wav2vec2-base-timit-demo-colab/checkpoint-1400/pytorch_model.bin
Configuration saved in wav2vec2-base-timit-demo-colab/checkpoint-1400/preprocessor_config.json
Deleting older checkpoint [wav2vec2-base-timit-demo-colab/checkpoint-900] due to args.save_total_limit
The following columns in the evaluation set  don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.
***** Running Evaluation *****
  Num examples = 1680
  Batch size = 8
Saving model checkpoint to wav2vec2-base-timit-demo-colab/checkpoint-1500
Configuration saved in wav2vec2-base-timit-demo-colab/checkpoint-1500/config.json
Model weights saved in wav2vec2-base-timit-demo-colab/checkpoint-1500/pytorch_model.bin
Configuration saved in wav2v

Configuration saved in wav2vec2-base-timit-demo-colab/checkpoint-2600/preprocessor_config.json
Deleting older checkpoint [wav2vec2-base-timit-demo-colab/checkpoint-2100] due to args.save_total_limit
The following columns in the evaluation set  don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.
***** Running Evaluation *****
  Num examples = 1680
  Batch size = 8
Saving model checkpoint to wav2vec2-base-timit-demo-colab/checkpoint-2700
Configuration saved in wav2vec2-base-timit-demo-colab/checkpoint-2700/config.json
Model weights saved in wav2vec2-base-timit-demo-colab/checkpoint-2700/pytorch_model.bin
Configuration saved in wav2vec2-base-timit-demo-colab/checkpoint-2700/preprocessor_config.json
Deleting older checkpoint [wav2vec2-base-timit-demo-colab/checkpoint-2200] due to args.save_total_limit
The following columns in the evaluation set  don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_leng

TrainOutput(global_step=3750, training_loss=0.9929066851298014, metrics={'train_runtime': 9556.2086, 'train_samples_per_second': 12.488, 'train_steps_per_second': 0.392, 'total_flos': 3.0987418295396214e+18, 'train_loss': 0.9929066851298014, 'epoch': 30.0})

In [None]:
# Upload the training result to the Hugging Face Hub.
trainer.push_to_hub()

In [None]:
import copy
# Deep-copy the (already pruned) fine-tuned model so the copy carries the same pruning masks.
finetuned_pruned_model = copy.deepcopy(finetuned_model)
# Display one layer's weight mask to confirm the copy has the pruning buffers.
finetuned_pruned_model.wav2vec2.encoder.layers[11].attention.k_proj.weight_mask

tensor([[1., 1., 1.,  ..., 0., 1., 0.],
        [1., 0., 1.,  ..., 0., 1., 1.],
        [1., 0., 1.,  ..., 0., 0., 1.],
        ...,
        [0., 1., 0.,  ..., 0., 0., 1.],
        [1., 0., 0.,  ..., 1., 1., 1.],
        [1., 1., 0.,  ..., 1., 0., 1.]], device='cuda:0')

In [None]:
# Overwrite the copy's parameters with the re-trained checkpoint saved at step 3700.
# NOTE(review): the checkpoint path is hard-coded and torch.load unpickles the file — only load trusted checkpoints.
finetuned_pruned_model.load_state_dict(torch.load("wav2vec2-base-timit-demo-colab/checkpoint-3700/pytorch_model.bin"))
# Switch to evaluation mode before decoding.
finetuned_pruned_model.eval()

Wav2Vec2ForCTC(
  (wav2vec2): Wav2Vec2Model(
    (feature_extractor): Wav2Vec2FeatureExtractor(
      (conv_layers): ModuleList(
        (0): Wav2Vec2GroupNormConvLayer(
          (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
          (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
        )
        (1): Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        )
        (2): Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        )
        (3): Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        )
        (4): Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        )
        (5): Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
        )
        (6): Wav2Vec

Double-check that the new model has 50% sparsity.

In [None]:
# Report the zero rate of the re-trained model's wav2vec2 backbone.
see_weight_rate(finetuned_pruned_model.wav2vec2)

num_transformer_blocks is 12
BERT zero rate is 50.00


50.0

In [None]:
# decode Timit test set again
def map_to_result(batch):
    """Transcribe one example with `finetuned_pruned_model` and decode its reference.

    This definition replaces the earlier `map_to_result` in the notebook namespace.
    Both the prediction (``batch["pred_str"]``) and the reference (``batch["text"]``)
    are decoded with `processor`; the updated batch is returned.
    """
    # gradients are not needed for decoding
    with torch.no_grad():
        # unsqueeze(0) adds a leading dimension of size 1 before the forward pass
        model_input = torch.tensor(batch["input_values"], device="cuda").unsqueeze(0)
        model_output = finetuned_pruned_model(model_input)
        logits = model_output.logits

    # pick the highest-scoring id along the last axis
    best_ids = logits.argmax(dim=-1)
    decoded = processor.batch_decode(best_ids)
    batch["pred_str"] = decoded[0]
    batch["text"] = processor.decode(batch["labels"], group_tokens=False)

    return batch

# Decode the test set with the re-trained sparse model; input columns are replaced by "pred_str" and "text".
results = timit["test"].map(map_to_result, remove_columns=timit["test"].column_names)

# Word error rate of the re-trained sparse model.
print("Test WER: {:.3f}".format(wer_metric.compute(predictions=results["pred_str"], references=results["text"])))

# Show a few random prediction/reference pairs.
show_random_elements(results)

  0%|          | 0/1680 [00:00<?, ?ex/s]

Test WER: 0.241


Unnamed: 0,pred_str,text
0,the tooth verry forgot to come when roger's tooth fell out,the tooth fairy forgot to come when roger's tooth fell out
1,why single me out on this permite deal,why single me out on this permit deal
2,do without fancy table clauths,do without fancy tablecloths
3,are yolooking for employiment,are you looking for employment
4,a chosen few will become generals,a chosen few will become generals
5,widow anice sort of woman,widow nice sort of woman
6,don't do charly's derty dishes,don't do charlie's dirty dishes
7,don't ask me to carry an oily rag like that,don't ask me to carry an oily rag like that
8,how good is your endurance,how good is your endurance
9,she slipped and sprained her ancle on the steep slope,she slipped and sprained her ankle on the steep slope
