In [69]:
from datasets import load_dataset, load_metric
import soundfile as sf
from transformers import Wav2Vec2Processor
from transformers import Wav2Vec2ForCTC
import torch
import random
import torchaudio
import IPython.display as ipd
import numpy as np
from sklearn.preprocessing import MinMaxScaler

In [70]:
speech_dataset = load_dataset("CAiRE/ASCEND")

In [71]:
speech_dataset['train'][9]

{'id': '00009',
 'path': '/storage/hf-datasets-cache/all/datasets/16739474757983-config-parquet-and-info-CAiRE-ASCEND-5c1abf9c/downloads/extracted/f0790e45797bd654a35ecd1eb4865fa761f1cbd842b674e0defb6812ae8cffbf/waves/ses1_spk1_L22_38.900_1.480.wav',
 'audio': {'path': 'ses1_spk1_L22_38.900_1.480.wav',
  'array': array([ 2.34680176e-02,  4.13513184e-02,  3.21044922e-02, ...,
         -3.05175781e-05,  3.35693359e-04, -6.10351562e-05]),
  'sampling_rate': 16000},
 'transcription': 'delicious sea food',
 'duration': 1.4800000190734863,
 'language': 'en',
 'original_speaker_id': 1,
 'session_id': 1,
 'topic': 'persona'}

In [158]:
class AudioUtil():
  @staticmethod
  def open(audio_file):
    language = speech_dataset['train'][audio_file]['language']
    while language != 'en':
      audio_file = random.randint(0, len(speech_dataset['train']))
      language = speech_dataset['train'][audio_file]['language']

    sig, sr = torch.tensor(speech_dataset['train'][audio_file]['audio']['array']), speech_dataset['train'][audio_file]['audio']['sampling_rate']
    sig = sig.view(1, sig.shape[0]).to(torch.float32)
    truth = speech_dataset['train'][audio_file]['transcription']
    return (sig, sr), truth
  
  @staticmethod
  def pad_trunc(aud, max_ms):
    sig, sr = aud
    num_rows, sig_len = sig.shape
    max_len = sr//1000 * max_ms

    if (sig_len > max_len):
      # Truncate the signal to the given length
      sig = sig[:,:max_len]

    elif (sig_len < max_len):
      # Length of padding to add at the beginning and end of the signal
      pad_begin_len = random.randint(0, max_len - sig_len)
      pad_end_len = max_len - sig_len - pad_begin_len

      # Pad with 0s
      pad_begin = torch.zeros((num_rows, pad_begin_len))
      pad_end = torch.zeros((num_rows, pad_end_len))

      sig = torch.cat((pad_begin, sig, pad_end), 1)
      
    return (sig, sr)
  
  @staticmethod
  def resample(aud, newsr):
    sig, sr = aud

    if (sr == newsr):
      # Nothing to do
      return aud

    num_channels = sig.shape[0]
    # Resample first channel
    resig = torchaudio.transforms.Resample(sr, newsr)(sig[:1,:])
    if (num_channels > 1):
      # Resample the second channel and merge both channels
      retwo = torchaudio.transforms.Resample(sr, newsr)(sig[1:,:])
      resig = torch.cat([resig, retwo])

    return ((resig, newsr))
  
  @staticmethod
  def rechannel(aud, new_channel):
    sig, sr = aud

    if (sig.shape[0] == new_channel):
      # Nothing to do
      return aud

    if (new_channel == 1):
      # Convert from stereo to mono by selecting only the first channel
      resig = sig[:1, :]
    else:
      # Convert from mono to stereo by duplicating the first channel
      resig = torch.cat([sig, sig])

    return ((resig, sr))

  @staticmethod
  def quantize(aud, bits):
    sig, sr = aud
    bit_range = (0, (2**bits)-1)
    scaler = MinMaxScaler(feature_range=bit_range)
    
    quantized_signal = scaler.fit_transform(sig.reshape(-1, 1))
        
    quantized_signal = np.clip(quantized_signal, 0, (2**bits)-1)
    quantized_signal = quantized_signal[0:len(quantized_signal)//2].astype(np.int8)
    quantized_signal = torch.tensor(quantized_signal)

    return ((quantized_signal, sr))

In [118]:
rand_int = random.randint(0, len(speech_dataset['train']))
aud, truth = AudioUtil.open(rand_int)
sig, sr = aud
ipd.Audio(data=np.asarray(sig), autoplay=True, rate=16000)

In [119]:
wer_metric = load_metric("wer")

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


In [120]:
processor = Wav2Vec2Processor.from_pretrained("patrickvonplaten/wav2vec2-base-timit-demo")
model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-base-timit-demo")

Some weights of the model checkpoint at patrickvonplaten/wav2vec2-base-timit-demo were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at patrickvonplaten/wav2vec2-base-timit-demo and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You shoul

In [165]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
bit = 3
model.to(device)

for i in range(0, 5):
    rand_int = random.randint(0, len(speech_dataset['train']))
    normal_aud, truth = AudioUtil.open(rand_int)
    normal_aud = AudioUtil.rechannel(normal_aud, 2)
    normal_values = processor(
        normal_aud[0], 
        sampling_rate=normal_aud[1], 
        return_tensors="pt"
    ).input_values.to(device)

    quant_aud = AudioUtil.quantize(normal_aud, bit)
    quant_aud = (torch.t(quant_aud[0]), quant_aud[1])
    quant_aud = AudioUtil.rechannel(quant_aud, 2)
    quant_values = processor(
        quant_aud[0],
        sampling_rate=quant_aud[1],
        return_tensors="pt"
    ).input_values.to(device)

    with torch.no_grad():
        normal_logits = model(normal_values.squeeze()).logits
        quant_logits = model(quant_values.squeeze()).logits
    normal_pred_ids = torch.argmax(normal_logits, dim=-1)
    quant_pred_ids = torch.argmax(quant_logits, dim=-1)

    normal_pred = processor.batch_decode(normal_pred_ids)[0]
    quant_pred = processor.batch_decode(quant_pred_ids)[0]

    print(f"Truth: {truth}")
    print(f"Prediction: {normal_pred}")
    print(f"Predication Quantized {bit} bits: {quant_pred}\n")

Truth: society moral system
Prediction: sussietyen moras his tomgue
Predication Quantized 3 bits: ie emaos

Truth: environment just like more
Prediction: e maman to seley mord
Predication Quantized 3 bits: enbirment make morr

Truth: yeah
Prediction: ya
Predication Quantized 3 bits: 

Truth: and when you do vedio chat you can see the person
Prediction: and whilloum do readerchads you conseete a present
Predication Quantized 3 bits: w it o

Truth: ok
Prediction: o' ke
Predication Quantized 3 bits: 

