# Wav2Vec2 Speech-to-Text: Feasibility & Baseline (SageMaker)

Objective: establish a lean, repeatable path to an **ASR endpoint on SageMaker** using a Hugging Face Wav2Vec2 model; validate basic **latency**, **cost posture**, and **quality** (WER/CER) on a small test set.

**Scope**
- No training; inference only, via the Hugging Face Deep Learning Container (HF DLC) or a minimal custom handler.
- Tiny sample set to sanity-check performance before productizing.

**Contents**
0. Environment & config
1. Local smoke test (no endpoint)
2. Invoke SageMaker endpoint (helper and small batch run)
3. Optional evaluation (WER/CER) if references are available
4. Notes


In [ ]:
# %% Environment (run locally). Uncomment the install line if any package is missing.
# pydub needs ffmpeg (or libav) on PATH to decode MP3 files.
# !pip install --quiet boto3 jiwer soundfile pydub
import os, io, json, time, pathlib
import boto3
from botocore.config import Config
from jiwer import wer, cer
from pydub import AudioSegment

# Config
AWS_REGION = os.environ.get("AWS_DEFAULT_REGION", "us-east-1")
ENDPOINT_NAME = os.environ.get("SM_ENDPOINT_NAME", "vrynt-stt-demo")
SAMPLE_DIR = pathlib.Path("samples")  # put a few short wav/mp3 files here


## 1) Local smoke test (no endpoint)
For a purely local check with a small model, use `transformers` (optional). The cell below is skipped by default (`SKIP_LOCAL = True`) to keep the notebook lightweight.

> To run it locally, install `transformers`, `torch`, and `soundfile`, then load `facebook/wav2vec2-base-960h`.

In [ ]:
SKIP_LOCAL = True  # set to False to run the local small-model smoke test
if not SKIP_LOCAL:
    # Heavy imports stay inside the guard so the notebook runs without them.
    import torch, soundfile as sf
    from transformers import AutoProcessor, AutoModelForCTC

    processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
    model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h")
    waveform, sr = sf.read("samples/short.wav", dtype="float32")
    if waveform.ndim > 1:
        waveform = waveform.mean(axis=1)  # downmix multi-channel audio to mono
    # The model expects 16 kHz audio; the processor raises on any other rate.
    inputs = processor(waveform, sampling_rate=sr, return_tensors="pt")
    with torch.inference_mode():
        ids = model(**inputs).logits.argmax(-1)  # greedy CTC: best token per frame
    print(processor.batch_decode(ids)[0])
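
The `argmax(-1)` followed by `batch_decode` in the cell above amounts to greedy CTC decoding: take the most likely token per frame, merge consecutive repeats, then drop the blank token. A toy sketch of that collapse step is below. The frame sequence is made up for illustration, and the token names (`<pad>` as the blank, `|` as the word separator) are assumed to match the convention of the Wav2Vec2 CTC tokenizer rather than read from the model.

```python
from itertools import groupby

BLANK = "<pad>"   # assumed CTC blank token
WORD_SEP = "|"    # assumed word delimiter token


def ctc_greedy_collapse(frame_tokens):
    """Collapse per-frame argmax tokens into text (toy version of CTC greedy decoding)."""
    merged = [tok for tok, _ in groupby(frame_tokens)]   # merge consecutive repeats
    kept = [tok for tok in merged if tok != BLANK]       # drop blanks
    return "".join(kept).replace(WORD_SEP, " ").strip()


# Hypothetical per-frame predictions; the blank between the two L runs keeps the double letter.
frames = ["H", "H", "<pad>", "I", "|", "|", "A", "L", "<pad>", "L", "L"]
decoded = ctc_greedy_collapse(frames)
print(decoded)
```

The blank is what lets CTC emit a genuine double letter: without it, the repeated `L` frames would merge into one.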

## 2) Invoke SageMaker endpoint
Helper to send audio bytes to the endpoint and get a transcript JSON back.

In [ ]:
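def invoke_stt(endpoint_name: str, region: str, audio_bytes: bytes,
               content_type: str = "audio/x-audio"):
    """Send audio bytes to the endpoint; return (parsed response, latency in seconds)."""
    # The content type is assumed to match what the HF inference toolkit accepts for
    # audio; pass a different value if a custom handler expects one.
    smr = boto3.client("sagemaker-runtime", region_name=region,
                       config=Config(retries={"max_attempts": 3}))
    t0 = time.perf_counter()  # monotonic clock; the measured time includes any retries
    resp = smr.invoke_endpoint(EndpointName=endpoint_name, ContentType=content_type, Body=audio_bytes)
    latency = time.perf_counter() - t0
    body = resp["Body"].read()
    try:
        data = json.loads(body.decode("utf-8"))
    except (UnicodeDecodeError, json.JSONDecodeError):
        # Keep non-JSON responses for inspection instead of failing the batch.
        data = {"raw": body.decode("utf-8", errors="ignore")}
    return data, latency

def load_audio_bytes(path: pathlib.Path) -> bytes:
    """Decode a wav/mp3 file with pydub and re-encode it as WAV bytes."""
    audio = AudioSegment.from_file(path, format=path.suffix.lstrip(".").lower())
    buf = io.BytesIO()
    audio.export(buf, format="wav")
    return buf.getvalue()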
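
Before sending audio, it can help to check what the WAV payload actually contains, since `facebook/wav2vec2-base-960h` was trained on 16 kHz speech and payload size drives upload time. The sketch below reads the header with the standard-library `wave` module. `describe_wav` and `make_silence_wav` are hypothetical helper names, and the input is assumed to be plain PCM WAV bytes like those returned by `load_audio_bytes`.

```python
import io
import wave


def describe_wav(audio_bytes: bytes) -> dict:
    """Read basic properties from PCM WAV bytes without decoding the samples."""
    with wave.open(io.BytesIO(audio_bytes), "rb") as wf:
        frames = wf.getnframes()
        rate = wf.getframerate()
        return {
            "sample_rate_hz": rate,
            "channels": wf.getnchannels(),
            "sample_width_bytes": wf.getsampwidth(),
            "duration_s": round(frames / rate, 3),
            "size_bytes": len(audio_bytes),
        }


def make_silence_wav(seconds: float = 1.0, rate: int = 16000) -> bytes:
    """Build a mono 16-bit silent WAV in memory (stand-in for a real sample file)."""
    buf = io.BytesIO()
    with wave.open(buf, "wb") as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)
        wf.setframerate(rate)
        wf.writeframes(b"\x00\x00" * int(seconds * rate))
    return buf.getvalue()


info = describe_wav(make_silence_wav())
print(info)
```

In the notebook, `describe_wav(load_audio_bytes(p))` would give the same fields for a real sample, which is handy when comparing latency against duration and request size.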

### Batch over a small sample folder
If you add files under `samples/` (e.g., `short1.wav`, `short2.mp3`), this loop records the transcript and round-trip latency for each file.


In [ ]:
results = []
if SAMPLE_DIR.exists():
    for p in sorted(SAMPLE_DIR.iterdir()):
        if p.suffix.lower() not in {'.wav', '.mp3'}:
            continue
        data, lat = invoke_stt(ENDPOINT_NAME, AWS_REGION, load_audio_bytes(p))
        # Fall back to the full payload when the response is not a dict with a "text" field.
        text = data.get("text") if isinstance(data, dict) else None
        results.append({"file": p.name, "latency_s": round(lat, 3), "text": data if text is None else text})
results
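
The loop above stores one latency per file; to judge the latency tail, a small summary is useful. The following is a minimal sketch, assuming rows shaped like the `results` entries. `summarize_latencies` is a hypothetical helper, the sample values are invented, and p95 uses the nearest-rank definition, which is coarse on a handful of files.

```python
import math
import statistics


def summarize_latencies(latencies_s):
    """Return count, mean, median, nearest-rank p95, and max for latencies in seconds."""
    if not latencies_s:
        return {"count": 0}
    ordered = sorted(latencies_s)
    # Nearest rank: the smallest sample with at least 95% of samples at or below it.
    rank = max(1, math.ceil(0.95 * len(ordered)))
    return {
        "count": len(ordered),
        "mean_s": round(statistics.fmean(ordered), 3),
        "p50_s": round(statistics.median(ordered), 3),
        "p95_s": round(ordered[rank - 1], 3),
        "max_s": round(ordered[-1], 3),
    }


# Invented rows, shaped like the entries appended to `results`.
example_results = [
    {"file": "short1.wav", "latency_s": 0.42},
    {"file": "short2.mp3", "latency_s": 0.55},
    {"file": "short3.wav", "latency_s": 1.30},
]
summary = summarize_latencies([r["latency_s"] for r in example_results])
print(summary)
```

With only a few samples, p95 collapses to the maximum, so treat it as a rough signal until the sample set grows.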

## 3) Optional: quality evaluation (WER/CER)
If you have ground-truth references, place a `references.jsonl` in `samples/` with one JSON object per line:

```json
{"file": "short1.wav", "reference": "hello world"}
```

This cell will join predictions with references and compute WER/CER.

In [ ]:
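def normalize(s: str) -> str:
    # wav2vec2-base-960h emits upper-case text, so compare case-insensitively
    # and collapse whitespace before scoring.
    return " ".join(s.lower().split())

pred_map = {r['file']: r for r in results}
refs_path = SAMPLE_DIR / 'references.jsonl'
if refs_path.exists():
    with open(refs_path, encoding='utf-8') as f:
        refs = [json.loads(line) for line in f if line.strip()]  # skip blank lines
    y_true, y_pred = [], []
    for r in refs:
        # A file with no prediction scores as an empty hypothesis.
        pred = pred_map.get(r['file'], {}).get('text', '')
        if isinstance(pred, dict):
            pred = pred.get('text', '')
        y_true.append(normalize(r['reference']))
        y_pred.append(normalize(str(pred)))
    print({
        'samples': len(y_true),
        'wer': wer(y_true, y_pred),
        'cer': cer(y_true, y_pred)
    })
else:
    print('No references.jsonl found; skipping WER/CER.')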
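
For intuition about the two metrics, WER is the word-level edit distance divided by the number of reference words, and CER is the same ratio over characters. The sketch below is a from-scratch illustration, not `jiwer`'s implementation: `edit_distance`, `simple_wer`, and `simple_cer` are hypothetical names, text is lower-cased first, and spaces count as characters in the CER.

```python
def edit_distance(ref_tokens, hyp_tokens):
    """Levenshtein distance: substitutions, deletions, and insertions each cost 1."""
    prev = list(range(len(hyp_tokens) + 1))
    for i, r in enumerate(ref_tokens, start=1):
        curr = [i]
        for j, h in enumerate(hyp_tokens, start=1):
            cost = 0 if r == h else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # match or substitution
        prev = curr
    return prev[-1]


def simple_wer(reference: str, hypothesis: str) -> float:
    ref = reference.lower().split()
    if not ref:
        raise ValueError("reference must contain at least one word")
    return edit_distance(ref, hypothesis.lower().split()) / len(ref)


def simple_cer(reference: str, hypothesis: str) -> float:
    ref = list(reference.lower())
    if not ref:
        raise ValueError("reference must not be empty")
    return edit_distance(ref, list(hypothesis.lower())) / len(ref)


# One wrong word out of two; one missing character out of eleven.
example_wer = simple_wer("hello world", "HELLO WORD")
example_cer = simple_cer("hello world", "HELLO WORD")
print(example_wer, round(example_cer, 4))
```

The example shows why both numbers are worth reporting: a single dropped letter costs half the words but under a tenth of the characters.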

## 4) Notes
- For the public demo, start with a single `ml.m5.xlarge` and autoscale when needed.
- Keep transcripts out of logs; store only timing + request size if needed.
- If latency tails are high on MP3, prefer WAV uploads from the client.
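
For the autoscaling note above, one option is a target-tracking policy registered through Application Auto Scaling. The sketch below only builds the request payloads so they can be reviewed first. The variant name `AllTraffic`, the capacity bounds, the target of 60 invocations per instance per minute, and the cooldowns are all placeholder assumptions to be tuned against measured latency, and `build_autoscaling_requests` and `apply_autoscaling` are hypothetical helper names.

```python
def build_autoscaling_requests(endpoint_name, variant_name="AllTraffic",
                               min_instances=1, max_instances=2,
                               invocations_per_instance_per_minute=60.0):
    """Build kwargs for register_scalable_target and put_scaling_policy (no AWS calls)."""
    common = {
        "ServiceNamespace": "sagemaker",
        "ResourceId": f"endpoint/{endpoint_name}/variant/{variant_name}",
        "ScalableDimension": "sagemaker:variant:DesiredInstanceCount",
    }
    target = {**common, "MinCapacity": min_instances, "MaxCapacity": max_instances}
    policy = {
        **common,
        "PolicyName": f"{endpoint_name}-invocations-target",  # hypothetical name
        "PolicyType": "TargetTrackingScaling",
        "TargetTrackingScalingPolicyConfiguration": {
            "TargetValue": float(invocations_per_instance_per_minute),
            "PredefinedMetricSpecification": {
                "PredefinedMetricType": "SageMakerVariantInvocationsPerInstance"
            },
            "ScaleOutCooldown": 60,   # seconds; placeholder value
            "ScaleInCooldown": 300,   # seconds; placeholder value
        },
    }
    return target, policy


def apply_autoscaling(endpoint_name, region, **kwargs):
    """Send both requests; needs AWS credentials and an existing endpoint."""
    import boto3  # imported lazily so the builder can be used without boto3
    client = boto3.client("application-autoscaling", region_name=region)
    target, policy = build_autoscaling_requests(endpoint_name, **kwargs)
    client.register_scalable_target(**target)
    client.put_scaling_policy(**policy)


target, policy = build_autoscaling_requests("vrynt-stt-demo")
print(target["ResourceId"])
```

Keeping the payload builder separate from the call makes the scaling settings easy to review or diff before anything touches the account.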
