diff --git a/scripts/tts_comparison_report/README.md b/scripts/tts_comparison_report/README.md new file mode 100644 index 000000000000..a234c721a49a --- /dev/null +++ b/scripts/tts_comparison_report/README.md @@ -0,0 +1,215 @@ +# TTS Comparison Report + +This tool generates HTML comparison reports for TTS evaluation buckets and uploads them to S3. + +The `generate_report` script compares two evaluation buckets produced by `magpietts_inference` and generates: +1. an HTML evaluation report with aggregated and per-benchmark metrics; +2. an optional HTML audio comparison report with side-by-side audio samples. + +Both reports are uploaded to S3-compatible object storage and returned as presigned URLs, +which can be opened directly in a browser. If the audio report is enabled, its link is +also embedded into the evaluation report. + +The generated reports are designed to make model comparison faster, easier to share, and easier to review. + +## Supported workflows + +The script supports: + +- local evaluation buckets; +- remote evaluation buckets accessed over SSH/SFTP; +- upload of generated reports and audio assets to S3-compatible object storage. + +## Terminology + +A **bucket** in this tool means the root directory of one evaluation run. + +Evaluation artifacts are expected to be located either: +- directly inside the experiment root; or +- inside the subdirectory given by `--results_subdir` (default: `results`). + +Typical layouts: + +**Local generation** +```text +experiment_root/ +├── benchmark_1 +├── benchmark_2 +└── benchmark_3 +``` + +**Cluster / Slurm generation** +```text +experiment_root/ +├── logs_dir +└── results_subdir + ├── benchmark_1 + ├── benchmark_2 + └── benchmark_3 +``` + +In both cases, `--baseline_path` and `--candidate_path` should point to the experiment root. +If evaluation artifacts are stored directly inside the experiment root, set `--results_subdir` +to an empty string. 
If evaluation artifacts are stored under a dedicated subdirectory, +set `--results_subdir` to that subdirectory name. + +## Installation + +It is recommended to run the script inside the NeMo Docker container, starting from version `25.11`, +since it already includes all required dependencies. + +From the repository root, update the local NeMo package with: +```bash +pip install -e ./ --no-deps +``` +The `--no-deps` flag is used because the required dependencies are already available +in the recommended NeMo Docker environment. + +If you use a different environment, install the required dependencies from `requirements.txt`: +```bash +pip install -r scripts/tts_comparison_report/requirements.txt +``` + +## Environment variables + +Before running the script, make sure that the required environment variables are set. + +The following variables are required for uploading reports and related assets +to S3-compatible object storage: +- `S3_ACCESS_KEY_ID` - S3 access key, +- `S3_SECRET_ACCESS_KEY` - S3 secret key. + +Example: +```bash +export S3_ACCESS_KEY_ID='your_s3_key_id' S3_SECRET_ACCESS_KEY='your_s3_secret' +``` + +If the evaluation buckets are stored on a remote machine, also set: +- `REMOTE_PASSWORD` - password used for SSH authentication. + +Example: +```bash +export REMOTE_PASSWORD='your_ssh_password' +``` + +## Usage examples + +To generate and upload only the evaluation report from local buckets, run: + +```bash +python scripts/tts_comparison_report/generate_report.py \ + --baseline_name "Model A" \ + --baseline_path /workspace/NeMo/exp/buckets/baseline \ + --candidate_name "Model B" \ + --candidate_path /workspace/NeMo/exp/buckets/candidate \ + --s3_endpoint https://your-s3-endpoint \ + --s3_bucket your_bucket_name \ + --s3_region us-west-2 \ + --task_id NEMOTTS-2007 +``` + +If the evaluation artifacts are stored in a non-default results subdirectory, +use `--results_subdir`. 
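
As a rough illustration of the rule described in the Terminology section, the directory that holds the benchmark folders can be thought of as the experiment root joined with `--results_subdir`, where an empty string means the root itself. The sketch below shows that idea only; `resolve_results_dir` is a hypothetical helper that is not part of this script, whose own lookup is handled by `BucketStructure`.

```python
from pathlib import Path


def resolve_results_dir(experiment_root: str, results_subdir: str = "results") -> Path:
    """Hypothetical helper: return the directory expected to hold the benchmark folders."""
    root = Path(experiment_root)
    # An empty subdirectory name means the benchmarks sit directly in the experiment root.
    return root / results_subdir if results_subdir else root


# Cluster / Slurm layout: benchmarks live under the default "results" subdirectory.
print(resolve_results_dir("/mnt/exps/baseline"))
# Local layout: pass an empty string so the experiment root itself is used.
print(resolve_results_dir("/mnt/exps/baseline", ""))
```

Under this reading, `--baseline_path` and `--candidate_path` always point to the experiment root, and only `--results_subdir` changes between the two layouts.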
+ +To generate both the evaluation report and the audio comparison report, use `--audio_report`. +You can also use `--audio_report_benchmarks` and `--samples_per_benchmark` +to control which benchmarks are included in the audio report and how many +samples are selected for each benchmark. + +```bash +python scripts/tts_comparison_report/generate_report.py \ + --baseline_name "Model A" \ + --baseline_path /workspace/NeMo/exp/buckets/baseline \ + --candidate_name "Model B" \ + --candidate_path /workspace/NeMo/exp/buckets/candidate \ + --s3_endpoint https://your-s3-endpoint \ + --s3_bucket your_bucket_name \ + --s3_region us-west-2 \ + --task_id NEMOTTS-2007 \ + --audio_report \ + --audio_report_benchmarks libritts_test_clean,riva_hard_digits \ + --samples_per_benchmark 20 +``` + +If the buckets are located on a remote machine, specify `--remote_hostname` +and `--remote_username`: + +```bash +python scripts/tts_comparison_report/generate_report.py \ + --baseline_name "Model A" \ + --baseline_path /mnt/exps/baseline \ + --candidate_name "Model B" \ + --candidate_path /mnt/exps/candidate \ + --s3_endpoint https://your-s3-endpoint \ + --s3_bucket your_bucket_name \ + --s3_region us-west-2 \ + --task_id NEMOTTS-2007 \ + --audio_report \ + --audio_report_benchmarks libritts_test_clean,riva_hard_digits \ + --samples_per_benchmark 20 \ + --remote_hostname your_remote_host \ + --remote_username your_user +``` + +You can also restrict the evaluation report to a selected set of benchmarks +by using `--benchmarks`: + +```bash +python scripts/tts_comparison_report/generate_report.py \ + --baseline_name "Model A" \ + --baseline_path /workspace/NeMo/exp/buckets/baseline \ + --candidate_name "Model B" \ + --candidate_path /workspace/NeMo/exp/buckets/candidate \ + --benchmarks libritts_test_clean,riva_hard_digits,riva_hard_letters \ + --s3_endpoint https://your-s3-endpoint \ + --s3_bucket your_bucket_name \ + --s3_region us-west-2 \ + --task_id NEMOTTS-2007 \ + --audio_report \ + 
--audio_report_benchmarks libritts_test_clean,riva_hard_digits \ + --samples_per_benchmark 20 +``` + +## Notes + +- `magpietts_inference` supports several repetitions, but this script compares +only artifacts from repetition `0`. +- `--results_subdir` is not the experiment root. It is the subdirectory inside +the experiment root that contains evaluation outputs such as metrics and generated audio. +If evaluation artifacts are stored directly inside the experiment root, `--results_subdir` +should be set to an empty string. +- Both generated reports are HTML reports uploaded to S3-compatible object storage. +- If the audio report is enabled, the evaluation report includes a link to the audio report. +- Audio files referenced by the audio report are uploaded separately and linked through presigned URLs. +- Box plot images are also uploaded to S3 and embedded into the evaluation report via presigned URLs. +- Presigned S3 links expire. Both generated reports include the expiration time directly in the HTML page. +The default expiration time is one year. +- The expiration time is also included as a suffix in the uploaded artifacts +directory name, using the format `%Y-%m-%dT%H-%M-%SZ`, so uploaded reports +can be filtered and deleted later if needed. +- Both generated reports include a clickable Jira link derived from `--task_id`. +If no task ID is specified, the link points to the Jira project page. + +## Maintenance + +### Updating benchmarks + +To add or remove a benchmark, update `SUPPORTED_BENCHMARK_NAMES` in `reporting/constants.py`. + +### Updating metrics + +By design, metrics are divided into two groups: +- **standard metrics**, used for reporting aggregated values in the evaluation report; +- **distribution metrics**, used for statistical tests and box plot visualization. + +The metric specifications are defined in `reporting/metrics/specs.py`. 
+ +To add or remove a metric, update the metric registries in `reporting/metrics/registry.py`: +- `MetricsRegistry` - for standard aggregated metrics; +- `DistributionMetricsRegistry` - for metrics used in statistical tests and visualizations. + +### Modifying bucket structure + +If the bucket structure changes, update the `BucketStructure` class, +which defines how report artifacts are located inside an evaluation bucket. +See `reporting/models.py`. diff --git a/scripts/tts_comparison_report/__init__.py b/scripts/tts_comparison_report/__init__.py new file mode 100644 index 000000000000..9e3fb699d9f6 --- /dev/null +++ b/scripts/tts_comparison_report/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/scripts/tts_comparison_report/generate_report.py b/scripts/tts_comparison_report/generate_report.py new file mode 100644 index 000000000000..287731e29919 --- /dev/null +++ b/scripts/tts_comparison_report/generate_report.py @@ -0,0 +1,320 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +from argparse import ArgumentParser, RawDescriptionHelpFormatter +from pathlib import Path +from typing import Optional + +from paramiko import AutoAddPolicy, SSHClient +from paramiko.sftp_client import SFTPClient +from scripts.tts_comparison_report.reporting import ( + DUMMY_TASK_ID, + SUPPORTED_BENCHMARK_NAMES, + TEMPLATES_DIR, + BaseStorage, + BucketStructure, + LocalStorage, + Orchestrator, + Renderer, + S3Client, + S3Config, + SFTPStorage, +) + +logging.basicConfig(level=logging.INFO, format="%(message)s") +logging.captureWarnings(True) + + +_REMOTE_PASSWORD: str = "REMOTE_PASSWORD" +_S3_ACCESS_KEY_ID: str = "S3_ACCESS_KEY_ID" +_S3_SECRET_ACCESS_KEY: str = "S3_SECRET_ACCESS_KEY" + +_DEFAULT_BENCHMARK_NAMES: str = ",".join(SUPPORTED_BENCHMARK_NAMES) + + +def _create_argparser() -> ArgumentParser: + parser = ArgumentParser( + description="Script for generating MagpieTTS evaluation comparison reports", + formatter_class=RawDescriptionHelpFormatter, + ) + parser.add_argument( + "--baseline_name", + type=str, + required=True, + help="Name of the baseline model that will be used in the report.", + ) + parser.add_argument( + "--baseline_path", + type=str, + required=True, + help="Path to the generated evaluation bucket for the baseline model.", + ) + parser.add_argument( + "--candidate_name", + type=str, + required=True, + help="Name of the candidate model that will be used in the report.", + ) + parser.add_argument( + "--candidate_path", + type=str, + required=True, + help="Path to the generated evaluation bucket for the candidate 
model.", + ) + parser.add_argument( + "--benchmarks", + type=str, + default=_DEFAULT_BENCHMARK_NAMES, + help="Comma-separated list of benchmarks used in the evaluation report.", + ) + parser.add_argument( + "--s3_endpoint", + type=str, + required=True, + help="S3 endpoint URL used for uploading the generated reports and related assets.", + ) + parser.add_argument( + "--s3_bucket", + type=str, + required=True, + help="Name of the S3 bucket where the generated reports and related assets will be uploaded.", + ) + parser.add_argument( + "--s3_region", + type=str, + required=True, + help="AWS region name for the S3 client.", + ) + parser.add_argument( + "--remote_hostname", + type=str, + default=None, + help="Name of the remote host, if the generated buckets are located there.", + ) + parser.add_argument( + "--remote_username", + type=str, + default=None, + help="Name of the user on the remote host.", + ) + parser.add_argument( + "--task_id", + type=str, + default=DUMMY_TASK_ID, + help="Jira task number associated with this report.", + ) + parser.add_argument( + "--results_subdir", + type=str, + default="results", + help="Subdirectory inside the bucket root that contains evaluation outputs produced by `magpietts_inference`.", + ) + parser.add_argument( + "--audio_report", + action='store_true', + help="Generate an additional report with side-by-side audio comparison.", + ) + parser.add_argument( + "--audio_report_benchmarks", + type=str, + default="libritts_test_clean,riva_hard_digits,riva_hard_letters", + help="Comma-separated list of benchmarks to include in the audio report.", + ) + parser.add_argument( + "--samples_per_benchmark", + type=int, + default=30, + help="Number of samples per benchmark in the audio report.", + ) + return parser + + +def _get_benchmarks_list(benchmarks: str) -> list[str]: + return [x.strip() for x in benchmarks.split(",") if x.strip()] + + +def _validate_benchmarks(benchmarks: list[str]) -> None: + if not benchmarks: + raise ValueError("Empty list of benchmark names was 
provided.") + + supported_set = set(SUPPORTED_BENCHMARK_NAMES) + + for name in benchmarks: + if name not in supported_set: + raise ValueError(f"Unknown benchmark name: '{name}'.") + + +def _validate_audio_report_benchmarks( + benchmarks: list[str], + audio_report_benchmarks: list[str], +) -> None: + if not audio_report_benchmarks: + raise ValueError("Empty list of benchmark names was provided for the audio report.") + + supported_set = set(benchmarks) + + for name in audio_report_benchmarks: + if name not in supported_set: + raise ValueError(f"Benchmark name for audio report '{name}' is not included in evaluation benchmarks.") + + +def main() -> None: + """Parse CLI arguments, generate comparison reports, and upload them to S3. + + This function serves as the command-line entry point for the report + generation workflow. It validates user input, initializes storage and S3 + clients, runs the report orchestrator, and logs the resulting report URLs. + + Raises: + ValueError: If required environment variables are missing or CLI + arguments are invalid. + RuntimeError: If report generation or upload does not complete + successfully. 
+ """ + logger = logging.getLogger(__name__) + + parser = _create_argparser() + args = parser.parse_args() + + bucket_structure = BucketStructure() + bucket_structure.eval_output_subdir = args.results_subdir + baseline_path = Path(args.baseline_path).resolve() + candidate_path = Path(args.candidate_path).resolve() + task_id = args.task_id + + storage: BaseStorage + s3_client: Optional[S3Client] = None + ssh_client: Optional[SSHClient] = None + sftp: Optional[SFTPClient] = None + eval_report_url: Optional[str] = None + audio_report_url: Optional[str] = None + audio_report_benchmarks: Optional[list[str]] = None + + s3_key_id = os.getenv(_S3_ACCESS_KEY_ID) + s3_secret_key = os.getenv(_S3_SECRET_ACCESS_KEY) + + if s3_key_id is None or s3_secret_key is None: + raise ValueError( + f"Environment variables '{_S3_ACCESS_KEY_ID}' and '{_S3_SECRET_ACCESS_KEY}' " + "must be set for uploading reports to S3." + ) + + s3_cfg = S3Config( + bucket=args.s3_bucket, + endpoint_url=args.s3_endpoint, + region_name=args.s3_region, + ) + s3_client = S3Client( + cfg=s3_cfg, + aws_access_key_id=s3_key_id, + aws_secret_access_key=s3_secret_key, + ) + + benchmarks = _get_benchmarks_list(args.benchmarks) + _validate_benchmarks(benchmarks) + + if args.audio_report: + audio_report_benchmarks = _get_benchmarks_list(args.audio_report_benchmarks) + _validate_audio_report_benchmarks(benchmarks, audio_report_benchmarks) + + if args.samples_per_benchmark <= 0: + raise ValueError("Number of samples per benchmark for the audio report must be greater than 0.") + + if task_id == DUMMY_TASK_ID: + logger.warning("\nWARNING: It is recommended to assign the evaluation report to a specific ticket!") + + if baseline_path == candidate_path: + logger.warning( + "\nWARNING: Baseline and candidate paths are identical. " + "Comparison report is not meaningful in this case!" 
+ ) + + logger.info(f"\nComparing baseline '{args.baseline_name}' against candidate '{args.candidate_name}'") + + try: + if args.remote_hostname is not None or args.remote_username is not None: + if args.remote_username is None: + raise ValueError("'remote_username' must be provided when using remote access.") + + if args.remote_hostname is None: + raise ValueError("'remote_hostname' must be provided when using remote access.") + + remote_password = os.getenv(_REMOTE_PASSWORD) + + if remote_password is None: + raise ValueError(f"Environment variable '{_REMOTE_PASSWORD}' is not set.") + + logger.info(f"\nSetting remote connection with host: {args.remote_hostname}") + + ssh_client = SSHClient() + ssh_client.set_missing_host_key_policy(policy=AutoAddPolicy()) + ssh_client.connect( + hostname=args.remote_hostname, + username=args.remote_username, + password=remote_password, + ) + sftp = ssh_client.open_sftp() + storage = SFTPStorage(sftp) + + else: + storage = LocalStorage() + + renderer = Renderer(templates_dir=TEMPLATES_DIR) + + orchestrator = Orchestrator( + bucket_structure=bucket_structure, + storage=storage, + s3_client=s3_client, + renderer=renderer, + logger=logger, + ) + eval_report_url, audio_report_url = orchestrator.run( + baseline_name=args.baseline_name, + candidate_name=args.candidate_name, + baseline_path=baseline_path, + candidate_path=candidate_path, + benchmarks=benchmarks, + generate_audio_report=args.audio_report, + audio_report_benchmarks=audio_report_benchmarks, + samples_per_benchmark=args.samples_per_benchmark, + task_id=task_id, + ) + + finally: + if sftp is not None: + sftp.close() + + if ssh_client is not None: + ssh_client.close() + + if s3_client is not None: + s3_client.close() + + if eval_report_url is None: + raise RuntimeError("Failed to generate evaluation report and upload it to S3.") + + if args.audio_report and audio_report_url is None: + raise RuntimeError("Failed to upload audio report to S3 and create URL.") + + if 
audio_report_url is not None: + logger.info(f"\nAudio report is available at:\n{audio_report_url}") + + logger.info(f"\nEvaluation report is available at:\n{eval_report_url}") + + logger.info("\nSave the links and open them in your browser!\n") + + +if __name__ == "__main__": + main() diff --git a/scripts/tts_comparison_report/reporting/__init__.py b/scripts/tts_comparison_report/reporting/__init__.py new file mode 100644 index 000000000000..660f6e7aff73 --- /dev/null +++ b/scripts/tts_comparison_report/reporting/__init__.py @@ -0,0 +1,33 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from scripts.tts_comparison_report.reporting.constants import DUMMY_TASK_ID, SUPPORTED_BENCHMARK_NAMES, TEMPLATES_DIR +from scripts.tts_comparison_report.reporting.models import BucketStructure +from scripts.tts_comparison_report.reporting.orchestrator import Orchestrator +from scripts.tts_comparison_report.reporting.renderer import Renderer +from scripts.tts_comparison_report.reporting.s3_client import S3Client, S3Config +from scripts.tts_comparison_report.reporting.storage import BaseStorage, LocalStorage, SFTPStorage + +__all__ = [ + "BaseStorage", + "BucketStructure", + "DUMMY_TASK_ID", + "LocalStorage", + "Orchestrator", + "Renderer", + "S3Client", + "S3Config", + "SUPPORTED_BENCHMARK_NAMES", + "SFTPStorage", + "TEMPLATES_DIR", +] diff --git a/scripts/tts_comparison_report/reporting/components/__init__.py b/scripts/tts_comparison_report/reporting/components/__init__.py new file mode 100644 index 000000000000..e90ba7d36967 --- /dev/null +++ b/scripts/tts_comparison_report/reporting/components/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from scripts.tts_comparison_report.reporting.components.audio_report import prepare_audio_pairs +from scripts.tts_comparison_report.reporting.components.boxplots import BoxPlotsConfig +from scripts.tts_comparison_report.reporting.components.eval_report import prepare_eval_artifacts + +__all__ = [ + "BoxPlotsConfig", + "prepare_audio_pairs", + "prepare_eval_artifacts", +] diff --git a/scripts/tts_comparison_report/reporting/components/audio_report.py b/scripts/tts_comparison_report/reporting/components/audio_report.py new file mode 100644 index 000000000000..4d0ac485705d --- /dev/null +++ b/scripts/tts_comparison_report/reporting/components/audio_report.py @@ -0,0 +1,98 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import random +import warnings + +from scripts.tts_comparison_report.reporting.constants import SEED +from scripts.tts_comparison_report.reporting.models import AudioPair, BucketData, BucketStructure + +_RNG = random.Random(SEED) + + +def _collect_audio_pairs( + benchmark_name: str, + bucket_baseline: BucketData, + bucket_candidate: BucketData, + bucket_structure: BucketStructure, +) -> list[AudioPair]: + baseline_paths = bucket_baseline.get_benchmark_audio_paths(benchmark_name) + candidate_paths = bucket_candidate.get_benchmark_audio_paths(benchmark_name) + baseline_meta = bucket_baseline.get_benchmark_sample_meta(benchmark_name, bucket_structure) + candidate_meta = bucket_candidate.get_benchmark_sample_meta(benchmark_name, bucket_structure) + pairs = [] + + if set(baseline_paths) != set(candidate_paths): + raise ValueError(f"Audio sample sets differ for benchmark '{benchmark_name}'.") + + for name in baseline_paths: + if name not in candidate_paths or name not in baseline_meta or name not in candidate_meta: + raise ValueError( + f"Missing matched sample '{name}' in audio paths or metadata for benchmark '{benchmark_name}'." + ) + + if baseline_meta[name].sample_id != candidate_meta[name].sample_id: + raise ValueError( + f"Sample id mismatch for '{name}' in benchmark '{benchmark_name}'. " + "You are probably using different versions of the buckets." + ) + + pair = AudioPair( + context_path=baseline_meta[name].context_path, + baseline_path=baseline_paths[name], + candidate_path=candidate_paths[name], + text=baseline_meta[name].gt_text, + ) + pairs.append(pair) + + pairs.sort(key=lambda p: p.baseline_path.stem) + + return pairs + + +def prepare_audio_pairs( + bucket_baseline: BucketData, + bucket_candidate: BucketData, + bucket_structure: BucketStructure, + used_benchmarks: list[str], + samples_per_benchmark: int, +) -> dict[str, list[AudioPair]]: + """Prepare audio pairs for the selected benchmarks. + + Args: + bucket_baseline: Baseline bucket data. 
+ bucket_candidate: Candidate bucket data. + used_benchmarks: Benchmark names to include in the audio report. + samples_per_benchmark: Maximum number of audio pairs to sample per benchmark. + + Returns: + Mapping from benchmark name to sampled baseline/candidate audio pairs. + + Raises: + ValueError: If benchmark audio sets or sample metadata are inconsistent. + """ + pairs = {} + + for benchmark_name in used_benchmarks: + benchmark_pairs = _collect_audio_pairs(benchmark_name, bucket_baseline, bucket_candidate, bucket_structure) + sampled_pairs = _RNG.sample(benchmark_pairs, k=min(samples_per_benchmark, len(benchmark_pairs))) + + if len(sampled_pairs) < samples_per_benchmark: + warnings.warn( + f"\nBenchmark '{benchmark_name}' contains only {len(sampled_pairs)} available paired samples, " + f"but {samples_per_benchmark} were requested.", + stacklevel=2, + ) + pairs[benchmark_name] = sampled_pairs + + return pairs diff --git a/scripts/tts_comparison_report/reporting/components/boxplots.py b/scripts/tts_comparison_report/reporting/components/boxplots.py new file mode 100644 index 000000000000..302a0be5bb08 --- /dev/null +++ b/scripts/tts_comparison_report/reporting/components/boxplots.py @@ -0,0 +1,221 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from dataclasses import dataclass, field +from io import BytesIO +from typing import Optional + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.axes import Axes +from matplotlib.patches import PathPatch +from scripts.tts_comparison_report.reporting.metrics import DistributionMetricSpec, DistributionMetricsRegistry +from scripts.tts_comparison_report.reporting.models import BucketData, StatTestResult, Winner + + +@dataclass +class BoxPlotsConfig: + """Styling and layout configuration for generated benchmark box plots.""" + + font_family: str = "sans-serif" + font_list: list[str] = field(default_factory=lambda: ["Arial", "Helvetica", "DejaVu Sans"]) + + linewidth: float = 0.4 + default_model_color: str = "#36454F" + winner_model_color: str = "#7393B3" + box_alpha: float = 0.35 + grid_alpha: float = 0.4 + fontsize: int = 6 + fontsize_title: int = 8 + + widths: float = 0.6 + mean_marker: str = "o" + mean_marker_color: str = "#CD5C5C" + mean_marker_size: float = 4.0 + median_color: str = "black" + whisker_color: str = "#666666" + cap_color: str = "#666666" + outlier_color: str = "#708090" + outlier_marker: str = "o" + outlier_markersize: float = 3.0 + outlier_alpha: float = 0.5 + + +def _style_boxplot( + bp: dict[str, PathPatch], + metric: DistributionMetricSpec, + winner_lookup: dict[str, Winner], + cfg: BoxPlotsConfig, +) -> None: + for i, patch in enumerate(bp["boxes"]): + winner = winner_lookup[metric.report_name] + + if (i == 0 and winner == Winner.baseline) or (i == 1 and winner == Winner.candidate): + color = cfg.winner_model_color + else: + color = cfg.default_model_color + + patch.set_facecolor(color) + patch.set_alpha(cfg.box_alpha) + patch.set_edgecolor(color) + patch.set_linewidth(cfg.linewidth) + + +def _add_mean_ci_labels( + ax: Axes, + baseline: np.ndarray, + candidate: np.ndarray, + metric: DistributionMetricSpec, + cfg: BoxPlotsConfig, +) -> None: + for x, values in [(1, baseline), (2, candidate)]: + mean, median = values.mean(), 
np.median(values) + sem = values.std(ddof=1) / np.sqrt(len(values)) if len(values) > 1 else 0.0 + ci95 = 1.96 * sem + label = f"{mean:.3f} ± {ci95:.3f}" + + if metric.plot_range is not None: + range_ = metric.plot_range[1] - metric.plot_range[0] + else: + range_ = values.max() - values.min() + + x_offset = 0.02 + y_offset = 0.03 * range_ + + if median > mean and mean - y_offset > 0: + y_offset = -y_offset + + ax.text(x + x_offset, mean + y_offset, label, ha="left", va="center", fontsize=cfg.fontsize) + + +def _configure_boxplot_axis( + ax: Axes, + metric: DistributionMetricSpec, + baseline_name: str, + candidate_name: str, + cfg: BoxPlotsConfig, +) -> None: + ax.set_title(metric.report_name, fontsize=cfg.fontsize_title) + ax.set_xticks([1, 2]) + ax.set_xticklabels([baseline_name, candidate_name]) + ax.tick_params(axis="x", labelsize=cfg.fontsize) + ax.tick_params(axis="y", labelsize=cfg.fontsize) + ax.grid(True, axis="y", linestyle="dotted", alpha=cfg.grid_alpha) + + for spine in ax.spines.values(): + spine.set_linewidth(cfg.linewidth) + + ax.tick_params(axis="both", width=cfg.linewidth) + + if metric.plot_range is not None: + ax.set_ylim(metric.plot_range[0], metric.plot_range[1]) + + +def prepare_boxplots( + bucket_baseline: BucketData, + bucket_candidate: BucketData, + stat_test_results: list[StatTestResult], + cfg: BoxPlotsConfig, + benchmark_name: Optional[str] = None, +) -> BytesIO: + """Create an in-memory box plot figure for summary or benchmark-level metrics. + + Args: + bucket_baseline: Baseline bucket data. + bucket_candidate: Candidate bucket data. + stat_test_results: Statistical test results used to highlight the winning model. + cfg: Plot styling and layout configuration. + benchmark_name: Benchmark name. If omitted, metric samples are aggregated + across all benchmarks. + + Returns: + PNG image stored in an in-memory bytes buffer. 
+ """ + baseline_name = bucket_baseline.name + candidate_name = bucket_candidate.name + winner_lookup = {res.metric_name: res.winner for res in stat_test_results} + num_rows = sum(m.add_to_box_plot for m in DistributionMetricsRegistry) + fig_height = max(2.0 * num_rows, 4.5) + + with plt.rc_context({"font.family": cfg.font_family, "font.sans-serif": cfg.font_list}): + fig, axs = plt.subplots(num_rows, 1, figsize=(6, fig_height), squeeze=False) + axs = axs.flatten() + plot_idx = 0 + + for metric in DistributionMetricsRegistry: + if not metric.add_to_box_plot: + continue + + baseline = bucket_baseline.get_metric_samples( + metric_name=metric.key, + benchmark_name=benchmark_name, + ) + candidate = bucket_candidate.get_metric_samples( + metric_name=metric.key, + benchmark_name=benchmark_name, + ) + baseline = np.asarray(baseline, dtype=float) + candidate = np.asarray(candidate, dtype=float) + + ax = axs[plot_idx] + plot_idx += 1 + + bp = ax.boxplot( + [baseline, candidate], + positions=[1, 2], + widths=cfg.widths, + patch_artist=True, + showmeans=True, + meanline=False, + meanprops={ + "marker": cfg.mean_marker, + "markerfacecolor": cfg.mean_marker_color, + "markeredgecolor": cfg.mean_marker_color, + "markersize": cfg.mean_marker_size, + }, + medianprops={ + "color": cfg.median_color, + "linewidth": cfg.linewidth, + }, + whiskerprops={ + "color": cfg.whisker_color, + "linewidth": cfg.linewidth, + }, + capprops={ + "color": cfg.cap_color, + "linewidth": cfg.linewidth, + }, + boxprops={ + "linewidth": cfg.linewidth, + }, + flierprops={ + "marker": cfg.outlier_marker, + "markerfacecolor": cfg.outlier_color, + "markeredgecolor": cfg.outlier_color, + "markersize": cfg.outlier_markersize, + "alpha": cfg.outlier_alpha, + }, + ) + + _style_boxplot(bp, metric, winner_lookup, cfg) + _add_mean_ci_labels(ax, baseline, candidate, metric, cfg) + _configure_boxplot_axis(ax, metric, baseline_name, candidate_name, cfg) + + fig.tight_layout(rect=[0, 0, 1, 0.985]) + + buffer = BytesIO() 
+ fig.savefig(buffer, format="png", dpi=300, bbox_inches="tight") + plt.close(fig) + buffer.seek(0) + + return buffer diff --git a/scripts/tts_comparison_report/reporting/components/eval_report.py b/scripts/tts_comparison_report/reporting/components/eval_report.py new file mode 100644 index 000000000000..1d6d176a5686 --- /dev/null +++ b/scripts/tts_comparison_report/reporting/components/eval_report.py @@ -0,0 +1,97 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from scripts.tts_comparison_report.reporting.components.boxplots import BoxPlotsConfig, prepare_boxplots +from scripts.tts_comparison_report.reporting.components.metrics_table import ( + prepare_benchmark_metrics_table_rows, + prepare_summary_metrics_table_rows, +) +from scripts.tts_comparison_report.reporting.components.stat_tests import ( + prepare_stat_tests_analysis_info, + prepare_stat_tests_table_rows, + run_stat_tests, +) +from scripts.tts_comparison_report.reporting.models import BucketData, EvalArtifacts, EvalResult, ModelConfiguration + + +def prepare_eval_artifacts( + bucket_baseline: BucketData, + bucket_candidate: BucketData, + box_plots_cfg: BoxPlotsConfig, +) -> EvalArtifacts: + """Prepare summary and benchmark-level evaluation artifacts for report rendering. + + Args: + bucket_baseline: Baseline bucket data. + bucket_candidate: Candidate bucket data. 
+ box_plots_cfg: Configuration used to generate benchmark and summary box plots. + + Returns: + Evaluation artifacts containing configuration metadata, summary results, + and per-benchmark results. + """ + baseline_name = bucket_baseline.name + candidate_name = bucket_candidate.name + is_self_comparison = bucket_baseline.path == bucket_candidate.path + + metrics_table_row = prepare_summary_metrics_table_rows(bucket_baseline, bucket_candidate) + stat_test_results = run_stat_tests(bucket_baseline, bucket_candidate) + stat_test_table_row = prepare_stat_tests_table_rows(baseline_name, candidate_name, stat_test_results) + stat_tests_analysis_info = prepare_stat_tests_analysis_info(baseline_name, candidate_name, stat_test_results) + + box_plots = prepare_boxplots( + bucket_baseline=bucket_baseline, + bucket_candidate=bucket_candidate, + stat_test_results=stat_test_results, + cfg=box_plots_cfg, + ) + + configuration = ModelConfiguration( + baseline=bucket_baseline.configuration_str, + candidate=bucket_candidate.configuration_str, + ) + summary = EvalResult( + metrics_table_row=metrics_table_row, + stat_test_table_row=stat_test_table_row, + stat_tests_analysis_info=stat_tests_analysis_info, + box_plots=box_plots, + ) + benchmarks = {} + + for benchmark_name in bucket_baseline.benchmarks: + metrics_table_row = prepare_benchmark_metrics_table_rows(benchmark_name, bucket_baseline, bucket_candidate) + stat_test_results = run_stat_tests(bucket_baseline, bucket_candidate, benchmark_name) + stat_test_table_row = prepare_stat_tests_table_rows(baseline_name, candidate_name, stat_test_results) + stat_tests_analysis_info = prepare_stat_tests_analysis_info(baseline_name, candidate_name, stat_test_results) + + box_plots = prepare_boxplots( + bucket_baseline=bucket_baseline, + bucket_candidate=bucket_candidate, + stat_test_results=stat_test_results, + cfg=box_plots_cfg, + benchmark_name=benchmark_name, + ) + + benchmarks[benchmark_name] = EvalResult( + 
metrics_table_row=metrics_table_row, + stat_test_table_row=stat_test_table_row, + stat_tests_analysis_info=stat_tests_analysis_info, + box_plots=box_plots, + ) + + return EvalArtifacts( + configuration=configuration, + summary=summary, + benchmarks=benchmarks, + is_self_comparison=is_self_comparison, + ) diff --git a/scripts/tts_comparison_report/reporting/components/metrics_table.py b/scripts/tts_comparison_report/reporting/components/metrics_table.py new file mode 100644 index 000000000000..4878945501da --- /dev/null +++ b/scripts/tts_comparison_report/reporting/components/metrics_table.py @@ -0,0 +1,154 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import html +from typing import Optional + +import numpy as np +from scripts.tts_comparison_report.reporting.metrics import MetricSpec, MetricsRegistry +from scripts.tts_comparison_report.reporting.models import BucketData + + +def _metric_comparator( + a: float, + b: float, + lower_is_better: Optional[bool], +) -> Optional[bool]: + if lower_is_better is None: + return None + + if lower_is_better: + # If the values are equal, the baseline wins. 
+ return a <= b + + return a >= b + + +def _format_metric_values( + a: float, + b: float, + metric: MetricSpec, +) -> tuple[str, str]: + a, b = metric.multiplier * a, metric.multiplier * b + a_is_better = _metric_comparator(a, b, metric.lower_is_better) + a, b = round(a, metric.round_digits), round(b, metric.round_digits) + a_str, b_str = f"{a}{metric.units}", f"{b}{metric.units}" + + a_str = html.escape(a_str) + b_str = html.escape(b_str) + + if metric.lower_is_better is not None: + if a_is_better: + a_str = f"{a_str}" + else: + b_str = f"{b_str}" + + return a_str, b_str + + +def prepare_benchmark_metrics_table_rows( + benchmark_name: str, + bucket_baseline: BucketData, + bucket_candidate: BucketData, +) -> list[list[str]]: + """Prepare formatted metric rows for one benchmark comparison table. + + Args: + benchmark_name: Name of the benchmark to render. + bucket_baseline: Baseline bucket data. + bucket_candidate: Candidate bucket data. + + Returns: + Table rows containing metric names and formatted baseline/candidate values. + + Raises: + ValueError: If a required metric is missing for the benchmark. + """ + rows = [] + + for metric in MetricsRegistry: + a = bucket_baseline.get_metric_avg_value( + metric_name=metric.key, + benchmark_name=benchmark_name, + ) + b = bucket_candidate.get_metric_avg_value( + metric_name=metric.key, + benchmark_name=benchmark_name, + ) + + if a is None or b is None: + if metric.optional: + continue + raise ValueError(f"Unknown metric '{metric.key}' for benchmark '{benchmark_name}'.") + + a_str, b_str = _format_metric_values(a, b, metric) + + rows.append([html.escape(metric.report_name), a_str, b_str]) + + return rows + + +def prepare_summary_metrics_table_rows( + bucket_baseline: BucketData, + bucket_candidate: BucketData, +) -> list[list[str]]: + """Prepare formatted metric rows for the summary comparison table. + + Args: + bucket_baseline: Baseline bucket data. + bucket_candidate: Candidate bucket data. 
+ + Returns: + Table rows containing metric names and formatted macro-averaged + baseline/candidate values. + + Raises: + ValueError: If a required metric is missing for any benchmark included + in the summary. + """ + rows = [] + + for metric in MetricsRegistry: + if not metric.include_in_summary: + continue + + a_vals, b_vals = [], [] + skip = False + + for benchmark_name in bucket_baseline.benchmarks: + a = bucket_baseline.get_metric_avg_value( + metric_name=metric.key, + benchmark_name=benchmark_name, + ) + b = bucket_candidate.get_metric_avg_value( + metric_name=metric.key, + benchmark_name=benchmark_name, + ) + + if a is None or b is None: + if metric.optional: + skip = True + break + raise ValueError(f"Unknown metric '{metric.key}' for benchmark '{benchmark_name}'.") + + a_vals.append(a) + b_vals.append(b) + + if skip: + continue + + avg_a, avg_b = np.mean(a_vals), np.mean(b_vals) + a_str, b_str = _format_metric_values(avg_a, avg_b, metric) + rows.append([html.escape(metric.report_name), a_str, b_str]) + + return rows diff --git a/scripts/tts_comparison_report/reporting/components/stat_tests.py b/scripts/tts_comparison_report/reporting/components/stat_tests.py new file mode 100644 index 000000000000..3dfcd8fa8aa1 --- /dev/null +++ b/scripts/tts_comparison_report/reporting/components/stat_tests.py @@ -0,0 +1,191 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import html +import warnings +from enum import Enum +from typing import Optional + +from scipy.stats import mannwhitneyu +from scripts.tts_comparison_report.reporting.constants import P_VAL_ROUND_DIGITS +from scripts.tts_comparison_report.reporting.metrics import DistributionMetricsRegistry +from scripts.tts_comparison_report.reporting.models import BucketData, StatTestAnalysisInfo, StatTestResult, Winner + + +_SIGNIFICANCE_LEVEL: float = 0.05 + + +class _Alternative(str, Enum): + two_sided = "two-sided" + greater = "greater" + less = "less" + + +def _run_single_stat_test( + baseline: list[float], + candidate: list[float], + lower_is_better: bool, +) -> tuple[Winner, _Alternative, float]: + if not baseline: + raise ValueError("Baseline sample is empty.") + + if not candidate: + raise ValueError("Candidate sample is empty.") + + if len(baseline) != len(candidate): + warnings.warn( + "\nBaseline and candidate contain different numbers of samples. " + "This may indicate missing filewise metrics or dataset mismatch.", + stacklevel=2, + ) + + # First test whether distributions differ at all, then determine direction. 
+ p_val_two_sided = mannwhitneyu(baseline, candidate, alternative="two-sided", method="auto").pvalue + + if p_val_two_sided >= _SIGNIFICANCE_LEVEL: + return Winner.tie, _Alternative.two_sided, round(p_val_two_sided, P_VAL_ROUND_DIGITS) + + p_val = mannwhitneyu(baseline, candidate, alternative="less", method="auto").pvalue + + if p_val < _SIGNIFICANCE_LEVEL: + winner = Winner.baseline if lower_is_better else Winner.candidate + p_val = round(p_val, P_VAL_ROUND_DIGITS) + return winner, _Alternative.less, p_val + + p_val = mannwhitneyu(baseline, candidate, alternative="greater", method="auto").pvalue + + if p_val < _SIGNIFICANCE_LEVEL: + winner = Winner.candidate if lower_is_better else Winner.baseline + p_val = round(p_val, P_VAL_ROUND_DIGITS) + return winner, _Alternative.greater, p_val + + return Winner.tie, _Alternative.two_sided, round(p_val_two_sided, P_VAL_ROUND_DIGITS) + + +def _map_winner_to_name( + winner: Winner, + baseline_name: str, + candidate_name: str, +) -> str: + if winner == Winner.baseline: + return baseline_name + if winner == Winner.candidate: + return candidate_name + return winner.value + + +def run_stat_tests( + bucket_baseline: BucketData, + bucket_candidate: BucketData, + benchmark_name: Optional[str] = None, +) -> list[StatTestResult]: + """Run statistical tests for all distribution metrics. + + Args: + bucket_baseline: Baseline bucket data. + bucket_candidate: Candidate bucket data. + benchmark_name: Benchmark name. If omitted, metric samples are aggregated + across all benchmarks. + + Returns: + List of StatTestResult instances for configured distribution metrics. + + Raises: + ValueError: If metric samples are missing or benchmark data is invalid. 
+ """ + results = [] + + for metric in DistributionMetricsRegistry: + winner, alternative, p_value = _run_single_stat_test( + baseline=bucket_baseline.get_metric_samples(metric.key, benchmark_name), + candidate=bucket_candidate.get_metric_samples(metric.key, benchmark_name), + lower_is_better=metric.lower_is_better, + ) + result = StatTestResult( + metric_name=metric.report_name, + winner=winner, + alternative=alternative.value, + p_value=p_value, + ) + results.append(result) + + return results + + +def prepare_stat_tests_table_rows( + baseline_name: str, + candidate_name: str, + stat_test_results: list[StatTestResult], +) -> list[list[str]]: + """Prepare formatted rows for a statistical test results table. + + Args: + baseline_name: Name of the baseline model used in the reports. + candidate_name: Name of the candidate model used in the reports. + stat_test_results: Statistical test results to format. + + Returns: + Table rows containing metric name, winner, alternative hypothesis, + and p-value. + """ + rows = [] + + for res in stat_test_results: + winner = _map_winner_to_name(res.winner, baseline_name, candidate_name) + rows.append( + [ + html.escape(res.metric_name), + html.escape(winner), + html.escape(res.alternative), + html.escape(str(res.p_value)), + ] + ) + + return rows + + +def prepare_stat_tests_analysis_info( + baseline_name: str, + candidate_name: str, + stat_test_results: list[StatTestResult], +) -> StatTestAnalysisInfo: + """Prepare summary information for the statistical test analysis section. + + Args: + baseline_name: Name of the baseline model used in the reports. + candidate_name: Name of the candidate model used in the reports. + stat_test_results: Statistical test results to summarize. + + Returns: + Instance of StatTestAnalysisInfo containing the overall winner and its + key advantages. If no statistically significant wins are present, both + fields are set to `None`. 
+ """ + a_wins, b_wins = [], [] + + for res in stat_test_results: + if res.winner == Winner.baseline: + a_wins.append(res.metric_name) + elif res.winner == Winner.candidate: + b_wins.append(res.metric_name) + + if not a_wins and not b_wins: + winner, advantages = None, None + else: + winner, wins = (baseline_name, a_wins) if len(a_wins) >= len(b_wins) else (candidate_name, b_wins) + advantages = ", ".join(wins) + + return StatTestAnalysisInfo( + winner=winner, + advantages=advantages, + ) diff --git a/scripts/tts_comparison_report/reporting/constants.py b/scripts/tts_comparison_report/reporting/constants.py new file mode 100644 index 000000000000..d69c96855d5a --- /dev/null +++ b/scripts/tts_comparison_report/reporting/constants.py @@ -0,0 +1,55 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pathlib import Path + + +_ROOT: Path = Path(__file__).parent.parent + +# Benchmark names supported by the comparison report pipeline. +SUPPORTED_BENCHMARK_NAMES: list[str] = [ + "libritts_seen", + "libritts_test_clean", + "riva_hard_digits", + "riva_hard_letters", + "riva_hard_money", + "riva_hard_short", + "vctk", +] + +# Default width of tqdm progress bars in terminal columns. +TQDM_NCOLS: int = 80 + +# Random seed used for reproducible sampling of audio examples. +SEED: int = 42 + +# Number of decimal digits used when formatting p-values in statistical tests. 
+P_VAL_ROUND_DIGITS: int = 4 + +# Default lifetime of generated S3 presigned links in seconds (one year). +S3_LINK_EXPIRES_IN: int = 31536000 + +# Subdirectory inside the S3 report prefix used for uploaded audio files. +S3_AUDIO_DIR: str = "audio" + +# Subdirectory inside the S3 report prefix used for uploaded plot images. +S3_IMAGES_DIR: str = "images" + +# Directory containing Jinja templates used for report rendering. +TEMPLATES_DIR: Path = _ROOT / "templates" + +# Fallback task id used when no real Jira ticket is provided. +DUMMY_TASK_ID: str = "NEMOTTS-0000" + +# URL prefix used to construct clickable Jira ticket links in reports. +JIRA_TICKET_URL_PREFIX: str = "https://jirasw.nvidia.com/browse" diff --git a/scripts/tts_comparison_report/reporting/helpers.py b/scripts/tts_comparison_report/reporting/helpers.py new file mode 100644 index 000000000000..671c2a8444ff --- /dev/null +++ b/scripts/tts_comparison_report/reporting/helpers.py @@ -0,0 +1,80 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from datetime import UTC, datetime, timedelta +from pathlib import Path + +from scripts.tts_comparison_report.reporting.constants import DUMMY_TASK_ID, JIRA_TICKET_URL_PREFIX +from scripts.tts_comparison_report.reporting.models import ExpirationInfo, TaskInfo + + +def make_expiration_info(expires_in: int) -> ExpirationInfo: + """Create formatted expiration metadata for reports and S3 artifact paths. + + Args: + expires_in: Link lifetime in seconds. + + Returns: + Expiration information with Unix timestamp and formatted string values. + """ + expires_at = datetime.now(UTC) + timedelta(seconds=expires_in) + + return ExpirationInfo( + timestamp=int(expires_at.timestamp()), + path_str=expires_at.strftime("%Y-%m-%dT%H-%M-%SZ"), + user_str=expires_at.strftime("%Y-%m-%d %H:%M UTC"), + ) + + +def make_task_info(task_id: str) -> TaskInfo: + """Create task metadata and the corresponding Jira link information. + + Args: + task_id: Jira task identifier used for the report. + + Returns: + Task information with the original task ID, derived Jira ID, and Jira URL. + """ + jira_id = task_id if task_id != DUMMY_TASK_ID else task_id.split("-")[0] + jira_url = f"{JIRA_TICKET_URL_PREFIX}/{jira_id}" + + return TaskInfo( + task_id=task_id, + jira_id=jira_id, + jira_url=jira_url, + ) + + +def generate_s3_prefix( + baseline_path: Path, + candidate_path: Path, + task_info: TaskInfo, + expiration_info: ExpirationInfo, +) -> str: + """Generate the S3 prefix used to store report artifacts. + + Args: + baseline_path: Path to the baseline bucket root. + candidate_path: Path to the candidate bucket root. + task_info: Task metadata. + expiration_info: Expiration metadata. + + Returns: + S3 key prefix for uploaded report artifacts. 
+ """ + parts = [ + task_info.task_id, + f"{baseline_path.stem}_vs_{candidate_path.stem}", + expiration_info.path_str, + ] + return "-".join(parts) diff --git a/scripts/tts_comparison_report/reporting/metrics/__init__.py b/scripts/tts_comparison_report/reporting/metrics/__init__.py new file mode 100644 index 000000000000..8a95550cfe09 --- /dev/null +++ b/scripts/tts_comparison_report/reporting/metrics/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from scripts.tts_comparison_report.reporting.metrics.registry import DistributionMetricsRegistry, MetricsRegistry +from scripts.tts_comparison_report.reporting.metrics.specs import DistributionMetricSpec, MetricSpec + +__all__ = [ + "DistributionMetricSpec", + "DistributionMetricsRegistry", + "MetricSpec", + "MetricsRegistry", +] diff --git a/scripts/tts_comparison_report/reporting/metrics/registry.py b/scripts/tts_comparison_report/reporting/metrics/registry.py new file mode 100644 index 000000000000..9f2201918e52 --- /dev/null +++ b/scripts/tts_comparison_report/reporting/metrics/registry.py @@ -0,0 +1,37 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from scripts.tts_comparison_report.reporting.metrics.specs import DistributionMetricSpec, MetricSpec + + +MetricsRegistry: list[MetricSpec] = [ + MetricSpec("wer_cumulative", "WER (cumulative)", True, 2, "%", 100), + MetricSpec("cer_cumulative", "CER (cumulative)", True, 2, "%", 100), + MetricSpec("wer_filewise_avg", "WER (filewise avg)", True, 2, "%", 100), + MetricSpec("cer_filewise_avg", "CER (filewise avg)", True, 2, "%", 100), + MetricSpec("utmosv2_avg", "UTMOS v2", False, 3), + MetricSpec("ssim_pred_gt_avg", "SSIM (pred vs GT)", False, 4), + MetricSpec("ssim_pred_context_avg", "SSIM (pred vs context)", False, 4), + MetricSpec("eou_cutoff_rate", "EoU cut-off rate", True, 3, "", 1, False, True), + MetricSpec("eou_silence_rate", "EoU silence rate", True, 3, "", 1, False, True), + MetricSpec("eou_noise_rate", "EoU noise rate", True, 3, "", 1, False, True), + MetricSpec("eou_error_rate", "EoU error rate", True, 3, "", 1, False, True), + MetricSpec("total_gen_audio_seconds", "Total audio (sec)", None, 1, "", 1, False), +] + + +DistributionMetricsRegistry: list[DistributionMetricSpec] = [ + DistributionMetricSpec("cer", "CER", True, True, (0.0, 0.3)), + DistributionMetricSpec("utmosv2", "UTMOS v2", False), + DistributionMetricSpec("pred_context_ssim", "SSIM (pred vs context)", False), +] diff --git a/scripts/tts_comparison_report/reporting/metrics/specs.py b/scripts/tts_comparison_report/reporting/metrics/specs.py new file mode 100644 index 000000000000..2f6f42be7609 --- /dev/null +++ b/scripts/tts_comparison_report/reporting/metrics/specs.py @@ -0,0 +1,53 
@@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Optional + + +@dataclass(frozen=True) +class MetricSpec: + """Specification of a metric shown in evaluation report tables.""" + + # Metric key expected in the aggregated metrics JSON. + key: str + # Metric name shown in the report tables. + report_name: str + # Whether smaller values are better; None means no winner highlighting. + lower_is_better: Optional[bool] + # Number of decimal digits used when formatting the metric value. + round_digits: int + # Optional unit suffix appended to the formatted metric value. + units: str = "" + # Scale factor applied before formatting, e.g. 100 for percentages. + multiplier: float | int = 1 + # Whether this metric should appear in the cross-benchmark summary table. + include_in_summary: bool = True + # Whether this metric may be absent from bucket metrics without causing an error. + optional: bool = False + + +@dataclass(frozen=True) +class DistributionMetricSpec: + """Specification of a metric used in statistical tests and distribution plots.""" + + # Metric key expected in the filewise metrics JSON used for statistical testing. + key: str + # Metric name shown in the statistical test tables. + report_name: str + # Whether smaller values indicate better quality for winner selection. 
+ lower_is_better: bool + # Whether this metric should be included in the generated box plot figure. + add_to_box_plot: bool = True + # Optional y-axis range applied to the metric plot as (min, max). + plot_range: Optional[tuple[float, float]] = None diff --git a/scripts/tts_comparison_report/reporting/models.py b/scripts/tts_comparison_report/reporting/models.py new file mode 100644 index 000000000000..51ea36190d45 --- /dev/null +++ b/scripts/tts_comparison_report/reporting/models.py @@ -0,0 +1,633 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import hashlib +from dataclasses import dataclass, field +from enum import Enum +from io import BytesIO +from pathlib import Path +from typing import Any, Optional, Self + +from scripts.tts_comparison_report.reporting.constants import TQDM_NCOLS +from scripts.tts_comparison_report.reporting.storage import BaseStorage +from tqdm import tqdm + + +_REQUIRED_SAMPLE_ID_KEYS: list[str] = [ + "pred_audio_filepath", + "gt_text", + "gt_audio_filepath", + "context_audio_filepath", +] + + +@dataclass +class BucketStructure: + """Paths and naming conventions used to locate artifacts inside an evaluation bucket.""" + + eval_output_subdir: str = "results" + metrics_suffix: str = "_metrics_0.json" + metrics_filewise_suffix: str = "_filewise_metrics_0.json" + context_audio_dir: str = "audio/repeat_0" + context_audio_prefix: str = "context_audio_" + generated_audio_dir: str = "audio/repeat_0" + generated_audio_prefix: str = "predicted_audio_" + + +def _map_generated_to_context_name( + generated_name: str, + generated_prefix: str, + context_prefix: str, +) -> str: + suffix = generated_name.split(generated_prefix)[-1] + return f"{context_prefix}{suffix}" + + +@dataclass(frozen=True) +class BenchmarkSampleMeta: + """Metadata describing one generated sample within a benchmark.""" + + name: str + gt_text: str + context_path: Path + sample_id: str + + @staticmethod + def _validate(item: dict[str, Any]) -> None: + for key in _REQUIRED_SAMPLE_ID_KEYS: + if key not in item: + raise ValueError(f"Missing required key '{key}' in filewise metrics item.") + + @staticmethod + def _get_sample_id(item: dict[str, Any]) -> str: + parts = [item["gt_audio_filepath"], item["context_audio_filepath"]] + return hashlib.sha256("|".join(parts).encode("utf-8")).hexdigest() + + @classmethod + def create( + cls, + item: dict[str, Any], + context_audio_paths: dict[str, Path], + bucket_structure: BucketStructure, + ) -> Self: + """Create sample metadata from one filewise metrics item. 
+ + Args: + item: One entry from the filewise metrics JSON. + context_audio_paths: Mapping from context audio file name to its path. + bucket_structure: Bucket naming and path conventions. + + Returns: + Sample metadata extracted from the given filewise metrics item. + + Raises: + ValueError: If required keys are missing from the item. + KeyError: If the corresponding context audio file is not found. + """ + cls._validate(item) + + name = Path(item["pred_audio_filepath"]).stem + + key = _map_generated_to_context_name( + generated_name=name, + generated_prefix=bucket_structure.generated_audio_prefix, + context_prefix=bucket_structure.context_audio_prefix, + ) + obj = cls( + name=name, + gt_text=item["gt_text"], + context_path=context_audio_paths[key], + sample_id=cls._get_sample_id(item), + ) + return obj + + +def _collect_audio_paths( + root: Path, + prefix: str, + audio_paths: dict[str, Path], + storage: BaseStorage, +) -> None: + if not storage.exists(root): + raise FileNotFoundError(f"Missing audio directory: '{root}'.") + + for p in storage.iter_dir(root): + if not p.stem.startswith(prefix) or p.suffix != ".wav": + continue + audio_paths[p.stem] = p + + +def _validate_audio_pairs( + context_audio_paths: dict[str, Path], + generated_audio_paths: dict[str, Path], + bucket_structure: BucketStructure, +) -> None: + for name in generated_audio_paths: + key = _map_generated_to_context_name( + generated_name=name, + generated_prefix=bucket_structure.generated_audio_prefix, + context_prefix=bucket_structure.context_audio_prefix, + ) + if key not in context_audio_paths: + raise ValueError(f"Missing context audio: '{key}'.") + + +@dataclass +class BenchmarkData: + """Artifacts and loaded data associated with one evaluation benchmark.""" + + name: str + metrics_path: Optional[Path] = None + filewise_metrics_path: Optional[Path] = None + generated_audio_paths: dict[str, Path] = field(default_factory=dict) + context_audio_paths: dict[str, Path] = field(default_factory=dict) 
+ + metrics: Optional[dict[str, float]] = None + filewise_metrics: Optional[list[dict[str, Any]]] = None + + @classmethod + def from_storage( + cls, + benchmark_name: str, + benchmark_path: Path, + bucket_structure: BucketStructure, + check_audio: bool, + storage: BaseStorage, + ) -> Self: + """Create benchmark data by discovering benchmark artifacts in storage. + + Args: + benchmark_name: Name of the benchmark. + benchmark_path: Path to the benchmark directory inside the evaluation bucket. + bucket_structure: Bucket naming and path conventions. + check_audio: Whether context and generated audio files should also be discovered and validated. + storage: Storage backend used to access local or remote files. + + Returns: + Benchmark data initialized with discovered artifact paths. + + Raises: + FileNotFoundError: If required metrics files are missing, audio directories + are missing, or expected audio files cannot be found. + ValueError: If generated audio files do not have matching context audio files. + """ + obj = cls(name=benchmark_name) + + path = benchmark_path / f"{benchmark_name}{bucket_structure.metrics_suffix}" + if not storage.exists(path): + raise FileNotFoundError(f"Missing metrics file: '{path}'.") + obj.metrics_path = path + + path = benchmark_path / f"{benchmark_name}{bucket_structure.metrics_filewise_suffix}" + if not storage.exists(path): + raise FileNotFoundError(f"Missing filewise metrics file: '{path}'.") + obj.filewise_metrics_path = path + + if check_audio: + _collect_audio_paths( + root=benchmark_path / bucket_structure.context_audio_dir, + prefix=bucket_structure.context_audio_prefix, + audio_paths=obj.context_audio_paths, + storage=storage, + ) + if not obj.context_audio_paths: + raise FileNotFoundError( + f"No context audio files were found in '{benchmark_path / bucket_structure.context_audio_dir}'. " + "The bucket structure likely differs from the one specified in 'BucketStructure'."
+ ) + _collect_audio_paths( + root=benchmark_path / bucket_structure.generated_audio_dir, + prefix=bucket_structure.generated_audio_prefix, + audio_paths=obj.generated_audio_paths, + storage=storage, + ) + if not obj.generated_audio_paths: + raise FileNotFoundError( + f"No generated audio files were found in '{benchmark_path / bucket_structure.generated_audio_dir}'. " + "The bucket structure likely differs from the one specified in 'BucketStructure'." + ) + _validate_audio_pairs( + context_audio_paths=obj.context_audio_paths, + generated_audio_paths=obj.generated_audio_paths, + bucket_structure=bucket_structure, + ) + return obj + + def load_metrics(self, storage: BaseStorage) -> None: + """Load aggregated benchmark metrics from storage. + + Args: + storage: Storage instance used to read the metrics file. + + Raises: + TypeError: If the metrics file does not contain a JSON object. + """ + if self.metrics_path is None: + return + + data = storage.read_json(self.metrics_path) + + if not isinstance(data, dict): + raise TypeError(f"Metrics file must contain a JSON object: '{self.metrics_path}'.") + + self.metrics = data + + def load_filewise_metrics(self, storage: BaseStorage) -> None: + """Load filewise benchmark metrics from storage. + + Args: + storage: Storage instance used to read the filewise metrics file. + + Raises: + TypeError: If the filewise metrics file does not contain a JSON array. + """ + if self.filewise_metrics_path is None: + return + + data = storage.read_json(self.filewise_metrics_path) + + if not isinstance(data, list): + raise TypeError(f"Filewise metrics file must contain a JSON array: '{self.filewise_metrics_path}'.") + + self.filewise_metrics = data + + +def _validate_numeric_metric_value( + value: Any, + metric_name: str, + context: str, +) -> float: + if not isinstance(value, (int, float)): + raise TypeError( + f"Metric '{metric_name}' in {context} must be numeric, " + f"but got value {value!r} of type {type(value).__name__}." 
+ ) + return float(value) + + +@dataclass +class BucketData: + """Evaluation bucket metadata and loaded metric data.""" + + name: str + path: Path + configuration_str: Optional[str] = None + benchmarks: dict[str, BenchmarkData] = field(default_factory=dict) + + @classmethod + def from_storage( + cls, + bucket_name: str, + bucket_path: Path, + bucket_structure: BucketStructure, + benchmark_names: tuple[str, ...], + check_audio: bool, + storage: BaseStorage, + ) -> Self: + """Create bucket data by discovering benchmark artifacts in storage. + + Args: + bucket_name: Display name of the bucket, typically the name of the model it belongs to. + bucket_path: Path to the bucket root directory. + bucket_structure: Bucket naming and path conventions. + benchmark_names: Benchmark names expected in the bucket. + check_audio: Whether context and generated audio files should also be discovered and validated. + storage: Storage instance used to access local or remote files. + + Returns: + Bucket data initialized with discovered benchmark artifacts. + + Raises: + FileNotFoundError: If the expected results directory is missing. 
+ """ + obj = cls(name=bucket_name, path=bucket_path) + results_path = bucket_path / bucket_structure.eval_output_subdir + + if not storage.exists(results_path): + raise FileNotFoundError(f"Missing results directory: '{results_path}'.") + + for benchmark_path in storage.iter_dir(results_path, only_dirs=True): + if len(obj.benchmarks) == len(benchmark_names): + break + + dir_name = benchmark_path.name + name = next((n for n in benchmark_names if dir_name == n or dir_name.endswith(f"_{n}")), None) + + if name is None: + continue + + obj.benchmarks[name] = BenchmarkData.from_storage( + benchmark_name=name, + benchmark_path=benchmark_path, + bucket_structure=bucket_structure, + check_audio=check_audio, + storage=storage, + ) + if obj.configuration_str is None: + suffix = f"_{name}" + obj.configuration_str = dir_name[: -len(suffix)] + + return obj + + def load_metrics( + self, + storage: BaseStorage, + show_pbar: bool = False, + ) -> None: + """Load aggregated and filewise metrics for all discovered benchmarks. + + Args: + storage: Storage instance used to read metrics files. + show_pbar: Whether to display a progress bar while loading metrics. + """ + pbar = tqdm(total=len(self.benchmarks), ncols=TQDM_NCOLS) if show_pbar else None + + for benchmark_data in self.benchmarks.values(): + benchmark_data.load_metrics(storage) + benchmark_data.load_filewise_metrics(storage) + + if pbar: + pbar.update(1) + + if pbar: + pbar.close() + + def get_metric_avg_value( + self, + metric_name: str, + benchmark_name: str, + ) -> Optional[float]: + """Return the aggregated value of a metric for one benchmark. + + Args: + metric_name: Name of the metric to retrieve. + benchmark_name: Name of the benchmark. + + Returns: + Aggregated metric value, or `None` if the metric is not present. + + Raises: + ValueError: If the benchmark is unknown or metrics are not loaded. + TypeError: If the metric value is not numeric. 
+ """ + if benchmark_name not in self.benchmarks: + raise ValueError(f"Unknown benchmark: '{benchmark_name}'.") + + metrics = self.benchmarks[benchmark_name].metrics + + if metrics is None: + raise ValueError(f"Metrics not loaded for benchmark: '{benchmark_name}'.") + + if metric_name not in metrics: + return None + + value = _validate_numeric_metric_value( + value=metrics[metric_name], + metric_name=metric_name, + context=f"averaged metrics for benchmark '{benchmark_name}'", + ) + return value + + def _get_metric_stats( + self, + metric_name: str, + benchmark_name: str, + ) -> list[float]: + if benchmark_name not in self.benchmarks: + raise ValueError(f"Unknown benchmark: '{benchmark_name}'.") + + items = self.benchmarks[benchmark_name].filewise_metrics + + if items is None or not items: + raise ValueError(f"Filewise metrics not loaded for benchmark: '{benchmark_name}'.") + + output = [] + validation_context = f"filewise metrics for benchmark '{benchmark_name}'" + + for item in items: + if metric_name not in item: + continue + + value = _validate_numeric_metric_value( + value=item[metric_name], + metric_name=metric_name, + context=validation_context, + ) + output.append(value) + + if not output: + raise ValueError(f"Unknown or empty metric '{metric_name}' for benchmark '{benchmark_name}'.") + + return output + + def _aggregate_metric_stats(self, metric_name: str) -> list[float]: + output = [] + + for benchmark_name in self.benchmarks: + output.extend(self._get_metric_stats(metric_name, benchmark_name)) + + if not output: + raise ValueError(f"Unknown or empty aggregated metric '{metric_name}'.") + + return output + + def get_metric_samples( + self, + metric_name: str, + benchmark_name: Optional[str] = None, + ) -> list[float]: + """Return filewise samples for a metric from one or all benchmarks. + + Args: + metric_name: Name of the metric to retrieve. + benchmark_name: Benchmark name. If omitted, samples are aggregated + across all benchmarks. 
+ + Returns: + List of numeric metric samples. + + Raises: + ValueError: If the benchmark is unknown, filewise metrics are not loaded, + or the metric is missing. + TypeError: If any metric value is not numeric. + """ + if benchmark_name is None: + return self._aggregate_metric_stats(metric_name) + return self._get_metric_stats(metric_name, benchmark_name) + + def get_benchmark_audio_paths(self, benchmark_name: str) -> dict[str, Path]: + """Return generated audio file paths for a benchmark. + + Args: + benchmark_name: Name of the benchmark. + + Returns: + Mapping from sample name to generated audio path. + + Raises: + ValueError: If the benchmark is unknown or audio paths are not loaded. + """ + if benchmark_name not in self.benchmarks: + raise ValueError(f"Unknown benchmark: '{benchmark_name}'.") + + paths = self.benchmarks[benchmark_name].generated_audio_paths + + if not paths: + raise ValueError(f"Generated audio paths not loaded for benchmark: '{benchmark_name}'.") + + return paths + + def get_benchmark_sample_meta( + self, + benchmark_name: str, + bucket_structure: BucketStructure, + ) -> dict[str, BenchmarkSampleMeta]: + """Return sample metadata for a benchmark derived from filewise metrics. + + Args: + benchmark_name: Name of the benchmark. + bucket_structure: Bucket naming and path conventions used to resolve + matching context audio files. + + Returns: + Mapping from sample name to benchmark sample metadata. + + Raises: + ValueError: If the benchmark is unknown, filewise metrics are not loaded, + or context audio paths are not loaded. + KeyError: If a matching context audio file cannot be found for a sample. 
+ """ + if benchmark_name not in self.benchmarks: + raise ValueError(f"Unknown benchmark: '{benchmark_name}'.") + + items = self.benchmarks[benchmark_name].filewise_metrics + + if not items: + raise ValueError(f"Filewise metrics not loaded for benchmark: '{benchmark_name}'.") + + paths = self.benchmarks[benchmark_name].context_audio_paths + + if not paths: + raise ValueError(f"Context audio paths not loaded for benchmark: '{benchmark_name}'.") + + output = {} + + for item in items: + meta = BenchmarkSampleMeta.create( + item=item, + context_audio_paths=paths, + bucket_structure=bucket_structure, + ) + output[meta.name] = meta + + return output + + +@dataclass(frozen=True) +class TaskInfo: + """Task identifiers and derived Jira link information used in reports.""" + + task_id: str + jira_id: str + jira_url: str + + +@dataclass(frozen=True) +class ExpirationInfo: + """Formatted expiration metadata used in reports and S3 artifact paths.""" + + timestamp: int + path_str: str + user_str: str + + +class Winner(str, Enum): + """Possible outcomes of a statistical comparison between baseline and candidate.""" + + baseline = "baseline" + candidate = "candidate" + tie = "tie" + + +@dataclass(frozen=True) +class StatTestResult: + """Result of a statistical comparison for a single metric.""" + + metric_name: str + winner: Winner + alternative: str + p_value: float + + +@dataclass(frozen=True) +class StatTestAnalysisInfo: + """Summary information used to describe statistical test outcomes in reports.""" + + winner: Optional[str] + advantages: Optional[str] + + +@dataclass(frozen=True) +class EvalResult: + """Evaluation results for one report section, including tables, analysis, and plot.""" + + metrics_table_row: list[str | float] + stat_test_table_row: list[str | float] + stat_tests_analysis_info: StatTestAnalysisInfo + box_plots: BytesIO + + +@dataclass(frozen=True) +class ModelConfiguration: + """Configuration strings associated with the baseline and candidate models.""" + + 
baseline: str + candidate: str + + +@dataclass(frozen=True) +class EvalArtifacts: + """Prepared evaluation results, configuration, and comparison metadata used to render reports.""" + + configuration: ModelConfiguration + summary: EvalResult + benchmarks: dict[str, EvalResult] + is_self_comparison: bool + + +@dataclass(frozen=True) +class UploadedBoxPlotsInfo: + """S3 URLs of uploaded summary and benchmark-level box plot images.""" + + summary_url: str + benchmark_urls: dict[str, str] + + +@dataclass(frozen=True) +class AudioPair: + """Matched context, baseline, and candidate audio files for one sample.""" + + context_path: Path + baseline_path: Path + candidate_path: Path + text: str + + +@dataclass(frozen=True) +class UploadedAudioPairInfo: + """Uploaded context, baseline, and candidate audio URLs for one sample.""" + + context_url: str + baseline_url: str + candidate_url: str + text: str diff --git a/scripts/tts_comparison_report/reporting/orchestrator.py b/scripts/tts_comparison_report/reporting/orchestrator.py new file mode 100644 index 000000000000..341c4c5ff4e7 --- /dev/null +++ b/scripts/tts_comparison_report/reporting/orchestrator.py @@ -0,0 +1,508 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from io import BytesIO +from logging import Logger +from pathlib import Path +from typing import Optional + +from scripts.tts_comparison_report.reporting.components import ( + BoxPlotsConfig, + prepare_audio_pairs, + prepare_eval_artifacts, +) +from scripts.tts_comparison_report.reporting.constants import ( + S3_AUDIO_DIR, + S3_IMAGES_DIR, + S3_LINK_EXPIRES_IN, + TQDM_NCOLS, +) +from scripts.tts_comparison_report.reporting.helpers import generate_s3_prefix, make_expiration_info, make_task_info +from scripts.tts_comparison_report.reporting.models import ( + AudioPair, + BucketData, + BucketStructure, + EvalArtifacts, + ExpirationInfo, + TaskInfo, + UploadedAudioPairInfo, + UploadedBoxPlotsInfo, +) +from scripts.tts_comparison_report.reporting.renderer import Renderer, TemplateName +from scripts.tts_comparison_report.reporting.s3_client import S3Client +from scripts.tts_comparison_report.reporting.storage import BaseStorage +from tqdm import tqdm + + +class Orchestrator: + """Coordinate loading, processing, rendering, and uploading of comparison reports.""" + + def __init__( + self, + bucket_structure: BucketStructure, + storage: BaseStorage, + s3_client: S3Client, + renderer: Renderer, + logger: Optional[Logger] = None, + ) -> None: + self.bucket_structure = bucket_structure + self.storage = storage + self.s3_client = s3_client + self.renderer = renderer + self.logger = logger + + self.show_pbar = logger is not None + + def _log_info(self, msg: str) -> None: + if self.logger is not None: + self.logger.info(msg) + + def _load_buckets( + self, + baseline_name: str, + candidate_name: str, + baseline_path: Path, + candidate_path: Path, + benchmark_names: tuple[str, ...], + check_audio: bool, + ) -> tuple[BucketData, BucketData]: + self._log_info(f"\nLoading metadata for {baseline_name}...") + bucket_baseline = BucketData.from_storage( + bucket_name=baseline_name, + bucket_path=baseline_path, + bucket_structure=self.bucket_structure, + 
benchmark_names=benchmark_names, + check_audio=check_audio, + storage=self.storage, + ) + self._log_info(f"Loading metadata for {candidate_name}...") + bucket_candidate = BucketData.from_storage( + bucket_name=candidate_name, + bucket_path=candidate_path, + bucket_structure=self.bucket_structure, + benchmark_names=benchmark_names, + check_audio=check_audio, + storage=self.storage, + ) + + baseline_set = set(bucket_baseline.benchmarks.keys()) + candidate_set = set(bucket_candidate.benchmarks.keys()) + + if baseline_set != candidate_set: + raise ValueError(f"Benchmark sets differ: '{baseline_set}' vs '{candidate_set}'.") + + self._log_info(f"\nLoading metric data for {baseline_name}:") + bucket_baseline.load_metrics(storage=self.storage, show_pbar=self.show_pbar) + + self._log_info(f"\nLoading metric data for {candidate_name}:") + bucket_candidate.load_metrics(storage=self.storage, show_pbar=self.show_pbar) + + return bucket_baseline, bucket_candidate + + def _upload_audio_file( + self, + path: Path, + key: str, + ) -> str: + with self.storage.open_file(path) as f: + url = self.s3_client.upload_fileobj( + fileobj=f, + key=key, + expires_in=S3_LINK_EXPIRES_IN, + content_type="audio/wav", + ) + return url + + def _upload_png_image( + self, + image: BytesIO, + key: str, + ) -> str: + url = self.s3_client.upload_bytes( + data=image.getvalue(), + key=key, + expires_in=S3_LINK_EXPIRES_IN, + content_type="image/png", + ) + return url + + def _upload_audio( + self, + used_benchmarks: list[str], + audio_pairs: dict[str, list[AudioPair]], + s3_prefix: str, + ) -> dict[str, list[UploadedAudioPairInfo]]: + total = sum(len(v) for v in audio_pairs.values()) + pbar = tqdm(total=total, ncols=TQDM_NCOLS) if self.show_pbar else None + uploaded_info = {} + + for benchmark_name in used_benchmarks: + benchmark_info = [] + + for i, pair in enumerate(audio_pairs[benchmark_name]): + context_url = self._upload_audio_file( + path=pair.context_path, + 
key=f"{s3_prefix}/{S3_AUDIO_DIR}/context_{benchmark_name}_{i}.wav", + ) + baseline_url = self._upload_audio_file( + path=pair.baseline_path, + key=f"{s3_prefix}/{S3_AUDIO_DIR}/baseline_{benchmark_name}_{i}.wav", + ) + candidate_url = self._upload_audio_file( + path=pair.candidate_path, + key=f"{s3_prefix}/{S3_AUDIO_DIR}/candidate_{benchmark_name}_{i}.wav", + ) + pair_info = UploadedAudioPairInfo( + context_url=context_url, + baseline_url=baseline_url, + candidate_url=candidate_url, + text=pair.text, + ) + benchmark_info.append(pair_info) + + if pbar: + pbar.update(1) + + uploaded_info[benchmark_name] = benchmark_info + + if pbar: + pbar.close() + + return uploaded_info + + def _upload_boxplots( + self, + eval_artifacts: EvalArtifacts, + s3_prefix: str, + ) -> UploadedBoxPlotsInfo: + name_prefix = "box_plot" + pbar = tqdm(total=len(eval_artifacts.benchmarks) + 1, ncols=TQDM_NCOLS) if self.show_pbar else None + + summary_url = self._upload_png_image( + image=eval_artifacts.summary.box_plots, + key=f"{s3_prefix}/{S3_IMAGES_DIR}/{name_prefix}_summary.png", + ) + if pbar: + pbar.update(1) + + benchmark_urls = {} + + for benchmark_name, benchmark_result in eval_artifacts.benchmarks.items(): + benchmark_urls[benchmark_name] = self._upload_png_image( + image=benchmark_result.box_plots, + key=f"{s3_prefix}/{S3_IMAGES_DIR}/{name_prefix}_{benchmark_name}.png", + ) + if pbar: + pbar.update(1) + + if pbar: + pbar.close() + + return UploadedBoxPlotsInfo( + summary_url=summary_url, + benchmark_urls=benchmark_urls, + ) + + def _upload_report( + self, + report: str, + s3_prefix: str, + report_name: str, + ) -> str: + report_url = self.s3_client.upload_bytes( + data=report.encode("utf-8"), + key=f"{s3_prefix}/{report_name}.html", + expires_in=S3_LINK_EXPIRES_IN, + content_type="text/html; charset=utf-8", + ) + return report_url + + def _render_audio_report( + self, + baseline_name: str, + candidate_name: str, + used_benchmarks: list[str], + uploaded_audio_info: dict[str, 
list[UploadedAudioPairInfo]], + task_info: TaskInfo, + expiration_info: ExpirationInfo, + ) -> str: + expiration_comment = f"This report will expire at {expiration_info.user_str}" + + header_block = self.renderer.render( + name=TemplateName.audio_report_header, + baseline_name=baseline_name, + candidate_name=candidate_name, + expiration_comment=expiration_comment, + ) + benchmark_blocks, benchmark_section_info = [], [] + + for benchmark_name in used_benchmarks: + pair_blocks = [] + + for pair in uploaded_audio_info[benchmark_name]: + block = self.renderer.render( + name=TemplateName.audio_report_pair, + context_url=pair.context_url, + baseline_url=pair.baseline_url, + candidate_url=pair.candidate_url, + text=pair.text, + ) + pair_blocks.append(block) + + block = self.renderer.render( + name=TemplateName.audio_report_block, + title=benchmark_name, + section_id=benchmark_name, + baseline_name=baseline_name, + candidate_name=candidate_name, + pair_blocks=pair_blocks, + ) + benchmark_blocks.append(block) + benchmark_section_info.append((benchmark_name, benchmark_name)) + + report = self.renderer.render( + name=TemplateName.audio_report, + jira_id=task_info.jira_id, + jira_url=task_info.jira_url, + header_block=header_block, + benchmark_blocks=benchmark_blocks, + benchmark_section_info=benchmark_section_info, + ) + + return report + + def _render_eval_report( + self, + baseline_name: str, + candidate_name: str, + eval_artifacts: EvalArtifacts, + uploaded_box_plots_info: UploadedBoxPlotsInfo, + task_info: TaskInfo, + expiration_info: ExpirationInfo, + audio_report_url: Optional[str], + ) -> str: + expiration_comment = f"This report will expire at {expiration_info.user_str}" + + configuration_block = self.renderer.render( + name=TemplateName.eval_report_configuration, + baseline_name=baseline_name, + baseline_configuration=eval_artifacts.configuration.baseline, + candidate_name=candidate_name, + candidate_configuration=eval_artifacts.configuration.candidate, + ) + 
header_block = self.renderer.render( + name=TemplateName.eval_report_header, + baseline_name=baseline_name, + candidate_name=candidate_name, + expiration_comment=expiration_comment, + ) + metrics_table = self.renderer.render( + name=TemplateName.eval_report_table, + title="Metrics (macro-average across benchmarks)", + headers=["Metric", baseline_name, candidate_name], + rows=eval_artifacts.summary.metrics_table_row, + ) + stat_tests_table = self.renderer.render( + name=TemplateName.eval_report_table, + title="Statistical Tests (pooled filewise across benchmarks)", + headers=["Metric", "Winner", "Alternative", "p-value"], + rows=eval_artifacts.summary.stat_test_table_row, + ) + stat_tests_analysis = self.renderer.render( + name=TemplateName.eval_report_stat_analysis, + winner=eval_artifacts.summary.stat_tests_analysis_info.winner, + advantages=eval_artifacts.summary.stat_tests_analysis_info.advantages, + ) + image_block = self.renderer.render( + name=TemplateName.eval_report_image, + image_url=uploaded_box_plots_info.summary_url, + ) + summary_block = self.renderer.render( + name=TemplateName.eval_report_block, + is_summary=True, + metrics_table=metrics_table, + stat_tests_table=stat_tests_table, + stat_tests_analysis=stat_tests_analysis, + image_block=image_block, + ) + benchmark_blocks, benchmark_section_info = [], [] + + for benchmark_name in sorted(eval_artifacts.benchmarks.keys()): + metrics_table = self.renderer.render( + name=TemplateName.eval_report_table, + title="Metrics", + headers=["Metric", baseline_name, candidate_name], + rows=eval_artifacts.benchmarks[benchmark_name].metrics_table_row, + ) + stat_tests_table = self.renderer.render( + name=TemplateName.eval_report_table, + title="Statistical Tests", + headers=["Metric", "Winner", "Alternative", "p-value"], + rows=eval_artifacts.benchmarks[benchmark_name].stat_test_table_row, + ) + stat_tests_analysis = self.renderer.render( + name=TemplateName.eval_report_stat_analysis, + 
winner=eval_artifacts.benchmarks[benchmark_name].stat_tests_analysis_info.winner, + advantages=eval_artifacts.benchmarks[benchmark_name].stat_tests_analysis_info.advantages, + ) + image_block = self.renderer.render( + name=TemplateName.eval_report_image, + image_url=uploaded_box_plots_info.benchmark_urls[benchmark_name], + ) + block = self.renderer.render( + name=TemplateName.eval_report_block, + is_summary=False, + title=benchmark_name, + section_id=benchmark_name, + metrics_table=metrics_table, + stat_tests_table=stat_tests_table, + stat_tests_analysis=stat_tests_analysis, + image_block=image_block, + ) + benchmark_blocks.append(block) + benchmark_section_info.append((benchmark_name, benchmark_name)) + + report = self.renderer.render( + name=TemplateName.eval_report, + is_self_comparison=eval_artifacts.is_self_comparison, + jira_id=task_info.jira_id, + jira_url=task_info.jira_url, + audio_report_url=audio_report_url, + configuration_block=configuration_block, + header_block=header_block, + summary_block=summary_block, + benchmark_blocks=benchmark_blocks, + benchmark_section_info=benchmark_section_info, + ) + return report + + def run( + self, + baseline_name: str, + candidate_name: str, + baseline_path: Path, + candidate_path: Path, + benchmarks: list[str], + generate_audio_report: bool, + audio_report_benchmarks: Optional[list[str]], + samples_per_benchmark: int, + task_id: str, + ) -> tuple[str, Optional[str]]: + """Generate evaluation reports, upload report artifacts to S3, and return report URLs. + + This method performs the full end-to-end comparison workflow: + it loads evaluation buckets, prepares summary and benchmark-level artifacts, + uploads plots and optional audio samples to S3, renders the final HTML + reports, uploads them, and returns their presigned URLs. + + Args: + baseline_name: Name of the baseline model used in the reports. + candidate_name: Name of the candidate model used in the reports. 
+ baseline_path: Path to the baseline evaluation bucket root. + candidate_path: Path to the candidate evaluation bucket root. + benchmarks: Benchmark names to include in the evaluation report. + generate_audio_report: Whether to generate the audio comparison report. + audio_report_benchmarks: Benchmark names to include in the audio report. + samples_per_benchmark: Number of audio pairs to sample per benchmark. + task_id: Task identifier used for report metadata and Jira linking. + + Returns: + Tuple containing the evaluation report URL and the optional audio report URL. + + Raises: + ValueError: If input configuration is inconsistent, required benchmarks + are missing, or report generation inputs are invalid. + FileNotFoundError: If required bucket artifacts are missing from storage. + TypeError: If loaded metric files have unexpected types. + """ + benchmark_names = tuple(sorted(benchmarks, key=len, reverse=True)) + + audio_report: Optional[str] = None + audio_report_url: Optional[str] = None + + bucket_baseline, bucket_candidate = self._load_buckets( + baseline_name=baseline_name, + candidate_name=candidate_name, + baseline_path=baseline_path, + candidate_path=candidate_path, + benchmark_names=benchmark_names, + check_audio=generate_audio_report, + ) + + task_info = make_task_info(task_id) + expiration_info = make_expiration_info(S3_LINK_EXPIRES_IN) + s3_prefix = generate_s3_prefix(baseline_path, candidate_path, task_info, expiration_info) + box_plots_cfg = BoxPlotsConfig() + + self._log_info("\nPreparing evaluation artifacts...") + eval_artifacts = prepare_eval_artifacts( + bucket_baseline=bucket_baseline, + bucket_candidate=bucket_candidate, + box_plots_cfg=box_plots_cfg, + ) + self._log_info("\nUploading images to S3:") + uploaded_box_plots_info = self._upload_boxplots( + eval_artifacts=eval_artifacts, + s3_prefix=s3_prefix, + ) + + if generate_audio_report: + if audio_report_benchmarks is None: + raise ValueError("Audio report benchmarks must be provided when 
audio report is enabled.") + + audio_pairs = prepare_audio_pairs( + bucket_baseline=bucket_baseline, + bucket_candidate=bucket_candidate, + bucket_structure=self.bucket_structure, + used_benchmarks=audio_report_benchmarks, + samples_per_benchmark=samples_per_benchmark, + ) + self._log_info("\nUploading audio files to S3:") + uploaded_audio_info = self._upload_audio( + used_benchmarks=audio_report_benchmarks, + audio_pairs=audio_pairs, + s3_prefix=s3_prefix, + ) + self._log_info("\nPreparing audio report...") + audio_report = self._render_audio_report( + baseline_name=baseline_name, + candidate_name=candidate_name, + used_benchmarks=audio_report_benchmarks, + uploaded_audio_info=uploaded_audio_info, + task_info=task_info, + expiration_info=expiration_info, + ) + audio_report_url = self._upload_report( + report=audio_report, + s3_prefix=s3_prefix, + report_name="audio_report", + ) + + self._log_info("\nPreparing evaluation report...") + eval_report = self._render_eval_report( + baseline_name=bucket_baseline.name, + candidate_name=bucket_candidate.name, + eval_artifacts=eval_artifacts, + uploaded_box_plots_info=uploaded_box_plots_info, + task_info=task_info, + expiration_info=expiration_info, + audio_report_url=audio_report_url, + ) + eval_report_url = self._upload_report( + report=eval_report, + s3_prefix=s3_prefix, + report_name="eval_report", + ) + self._log_info(f"\nUploaded artifacts to bucket '{self.s3_client.cfg.bucket}' with prefix '{s3_prefix}'.") + + return eval_report_url, audio_report_url diff --git a/scripts/tts_comparison_report/reporting/renderer.py b/scripts/tts_comparison_report/reporting/renderer.py new file mode 100644 index 000000000000..00b3275ab06f --- /dev/null +++ b/scripts/tts_comparison_report/reporting/renderer.py @@ -0,0 +1,59 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from enum import Enum +from pathlib import Path + +from jinja2 import Environment, FileSystemLoader, select_autoescape + + +class TemplateName(str, Enum): + """Template file names used by the report renderer.""" + + eval_report = "eval_report.jinja" + eval_report_configuration = "eval_report_configuration.jinja" + eval_report_header = "eval_report_header.jinja" + eval_report_table = "eval_report_table.jinja" + eval_report_stat_analysis = "eval_report_stat_analysis.jinja" + eval_report_image = "eval_report_image.jinja" + eval_report_block = "eval_report_block.jinja" + + audio_report = "audio_report.jinja" + audio_report_header = "audio_report_header.jinja" + audio_report_pair = "audio_report_pair.jinja" + audio_report_block = "audio_report_block.jinja" + + +class Renderer: + """Load and render Jinja templates used for report generation.""" + + def __init__(self, templates_dir: Path) -> None: + self.env = Environment( + loader=FileSystemLoader(templates_dir.as_posix()), + autoescape=select_autoescape(["html", "xml", "jinja"]), + trim_blocks=True, + lstrip_blocks=True, + ) + self.templates = {t: self.env.get_template(t.value) for t in TemplateName} + + def render(self, name: TemplateName, **kwargs) -> str: + """Render the selected template with the provided context variables. + + Args: + name: Template identifier to render. + **kwargs: Context variables passed to the template. + + Returns: + Rendered template as a string. 
+ """ + return self.templates[name].render(**kwargs) diff --git a/scripts/tts_comparison_report/reporting/s3_client.py b/scripts/tts_comparison_report/reporting/s3_client.py new file mode 100644 index 000000000000..d55a31daaa3b --- /dev/null +++ b/scripts/tts_comparison_report/reporting/s3_client.py @@ -0,0 +1,135 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import BinaryIO, Optional + +import boto3 +from botocore.config import Config + + +@dataclass +class S3Config: + """Configuration required to connect to the target S3-compatible storage.""" + + bucket: str + endpoint_url: str + region_name: str + connect_timeout: int = 10 + + +class S3Client: + """Client for uploading report artifacts to S3-compatible object storage.""" + + def __init__( + self, + cfg: S3Config, + aws_access_key_id: str, + aws_secret_access_key: str, + ) -> None: + self.cfg = cfg + self.client = boto3.client( + "s3", + endpoint_url=cfg.endpoint_url, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + region_name=cfg.region_name, + config=Config(connect_timeout=cfg.connect_timeout), + ) + + def upload_fileobj( + self, + fileobj: BinaryIO, + key: str, + expires_in: int, + content_type: Optional[str] = None, + ) -> str: + """Upload a binary file-like object and return a presigned download URL. 
+
+        Args:
+            fileobj: File-like object to upload.
+            key: S3 object key for the uploaded file.
+            expires_in: Lifetime of the presigned URL in seconds.
+            content_type: Optional content type stored with the uploaded object.
+
+        Returns:
+            Presigned URL for downloading the uploaded object.
+        """
+        kwargs = {
+            "Fileobj": fileobj,
+            "Bucket": self.cfg.bucket,
+            "Key": key,
+        }
+
+        if content_type is not None:
+            kwargs["ExtraArgs"] = {"ContentType": content_type}
+
+        self.client.upload_fileobj(**kwargs)
+
+        return self.get_presigned_url(key, expires_in)
+
+    def upload_bytes(
+        self,
+        data: bytes,
+        key: str,
+        expires_in: int,
+        content_type: Optional[str] = None,
+    ) -> str:
+        """Upload raw bytes and return a presigned download URL.
+
+        Args:
+            data: File content to upload.
+            key: S3 object key for the uploaded file.
+            expires_in: Lifetime of the presigned URL in seconds.
+            content_type: Optional content type stored with the uploaded object.
+
+        Returns:
+            Presigned URL for downloading the uploaded object.
+        """
+        extra_args = {}
+
+        if content_type is not None:
+            extra_args["ContentType"] = content_type
+
+        self.client.put_object(
+            Bucket=self.cfg.bucket,
+            Key=key,
+            Body=data,
+            **extra_args,
+        )
+
+        return self.get_presigned_url(key, expires_in)
+
+    def get_presigned_url(
+        self,
+        key: str,
+        expires_in: int,
+    ) -> str:
+        """Generate a presigned download URL for an uploaded object.
+
+        Args:
+            key: S3 object key.
+            expires_in: Lifetime of the presigned URL in seconds.
+
+        Returns:
+            Presigned URL for the requested object.
+        """
+        return self.client.generate_presigned_url(
+            "get_object",
+            Params={"Bucket": self.cfg.bucket, "Key": key},
+            ExpiresIn=expires_in,
+        )
+
+    def close(self) -> None:
+        """Close the underlying S3 client."""
+        self.client.close()
diff --git a/scripts/tts_comparison_report/reporting/storage.py b/scripts/tts_comparison_report/reporting/storage.py
new file mode 100644
index 000000000000..444fd39d3208
--- /dev/null
+++ b/scripts/tts_comparison_report/reporting/storage.py
@@ -0,0 +1,172 @@
+# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import errno
+import json
+from abc import ABC, abstractmethod
+from pathlib import Path
+from stat import S_ISDIR
+from typing import Any, BinaryIO, Generator
+
+from paramiko.sftp_client import SFTPClient
+
+
+class BaseStorage(ABC):
+    """Abstract storage interface for accessing report artifacts."""
+
+    @abstractmethod
+    def exists(self, path: Path) -> bool:
+        """Check whether a path exists.
+
+        Args:
+            path: Path to check.
+
+        Returns:
+            True if the path exists, otherwise False.
+        """
+        ...
+
+    @abstractmethod
+    def iter_dir(
+        self,
+        path: Path,
+        only_dirs: bool = False,
+    ) -> Generator[Path, None, None]:
+        """Iterate over items in a directory.
+
+        Args:
+            path: Directory path to iterate.
+            only_dirs: Whether to yield only directory entries.
+
+        Yields:
+            Paths of directory entries.
+        """
+        ...
+
+    @abstractmethod
+    def open_file(self, path: Path) -> BinaryIO:
+        """Open a file for binary reading.
+
+        Args:
+            path: Path to the file.
+
+        Returns:
+            Binary file-like object.
+        """
+        ...
+
+    @abstractmethod
+    def read_json(self, path: Path) -> Any:
+        """Read and parse a JSON file.
+
+        Args:
+            path: Path to the JSON file.
+
+        Returns:
+            Parsed JSON content.
+        """
+        ...
+
+    @abstractmethod
+    def read_bytes(self, path: Path) -> bytes:
+        """Read raw bytes from a file in storage.
+
+        Args:
+            path: Path to the file.
+
+        Returns:
+            File content as bytes.
+        """
+        ...
+
+
+class LocalStorage(BaseStorage):
+    """Storage backend for accessing artifacts on the local filesystem."""
+
+    def exists(self, path: Path) -> bool:
+        """See the BaseStorage class docstring."""
+        return path.exists()
+
+    def iter_dir(
+        self,
+        path: Path,
+        only_dirs: bool = False,
+    ) -> Generator[Path, None, None]:
+        """See the BaseStorage class docstring."""
+        for p in path.iterdir():
+            if only_dirs and not p.is_dir():
+                continue
+            yield p
+
+    def open_file(self, path: Path) -> BinaryIO:
+        """See the BaseStorage class docstring."""
+        return open(path, "rb")
+
+    def read_json(self, path: Path) -> Any:
+        """See the BaseStorage class docstring."""
+        with open(path, "r", encoding="utf-8") as f:
+            data = json.load(f)
+        return data
+
+    def read_bytes(self, path: Path) -> bytes:
+        """See the BaseStorage class docstring."""
+        return path.read_bytes()
+
+
+class SFTPStorage(BaseStorage):
+    """Storage backend for accessing artifacts on a remote host over SFTP."""
+
+    def __init__(self, sftp: SFTPClient) -> None:
+        super().__init__()
+        self.sftp = sftp
+
+    def exists(self, path: Path) -> bool:
+        """See the BaseStorage class docstring."""
+        try:
+            self.sftp.stat(path.as_posix())
+            return True
+
+        except FileNotFoundError:
+            return False
+
+        except OSError as e:
+            if getattr(e, "errno", None) == errno.ENOENT:
+                return False
+            raise
+
+    def iter_dir(
+        self,
+        path: Path,
+        only_dirs: bool = False,
+    ) -> Generator[Path, None, None]:
+        """See the BaseStorage class docstring."""
+        for item in self.sftp.listdir_attr(path.as_posix()):
+            if only_dirs and not S_ISDIR(item.st_mode):
+                continue
+            yield path / item.filename
+
+    def open_file(self, path: Path) -> BinaryIO:
+        """See the BaseStorage class docstring."""
+        return self.sftp.open(path.as_posix(), "rb")
+
+    def read_json(self, path: Path) -> Any:
+        """See the BaseStorage class docstring."""
+        with self.sftp.open(path.as_posix(), "rb") as f:
+            data = json.loads(f.read().decode("utf-8"))
+        return data
+
+    def read_bytes(self, path: Path) -> bytes:
+        """See the BaseStorage class docstring."""
+        with self.sftp.open(path.as_posix(), "rb") as f:
+            data = f.read()
+        return data
diff --git a/scripts/tts_comparison_report/requirements.txt b/scripts/tts_comparison_report/requirements.txt
new file mode 100644
index 000000000000..e0432982d045
--- /dev/null
+++ b/scripts/tts_comparison_report/requirements.txt
@@ -0,0 +1,7 @@
+numpy
+matplotlib
+scipy
+paramiko
+tqdm
+jinja2
+boto3
diff --git a/scripts/tts_comparison_report/templates/audio_report.jinja b/scripts/tts_comparison_report/templates/audio_report.jinja
new file mode 100644
index 000000000000..617452cdeb16
--- /dev/null
+++ b/scripts/tts_comparison_report/templates/audio_report.jinja
@@ -0,0 +1,370 @@
+
+
+
+ + ++ Comparing baseline {{ baseline_name }} + against candidate {{ candidate_name }} +
++ {{ expiration_comment }} +
\ No newline at end of file diff --git a/scripts/tts_comparison_report/templates/audio_report_pair.jinja b/scripts/tts_comparison_report/templates/audio_report_pair.jinja new file mode 100644 index 000000000000..57da7bf637fc --- /dev/null +++ b/scripts/tts_comparison_report/templates/audio_report_pair.jinja @@ -0,0 +1,22 @@ ++ Text: {{ text }} +
+{{ baseline_configuration }}
+{{ candidate_configuration }}
++ Comparing baseline {{ baseline_name }} + against candidate {{ candidate_name }} +
+ ++ {{ expiration_comment }} +
\ No newline at end of file diff --git a/scripts/tts_comparison_report/templates/eval_report_image.jinja b/scripts/tts_comparison_report/templates/eval_report_image.jinja new file mode 100644 index 000000000000..79e92aa75aec --- /dev/null +++ b/scripts/tts_comparison_report/templates/eval_report_image.jinja @@ -0,0 +1,6 @@ +No statistically significant difference was observed between these two models.
+ {% else %} +{{ winner }} performs best. + Key statistically significant advantages: {{ advantages }}.
+ {% endif %} +| {{ header }} | + {% endfor %} +
|---|
| {{ cell|safe }} | + {% endfor %} +