Please note: this notebook was designed to be used in an Amazon SageMaker environment.

In [None]:
# Install the packages listed in requirements.txt (shell command run from the notebook).
!pip install -r requirements.txt

In [2]:
import warnings
import os
import torch
import time
import boto3
from os.path import join
from pathlib import Path
from tqdm import tqdm
from io import BytesIO
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torchaudio
import librosa

In [None]:
# S3 bucket read by every helper below. The value is a placeholder:
# replace it with the name of the bucket that holds the audio folders.
BUCKET_NAME = 'your/input/bucket'
# Sampling rate every clip must have; compared against the rate reported by
# torchaudio.load (presumably Hz, i.e. 16 kHz audio).
SAMPLING_RATE = 16000

# Current working directory (join() with a single argument returns it unchanged).
script_dir = join(os.getcwd())
# Silence all warnings, then explicitly silence FutureWarning as well.
warnings.filterwarnings("ignore")
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
def get_s3_client():
    """Create and return a boto3 S3 client.

    No explicit arguments are passed, so boto3's own default configuration
    is used.
    """
    client = boto3.client('s3')
    return client


def list_s3_files_folder(folder_name):
    """Return the keys of all ``.wav`` objects under a folder of the bucket.

    Args:
        folder_name: Key prefix of the folder; a trailing ``/`` is appended
            when missing.

    Returns:
        A list of object keys ending in ``.wav``, gathered across every
        result page of ``list_objects_v2``.
    """
    prefix = folder_name if folder_name.endswith('/') else folder_name + '/'

    wav_keys = []
    next_token = None

    while True:
        # Build the request afresh so the token is only sent when we have one.
        request = {'Bucket': BUCKET_NAME, 'Prefix': prefix}
        if next_token:
            request['ContinuationToken'] = next_token

        page = s3.list_objects_v2(**request)

        for entry in page.get('Contents', []):
            if entry['Key'].endswith('.wav'):
                wav_keys.append(entry['Key'])

        if not page.get('IsTruncated'):
            return wav_keys

        next_token = page.get('NextContinuationToken')


def get_file(file_name):
    """Download one object from the bucket into memory.

    Args:
        file_name: Key of the object inside ``BUCKET_NAME``.

    Returns:
        A ``BytesIO`` holding the object's bytes, rewound to position 0 so
        it can be read straight away.
    """
    stream = BytesIO()
    s3.download_fileobj(Bucket=BUCKET_NAME, Key=file_name, Fileobj=stream)
    stream.seek(0)
    return stream


def get_reference(folder_name):
    """Fetch the STM reference transcript stored inside an S3 folder.

    The reference is expected at ``<folder>/<stem>_reference.stm``, where
    ``<stem>`` is the folder name up to (not including) its first ``.``.

    Args:
        folder_name: Key prefix of the folder; a trailing ``/`` is appended
            when missing.

    Returns:
        The reference file decoded as UTF-8 text, or ``None`` when the
        expected key does not exist in the bucket.
    """
    if not folder_name.endswith('/'):
        folder_name += '/'

    stem = folder_name.rstrip('/').split('.')[0]
    expected_key = f"{folder_name}{stem}_reference.stm"

    # Use the exact key as the prefix instead of listing the whole folder.
    # A single un-paginated listing of the folder only covers its first page,
    # so in a large folder the reference could be wrongly reported as missing.
    response = s3.list_objects_v2(Bucket=BUCKET_NAME, Prefix=expected_key)
    found = any(
        obj['Key'] == expected_key for obj in response.get('Contents', [])
    )
    if not found:
        return None

    buffer = BytesIO()
    s3.download_fileobj(BUCKET_NAME, expected_key, buffer)
    return buffer.getvalue().decode("utf-8")


def list_folders():
    """Return the names of all top-level folders in the bucket.

    Folders are the common prefixes reported by ``list_objects_v2`` when
    listing with ``/`` as delimiter; the trailing ``/`` is stripped.

    Returns:
        A list of folder names collected across every result page.
    """
    folders = []
    request = {'Bucket': BUCKET_NAME, 'Delimiter': '/'}

    while True:
        response = s3.list_objects_v2(**request)
        folders.extend(
            p['Prefix'].rstrip('/') for p in response.get('CommonPrefixes', [])
        )

        # A single response is capped in size, so keep following the
        # continuation token; otherwise folders past the first page are lost.
        token = response.get('NextContinuationToken')
        if not response.get('IsTruncated') or not token:
            return folders
        request['ContinuationToken'] = token


def get_duration(stm_text, s3_key):
    """Look up the duration of a clip in STM reference text.

    Each line is split on whitespace. The first line whose first field
    equals the clip's file name (without directory or extension) and whose
    4th and 5th fields parse as floats determines the result.

    Args:
        stm_text: Contents of the reference file, or ``None``/empty when no
            reference is available.
        s3_key: Full S3 key of the audio clip.

    Returns:
        ``end - start`` (5th field minus 4th field) as a float, or ``None``
        when there is no reference text or no usable line for the clip.
    """
    # A missing reference is reported as None by get_reference; treat it
    # like "clip not found" instead of failing on None.splitlines().
    if not stm_text:
        return None

    # Extract the filename without extension from the full S3 key
    clip_base = os.path.splitext(os.path.basename(s3_key))[0]

    for line in stm_text.splitlines():
        parts = line.split()
        if len(parts) < 5 or parts[0] != clip_base:
            continue

        try:
            return float(parts[4]) - float(parts[3])
        except ValueError:
            # Malformed timestamps: keep looking for another matching line.
            continue

    return None

In [5]:
# Module-level S3 client; the helper functions above read this global name `s3`.
s3 = get_s3_client()

In [None]:
def predict_clips_with_models(models):
    """Predict audio clips using wav2vec2 models and write .ctm + .tsv output per folder.

    For every model, each folder returned by ``list_folders()`` is processed:
    all ``.wav`` clips in the folder are transcribed, and two files are
    written under ``results/<model>/``:

    * ``ctm/<model>_<folder>.ctm`` - one line per decoded word
      (``file_id 1 start duration word confidence``).
    * ``tsv/<model>_<folder>.tsv`` - one line per clip
      (``file``, ``RTF``, ``prediction``).

    Finished folders are appended to ``results/<model>/progress.tsv`` and
    skipped on later runs, so an interrupted run can be resumed.

    Args:
        models: Iterable of model identifiers accepted by ``from_pretrained``.
            Only the part after the last ``/`` is used in output paths.

    Raises:
        ValueError: If a clip's sampling rate differs from ``SAMPLING_RATE``.
    """
    # Use the GPU when there is one, but still run on CPU-only machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for model_name in models:
        processor = Wav2Vec2Processor.from_pretrained(model_name)
        model = Wav2Vec2ForCTC.from_pretrained(model_name)
        model.eval()
        model = model.to(device)

        # Factor that converts a logit-frame offset into seconds. It is the
        # same for every clip, so compute it once outside the timed region.
        time_offset = model.config.inputs_to_logits_ratio / processor.feature_extractor.sampling_rate

        # Set up output directories
        short_name = model_name.split('/')[-1]
        model_dir = Path("results") / short_name
        ctm_dir = model_dir / "ctm"
        tsv_dir = model_dir / "tsv"
        model_dir.mkdir(parents=True, exist_ok=True)
        ctm_dir.mkdir(exist_ok=True)
        tsv_dir.mkdir(exist_ok=True)

        # Progress tracking
        progress_file = model_dir / "progress.tsv"
        completed_folders = set()
        if progress_file.exists():
            with open(progress_file, "r", encoding="utf-8") as pf:
                completed_folders = {line.strip() for line in pf}

        folders = list_folders()

        for folder in tqdm(folders, desc=f"{short_name} - folders", unit="folder", dynamic_ncols=True, position=0):
            if folder in completed_folders:
                continue

            files = list_s3_files_folder(folder)

            ctm_lines = []
            tsv_lines = []

            # May be None when the folder has no reference file.
            ref_file = get_reference(folder)

            for key in tqdm(files, desc=f"  {folder}", unit="file", leave=False, dynamic_ncols=True, position=1):
                audio_buffer = get_file(key)
                waveform, sr = torchaudio.load(audio_buffer)

                # Explicit check instead of assert, which is stripped under -O.
                if sr != SAMPLING_RATE:
                    raise ValueError(f"Expected sampling rate {SAMPLING_RATE}, got {sr} in {key}")

                # Remember the length before squeezing; it is the fallback
                # for the clip duration when the reference has no entry.
                num_frames = waveform.shape[-1]
                waveform = waveform.squeeze().to(device)
                file_id = Path(key).stem

                inputs = processor(waveform, sampling_rate=sr, return_tensors="pt", padding=True).to(device)

                # Time the forward pass plus decoding.
                start_time = time.perf_counter()
                with torch.no_grad():
                    logits = model(**inputs).logits

                pred_ids = torch.argmax(logits, dim=-1)
                output = processor.decode(pred_ids[0], output_word_offsets=True)
                execution_time = time.perf_counter() - start_time

                words = []

                for d in output.word_offsets:
                    word = d["word"]
                    start = round(d["start_offset"] * time_offset, 2)
                    end = round(d["end_offset"] * time_offset, 2)
                    word_duration = round(end - start, 2)

                    confidence = d.get("score", 1.0)  # score sometimes not present
                    ctm_lines.append(f"{file_id} 1 {start:.2f} {word_duration:.2f} {word} {confidence:.4f}")
                    words.append(word)

                # Look the clip duration up once per clip, outside the word
                # loop, so it is defined even when no words were decoded.
                clip_duration = get_duration(ref_file, key) if ref_file else None
                if not clip_duration or clip_duration <= 0:
                    # No usable reference entry: use the audio length instead.
                    clip_duration = num_frames / sr

                # An empty clip has no meaningful real-time factor.
                rtf = f"{(execution_time / clip_duration):.4f}" if clip_duration > 0 else "nan"

                prediction = " ".join(words)
                tsv_lines.append(f"{file_id}\t{rtf}\t{prediction}")

            safe_folder_name = folder.split(".")[0]
            ctm_filename = f"{short_name}_{safe_folder_name}.ctm"
            tsv_filename = f"{short_name}_{safe_folder_name}.tsv"

            with open(ctm_dir / ctm_filename, "w", encoding="utf-8") as f:
                f.write("\n".join(ctm_lines) + "\n")

            with open(tsv_dir / tsv_filename, "w", encoding="utf-8") as f:
                f.write("file\tRTF\tprediction\n")  # TSV header
                f.write("\n".join(tsv_lines) + "\n")

            # Mark this folder as completed
            with open(progress_file, "a", encoding="utf-8") as pf:
                pf.write(folder + "\n")

In [None]:
# Model identifiers to evaluate; each one is passed to from_pretrained
# inside predict_clips_with_models.
models_to_test = ["jonatasgrosman/wav2vec2-large-xlsr-53-dutch", "GroNLP/wav2vec2-dutch-large-ft-cgn"]

# Run the predictions; results are written under results/<model>/.
predict_clips_with_models(models_to_test)