In [1]:
import boto3
import json
import requests
from concurrent.futures import ThreadPoolExecutor
import concurrent
from typing import Generator, Dict, Any
import itertools
from urllib.parse import urlparse, parse_qs
import os
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
import pandas as pd
import ast
import plotly.express as px


# Make the named profile 'r2' the default for every boto3 client created below.
# NOTE(review): the profile name suggests Cloudflare R2 credentials; confirm.
boto3.setup_default_session(profile_name='r2')
s3_client = boto3.client('s3')

# Bucket that is listed to enumerate the benchmark images, and the public base
# URL that object keys are appended to when building download links.
image_bucket = "salad-benchmark-public-assets"
bucket_domain = "https://salad-benchmark-assets.download"

# Hex colour constant; not referenced in the visible cells. Presumably the
# brand colour for plots later in the notebook (TODO confirm).
salad_green = "#53a626"

In [2]:
def list_all_file_urls(bucket: str, prefix: str) -> Generator[str, None, None]:
    """
    Lazily yield a public download URL for every object under a key prefix.

    Pages through ``list_objects_v2`` on the module-level ``s3_client`` and
    appends each object key to the module-level ``bucket_domain``.

    :param bucket: Name of the bucket to list.
    :param prefix: Key prefix restricting which objects are listed.
    :return: Generator of ``"{bucket_domain}/{key}"`` strings, in listing order.
    """
    paginator = s3_client.get_paginator('list_objects_v2')
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        # A page without matching objects carries no 'Contents' key at all,
        # so plain indexing would raise KeyError for an empty prefix.
        for content in page.get('Contents', []):
            yield f"{bucket_domain}/{content['Key']}"
            
def create_session(max_retries=3, backoff_factor=0.3, status_forcelist=None):
    """
    Create a requests session with a retry strategy.

    The same retry-enabled adapter is mounted for both ``http://`` and
    ``https://`` URLs, and ``Retry-After`` response headers are honoured.

    :param max_retries: Maximum number of retries for each request; applied to
        the total, read and connect retry budgets alike.
    :param backoff_factor: A backoff factor to apply between attempts.
    :param status_forcelist: Iterable of HTTP status codes that force a retry.
        Defaults to every 5xx code (500-599) when omitted.
    :return: A requests session object.
    """
    if status_forcelist is None:
        status_forcelist = range(500, 600)

    retries = Retry(total=max_retries,
                    read=max_retries,
                    connect=max_retries,
                    backoff_factor=backoff_factor,
                    status_forcelist=list(status_forcelist),
                    respect_retry_after_header=True
                    )
    adapter = HTTPAdapter(max_retries=retries)

    session = requests.Session()
    for scheme in ('http://', 'https://'):
        session.mount(scheme, adapter)
    return session

# Shared module-level session, built with the default retry settings;
# http_worker sends every request through it.
session = create_session()
            
# Progress counters and the queue of pending request kwargs.
# http_worker updates the counters via `global` and pops from the list;
# later cells rebind list_of_all_requests to the real workload.
num_requests = 0
num_responses = 0
list_of_all_requests = []
def http_worker() -> "requests.Response | None":
    """
    Pop one queued request off ``list_of_all_requests`` and send it.

    The request kwargs are taken from the END of the module-level queue and
    passed to ``session.request``. The module-level ``num_requests`` and
    ``num_responses`` counters are updated and a one-line progress message is
    printed; the counters are progress indicators only and are updated
    without locking.

    :return: The response, or ``None`` when the queue was already empty or
        the request raised (the error is printed rather than propagated).
    """
    global num_requests, num_responses
    try:
        request_params = list_of_all_requests.pop()
    except IndexError:
        # Queue already drained (e.g. more workers were submitted than there
        # are requests left): nothing to send, and nothing to count.
        return None

    num_requests += 1
    print(f"\r{num_requests} submitted | {num_responses} responses", end="", flush=True)
    try:
        response = session.request(**request_params)
    except Exception as e:
        # Deliberate best-effort: report the failure and let the caller skip it.
        print(f"\nError: {e}", flush=True)
        response = None
    num_responses += 1
    return response
      
def chunk_list(lst, chunk_size):
    """Lazily split ``lst`` into consecutive slices of at most ``chunk_size`` items."""
    offsets = range(0, len(lst), chunk_size)
    yield from (lst[start:start + chunk_size] for start in offsets)

def fetch_responses(pool_size: int = 5) -> "Generator[requests.Response | None, None, None]":
    """
    Run ``http_worker`` once per queued request on a thread pool.

    Work is submitted in batches of ``pool_size * 20`` so that only one batch
    of futures is outstanding at a time. The number of requests is read from
    ``list_of_all_requests`` once, when iteration starts.

    :param pool_size: Number of worker threads.
    :return: Generator of whatever ``http_worker`` returns, yielded in
        completion order within each batch.
    """
    batch_size = pool_size * 20
    remaining = len(list_of_all_requests)
    with ThreadPoolExecutor(max_workers=pool_size) as executor:
        while remaining > 0:
            # Never submit more workers than there are requests left;
            # surplus workers would pop from an empty queue.
            batch = min(batch_size, remaining)
            remaining -= batch
            futures = [executor.submit(http_worker) for _ in range(batch)]
            for future in concurrent.futures.as_completed(futures):
                yield future.result()

In [3]:
# Base URL of the image segmentation service; all_requests targets its /segment route.
image_segmentation_api = "https://feta-coriander-ljq4xhb0fbxkxuye.salad.cloud"

def all_requests():
    """
    Yield the ``session.request`` kwargs for every benchmark image.

    Images are enumerated from ``image_bucket``, first under
    ``coco2017/train2017/`` and then under ``ava/images/``. Each one becomes a
    GET to the ``/segment`` route with the image URL as the ``url`` query
    parameter and ``multimask_output`` set to False.
    """
    for prefix in ("coco2017/train2017/", "ava/images/"):
        for image_url in list_all_file_urls(image_bucket, prefix):
            query = {"url": image_url, "multimask_output": False}
            yield {
                "method": "GET",
                "url": f"{image_segmentation_api}/segment",
                "params": query,
            }
        
def get_image_url_from_response(response: requests.Response) -> str:
    """Return the first ``url`` query-parameter value of the request that produced ``response``."""
    request_url = response.request.url
    query_params = parse_qs(urlparse(request_url).query)
    return query_params["url"][0]
    
    
def get_rows():
    """
    Turn each completed response into a flat result row.

    Consumes ``fetch_responses(pool_size=50)``. Entries that are ``None`` are
    skipped. Each row holds the requested ``image_url`` plus whichever of the
    Salad timing/identity headers the response carries, keyed by the
    lower-case header name.

    :return: Generator of row dicts.
    :raises Exception: On the first response whose status code is not 200.
    """
    wanted_headers = (
        "x-gpu-name",
        "x-salad-machine-id",
        "x-salad-container-group-id",
        "x-inference-time",
        "x-image-load-time",
    )
    for response in fetch_responses(pool_size=50):
        if response is None:
            continue
        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.reason}")
        row = {
            "image_url": get_image_url_from_response(response),
        }
        for name in wanted_headers:
            # Look names up through the header mapping (case-insensitive for
            # requests responses) rather than string-comparing the names the
            # server happened to send, so "X-GPU-Name" is not dropped.
            if name in response.headers:
                row[name] = response.headers[name]
        yield row

In [4]:
# The full request list is cached on disk as JSON so the bucket listing only
# has to run when the cache file is missing.
request_list_cache_file = "all_requests.json"

if not os.path.exists(request_list_cache_file):
    # Cache miss: enumerate every request, then persist the list for next time.
    list_of_all_requests = list(all_requests())
    with open(request_list_cache_file, "w") as f:
        f.write(json.dumps(list_of_all_requests))
else:
    # Cache hit: reuse the previously saved list.
    with open(request_list_cache_file) as f:
        list_of_all_requests = json.loads(f.read())
print(f"Total number of requests: {len(list_of_all_requests)}")

Total number of requests: 373795


In [5]:
# Result rows are stored one JSON object per line. When the file already
# exists, load it and drop every request whose image URL has a row, so only
# the unprocessed requests remain queued.
row_file = "all_rows.jsonl"
rows = []
if os.path.exists(row_file):
    with open(row_file) as f:
        rows = list(map(json.loads, f))
    processed_urls = {row["image_url"] for row in rows}
    print(f"Resuming at row #{len(rows)}")
    list_of_all_requests = [
        pending
        for pending in list_of_all_requests
        if pending["params"]["url"] not in processed_urls
    ]
# Collection loop, currently disabled: it appends each new row to row_file
# and to `rows` as responses arrive.
# with open(row_file, "a") as f:
#     for row in get_rows():
#         f.write(json.dumps(row) + "\n")
#         rows.append(row)

print(f"Total number of rows: {len(rows)}")

Resuming at row #152848
Total number of rows: 152848
