In [None]:
!pip install datasets transformers
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# required to sample audio
!pip install librosa soundfile pyaudio

!pip install ipywidgets

# required for Wav2Vec2ProcessorWithLM
# !pip install pyctcdecode
# !pip install https://github.com/kpu/kenlm/archive/master.zip

In [67]:
from pyaudio import PyAudio

# List every device that can capture audio, then pick the system default input.
# `device_idx` and `device_sample_rate` are consumed by the recording cell below.
p = PyAudio()
try:
    print("Available input devices:")

    for idx in range(p.get_device_count()):
        device = p.get_device_info_by_index(idx)
        # Any device with input channels can record; devices that also report
        # output channels are still valid inputs, so don't filter them out.
        if device["maxInputChannels"] > 0:
            print(device)

    device = p.get_default_input_device_info()
    device_idx = int(device["index"])
    device_sample_rate = int(device["defaultSampleRate"])
    print("\nDefault input device:", device_idx)
finally:
    # Release PortAudio even if one of the queries above fails.
    p.terminate()

Available input devices:
{'index': 0, 'structVersion': 2, 'name': 'Microsoft Sound Mapper - Input', 'hostApi': 0, 'maxInputChannels': 2, 'maxOutputChannels': 0, 'defaultLowInputLatency': 0.09, 'defaultLowOutputLatency': 0.09, 'defaultHighInputLatency': 0.18, 'defaultHighOutputLatency': 0.18, 'defaultSampleRate': 44100.0}
{'index': 1, 'structVersion': 2, 'name': 'Microphone (3- SteelSeries Arct', 'hostApi': 0, 'maxInputChannels': 1, 'maxOutputChannels': 0, 'defaultLowInputLatency': 0.09, 'defaultLowOutputLatency': 0.09, 'defaultHighInputLatency': 0.18, 'defaultHighOutputLatency': 0.18, 'defaultSampleRate': 44100.0}
{'index': 2, 'structVersion': 2, 'name': 'Microphone (High Definition Aud', 'hostApi': 0, 'maxInputChannels': 2, 'maxOutputChannels': 0, 'defaultLowInputLatency': 0.09, 'defaultLowOutputLatency': 0.09, 'defaultHighInputLatency': 0.18, 'defaultHighOutputLatency': 0.18, 'defaultSampleRate': 44100.0}
{'index': 3, 'structVersion': 2, 'name': 'Microphone (Steam Streaming Mic', 'ho

In [68]:
from ipywidgets import widgets
from IPython.display import display
from threading import Thread, Lock
import numpy as np
import torch
import torchaudio.functional as F

from transformers import AutoModelForCTC, Wav2Vec2Processor
from queue import Queue
from pyaudio import PyAudio, paInt16, Stream

# --- Capture settings --------------------------------------------------------
CHANNELS = 1                # mono capture
SAMPLE_RATE = 16_000        # rate fed to the model; captured audio is resampled to it
RECORD_SECONDS = 30         # target length, in seconds, of each segment queued for transcription
CHUNK = 1024                # frames read from the stream per read call
AUDIO_FORMAT = paInt16      # 16-bit signed PCM
SAMPLE_SIZE = 2             # bytes per sample for AUDIO_FORMAT

# --- Model -------------------------------------------------------------------
MODEL_ID = "Jzuluaga/wav2vec2-large-960h-lv60-self-en-atc-uwb-atcc-and-atcosim"
model = AutoModelForCTC.from_pretrained(MODEL_ID)
processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)

# --- Shared runtime objects ----------------------------------------------------
recordings = Queue()        # lists of raw audio chunks waiting to be transcribed
output = widgets.Output()   # receives the status line and the transcripts

def is_buffer_full(buffer, chunk, sample_rate=None, seconds=None):
    """Return True once ``buffer`` holds at least ``seconds`` of captured audio.

    The input stream is opened at the device rate (``device_sample_rate``), not
    at the model rate ``SAMPLE_RATE``. Duration must therefore be measured
    against the rate the frames were actually captured at. Measuring against
    ``SAMPLE_RATE`` would cut segments at ``SAMPLE_RATE / device_sample_rate``
    of the intended length.

    Args:
        buffer: list of chunks read from the stream, each ``chunk`` frames long.
        chunk: number of frames per chunk.
        sample_rate: capture rate in Hz; defaults to ``device_sample_rate``.
        seconds: target segment duration; defaults to ``RECORD_SECONDS``.

    Returns:
        True when the frames held reach ``sample_rate * seconds``.
    """
    if sample_rate is None:
        sample_rate = device_sample_rate
    if seconds is None:
        seconds = RECORD_SECONDS
    # Integer form of len(buffer) >= sample_rate * seconds / chunk (no float division).
    return len(buffer) * chunk >= sample_rate * seconds

def process(sample: bytes):
    """Transcribe one recorded segment of raw 16-bit mono PCM audio.

    Args:
        sample: concatenated ``paInt16`` frames captured at ``device_sample_rate``.

    Returns:
        The greedy-decoded transcription string, or ``""`` when the segment is
        empty or too short for the model to produce any output frame.
    """
    # int16 spans [-32768, 32767]; dividing by 32768 maps it into [-1.0, 1.0).
    waveform = np.frombuffer(sample, dtype=np.int16).astype(np.float32) / 32768.0
    if waveform.size == 0:
        return ""

    resampled_audio = F.resample(torch.from_numpy(waveform), device_sample_rate, SAMPLE_RATE).numpy()

    # The standard wav2vec2 conv feature encoder has a 400-sample receptive field
    # (25 ms at 16 kHz). Shorter input, such as a tiny tail left in the buffer
    # when Stop is pressed, cannot yield a single frame and makes the conv stack
    # raise. NOTE(review): confirm against model.config.conv_kernel / conv_stride
    # if MODEL_ID changes.
    min_model_samples = 400
    if resampled_audio.size < min_model_samples:
        return ""

    input_values = processor(resampled_audio, sampling_rate=SAMPLE_RATE, return_tensors="pt").input_values

    with torch.no_grad():
        logits = model(input_values).logits

    # Greedy decoding: take the highest-scoring token for every output frame.
    pred_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(pred_ids)
    return transcription[0]

class RecordState:
    """Lock-protected on/off flag shared by the button callbacks and the recorder.

    The button handlers flip it, and ``record`` polls it to decide when to stop
    reading from the microphone.
    """

    def __init__(self) -> None:
        # Created per instance so separate RecordState objects never share a lock
        # (class-level attributes would be shared by every instance).
        self._state = False
        self._lock = Lock()

    def is_recording(self) -> bool:
        """Return True while a recording session is active."""
        with self._lock:
            return self._state

    def set_recording(self, mode: bool) -> None:
        """Start (``True``) or stop (``False``) the current recording session."""
        # ``with`` guarantees the lock is released, unlike manual acquire/release.
        with self._lock:
            self._state = mode

# Shared flag: the buttons toggle it and the recorder thread polls it.
state = RecordState()

# "Record" starts a capture session; "Stop" ends it.
record_btn = widgets.Button(
    description="Record",
    icon="microphone",
    button_style="success",
    disabled=False,
)
stop_btn = widgets.Button(
    description="Stop",
    icon="stop",
    button_style="warning",
    disabled=False,
)

def start_recording(data):
    """Button callback: open the microphone and start the recorder/transcriber threads.

    Args:
        data: the ``widgets.Button`` that was clicked (supplied by ``on_click``; unused).
    """
    # Ignore repeated clicks while a session is running. A second session would
    # open another stream on the same device and push its segments onto the
    # same ``recordings`` queue.
    if state.is_recording():
        return

    audio = PyAudio()
    try:
        stream = audio.open(
            format=AUDIO_FORMAT,
            channels=CHANNELS,
            rate=device_sample_rate,
            frames_per_buffer=CHUNK,
            input=True,
            input_device_index=device_idx,
        )
    except Exception:
        # Don't leak the PortAudio instance when the device can't be opened.
        audio.terminate()
        raise

    # daemon=True so a still-running worker never blocks kernel shutdown.
    record_thread = Thread(target=record, args=(output, audio, stream), daemon=True)
    transcribe_thread = Thread(target=transcribe_loop, args=(output,), daemon=True)
    output.append_stdout("Recording...\n")
    # Set the flag before starting the threads, otherwise ``record``'s loop
    # would see False and exit immediately.
    state.set_recording(True)
    record_thread.start()
    transcribe_thread.start()
        
def stop_recording(data):
    """Button callback: clear the recording flag so the recorder thread winds down.

    Args:
        data: the clicked ``widgets.Button`` (unused).
    """
    state.set_recording(mode=False)
        
def record(output: widgets.Output, audio: PyAudio, stream: Stream):
    """Worker thread: read the microphone until the recording flag is cleared.

    Audio is read ``CHUNK`` frames at a time, so ``CHUNK`` sets how often the
    microphone is read. Chunks are accumulated into segments. Each full segment,
    and the final partial one, is queued on ``recordings`` for
    ``transcribe_loop``.

    On exit, even after an error:
    - the recording flag is cleared;
    - a ``None`` end-of-session sentinel is queued;
    - the stream and PortAudio instance are released.
    Without the sentinel the transcriber would stay blocked on an empty queue.

    Args:
        output: transcript widget (kept for interface compatibility; unused here).
        audio: PyAudio instance that owns ``stream``; terminated on exit.
        stream: open input stream to read from; stopped and closed on exit.
    """
    buffer = []
    try:
        while state.is_recording():
            # Drop frames on input overflow instead of raising: an overflow only
            # means some audio was lost, while raising would end the session.
            buffer.append(stream.read(CHUNK, exception_on_overflow=False))

            if is_buffer_full(buffer, CHUNK):
                recordings.put(buffer)
                buffer = []  # rebind, so the queued list is never touched again

        if buffer:
            recordings.put(buffer)
    finally:
        # If we got here through an exception, reflect that the session is over.
        state.set_recording(False)
        # Sentinel goes after the last segment: no more audio for this session.
        recordings.put(None)
        stream.stop_stream()
        stream.close()
        audio.terminate()


def transcribe_loop(output: widgets.Output):
    """Worker thread: transcribe queued segments until ``record``'s sentinel arrives.

    Draining up to the sentinel, rather than polling the recording flag, means:
    - segments still queued when Stop is pressed are transcribed too;
    - the thread never stays blocked on ``recordings.get()`` after the session.
    A segment that fails is reported to ``output`` and skipped, so one bad
    segment does not strand the rest of the queue.
    """
    while True:
        try:
            if not transcribe(output):
                break
        except Exception as exc:
            output.append_stderr(f"Transcription failed: {exc}\n")


def transcribe(output: widgets.Output):
    """Take one segment from ``recordings``, transcribe it and append the text to ``output``.

    Blocks until an item is available on the queue.

    Returns:
        True after transcribing a segment; False when the ``None``
        end-of-session sentinel queued by ``record`` was received instead.
    """
    frames = recordings.get()
    if frames is None:
        return False

    text = process(b"".join(frames))
    output.append_stdout(text + "\n")
    return True
        
# Hook each button up to its callback, then display both buttons followed by the transcript widget.
for _button, _callback in ((record_btn, start_recording), (stop_btn, stop_recording)):
    _button.on_click(_callback)
del _button, _callback

display(record_btn, stop_btn, output)

Some weights of the model checkpoint at Jzuluaga/wav2vec2-large-960h-lv60-self-en-atc-uwb-atcc-and-atcosim were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at Jzuluaga/wav2vec2-large-960h-lv60-self-en-atc-uwb-atcc-and-atcosim and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.

Button(button_style='success', description='Record', icon='microphone', style=ButtonStyle())



Output()