Please note: this notebook was designed to be used in a SageMaker environment.

In [16]:
import boto3
import botocore
import asyncio
import aiohttp
import io
import wave
import nest_asyncio
import os
import numpy as np
from tqdm import tqdm
import time
from os.path import isfile, join
from io import BytesIO
import json
from datetime import timedelta
from dateutil import parser
from urllib.parse import urlparse
from pathlib import Path

In [None]:
# Audio format parameters; not referenced by the code visible in this notebook.
SAMPLE_RATE = 16000
CHANNEL_NUMS = 1

# Settings marked with ### are placeholders that must be edited before running.
REGION = "your region" ### passed to the AWS CLI as --region when starting jobs
BUCKET_NAME = "your/input/bucket" ### bucket that is listed for .wav clips and reference .stm files
LANGUAGE_CODE = "nl-NL"  # passed to the AWS CLI as --language-code
OUTPUT_BUCKET = "your/output/bucket" ### bucket the transcription JSON output is read back from

# Working directory at the time this cell is executed.
script_dir = os.getcwd()

In [None]:
def get_s3_client():
    """Create and return a new boto3 client for the S3 service."""
    client = boto3.client('s3')
    return client


def list_s3_files_folder(folder_name):
    """Return the keys of every ``.wav`` object under a folder prefix.

    Uses the module-level ``s3`` client and ``BUCKET_NAME``.

    Args:
        folder_name: Key prefix of the folder; a trailing '/' is appended
            when it is missing.

    Returns:
        A list of object keys ending in '.wav', collected across all
        result pages.
    """
    prefix = folder_name if folder_name.endswith('/') else folder_name + '/'

    wav_keys = []
    request = {'Bucket': BUCKET_NAME, 'Prefix': prefix}

    while True:
        page = s3.list_objects_v2(**request)

        for entry in page.get('Contents', []):
            if entry['Key'].endswith('.wav'):
                wav_keys.append(entry['Key'])

        if not page.get('IsTruncated'):
            return wav_keys

        # Carry the continuation token into the next request; when the
        # response supplies none, the next request is sent without one.
        token = page.get('NextContinuationToken')
        if token:
            request['ContinuationToken'] = token
        else:
            request.pop('ContinuationToken', None)


def get_reference(folder_name):
    """Download the reference STM transcript that belongs to a folder.

    The reference is expected at ``<folder>/<id>_reference.stm`` where
    ``<id>`` is the folder name up to its first '.'. Uses the module-level
    ``s3`` client and ``BUCKET_NAME``.

    Args:
        folder_name: Key prefix of the folder; a trailing '/' is appended
            when it is missing.

    Returns:
        The reference file decoded as UTF-8 text, or ``None`` when no object
        with the expected key exists.
    """
    if not folder_name.endswith('/'):
        folder_name += '/'

    folder_id = folder_name.rstrip('/').split('.')[0]
    expected_key = f"{folder_name}{folder_id}_reference.stm"

    # List with the expected key itself as the prefix instead of the whole
    # folder: a single list_objects_v2 response is capped in size, so in a
    # folder holding many clips the reference could fall outside the first
    # page and be wrongly reported as missing.
    response = s3.list_objects_v2(Bucket=BUCKET_NAME, Prefix=expected_key)
    exists = any(
        obj['Key'] == expected_key for obj in response.get('Contents', [])
    )
    if not exists:
        return None

    buffer = BytesIO()
    s3.download_fileobj(BUCKET_NAME, expected_key, buffer)
    return buffer.getvalue().decode("utf-8")


def list_folders():
    """Return the top-level folder names of the input bucket.

    Issues a single delimiter-based listing against ``BUCKET_NAME`` with the
    module-level ``s3`` client (only that one response is read) and strips
    the trailing '/' from each common prefix.
    """
    listing = s3.list_objects_v2(Bucket=BUCKET_NAME, Delimiter='/')

    folder_names = []
    for entry in listing.get('CommonPrefixes', []):
        folder_names.append(entry['Prefix'].rstrip('/'))
    return folder_names


def get_duration(stm_text, s3_key):
    """Look up the duration of a clip in STM reference text.

    A line matches when its first whitespace-separated field equals the
    clip's file name without directory and extension. The duration is the
    fifth field minus the fourth field of the first matching line whose
    fields both parse as floats.

    Args:
        stm_text: Contents of the reference file, or ``None``/empty when no
            reference is available.
        s3_key: S3 key (or path) of the audio clip.

    Returns:
        The duration as a float, or ``None`` when there is no reference text
        or no usable matching line.
    """
    # A missing reference is reported by the lookup helper as None; treat it
    # like "no match" rather than failing on .splitlines().
    if not stm_text:
        return None

    clip_base = os.path.splitext(os.path.basename(s3_key))[0]

    for line in stm_text.splitlines():
        fields = line.split()
        if len(fields) < 5 or fields[0] != clip_base:
            continue

        try:
            start = float(fields[3])
            end = float(fields[4])
        except ValueError:
            # Malformed timestamps: keep looking for another matching line.
            continue
        return end - start

    return None


def to_job_name(file_name):
    """Derive a transcription job name from an S3 key.

    Every '/' and ' ' becomes '_', and every occurrence of '.wav' is
    removed. Other characters (for example '+') are left untouched.
    """
    separators_to_underscore = str.maketrans({'/': '_', ' ': '_'})
    underscored = file_name.translate(separators_to_underscore)
    return underscored.replace('.wav', '')


def remove_folder_name(file_name):
    """Drop the first two '_'-separated components of a name.

    Returns everything after the second underscore, or an empty string when
    the name contains fewer than two underscores.
    """
    pieces = file_name.split('_', 2)
    if len(pieces) < 3:
        return ''
    return pieces[2]

In [None]:
def extract_ctm(data):
    """Convert a transcription result into CTM-formatted lines.

    Args:
        data: Parsed transcription JSON; reads ``data["jobName"]`` and
            ``data["results"]["items"]``.

    Returns:
        A list with one string per item of type "pronunciation", formatted
        as ``<file_id> <channel> <start> <duration> <word> <confidence>``.
        The channel is always '1'; start and duration use two decimals; the
        confidence falls back to "1.0" when the item has none.
    """
    file_id = remove_folder_name(data["jobName"])
    channel = '1'

    def _format_item(item):
        # Only the first alternative of each item is used.
        best = item["alternatives"][0]
        word = best["content"]
        start_time = float(item["start_time"])
        duration = float(item["end_time"]) - start_time
        confidence = best.get("confidence", "1.0")
        return f"{file_id} {channel} {start_time:.2f} {duration:.2f} {word} {confidence}"

    return [
        _format_item(item)
        for item in data["results"]["items"]
        if item["type"] == "pronunciation"
    ]

In [20]:
def extract_transcript(data):
    """Return the text of the first transcript in a transcription result.

    Args:
        data: Parsed transcription JSON; reads
            ``data["results"]["transcripts"][0]["transcript"]``.
    """
    first_transcript = data["results"]["transcripts"][0]
    return first_transcript["transcript"]

In [21]:
def get_json(file):
    """Fetch and parse the transcription output JSON for an audio file.

    The object key is the job name derived from ``file`` plus '.json', read
    from ``OUTPUT_BUCKET`` with the module-level ``s3`` client.

    Returns:
        The parsed JSON, or ``None`` (after printing a message) when S3
        reports a client error for the object.
    """
    output_key = to_job_name(file) + '.json'
    try:
        body = s3.get_object(Bucket=OUTPUT_BUCKET, Key=output_key)['Body']
        return json.loads(body.read())
    except botocore.exceptions.ClientError as e:
        reason = e.response['Error']['Message']
        print(f"Skipping {output_key} — could not retrieve from S3: {reason}")
        return None


def transcribe_audio(file):
    """Start an AWS Transcribe job for one audio file via the AWS CLI.

    The job name is derived from ``file``; the media URI points at ``file``
    inside ``BUCKET_NAME`` and the output goes to ``OUTPUT_BUCKET``. The call
    is best-effort: a non-zero CLI exit status is not raised, and a failure
    to launch the CLI is only printed.

    Args:
        file: S3 key of the audio file inside ``BUCKET_NAME``.
    """
    # Imported locally so this function does not depend on a file-level import.
    import subprocess

    job_name = to_job_name(file)
    s3_uri = f"s3://{BUCKET_NAME}/{file}"

    # Pass an argument list instead of a shell string: object keys may hold
    # quotes or shell metacharacters, which would break (or inject into) a
    # command line assembled by string formatting.
    command = [
        "aws", "transcribe", "start-transcription-job",
        "--region", REGION,
        "--transcription-job-name", job_name,
        "--media", f"MediaFileUri={s3_uri}",
        "--output-bucket-name", OUTPUT_BUCKET,
        "--language-code", LANGUAGE_CODE,
    ]
    try:
        # check=False: CLI errors are shown by the CLI itself and do not
        # abort the batch.
        subprocess.run(command, check=False)
    except OSError as exc:
        print(f"Could not launch the AWS CLI for {job_name}: {exc}")


def process_results(file):
    """Collect CTM lines, transcript and processing time for one audio file.

    Returns:
        A ``(ctm, transcript, processing_time)`` tuple, or ``None`` when the
        transcription output for ``file`` could not be loaded.
    """
    data = get_json(file)
    if data is not None:
        return (
            extract_ctm(data),
            extract_transcript(data),
            get_processing_time(to_job_name(file)),
        )
    return None

In [22]:
def get_processing_time(job_name):
    """Return how long a transcription job took, in seconds.

    The time is measured from the job's ``CreationTime`` to its
    ``CompletionTime`` as reported by ``get_transcription_job``.

    Args:
        job_name: Name of the transcription job.

    Returns:
        The elapsed time in seconds as a float.

    Raises:
        RuntimeError: If the job reports no ``CompletionTime``.
    """
    # NOTE(review): this client uses the default region configuration, while
    # jobs are started with an explicit --region; confirm both agree.
    transcribe = boto3.client('transcribe')
    response = transcribe.get_transcription_job(TranscriptionJobName=job_name)

    job = response['TranscriptionJob']
    creation_time = job['CreationTime']
    completion_time = job.get('CompletionTime')
    if completion_time is None:
        # Fail with an explanation instead of an opaque TypeError from
        # subtracting a datetime from None.
        status = job.get('TranscriptionJobStatus', 'unknown')
        raise RuntimeError(
            f"Transcription job {job_name} has no completion time (status: {status})"
        )

    return (completion_time - creation_time).total_seconds()

In [None]:
def predict_clips_with_models_ctm():
    """Transcribe every clip in the selected folders and write CTM/TSV results.

    For each folder not yet listed in ``results/aws-transcribe/progress.tsv``:
    start one transcription job per ``.wav`` file, then read back each job's
    output and write

    * ``results/aws-transcribe/ctm/aws-transcribe_<folder>.ctm`` with the
      word-level CTM lines, and
    * ``results/aws-transcribe/tsv/aws-transcribe_<folder>.tsv`` with one
      ``file<TAB>RTF<TAB>prediction`` row per clip, where RTF is the job's
      processing time divided by the clip duration.

    Clips whose output cannot be loaded, or whose duration is unknown or not
    positive, are counted as skipped. A folder that yields no results at all
    is left unmarked so that a later run processes it again.
    """
    model_name = "aws-transcribe"

    # Set up output directories
    model_dir = Path("results") / model_name
    ctm_dir = model_dir / "ctm"
    tsv_dir = model_dir / "tsv"
    model_dir.mkdir(parents=True, exist_ok=True)
    ctm_dir.mkdir(exist_ok=True)
    tsv_dir.mkdir(exist_ok=True)

    # Progress tracking
    progress_file = model_dir / "progress.tsv"
    completed_folders = set()
    if progress_file.exists():
        with open(progress_file, "r") as pf:
            completed_folders = {line.strip() for line in pf}

    folders = list_folders()
    folders = folders[0:9]  ## edit to process a subset

    for folder in tqdm(folders, desc=f"{model_name} - folders", unit="folder", dynamic_ncols=True, position=0):
        print(folder)
        if folder in completed_folders:
            continue

        folder_id = folder.rstrip('/').split('.')[0]
        files = list_s3_files_folder(folder)

        ctm_lines = []
        tsv_lines = []
        skipped = 0

        ref_file = get_reference(folder)

        # Folders numbered 5 and 6 use a fixed clip duration of 10 seconds
        # instead of the reference file. Decided once per folder; a
        # non-numeric folder prefix simply falls back to the reference.
        try:
            fixed_duration = int(folder_id) in (5, 6)
        except ValueError:
            fixed_duration = False

        for file in tqdm(files, desc=f"  {folder}", unit="file", leave=False, dynamic_ncols=True, position=1):
            transcribe_audio(file)

        # NOTE(review): nothing here waits for the submitted jobs to finish;
        # clips whose output cannot be loaded yet are counted as skipped.
        for file in tqdm(files, desc=f"  {folder}", unit="file", leave=False, dynamic_ncols=True, position=1):
            result = process_results(file)
            if result is None:
                print(f"Couldn't find job name for {file}")
                skipped += 1
                continue
            ctm, transcript, execution_time = result
            file_id = Path(file).stem

            if fixed_duration:
                duration = 10.0
            elif ref_file is not None:
                duration = get_duration(ref_file, file)
            else:
                duration = None

            # Without a positive duration the RTF cannot be computed; skip
            # the clip instead of aborting the whole run.
            if duration is None or duration <= 0:
                print(f"No usable duration for {file}")
                skipped += 1
                continue

            rtf = f"{(execution_time / duration):.4f}"
            tsv_lines.append(f"{file_id}\t{rtf}\t{transcript}")

            ctm_lines.append(ctm)

        if not tsv_lines:
            # Nothing to write. Leave the folder out of the progress file so
            # it is retried, and move on to the next folder.
            print(f"No results for {folder}; skipped {skipped} files")
            continue

        # Flatten the per-clip CTM lists into one list of lines.
        flat_ctm_lines = [line for clip_lines in ctm_lines for line in clip_lines]
        safe_folder_name = folder.split(".")[0]
        ctm_filename = f"{model_name}_{safe_folder_name}.ctm"
        tsv_filename = f"{model_name}_{safe_folder_name}.tsv"

        with open(ctm_dir / ctm_filename, "w", encoding="utf-8") as f:
            f.write("\n".join(flat_ctm_lines) + "\n")

        with open(tsv_dir / tsv_filename, "w", encoding="utf-8") as f:
            f.write("file\tRTF\tprediction\n")  # TSV header
            f.write("\n".join(tsv_lines) + "\n")

        # Mark this folder as completed
        with open(progress_file, "a") as pf:
            pf.write(folder + "\n")

        print(f"Skipped {skipped} lines in this folder")

In [27]:
# Module-level AWS clients. `s3` is the global that the S3 helper functions in
# this notebook use; `transcribe_client` is not referenced by any code visible here.
transcribe_client = boto3.client('transcribe')
s3 = get_s3_client()

In [None]:
# Run the pipeline: start the transcription jobs and write the CTM/TSV
# results under results/aws-transcribe/.
predict_clips_with_models_ctm()