In [1]:
import ast
import concurrent
import itertools
import json
import os
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, Generator
from urllib.parse import parse_qs, urlparse

import boto3
import pandas as pd
import plotly.express as px
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry


# Bucket that holds the benchmark audio clips (the name says "image", but the
# listings below pull .ogg/.mp3 files from it), and the public domain whose
# paths mirror the bucket's object keys.
image_bucket = "salad-benchmark-public-assets"
bucket_domain = "https://salad-benchmark-assets.download"

# Hex colour constant; presumably meant for plotly figures (not used in the cells shown).
salad_green = "#53a626"

# S3 client built from the "r2" named profile; credentials (and presumably a
# custom endpoint) come from the local AWS config -- confirm before running.
boto3.setup_default_session(profile_name="r2")
s3_client = boto3.client("s3")

In [2]:
def list_all_file_urls(bucket: str, prefix: str) -> Generator[str, None, None]:
    """Yield a download URL for every object under ``prefix`` in ``bucket``.

    Pages through ``list_objects_v2`` using the module-level ``s3_client`` and
    maps each object key onto ``bucket_domain``.

    :param bucket: name of the bucket to list.
    :param prefix: key prefix restricting the listing.
    :yields: ``f"{bucket_domain}/{key}"`` for each object, in listing order.
        Keys are not URL-encoded (e.g. spaces in "english 1" are kept as-is).
    """
    paginator = s3_client.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        # S3 omits "Contents" entirely when a page has no objects (e.g. the
        # prefix matches nothing), so direct indexing would raise KeyError.
        for obj in page.get("Contents", []):
            yield f"{bucket_domain}/{obj['Key']}"


def create_session(max_retries=3, backoff_factor=0.3, status_forcelist=None, allowed_methods=None):
    """
    Create a requests session with retry strategy.

    Requests made through the returned session (http:// and https://) are
    retried on connection errors, read errors and any status listed in
    ``status_forcelist``, honouring ``Retry-After`` response headers.

    :param max_retries: Maximum number of retries for each request (used for the
        total, read and connect budgets).
    :param backoff_factor: A backoff factor to apply between attempts.
    :param status_forcelist: HTTP status codes that force a retry. Defaults to
        every 5xx status (500-599).
    :param allowed_methods: HTTP verbs that may be retried after a read error or
        a forced status. ``None`` (the default) allows every verb. urllib3's own
        default leaves out POST, which would silently skip the 5xx retries for
        POST requests such as the ones ``all_requests`` builds.
    :return: A requests session object. When status-based retries are exhausted
        the last response is returned rather than raised, so callers can still
        inspect its status code.
    """
    if status_forcelist is None:
        status_forcelist = range(500, 600)
    retries = Retry(
        total=max_retries,
        read=max_retries,
        connect=max_retries,
        backoff_factor=backoff_factor,
        status_forcelist=list(status_forcelist),
        allowed_methods=allowed_methods,
        # Return the final 5xx response instead of raising RetryError, so the
        # caller's own status-code check still sees and reports it.
        raise_on_status=False,
        respect_retry_after_header=True,
    )
    adapter = HTTPAdapter(max_retries=retries)
    session = requests.Session()
    session.mount("http://", adapter)
    session.mount("https://", adapter)
    return session


# One shared retrying session that every http_worker call goes through (the
# workers run on a thread pool), plus the progress counters it updates and prints.
session = create_session()

num_requests = num_responses = 0


# Guards the global progress counters below; http_worker runs on pool threads.
_counter_lock = threading.Lock()


def http_worker(request_params: Dict[str, Any]) -> requests.Response:
    """Send one HTTP request through the shared ``session`` and update progress.

    This runs concurrently on ``ThreadPoolExecutor`` threads, so the global
    counters are changed under ``_counter_lock``. A bare ``+=`` on a global is a
    read-modify-write that can lose increments when threads interleave.

    :param request_params: keyword arguments forwarded verbatim to
        ``session.request`` (e.g. ``method``, ``url``, ``json``).
    :return: the response object; this function does not inspect its status.
    """
    global num_requests, num_responses
    with _counter_lock:
        num_requests += 1
        submitted, received = num_requests, num_responses
    # Print outside the lock so slow console output never blocks other workers.
    print(f"\r{submitted} submitted | {received} responses", end="", flush=True)
    response = session.request(**request_params)
    with _counter_lock:
        num_responses += 1
    return response


def chunk_list(lst, chunk_size):
    """Lazily split the sequence ``lst`` into consecutive slices of ``chunk_size``.

    The final slice is shorter when ``len(lst)`` is not a multiple of
    ``chunk_size``. Slicing preserves the input type (list -> lists, str -> strs).
    """
    offsets = range(0, len(lst), chunk_size)
    yield from (lst[offset : offset + chunk_size] for offset in offsets)
        

def chunk_generator(generator, chunk_size):
    """Group the items of an iterable into lists of at most ``chunk_size`` items.

    Items are pulled lazily, one chunk at a time, and the last chunk may be
    shorter. Any iterable is accepted. It is wrapped with ``iter()`` first
    because ``islice`` over a re-iterable object such as a list would restart
    from the beginning on every pass and never terminate.

    :param generator: iterable (or iterator) supplying the items.
    :param chunk_size: maximum number of items per yielded list.
    :yields: non-empty lists of consecutive items.
    """
    iterator = iter(generator)
    while True:
        chunk = list(itertools.islice(iterator, chunk_size))
        if not chunk:
            break
        yield chunk


def fetch_responses(
    requests_generator: Generator[Dict[str, Any], None, None],
    pool_size: int = 5,
    chunk_multiplier: int = 20,
) -> Generator[requests.Response, None, None]:
    """Run requests on a thread pool of ``http_worker`` calls and yield the responses.

    Requests are pulled lazily from ``requests_generator`` in chunks of
    ``pool_size * chunk_multiplier``, so a very large input is never turned into
    futures all at once. Within a chunk, responses are yielded in completion
    order, not submission order. The next chunk is only submitted after every
    response of the current chunk has been yielded.

    :param requests_generator: iterable of keyword-argument dicts for ``http_worker``.
    :param pool_size: number of worker threads.
    :param chunk_multiplier: chunk size expressed as a multiple of ``pool_size``
        (default 20, the value previously hard-coded).
    :raises Exception: whatever ``http_worker`` raised for a request is re-raised
        from ``future.result()``.
    """
    chunk_size = pool_size * chunk_multiplier
    with ThreadPoolExecutor(max_workers=pool_size) as executor:
        for chunk in chunk_generator(requests_generator, chunk_size):
            futures = [executor.submit(http_worker, params) for params in chunk]
            for future in concurrent.futures.as_completed(futures):
                yield future.result()

In [3]:
# Base URLs of the ASR endpoints; all_requests POSTs every audio URL to
# "<base>/asr" on each of them.
asr_apis = [
    "https://honeyberry-spinach-04iea1s0vf4y7jef.salad.cloud",
    "https://kumquat-potato-hqka58l4iluyvpin.salad.cloud",
]


def all_urls():
    """Yield the download URL of every audio clip under three hard-coded prefixes.

    The prefixes are listed in order: "wikipedia/english 1/" and
    "wikipedia/english 2/" (keeping only .ogg files), then
    "cv-corpus-15.0-2023-09-08/en/clips/" (keeping only .mp3 files). Every
    other object under those prefixes is skipped.
    """
    sources = (
        ("wikipedia/english 1/", ".ogg"),
        ("wikipedia/english 2/", ".ogg"),
        ("cv-corpus-15.0-2023-09-08/en/clips/", ".mp3"),
    )
    for prefix, extension in sources:
        yield from (
            url
            for url in list_all_file_urls(image_bucket, prefix)
            if url.endswith(extension)
        )

def all_requests(urls):
    """Yield one POST request spec per (audio URL, ASR endpoint) pair.

    Ordering is URL-major: every endpoint in ``asr_apis`` for the first URL,
    then the next URL. Each spec is a kwargs dict for ``session.request`` that
    targets ``<endpoint>/asr`` with ``{"url": <audio url>}`` as the JSON body.
    """
    for audio_url in urls:
        yield from (
            {"method": "POST", "url": f"{endpoint}/asr", "json": {"url": audio_url}}
            for endpoint in asr_apis
        )

def get_audio_url_from_response(response: requests.Response) -> str:
    """Recover the audio URL that was POSTed to produce ``response``.

    Parses the body of the request that ``response`` answers (built by
    ``all_requests`` as ``{"url": ...}``) and returns its ``"url"`` field.

    :raises ValueError: if the originating request carried no body.
    :raises KeyError: if the JSON body has no ``"url"`` key.
    """
    body = response.request.body
    if body is None:
        raise ValueError(f"Request to {response.request.url} has no body to read the audio URL from")
    # The body may be bytes (a ``json=`` payload) or str (a ``data=`` string);
    # json.loads accepts both, so no manual .decode() that would fail on str.
    return json.loads(body)["url"]


# Response headers copied verbatim into each result row (keys stay lowercase).
_SALAD_HEADERS = (
    "x-gpu-name",
    "x-salad-machine-id",
    "x-salad-container-group-id",
    "x-processing-time",
    "x-audio-length",
    "x-realtime-factor",
    "x-model-id",
)


def get_rows(request_generator, pool_size: int = 5):
    """Yield one result row per ASR response, fetching responses concurrently.

    Each row holds the submitted ``audio_url`` plus whichever of the
    ``_SALAD_HEADERS`` the response carried. Rows follow the order in which
    ``fetch_responses`` yields responses, which may differ from request order.

    :param request_generator: iterable of request kwargs dicts (e.g. from ``all_requests``).
    :param pool_size: number of concurrent HTTP workers.
    :raises Exception: on the first response whose status is not 200; no
        further rows are produced after that.
    """
    for response in fetch_responses(request_generator, pool_size=pool_size):
        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.reason}")
        row = {"audio_url": get_audio_url_from_response(response)}
        # response.headers is case-insensitive, so look each header up by name.
        # Comparing the server's spelling against lowercase names would drop
        # e.g. "X-GPU-Name".
        for name in _SALAD_HEADERS:
            value = response.headers.get(name)
            if value is not None:
                row[name] = value
        yield row

In [4]:
all_urls_file = "all_urls.txt"
if not os.path.exists(all_urls_file):
    # Write to a partial file and rename only after the S3 listing completes.
    # An interrupted run then cannot leave a truncated cache that later runs trust.
    partial_file = f"{all_urls_file}.partial"
    with open(partial_file, "w", encoding="utf-8") as f:
        for url in all_urls():
            f.write(f"{url}\n")
    os.replace(partial_file, all_urls_file)

# NOTE(review): this rebinds `all_urls` from the generator function to a list.
# Re-running this cell after deleting the cache therefore fails until that
# function's cell is re-run. The name is kept because the next cell reads `all_urls`.
with open(all_urls_file, "r", encoding="utf-8") as f:
    # readlines() would keep the trailing newline on every URL, and it would
    # end up inside each request's JSON body; strip it and skip blank lines.
    all_urls = [line.rstrip("\r\n") for line in f if line.strip()]

print(f"Total Number of Audio Clips: {len(all_urls)}")

Total Number of Audio Clips: 1159240


In [5]:
all_requests_file = "all_requests.jsonl"
if not os.path.exists(all_requests_file):
    # Write to a partial file and rename only when complete, so an interrupted
    # run is never mistaken for a finished cache on the next run.
    partial_file = f"{all_requests_file}.partial"
    with open(partial_file, "w", encoding="utf-8") as f:
        # Strip line endings defensively in case the URLs were read with readlines().
        clean_urls = (url.rstrip("\r\n") for url in all_urls)
        for request in all_requests(clean_urls):
            f.write(f"{json.dumps(request)}\n")
    os.replace(partial_file, all_requests_file)

def request_generator():
    """Lazily stream request kwargs dicts from ``all_requests_file``, one JSON object per line."""
    with open(all_requests_file, "r") as handle:
        yield from map(json.loads, handle)

In [6]:
# row_file = "all_rows.jsonl"
# if not os.path.exists(row_file):
#     with open(row_file, "w") as f:
#         for row in get_rows(request_generator()):
#             f.write(json.dumps(row) + "\n")

# with open(row_file) as f:
#     rows = [json.loads(line) for line in f]

# print(f"Total number of rows: {len(rows)}")