# Demo: Run inference and compute metrics

This notebook demonstrates how to: 
1. Download the required model checkpoints (encoder + decoder).
2. Provide a sample or your own dataset (Google Drive zip or local directory) in the format expected by `src.datasets.custom_dir`.
3. Run `inference.py` to produce predictions.
4. Run `evaluator.py` (metrics script) to compute WER/CER.

Note: the repository already contains `inference.py` and `evaluator.py`. The metrics script is named `evaluator.py`; there is no `calc_metrics.py`. Each cell below pairs a short explanation with runnable Python code.

## 1) Install dependencies
Run this cell once to install the packages listed in `requirements.txt`. In a Jupyter environment, a leading `!` runs a shell command. If the project dependencies are already installed (for example, in a venv), you can skip this cell.
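
To decide whether the install cell can be skipped, one option is to check which packages are importable in the current kernel. This is a minimal sketch: the `missing_packages` helper is hypothetical, and the list of names is an assumption (only `gdown` is known to be imported by the cells in this notebook), so extend it to match `requirements.txt`.

```python
from importlib.util import find_spec

def missing_packages(import_names):
    """Return the top-level import names that cannot be found in this environment."""
    return [name for name in import_names if find_spec(name) is None]

# 'gdown' is imported by later cells; the rest of the list is up to you.
missing = missing_packages(['gdown'])
if missing:
    print('Missing packages:', ', '.join(missing))
else:
    print('All checked packages are importable; the install cell can be skipped.')
```

Note that import names can differ from the distribution names used in `requirements.txt`, so this check is only a quick hint, not a replacement for `pip`.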

In [None]:
# Install requirements (uncomment if needed).
# Note: in some environments you may want to run this in your terminal instead of the notebook.
# !python -m pip install -r requirements.txt
print('If needed, run: python -m pip install -r requirements.txt')

## 2) Download model checkpoints
The repo's `Inferencer` downloads checkpoints automatically when needed, but this cell shows how to download them explicitly and save them under `checkpoints/`.
It uses the same Google Drive links referenced in `src/trainer/inferencer.py` and `src/configs/baseline.yaml`.

In [None]:
from pathlib import Path
import gdown

CHECKPOINTS_DIR = Path('checkpoints')
CHECKPOINTS_DIR.mkdir(exist_ok=True)
links = {
    'encoder.pth': 'https://drive.google.com/uc?export=download&id=1UpX3_UgrbRTWYunAMHPsR09a1_zHzj7E',
    'decoder.pth': 'https://drive.google.com/uc?export=download&id=1A1Cb1TCn5LWYuIADkOsfvzlBBBi2L2bi',
}
for name, url in links.items():
    out_path = CHECKPOINTS_DIR / name
    if out_path.exists():
        print(f'{out_path} already exists, skipping download')
        continue
    print(f'Downloading {name} to {out_path} ...')
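    try:
        result = gdown.download(url, str(out_path), quiet=False)
    except Exception as e:
        print('Download failed:', e)
        continue
    if result is None:
        # Some gdown versions may return None on failure instead of raising
        print(f'Download of {name} failed')

print('Done. Checkpoint directory:', CHECKPOINTS_DIR.resolve())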

## 3) Provide a dataset to run inference on
You have two options:
- Use the built-in HF Librispeech streaming dataset (configured in `src/configs/inference.yaml`).
- Provide your own dataset as a ZIP hosted on Google Drive (recommended format below).

If providing your own dataset, the ZIP should extract into a folder that contains `audio/` (wav/mp3/etc) and `transcriptions/` (matching .txt files for each audio file). The notebook will extract the archive into `data/datasets/custom_dir/` so `CustomDirDataset` can find it automatically.
Example structure after extraction:
- `data/datasets/custom_dir/audio/utt1.wav`
- `data/datasets/custom_dir/transcriptions/utt1.txt`
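
Before running inference, it can help to confirm that every audio file has a transcription. The sketch below assumes files are paired by stem (`utt1.wav` with `utt1.txt`), as the example structure suggests; `check_dataset_layout` is a hypothetical helper, not part of the repository, and `CustomDirDataset` may apply its own rules.

```python
from pathlib import Path

def check_dataset_layout(root):
    """Return (matched, missing): audio stems with and without a matching .txt file."""
    root = Path(root)
    audio_dir = root / 'audio'
    text_dir = root / 'transcriptions'
    if not audio_dir.is_dir() or not text_dir.is_dir():
        raise FileNotFoundError(f'Expected audio/ and transcriptions/ under {root}')
    # Assumption: an audio file and its transcription share the same stem.
    audio_stems = sorted(p.stem for p in audio_dir.iterdir() if p.is_file())
    text_stems = {p.stem for p in text_dir.glob('*.txt')}
    matched = [s for s in audio_stems if s in text_stems]
    missing = [s for s in audio_stems if s not in text_stems]
    return matched, missing

dataset_root = Path('data/datasets/custom_dir')
if (dataset_root / 'audio').is_dir() and (dataset_root / 'transcriptions').is_dir():
    matched, missing = check_dataset_layout(dataset_root)
    print(f'{len(matched)} paired files, {len(missing)} audio files without a transcription')
else:
    print('No custom dataset found yet at', dataset_root)
```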

The helper cell below asks for a Google Drive file URL (or file ID), then downloads the ZIP and extracts it into `data/datasets/custom_dir/`.

In [None]:
import re
import shutil
from pathlib import Path
import zipfile
import gdown

def download_and_unzip_gdrive(link_or_id, target_dir=Path('data/datasets/custom_dir')):
    target_dir = Path(target_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    # Extract the file id from common Google Drive URL formats
    # (...id=<ID> or .../d/<ID>, with or without a trailing slash)
    m = re.search(r'id=([A-Za-z0-9_-]+)', link_or_id) or re.search(r'/d/([A-Za-z0-9_-]+)', link_or_id)
    if m:
        file_id = m.group(1)
    else:
        # No URL pattern matched, so treat the input as a bare file id
        file_id = link_or_id

    out_zip = Path('custom_dataset.zip')
    url = f'https://drive.google.com/uc?export=download&id={file_id}'
    print('Downloading from', url)
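    # Some gdown versions may return None on failure instead of raising
    if gdown.download(url, str(out_zip), quiet=False) is None:
        raise RuntimeError(f'Download failed for {url}')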

    # Extract
    print('Extracting', out_zip)
    with zipfile.ZipFile(out_zip, 'r') as zf:
        # Extract into a temporary directory then move contents to target_dir
        tmpdir = Path('tmp_dataset_extract')
        if tmpdir.exists():
            shutil.rmtree(tmpdir)
        tmpdir.mkdir()
        zf.extractall(tmpdir)

        # Try to find audio/ and transcriptions/ inside extracted tree
        candidates = [p for p in tmpdir.iterdir() if p.is_dir()]
        if len(candidates) == 1:
            root = candidates[0]
        else:
            root = tmpdir

        # Move or copy audio/ and transcriptions/ into target_dir
        for sub in ['audio', 'transcriptions']:
            src = root / sub
            if src.exists():
                dest = target_dir / sub
                if dest.exists():
                    shutil.rmtree(dest)
                shutil.move(str(src), str(dest))
            else:
                print(f'Warning: {sub}/ not found in the archive; {target_dir / sub} was not updated')

    print('Cleaning up temporary files')
    out_zip.unlink(missing_ok=True)
    shutil.rmtree('tmp_dataset_extract', ignore_errors=True)
    print('Dataset prepared at', target_dir.resolve())
    return target_dir.resolve()

# Ask the user for a Google Drive link or file id. If you want to skip and use HF librispeech, just leave blank and press Enter.
gdrive_link = input('Enter Google Drive file link or id to a ZIP dataset (press Enter to skip and use HF Librispeech): ').strip()
custom_data_path = None
if gdrive_link:
    custom_data_path = download_and_unzip_gdrive(gdrive_link)
    print('Custom dataset ready at:', custom_data_path)
else:
    print('Skipping custom dataset. Notebook will demonstrate HF Librispeech streaming if enabled.')

## 4) Run inference
This cell runs `inference.py` with Hydra overrides. If you downloaded a custom dataset in the previous cell, the notebook will run inference using `datasets.custom_dir`. Otherwise it will run the HF Librispeech streaming evaluation (if enabled in config).
The command below runs `inference.py` as a subprocess and prints stdout/stderr.
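
With `capture_output=True`, `subprocess.run` returns the output only after the process exits, so a long inference run prints nothing until it finishes. As an alternative, the sketch below echoes output line by line while the process runs. `run_streaming` is a hypothetical helper, and the demo command is a stand-in for the real `inference.py` call.

```python
import subprocess
import sys

def run_streaming(cmd):
    """Run cmd, echoing merged stdout/stderr line by line; return (exit code, lines)."""
    lines = []
    with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                          text=True, bufsize=1) as proc:
        for line in proc.stdout:
            print(line, end='')
            lines.append(line.rstrip('\n'))
    # Leaving the with-block waits for the process, so returncode is set here.
    return proc.returncode, lines

# Stand-in command; replace it with the inference command list to stream its logs.
code, lines = run_streaming([sys.executable, '-c', 'print("hello")'])
print('exit code:', code)
```

If the child process buffers its output when writing to a pipe, lines may still arrive in bursts; running Python with `-u` in the child command is one way to reduce that.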

In [None]:
import subprocess
import sys
from pathlib import Path

def run_inference_on_custom(custom_path=None, output_dir='inference_predictions'):
    # Use the notebook's own interpreter so the subprocess sees the same environment
    cmd = [sys.executable, 'inference.py']
    if custom_path is not None:
        # Pass Hydra overrides to enable custom_dir and set its path
        cmd += [
            'datasets.custom_dir.enabled=True',
            f'datasets.custom_dir.path={custom_path}',
            f'inferencer.output_dir={output_dir}',
        ]
    else:
        print('No custom dataset provided. Running inference with default config (HF Librispeech if configured).')

    print('Running:', ' '.join(cmd))
    proc = subprocess.run(cmd, capture_output=True, text=True)
    print('=== STDOUT ===')
    print(proc.stdout)
    print('=== STDERR ===')
    print(proc.stderr)
    if proc.returncode != 0:
        raise RuntimeError(f'Inference script failed with code {proc.returncode}')
    return Path(output_dir).resolve()

# Run inference depending on whether the user provided a custom dataset
out_dir = None
try:
    if 'custom_data_path' in globals() and custom_data_path:
        out_dir = run_inference_on_custom(str(custom_data_path))
    else:
        out_dir = run_inference_on_custom(None)
    print('Inference finished. Predictions saved under:', out_dir)
except Exception as e:
    print('Inference failed:', e)

## 5) Compute metrics (WER/CER)
`evaluator.py` (the repo's metrics script) expects a folder with ground-truth `.txt` files and a folder with prediction `.txt` files. By default the config in `src/configs/metrics_eval.yaml` sets `paths.gt_dir` and `paths.pred_dir`.
If your ground-truth files are in `data/ground_truth` and predictions were written to, for example, `inference_predictions/custom/`, run the evaluator with Hydra overrides that match those paths.
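
For reference, WER is commonly defined as the word-level edit distance between reference and hypothesis divided by the number of reference words, and CER is the same ratio at character level. The sketch below illustrates that definition only; it is not the implementation in `evaluator.py`, which may normalize text (case, punctuation) differently.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two sequences (substitutions, insertions, deletions)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1]

def wer(reference, hypothesis):
    """Word error rate: word-level edit distance over the reference word count."""
    ref_words = reference.split()
    if not ref_words:
        raise ValueError('reference must contain at least one word')
    return edit_distance(ref_words, hypothesis.split()) / len(ref_words)

def cer(reference, hypothesis):
    """Character error rate: character-level edit distance over the reference length."""
    if not reference:
        raise ValueError('reference must not be empty')
    return edit_distance(reference, hypothesis) / len(reference)

# One substituted word out of three reference words
print('WER:', wer('the cat sat', 'the cat sit'))
# One substituted character out of eleven reference characters
print('CER:', cer('the cat sat', 'the cat sit'))
```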

In [None]:
import subprocess
import sys
from pathlib import Path

def run_evaluator(gt_dir, pred_dir, out_path='metrics/wer_cer_report.json'):
    # Use the notebook's own interpreter so the subprocess sees the same environment
    cmd = [sys.executable, 'evaluator.py', f'paths.gt_dir={gt_dir}', f'paths.pred_dir={pred_dir}', f'out_path={out_path}']
    print('Running:', ' '.join(cmd))
    proc = subprocess.run(cmd, capture_output=True, text=True)
    print('=== STDOUT ===')
    print(proc.stdout)
    print('=== STDERR ===')
    print(proc.stderr)
    if proc.returncode != 0:
        raise RuntimeError(f'Evaluator failed with code {proc.returncode}')
    return Path(out_path).resolve()

# Determine example gt_dir/pred_dir.
# If you used a custom dataset that had transcriptions/, use that folder as gt_dir.
if 'custom_data_path' in globals() and custom_data_path:
    gt_dir = Path(custom_data_path) / 'transcriptions'
    # Predictions structure depends on the inferencer output; often predictions will be in output_dir/<part>/pred_ID*.txt
    # We assume predictions were saved to inference_predictions/custom/ or inference_predictions/<part>/
    pred_dir = Path('inference_predictions')
else:
    # If you used HF Librispeech the predictions path must be set according to the inferencer output.
    gt_dir = Path('data/ground_truth')
    pred_dir = Path('data/predictions')

print('GT dir:', gt_dir)
print('Pred dir (top-level):', pred_dir)

# Run evaluator if directories look valid
if gt_dir.exists() and pred_dir.exists():
    report = run_evaluator(str(gt_dir), str(pred_dir))
    print('Saved metrics report to', report)
else:
    print('Ground-truth or prediction directory not found.\nPlease set the correct paths:')
    print('  - Ground truth (folder of .txt):', gt_dir)
    print('  - Predictions (folder of .txt):', pred_dir)

## Notes and next steps
- The notebook runs the existing `inference.py` and `evaluator.py` scripts via subprocess.
- If `inference.py` or `evaluator.py` require additional config flags in your environment (GPU selection, different batch sizes), pass them using additional Hydra overrides in the command lists above. Example: `inferencer.device=cuda` or `inferencer.batch_size=8`.
- The metrics script in this repository is `evaluator.py`; there is no `calc_metrics.py`. If you want a different evaluator script, adapt the evaluator-running cell accordingly.
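
Extra Hydra overrides are just additional `key=value` arguments on the command line. The sketch below builds such a command from a dictionary. `build_command` is a hypothetical helper, and the keys shown are the examples named above, which are assumed to exist in the project config.

```python
import sys

def build_command(script, overrides=None):
    """Build an argv list for a Hydra script from a dict of key=value overrides."""
    cmd = [sys.executable, script]
    for key, value in (overrides or {}).items():
        cmd.append(f'{key}={value}')
    return cmd

# Example keys taken from the note above; check them against your config files.
cmd = build_command('inference.py', {
    'inferencer.device': 'cuda',
    'inferencer.batch_size': 8,
})
print(' '.join(cmd))
```

The resulting list can be passed to `subprocess.run` in place of the hand-written command lists.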

Outputs:
- Checkpoints: `checkpoints/encoder.pth` and `checkpoints/decoder.pth`
- Predictions: folder specified by `inferencer.output_dir` (default `inference_predictions`)
- Metrics report: JSON at `metrics/wer_cer_report.json` (configurable via evaluator overrides)
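
To inspect the metrics report inside the notebook, it can be loaded with the standard `json` module. This is a minimal sketch: `load_report` is a hypothetical helper, and no particular keys are assumed because the report's schema is defined by `evaluator.py`.

```python
import json
from pathlib import Path

def load_report(path='metrics/wer_cer_report.json'):
    """Load the metrics report, or return None if it has not been written yet."""
    path = Path(path)
    if not path.is_file():
        return None
    with path.open(encoding='utf-8') as f:
        return json.load(f)

report = load_report()
if report is None:
    print('No report found; run the evaluator cell first.')
else:
    # Pretty-print whatever the evaluator wrote, without assuming its structure.
    print(json.dumps(report, indent=2, ensure_ascii=False))
```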