In [1]:
import concurrent.futures
import itertools
import json
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, Generator, Iterable, Iterator, Optional, Union
from urllib.parse import parse_qs, urlparse

import boto3
import requests

# Bucket holding the benchmark images, and the public base URL that object
# keys are appended to when building download URLs.
image_bucket = "salad-benchmark-public-assets"
bucket_domain = "https://salad-benchmark-assets.download"

# S3 API client built from the local "r2" credentials profile
# (presumably an S3-compatible R2 endpoint configured there — TODO confirm).
boto3.setup_default_session(profile_name="r2")
s3_client = boto3.client("s3")

In [2]:
def list_all_file_urls(bucket: str, prefix: str,
                       client: Any = None,
                       domain: Optional[str] = None) -> Generator[str, None, None]:
    """Yield a public URL for every object under ``prefix`` in ``bucket``.

    Keys are enumerated with the ``list_objects_v2`` paginator, so listings that
    span several pages are fully traversed.

    Args:
        bucket: Name of the bucket to list.
        prefix: Key prefix restricting the listing.
        client: S3 client exposing ``get_paginator``; defaults to the notebook's
            ``s3_client``.
        domain: Base URL prepended to each key; defaults to ``bucket_domain``.

    Yields:
        ``"{domain}/{key}"`` for each listed object, in the order the paginator
        returns them.
    """
    # Resolve defaults lazily so the notebook globals are only needed when the
    # caller does not supply its own client/domain.
    if client is None:
        client = s3_client
    if domain is None:
        domain = bucket_domain

    paginator = client.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        # A page with no matching objects omits "Contents" entirely (e.g. an
        # empty or mistyped prefix); treat that as "no keys" instead of KeyError.
        for obj in page.get("Contents", []):
            yield f"{domain}/{obj['Key']}"
            

def http_worker(request_params: Dict[str, Any],
                timeout: Optional[float] = 300.0) -> Union[requests.Response, requests.RequestException]:
    """Perform one HTTP request, returning any request error instead of raising it.

    Args:
        request_params: Keyword arguments forwarded to ``requests.request``
            (e.g. ``method``, ``url``, ``params``).
        timeout: Timeout in seconds used when ``request_params`` has no
            ``timeout`` key. ``requests`` otherwise waits indefinitely, which can
            stall a worker thread for good; pass ``None`` to restore that.

    Returns:
        The ``requests.Response`` (for any status code), or the
        ``requests.RequestException`` raised while sending. Errors are returned
        rather than raised so one failure does not abort a large batch; callers
        must check which of the two they received.
    """
    # Build a new dict: the caller's params are left untouched, and an explicit
    # "timeout" in request_params takes precedence over the default.
    params = {"timeout": timeout, **request_params}
    try:
        return requests.request(**params)
    except requests.RequestException as exc:
        return exc
      

def fetch_responses(requests_generator: Iterable[Dict[str, Any]],
                    pool_size: int = 5,
                    max_in_flight: Optional[int] = None,
                    worker: Optional[Callable[[Dict[str, Any]], Any]] = None,
                    log_every: int = 100) -> Iterator[Any]:
    """Run requests on a thread pool and yield each result as it completes.

    At most ``max_in_flight`` tasks are submitted-but-not-yet-yielded at any
    time. The input is therefore consumed lazily, and finished results are
    released once yielded, so memory stays bounded even for hundreds of
    thousands of requests. Closing the generator early only waits for that
    bounded set of queued tasks.

    Args:
        requests_generator: Iterable of keyword-argument dicts, one per request.
        pool_size: Number of worker threads.
        max_in_flight: Cap on outstanding tasks; defaults to ``4 * pool_size``.
        worker: Callable applied to each dict; defaults to ``http_worker``.
        log_every: Print a progress line every this many results; ``<= 0``
            disables progress output.

    Yields:
        Whatever ``worker`` returns for each input, in completion order (not
        input order). With the default worker that is a ``requests.Response``
        or the ``requests.RequestException`` that occurred.
    """
    if worker is None:
        worker = http_worker
    if max_in_flight is None:
        # 32 mirrors ThreadPoolExecutor's upper default if pool_size is None.
        max_in_flight = 4 * (pool_size or 32)
    max_in_flight = max(1, max_in_flight)

    pending_params = iter(requests_generator)
    num_processed = 0
    with ThreadPoolExecutor(max_workers=pool_size) as executor:

        def submit_up_to(limit: int) -> set:
            # Pull at most `limit` more inputs and submit them to the pool.
            return {executor.submit(worker, params)
                    for params in itertools.islice(pending_params, limit)}

        in_flight = submit_up_to(max_in_flight)
        while in_flight:
            done, in_flight = concurrent.futures.wait(
                in_flight, return_when=concurrent.futures.FIRST_COMPLETED)
            # Refill before yielding so workers stay busy while the consumer
            # processes this batch of results.
            in_flight |= submit_up_to(max_in_flight - len(in_flight))
            for future in done:
                num_processed += 1
                if log_every > 0 and num_processed % log_every == 0:
                    print(f"Processed {num_processed} requests")
                yield future.result()

In [3]:
image_tagging_api = "https://orange-splitpea-kunklrrfjwnyterv.salad.cloud"

# Bucket prefixes whose images are sent to the tagging API, in request order.
IMAGE_PREFIXES = ("coco2017/train2017/", "ava/images/")


def all_requests(prefixes: Iterable[str] = IMAGE_PREFIXES) -> Generator[Dict[str, Any], None, None]:
    """Yield ``requests.request`` kwargs asking the tagging API to tag each image.

    Args:
        prefixes: Key prefixes in ``image_bucket`` to enumerate, processed in
            order. Defaults to the COCO 2017 train split followed by AVA.

    Yields:
        ``{"method": "GET", "url": "<api>/tag", "params": {"url": <image url>}}``
        for every object under each prefix.
    """
    tag_endpoint = f"{image_tagging_api}/tag"
    for prefix in prefixes:
        for url in list_all_file_urls(image_bucket, prefix):
            yield {
                "method": "GET",
                "url": tag_endpoint,
                "params": {"url": url},
            }
        
def get_image_url_from_response(response: requests.Response) -> str:
    """Recover the image URL that was sent as the ``url`` query parameter.

    Parses the URL of the request that produced ``response`` and returns the
    first (percent-decoded) value of its ``url`` query-string parameter.
    Raises KeyError if that request carried no ``url`` parameter.
    """
    query_string = urlparse(response.request.url).query
    url_values = parse_qs(query_string)["url"]
    return url_values[0]
    
    
# Diagnostic headers copied from each tagging-API response into its row.
SALAD_HEADERS = (
    "x-gpu-name",
    "x-salad-machine-id",
    "x-salad-container-group-id",
    "x-inference-time",
    "x-image-download-time",
)


def get_rows(request_generator, pool_size: int = 80):
    """Execute every request and yield one flat dict (table row) per result.

    Args:
        request_generator: Iterable of ``requests.request`` kwargs dicts.
        pool_size: Number of concurrent worker threads.

    Yields:
        For a response: its JSON body (assumed to be a JSON object — TODO
        confirm against the tagging API) plus ``image_url`` and whichever of
        ``SALAD_HEADERS`` the response carried. Failures still produce a row
        with an ``error`` entry instead of aborting the whole run:
        transport errors yield ``{"image_url", "error"}``, bodies that are not
        valid JSON are replaced by an ``error`` entry, and non-2xx responses
        get ``error`` = ``"HTTP <status>"`` unless the body already has one.
    """
    for result in fetch_responses(request_generator, pool_size=pool_size):
        if isinstance(result, Exception):
            # http_worker returns (rather than raises) request errors. The
            # exception may lack a prepared request, so the URL is best-effort.
            try:
                image_url = get_image_url_from_response(result)
            except (AttributeError, KeyError):
                image_url = None
            yield {"image_url": image_url, "error": repr(result)}
            continue

        try:
            row = result.json()
        except ValueError as exc:  # requests' JSON decode errors subclass ValueError
            row = {"error": f"HTTP {result.status_code}: invalid JSON body ({exc})"}
        if not result.ok:
            row.setdefault("error", f"HTTP {result.status_code}")
        row["image_url"] = get_image_url_from_response(result)

        # response.headers is case-insensitive, so look each wanted name up
        # directly instead of comparing raw header names, whose case is
        # whatever the server sent.
        for header in SALAD_HEADERS:
            if header in result.headers:
                row[header] = result.headers[header]
        yield row

In [4]:
# Materialise the whole request list once so the total is known up front, and
# persist it to all_requests.json for inspection.
list_of_all_requests = [*all_requests()]
print(f"Total number of requests: {len(list_of_all_requests)}")
with open("all_requests.json", mode="w") as requests_file:
    json.dump(list_of_all_requests, requests_file)

Total number of requests: 373795


In [None]:
import pandas as pd

# Send every request through the worker pool and collect one row per result.
# Note: despite its name, all_requests.csv holds the API responses (rows from
# get_rows), not the requests themselves.
df = pd.DataFrame(data=get_rows(list_of_all_requests))
df.to_csv(path_or_buf="all_requests.csv", index=False)