Adapted From https://huggingface.co/blog/fine-tune-wav2vec2-english, with modifications for custom synthetic data set

In [20]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Sun Aug 14 21:06:56 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P8     9W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [21]:
%%capture
!pip install datasets==1.18.3
!pip install transformers==4.17.0
!pip install jiwer # library to work with audio files

In [8]:
from huggingface_hub import notebook_login

notebook_login()

Login successful
Your token has been saved to /root/.huggingface/token
[1m[31mAuthenticated through git-credential store but this isn't the helper defined on your machine.
You might have to re-authenticate when pushing to the Hugging Face Hub. Run the following command in your terminal in case you want to set this credential helper as the default

git config --global credential.helper store[0m


In [80]:
%%capture
!apt install git-lfs # to upload model checkpoints (lfs = large file storage)

# Creating the tokenier


In [150]:
from datasets import load_dataset, load_metric, Dataset, DatasetDict

dataset = load_dataset("saahith/ems_synth_2020_v1", use_auth_token=True)



  0%|          | 0/1 [00:00<?, ?it/s]

In [151]:
dataset = dataset.rename_column("transcripts", "text")


split_data = dataset["train"].train_test_split(test_size=0.05)
synth_data = DatasetDict(
    {
        'train': split_data['train'],
        'test': split_data['test']
    }
)

In [152]:
synth_data

DatasetDict({
    train: Dataset({
        features: ['audio', 'text'],
        num_rows: 414
    })
    test: Dataset({
        features: ['audio', 'text'],
        num_rows: 22
    })
})

In [153]:
from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    display(HTML(df.to_html()))

show_random_elements(synth_data["train"].remove_columns(["audio"]), num_examples=10)

Unnamed: 0,text
0,patient was placed on a continuous lead
1,O saturation obtained times three . Hospital
2,S . There are no other associated
3,substances. She states that she has taken
4,O . She denies consuming E T
5,"steps to E M S stretcher, where"
6,couch in the living room on the
7,observed coming from his head. The patient
8,due to combativeness. P T was carried
9,her toes. She states that she has


In [154]:
import re
chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\u200b\’]'

def remove_special_characters(batch):
    batch["text"] = re.sub(chars_to_ignore_regex, '', batch["text"]).lower() + " "
    return batch

In [155]:
synth_data = synth_data.map(remove_special_characters)

0ex [00:00, ?ex/s]

0ex [00:00, ?ex/s]

In [156]:
show_random_elements(synth_data["train"].remove_columns(["audio"]))

Unnamed: 0,text
0,negative edema strong regular radial pulses present
1,sidewalk outside of the business clothed in
2,f sitting on her couch knees to
3,becomes increasingly aggitated and swings from tearful
4,and moaning p t had no
5,non tender distended no rebound tenderness no
6,to the left side of her body
7,rash medical history medications and allergies as
8,well as nauseous and dizzy p t
9,report and care turned over to the


In [157]:
def extract_all_chars(batch):
  all_text = " ".join(batch["text"])
  vocab = list(set(all_text))
  return {"vocab": [vocab], "all_text": [all_text]}

In [158]:
vocabs = synth_data.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=synth_data.column_names["train"])

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [159]:
vocab_list = list(set(vocabs["train"]["vocab"][0]) | set(vocabs["test"]["vocab"][0]))

In [160]:
vocab_dict = {v: k for k, v in enumerate(vocab_list)}
vocab_dict

{' ': 18,
 "'": 2,
 'a': 14,
 'b': 16,
 'c': 9,
 'd': 5,
 'e': 15,
 'f': 3,
 'g': 7,
 'h': 0,
 'i': 12,
 'j': 1,
 'k': 23,
 'l': 21,
 'm': 25,
 'n': 20,
 'o': 10,
 'p': 26,
 'q': 27,
 'r': 22,
 's': 13,
 't': 6,
 'u': 24,
 'v': 4,
 'w': 17,
 'x': 8,
 'y': 19,
 'z': 11}

In [161]:
vocab_dict["|"] = vocab_dict[" "]   # replace " " with |
del vocab_dict[" "]

In [162]:
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
len(vocab_dict)

# len(vocab_dict) ---> linear layer added on top of Wav2Vec2 checkpoint will have output dimension of 30

30

In [163]:
# save vocabulary in .json file
import json
with open('vocab.json', 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)

In [164]:
from transformers import Wav2Vec2CTCTokenizer

tokenizer = Wav2Vec2CTCTokenizer("./vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")

In [165]:
repo_name = "ems-wav2vec2-tokenizer"
tokenizer.push_to_hub(repo_name)

tokenizer config file saved in ems-wav2vec2-tokenizer/tokenizer_config.json
Special tokens file saved in ems-wav2vec2-tokenizer/special_tokens_map.json
To https://huggingface.co/saahith/ems-wav2vec2-tokenizer
   616cd51..b93964b  main -> main

   616cd51..b93964b  main -> main



'https://huggingface.co/saahith/ems-wav2vec2-tokenizer/commit/b93964ba7b551003480c47d1b16a05a06d144f80'

# Creating the Feature Extractor

In [166]:
from transformers import Wav2Vec2FeatureExtractor

feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False)

In [167]:
from transformers import Wav2Vec2Processor

processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

# Data Preprocessing

In [168]:
import IPython.display as ipd
import numpy as np
import random

rand_int = random.randint(0, len(synth_data["train"]))

print(synth_data["train"][rand_int]["text"])
ipd.Audio(data=np.asarray(synth_data["train"][rand_int]["audio"]["array"]), autoplay=True, rate=16000)

condition during transport or treatment ridealong w 


In [169]:
rand_int = random.randint(0, len(synth_data["train"]))

print("Target text:", synth_data["train"][rand_int]["text"])
print("Input array shape:", np.asarray(synth_data["train"][rand_int]["audio"]["array"]).shape)
print("Sampling rate:", synth_data["train"][rand_int]["audio"]["sampling_rate"])

Target text: stretcher and secured with point harness in 
Input array shape: (50434,)
Sampling rate: 16000


In [170]:
# 1. load and resample audio data
# 2. extract input_values (features) using processor object defined previously 
# 3. enode transcriptions to label ids (words --> numbers)

def prepare_dataset(batch):
    audio = batch["audio"]

    # batched output is "un-batched" to ensure mapping is correct
    batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
    batch["input_length"] = len(batch["input_values"])
    
    with processor.as_target_processor():
        batch["labels"] = processor(batch["text"]).input_ids
    return batch

In [171]:
synth_data = synth_data.map(prepare_dataset, remove_columns=synth_data.column_names["train"], num_proc=4)

In [172]:
max_input_length_in_sec = 4.5
synth_data["train"] = synth_data["train"].filter(lambda x: x < max_input_length_in_sec * processor.feature_extractor.sampling_rate, input_columns=["input_length"])
synth_data  # filtering for clips under 4 seconds takes out a lot of samples :(

  0%|          | 0/1 [00:00<?, ?ba/s]

DatasetDict({
    train: Dataset({
        features: ['input_values', 'input_length', 'labels'],
        num_rows: 377
    })
    test: Dataset({
        features: ['input_values', 'input_length', 'labels'],
        num_rows: 22
    })
})

# Training and Evaluation

In [173]:
# data collator to take care of padding
# padding is different for inputs (form feature extractor) and labels (numbers representing characters outputted by network)

import torch

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for proccessing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence if provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
              different lengths).
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lenghts and need
        # different padding methods
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                return_tensors="pt",
            )

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch

In [174]:
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

In [175]:
wer_metric = load_metric("wer")

In [176]:
def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

In [177]:
from transformers import Wav2Vec2ForCTC

model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base",
    ctc_loss_reduction="mean", 
    pad_token_id=processor.tokenizer.pad_token_id,
)

loading configuration file https://huggingface.co/facebook/wav2vec2-base/resolve/main/config.json from cache at /root/.cache/huggingface/transformers/c7746642f045322fd01afa31271dd490e677ea11999e68660a92619ec7c892b4.ce1f96bfaf3d7475cb8187b9668c7f19437ade45fb9ceb78d2b06a2cec198015
  "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 "
Model config Wav2Vec2Config {
  "activation_dropout": 0.0,
  "adapter_kernel_size": 3,
  "adapter_stride": 2,
  "add_adapter": false,
  "apply_spec_augment": true,
  "architectures": [
    "Wav2Vec2ForPreTraining"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 1,
  "classifier_proj_size": 256,
  "codevector_dim": 256,
  "contrastive_logits_temperature": 0.1,
  "conv_bias": false,
  "conv_dim": [
    512,
    512,
    512,
    512,
    512,
    512,
    512
  ],
  "conv_kernel": [
    10,
    3,
    3,
    3,
    3,
    2,
    2
  ],
  "conv_stride": [
    5,
    2,
    2,
    2,
    2,
    2,
    2
  ],


In [178]:
model.freeze_feature_encoder()  # don't train the CNN layer that extract signal from raw speech signal, because it's already been trained

In [183]:
from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir=repo_name,
  group_by_length=True,
  per_device_train_batch_size=8,
  evaluation_strategy="steps",
  num_train_epochs=20,
  fp16=True,
  gradient_checkpointing=True,
  save_steps=100,
  eval_steps=100,
  logging_steps=100,
  learning_rate=1e-4,
  weight_decay=0.005,
  warmup_steps=1000,
  save_total_limit=2,
)

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


In [184]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=synth_data["train"],
    eval_dataset=synth_data["test"],
    tokenizer=processor.feature_extractor,
)

Using amp half precision backend


# Training

In [185]:
trainer.train()

The following columns in the training set  don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length. If input_length are not expected by `Wav2Vec2ForCTC.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 377
  Num Epochs = 20
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 960


Step,Training Loss,Validation Loss,Wer
100,0.7142,0.688542,0.644737
200,0.6147,0.620202,0.644737
300,0.4867,0.480505,0.539474
400,0.3872,0.413537,0.453947
500,0.3109,0.382683,0.407895
600,0.2618,0.360187,0.421053
700,0.2169,0.396493,0.401316
800,0.1983,0.332686,0.381579
900,0.1734,0.387157,0.328947


The following columns in the evaluation set  don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length. If input_length are not expected by `Wav2Vec2ForCTC.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 22
  Batch size = 8
Saving model checkpoint to ems-wav2vec2-tokenizer/checkpoint-100
Configuration saved in ems-wav2vec2-tokenizer/checkpoint-100/config.json
Model weights saved in ems-wav2vec2-tokenizer/checkpoint-100/pytorch_model.bin
Feature extractor saved in ems-wav2vec2-tokenizer/checkpoint-100/preprocessor_config.json
The following columns in the evaluation set  don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length. If input_length are not expected by `Wav2Vec2ForCTC.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 22
  Batch size = 8
Saving model checkpoint to ems-wav2vec2-tokenizer/checkpoint-200
Co

TrainOutput(global_step=960, training_loss=0.36006935040156046, metrics={'train_runtime': 390.3259, 'train_samples_per_second': 19.317, 'train_steps_per_second': 2.459, 'total_flos': 2.316850468499376e+17, 'train_loss': 0.36006935040156046, 'epoch': 20.0})

In [186]:
trainer.push_to_hub()

/content/ems-wav2vec2-tokenizer is already a clone of https://huggingface.co/saahith/ems-wav2vec2-tokenizer. Make sure you pull the latest changes with `repo.git_pull()`.
Saving model checkpoint to ems-wav2vec2-tokenizer
Configuration saved in ems-wav2vec2-tokenizer/config.json
Model weights saved in ems-wav2vec2-tokenizer/pytorch_model.bin
Feature extractor saved in ems-wav2vec2-tokenizer/preprocessor_config.json


Upload file pytorch_model.bin:   0%|          | 3.34k/360M [00:00<?, ?B/s]

Upload file runs/Aug14_22-26-22_801dfe646291/events.out.tfevents.1660515992.801dfe646291.71.2:  36%|###5      …

Upload file training_args.bin: 100%|##########| 2.92k/2.92k [00:00<?, ?B/s]

Upload file runs/Aug14_22-26-22_801dfe646291/1660515992.6972961/events.out.tfevents.1660515992.801dfe646291.71…

Upload file runs/Aug14_22-19-58_801dfe646291/1660515737.9994903/events.out.tfevents.1660515737.801dfe646291.71…

Upload file runs/Aug14_22-19-58_801dfe646291/events.out.tfevents.1660515737.801dfe646291.71.0:  63%|######2   …

remote: Scanning LFS files for validity, may be slow...        
remote: LFS file scan complete.        
To https://huggingface.co/saahith/ems-wav2vec2-tokenizer
   b93964b..0153b34  main -> main

remote: LFS file scan complete.        
To https://huggingface.co/saahith/ems-wav2vec2-tokenizer
   b93964b..0153b34  main -> main

Dropping the following result as it does not have all the necessary fields:
{}
To https://huggingface.co/saahith/ems-wav2vec2-tokenizer
   0153b34..aeadc20  main -> main

   0153b34..aeadc20  main -> main



'https://huggingface.co/saahith/ems-wav2vec2-tokenizer/commit/0153b342911e21ff5a297cde948896e63f5d6bc1'

# Evaluate

In [187]:
processor = Wav2Vec2Processor.from_pretrained("saahith/ems-wav2vec2-tokenizer")

https://huggingface.co/saahith/ems-wav2vec2-tokenizer/resolve/main/preprocessor_config.json not found in cache or force_download set to True, downloading to /root/.cache/huggingface/transformers/tmpqqpoymok


Downloading:   0%|          | 0.00/215 [00:00<?, ?B/s]

storing https://huggingface.co/saahith/ems-wav2vec2-tokenizer/resolve/main/preprocessor_config.json in cache at /root/.cache/huggingface/transformers/307ef44a2f153369c0ff8ef26eb200c05e483c1d21f66102882cbc653fb8cd6b.0e3e6656f99a6f7b9eddd943463eb7f34363640fce9e87b047ebd50d4b112b50
creating metadata file for /root/.cache/huggingface/transformers/307ef44a2f153369c0ff8ef26eb200c05e483c1d21f66102882cbc653fb8cd6b.0e3e6656f99a6f7b9eddd943463eb7f34363640fce9e87b047ebd50d4b112b50
loading feature extractor configuration file https://huggingface.co/saahith/ems-wav2vec2-tokenizer/resolve/main/preprocessor_config.json from cache at /root/.cache/huggingface/transformers/307ef44a2f153369c0ff8ef26eb200c05e483c1d21f66102882cbc653fb8cd6b.0e3e6656f99a6f7b9eddd943463eb7f34363640fce9e87b047ebd50d4b112b50
Feature extractor Wav2Vec2FeatureExtractor {
  "do_normalize": true,
  "feature_extractor_type": "Wav2Vec2FeatureExtractor",
  "feature_size": 1,
  "padding_side": "right",
  "padding_value": 0.0,
  "return

Downloading:   0%|          | 0.00/217 [00:00<?, ?B/s]

storing https://huggingface.co/saahith/ems-wav2vec2-tokenizer/resolve/main/tokenizer_config.json in cache at /root/.cache/huggingface/transformers/cce2de62c28702b867a3b58c0fab1ee29426cdc622efff65c4819ca87854a312.59710b1a6a5501d31e746b6e464f5c44de3e55a58f80634196025936683a68a9
creating metadata file for /root/.cache/huggingface/transformers/cce2de62c28702b867a3b58c0fab1ee29426cdc622efff65c4819ca87854a312.59710b1a6a5501d31e746b6e464f5c44de3e55a58f80634196025936683a68a9
https://huggingface.co/saahith/ems-wav2vec2-tokenizer/resolve/main/vocab.json not found in cache or force_download set to True, downloading to /root/.cache/huggingface/transformers/tmp3bqr_93g


Downloading:   0%|          | 0.00/268 [00:00<?, ?B/s]

storing https://huggingface.co/saahith/ems-wav2vec2-tokenizer/resolve/main/vocab.json in cache at /root/.cache/huggingface/transformers/28bcb47162a8f602098eaa76b931356b1d2abee39e55410a055cab053477b0b1.f4656e6c876c32251681b0b638f49e4a2a7a66d8fddeda2989a02d40bdab6475
creating metadata file for /root/.cache/huggingface/transformers/28bcb47162a8f602098eaa76b931356b1d2abee39e55410a055cab053477b0b1.f4656e6c876c32251681b0b638f49e4a2a7a66d8fddeda2989a02d40bdab6475
https://huggingface.co/saahith/ems-wav2vec2-tokenizer/resolve/main/special_tokens_map.json not found in cache or force_download set to True, downloading to /root/.cache/huggingface/transformers/tmpxwqh1sjt


Downloading:   0%|          | 0.00/85.0 [00:00<?, ?B/s]

storing https://huggingface.co/saahith/ems-wav2vec2-tokenizer/resolve/main/special_tokens_map.json in cache at /root/.cache/huggingface/transformers/c78b88a91090b962e1548f3590a88c12fdb38e7311f80ae82bd3a49e7e80d14f.a21d51735cf8667bcd610f057e88548d5d6a381401f6b4501a8bc6c1a9dc8498
creating metadata file for /root/.cache/huggingface/transformers/c78b88a91090b962e1548f3590a88c12fdb38e7311f80ae82bd3a49e7e80d14f.a21d51735cf8667bcd610f057e88548d5d6a381401f6b4501a8bc6c1a9dc8498
loading file https://huggingface.co/saahith/ems-wav2vec2-tokenizer/resolve/main/vocab.json from cache at /root/.cache/huggingface/transformers/28bcb47162a8f602098eaa76b931356b1d2abee39e55410a055cab053477b0b1.f4656e6c876c32251681b0b638f49e4a2a7a66d8fddeda2989a02d40bdab6475
loading file https://huggingface.co/saahith/ems-wav2vec2-tokenizer/resolve/main/tokenizer_config.json from cache at /root/.cache/huggingface/transformers/cce2de62c28702b867a3b58c0fab1ee29426cdc622efff65c4819ca87854a312.59710b1a6a5501d31e746b6e464f5c44de

In [188]:
model = Wav2Vec2ForCTC.from_pretrained("saahith/ems-wav2vec2-tokenizer").cuda()

https://huggingface.co/saahith/ems-wav2vec2-tokenizer/resolve/main/config.json not found in cache or force_download set to True, downloading to /root/.cache/huggingface/transformers/tmpqdzqpuyu


Downloading:   0%|          | 0.00/2.29k [00:00<?, ?B/s]

storing https://huggingface.co/saahith/ems-wav2vec2-tokenizer/resolve/main/config.json in cache at /root/.cache/huggingface/transformers/2215b6a58ceb4814aa52c746cf8bf7edef2235a5aa0c099a90f0e2802bdb94bc.3202909be16e9c482257d0c4d35b4f4c2558cea585027c729eed803ba4bf0d8e
creating metadata file for /root/.cache/huggingface/transformers/2215b6a58ceb4814aa52c746cf8bf7edef2235a5aa0c099a90f0e2802bdb94bc.3202909be16e9c482257d0c4d35b4f4c2558cea585027c729eed803ba4bf0d8e
loading configuration file https://huggingface.co/saahith/ems-wav2vec2-tokenizer/resolve/main/config.json from cache at /root/.cache/huggingface/transformers/2215b6a58ceb4814aa52c746cf8bf7edef2235a5aa0c099a90f0e2802bdb94bc.3202909be16e9c482257d0c4d35b4f4c2558cea585027c729eed803ba4bf0d8e
Model config Wav2Vec2Config {
  "_name_or_path": "facebook/wav2vec2-base",
  "activation_dropout": 0.0,
  "adapter_kernel_size": 3,
  "adapter_stride": 2,
  "add_adapter": false,
  "apply_spec_augment": true,
  "architectures": [
    "Wav2Vec2ForCTC"

Downloading:   0%|          | 0.00/360M [00:00<?, ?B/s]

storing https://huggingface.co/saahith/ems-wav2vec2-tokenizer/resolve/main/pytorch_model.bin in cache at /root/.cache/huggingface/transformers/f2a877c9fdf110af85b0f816152f8cda7060566167ae1a367bf401d303a42eec.d050abd26e01b39dd2c629b162be6f37c2461c6a7462e01ae009ac6066f72769
creating metadata file for /root/.cache/huggingface/transformers/f2a877c9fdf110af85b0f816152f8cda7060566167ae1a367bf401d303a42eec.d050abd26e01b39dd2c629b162be6f37c2461c6a7462e01ae009ac6066f72769
loading weights file https://huggingface.co/saahith/ems-wav2vec2-tokenizer/resolve/main/pytorch_model.bin from cache at /root/.cache/huggingface/transformers/f2a877c9fdf110af85b0f816152f8cda7060566167ae1a367bf401d303a42eec.d050abd26e01b39dd2c629b162be6f37c2461c6a7462e01ae009ac6066f72769
All model checkpoint weights were used when initializing Wav2Vec2ForCTC.

All the weights of Wav2Vec2ForCTC were initialized from the model checkpoint at saahith/ems-wav2vec2-tokenizer.
If your task is similar to the task the model of the check

In [190]:
def map_to_result(batch):
  with torch.no_grad():
    input_values = torch.tensor(batch["input_values"], device="cuda").unsqueeze(0)
    logits = model(input_values).logits

  pred_ids = torch.argmax(logits, dim=-1)
  batch["pred_str"] = processor.batch_decode(pred_ids)[0]
  batch["text"] = processor.decode(batch["labels"], group_tokens=False)
  
  return batch

In [191]:
results = synth_data["test"].map(map_to_result, remove_columns=synth_data["test"].column_names)

0ex [00:00, ?ex/s]

In [192]:
print("Test WER: {:.3f}".format(wer_metric.compute(predictions=results["pred_str"], references=results["text"])))

Test WER: 0.230


In [193]:
show_random_elements(results)

Unnamed: 0,pred_str,text
0,e d no patient belongings limited,e d no patient belongings limited
1,vius stretcher she transfered her self to the,via stretcher she transferred herself to the
2,a combnation of facters vitals lead point,a combination of factors vitals lead point
3,monitord as noted p t placed on,monitored as noted p t placed on
4,rapedely moved from the sidewalk to the,rapidly moved from the sidewalk to the
5,intact exstremities unremarkable integamenteury skin warm dry,intact extremities unremarkable integumentary skin warm dry
6,deformity to any other body ariies leid,deformity to any other body areas lead
7,m four at e n,m four h e e n
8,upon arrival at destination is noted no,upon arrival at destination as noted no
9,and independent respirations seurculation ful and regular,and independent respirations circulation full and regular


In [194]:
model.to("cuda")

with torch.no_grad():
  logits = model(torch.tensor(synth_data["test"][:1]["input_values"], device="cuda")).logits

pred_ids = torch.argmax(logits, dim=-1)

# convert ids to tokens
" ".join(processor.tokenizer.convert_ids_to_tokens(pred_ids[0].tolist()))

'[PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] e e e | | | [PAD] [PAD] d [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] | | [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] n n [PAD] [PAD] [PAD] [PAD] o o | | [PAD] [PAD] p [PAD] [PAD] [PAD] [PAD] [PAD] a a a a t t i i e e n t t t | | b [PAD] e e e e [PAD] l l [PAD] [PAD] [PAD] [PAD] [PAD] o o n n g g [PAD] [PAD] [PAD] i i i n n n g g g s s | | [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] l l [PAD] [PAD] i i i m m [PAD] [PAD] i i t t [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] e e d d d | | [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'