In [None]:
import io

import numpy as np
import pandas as pd
import torch
import torchaudio
import torchaudio.transforms as T
from pyspark.sql.functions import pandas_udf, col, length, udf
from pyspark.sql.types import ArrayType, FloatType, StringType, IntegerType

In [None]:
%%configure -f
{
  "executorCores": 1,
  "conf": {
    "spark.sql.execution.arrow.maxRecordsPerBatch": "64",
    "spark.executorEnv.HF_HOME": "/tmp/huggingface"
  }
}

In [None]:
# Checkpoint identifier handed to the transformers `from_pretrained` loaders.
TRANSCRIPTION_MODEL = "openai/whisper-tiny"

# Target sample rate (Hz) used when resampling audio and when calling the processor.
NEW_SAMPLING_RATE = 16_000

# Module-level slot for the lazily created processor; filled by get_processor().
_processor_cache = {"processor": None}


def get_processor():
    """Return the processor for ``TRANSCRIPTION_MODEL``, loading it on first use.

    The loaded instance is stored in the module-level ``_processor_cache`` so
    later calls in the same Python process reuse it instead of reloading.

    Returns:
        The object produced by ``AutoProcessor.from_pretrained``.
    """
    cached = _processor_cache["processor"]
    if cached is not None:
        return cached

    # Imported lazily so transformers is only required where the processor is used.
    from transformers import AutoProcessor

    loaded = AutoProcessor.from_pretrained(TRANSCRIPTION_MODEL)
    _processor_cache["processor"] = loaded
    return loaded


# Module-level slots for the lazily loaded model and the device/dtype it was placed on.
_model_cache = dict.fromkeys(("model", "device", "dtype"))


def get_model():
    """Return ``(model, device, dtype)`` for ``TRANSCRIPTION_MODEL``.

    On the first call the model is loaded with
    ``AutoModelForSpeechSeq2Seq.from_pretrained`` and moved to ``"cuda"`` in
    float16 when CUDA is available, otherwise to ``"cpu"`` in float32. The
    triple is stored in ``_model_cache`` and reused by later calls in the same
    Python process.

    Returns:
        Tuple of the loaded model, the device string, and the torch dtype.
    """
    if _model_cache["model"] is not None:
        return _model_cache["model"], _model_cache["device"], _model_cache["dtype"]

    # Imported lazily so transformers is only required where the model is used.
    from transformers import AutoModelForSpeechSeq2Seq

    has_cuda = torch.cuda.is_available()
    if has_cuda:
        device, dtype = "cuda", torch.float16
    else:
        device, dtype = "cpu", torch.float32

    loaded = AutoModelForSpeechSeq2Seq.from_pretrained(
        TRANSCRIPTION_MODEL,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
    )
    loaded = loaded.to(device)

    _model_cache.update({"model": loaded, "device": device, "dtype": dtype})
    return _model_cache["model"], _model_cache["device"], _model_cache["dtype"]


@pandas_udf(ArrayType(FloatType()))
def resample_udf(audio_bytes: pd.Series) -> pd.Series:
    """Decode encoded audio and resample each clip to ``NEW_SAMPLING_RATE``.

    Args:
        audio_bytes: Series whose elements are the raw bytes of an encoded
            audio file, one clip per row.

    Returns:
        Series with one flat list of float32 samples per input row. Clips with
        more than one channel are averaged down to a single channel so every
        row matches the declared ``ArrayType(FloatType())`` return type.
    """
    # Resample precomputes its kernel on construction, so keep one transform per
    # distinct source rate seen in this batch instead of rebuilding it per row.
    resamplers = {}
    results = []
    for bytes_arr in audio_bytes:
        waveform, sampling_rate = torchaudio.load(io.BytesIO(bytes_arr))

        # torchaudio.load yields (channels, frames); a bare squeeze() would leave
        # multi-channel audio 2-D and produce nested lists downstream.
        if waveform.dim() > 1 and waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        resampler = resamplers.get(sampling_rate)
        if resampler is None:
            resampler = T.Resample(sampling_rate, NEW_SAMPLING_RATE)
            resamplers[sampling_rate] = resampler

        # reshape(-1) always gives a 1-D tensor, even for a one-sample clip
        # (squeeze() would collapse that case to a 0-d scalar).
        samples = resampler(waveform).reshape(-1)
        results.append(samples.numpy().astype(np.float32).tolist())
    return pd.Series(results)


@pandas_udf(ArrayType(ArrayType(FloatType())))
def whisper_preprocess_udf(resampled: pd.Series) -> pd.Series:
    """Run the cached processor over a batch of resampled waveforms.

    Args:
        resampled: Series of per-row sample sequences, already at
            ``NEW_SAMPLING_RATE``.

    Returns:
        Series with one nested list of float32 values per row, taken from the
        processor's ``input_features`` (presumably 2-D spectrogram features —
        confirm against the processor in use).
    """
    waveforms = resampled.tolist()
    encoded = get_processor()(
        waveforms, sampling_rate=NEW_SAMPLING_RATE, return_tensors="np"
    )

    rows = []
    for feature in encoded.input_features:
        rows.append(feature.astype(np.float32).tolist())
    return pd.Series(rows)


@pandas_udf(ArrayType(IntegerType()))
def transcriber_udf(extracted_features: pd.Series) -> pd.Series:
    """Generate token ids for a batch of extracted audio features.

    Args:
        extracted_features: Series whose elements are 2-D feature matrices for
            one clip each, given either as nested lists or as an array of
            row arrays. All rows in a batch must share the same shape.

    Returns:
        Series with one list of integer token ids per input row, as returned
        by ``model.generate``.
    """
    if len(extracted_features) == 0:
        # Nothing to transcribe; avoid loading the model or calling generate().
        return pd.Series([], dtype=object)

    model, device, dtype = get_model()

    # Stack row by row so both nested lists and object arrays holding row arrays
    # become one contiguous float32 array. Building the tensor from that single
    # array avoids the slow list-of-ndarrays path of torch.tensor().
    batch = np.stack(
        [
            np.stack([np.asarray(row, dtype=np.float32) for row in feat])
            for feat in extracted_features
        ]
    )
    spectrograms = torch.from_numpy(batch).to(device=device, dtype=dtype)

    with torch.no_grad():
        token_ids = model.generate(spectrograms)

    # Move the whole batch to host memory once rather than once per row.
    return pd.Series(token_ids.cpu().tolist())


@pandas_udf(StringType())
def decode_udf(token_ids: pd.Series) -> pd.Series:
    """Turn a batch of token-id sequences into text.

    Args:
        token_ids: Series with one sequence of token ids per row.

    Returns:
        Series of decoded strings, one per row, with special tokens removed.
    """
    sequences = token_ids.tolist()
    texts = get_processor().batch_decode(sequences, skip_special_tokens=True)
    return pd.Series(texts)

In [None]:
# Build the transcription pipeline: each step adds one column derived from the
# previous one (audio bytes -> waveform -> features -> token ids -> text).
df = spark.read.parquet("s3://daft-public-datasets/common_voice_17")
df = df.withColumn("resampled", resample_udf(col("audio.bytes")))
df = df.withColumn("extracted_features", whisper_preprocess_udf(col("resampled")))
df = df.withColumn("token_ids", transcriber_udf(col("extracted_features")))
df = df.withColumn("transcription", decode_udf(col("token_ids")))
# Built-in column function instead of a row-at-a-time Python lambda UDF: no
# extra Python round trip, and a null transcription yields null, not an error.
df = df.withColumn("transcription_length", length(col("transcription")))

# Drop the large intermediate columns before writing the results.
final_df = df.drop("token_ids", "extracted_features", "resampled")
final_df.write.mode("append").parquet(
    "s3://eventual-dev-benchmarking-results/ai-benchmark-results/audio-transcription"
)