Skip to content

Commit

Permalink
feat: add common docker handler and common mirrors
Browse files Browse the repository at this point in the history
add `follow_redirects` options
add `out` params for `aria2.addUri`
  • Loading branch information
Anonymous committed Jun 13, 2024
1 parent d06c749 commit 833a31a
Show file tree
Hide file tree
Showing 10 changed files with 191 additions and 172 deletions.
16 changes: 14 additions & 2 deletions src/mirrorsrun/aria2_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
logger = logging.getLogger(__name__)


# refer to https://aria2.github.io/manual/en/html/aria2c.html
async def send_request(method, params=None):
request_id = uuid.uuid4().hex
payload = {
Expand All @@ -32,9 +33,20 @@ async def send_request(method, params=None):
raise e


async def add_download(url, save_dir="/app/cache"):
async def add_download(url, save_dir="/app/cache", out_file=None):
logger.info(f"[Aria2] add_download {url=} {save_dir=} {out_file=}")

method = "aria2.addUri"
params = [[url], {"dir": save_dir, "header": []}]
options = {
"dir": save_dir,
"header": [],
"out": out_file,
}

if out_file:
options["out"] = out_file

params = [[url], options]
response = await send_request(method, params)
return response["result"]

Expand Down
68 changes: 0 additions & 68 deletions src/mirrorsrun/docker_utils.py

This file was deleted.

8 changes: 6 additions & 2 deletions src/mirrorsrun/proxy/direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,15 @@ async def direct_proxy(
target_url: str,
pre_process: typing.Union[SyncPreProcessor, AsyncPreProcessor, None] = None,
post_process: typing.Union[SyncPostProcessor, AsyncPostProcessor, None] = None,
follow_redirects: bool = True,
) -> Response:

# httpx will use the following environment variables to determine the proxy
# https://www.python-httpx.org/environment_variables/#http_proxy-https_proxy-all_proxy
async with httpx.AsyncClient() as client:
req_headers = request.headers.mutablecopy()
for key in req_headers.keys():
if key not in ["user-agent", "accept"]:
if key not in ["user-agent", "accept", "authorization"]:
del req_headers[key]

httpx_req: HttpxRequest = client.build_request(
Expand All @@ -76,7 +78,9 @@ async def direct_proxy(

httpx_req = await pre_process_request(request, httpx_req, pre_process)

upstream_response = await client.send(httpx_req)
upstream_response = await client.send(
httpx_req, follow_redirects=follow_redirects
)

res_headers = upstream_response.headers

Expand Down
20 changes: 14 additions & 6 deletions src/mirrorsrun/proxy/file_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
from urllib.parse import urlparse, quote

import httpx
from mirrorsrun.aria2_api import add_download
from mirrorsrun.config import CACHE_DIR, EXTERNAL_URL_ARIA2
from starlette.requests import Request
from starlette.responses import Response
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_504_GATEWAY_TIMEOUT

from mirrorsrun.aria2_api import add_download
from mirrorsrun.config import CACHE_DIR, EXTERNAL_URL_ARIA2

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -80,7 +81,7 @@ async def try_file_based_cache(
return make_cached_response(target_url)

if cache_status == DownloadingStatus.DOWNLOADING:
logger.info(f"Download is not finished, return 503 for {target_url}")
logger.info(f"Download is not finished, return 504 for {target_url}")
return Response(
content=f"This file is downloading, view it at {EXTERNAL_URL_ARIA2}",
status_code=HTTP_504_GATEWAY_TIMEOUT,
Expand All @@ -94,8 +95,12 @@ async def try_file_based_cache(
processed_url = quote(target_url, safe="/:?=&%")

try:
logger.info(f"Start download {processed_url}")
await add_download(processed_url, save_dir=cache_file_dir)
# resolve redirect via aria2
await add_download(
processed_url,
save_dir=cache_file_dir,
out_file=os.path.basename(cache_file),
)
except Exception as e:
logger.error(f"Download error, return 500 for {target_url}", exc_info=e)
return Response(
Expand All @@ -110,7 +115,10 @@ async def try_file_based_cache(
if cache_status == DownloadingStatus.DOWNLOADED:
logger.info(f"Cache hit for {target_url}")
return make_cached_response(target_url)
logger.info(f"Download is not finished, return 503 for {target_url}")

assert cache_status != DownloadingStatus.NOT_FOUND

logger.info(f"Download is not finished, return 504 for {target_url}")
return Response(
content=f"This file is downloading, view it at {EXTERNAL_URL_ARIA2}",
status_code=HTTP_504_GATEWAY_TIMEOUT,
Expand Down
16 changes: 11 additions & 5 deletions src/mirrorsrun/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,27 @@
EXTERNAL_HOST_ARIA2,
SCHEME,
)
from mirrorsrun.sites.docker import docker

from mirrorsrun.sites.npm import npm
from mirrorsrun.sites.pypi import pypi
from mirrorsrun.sites.torch import torch
from mirrorsrun.sites.k8s import k8s
from mirrorsrun.sites.docker import dockerhub, k8s, quay, ghcr
from mirrorsrun.sites.common import common

subdomain_mapping = {
"mirrors": common,
"pypi": pypi,
"torch": torch,
"docker": docker,
"npm": npm,
"docker": dockerhub,
"k8s": k8s,
"ghcr": ghcr,
"quay": quay,
}

logging.basicConfig(level=logging.INFO)
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -123,7 +129,7 @@ async def capture_request(request: Request, call_next: Callable):
app="server:app",
host="0.0.0.0",
port=port,
reload=True, # TODO: reload only in dev mode
reload=True, # TODO: reload only in dev mode
proxy_headers=True, # trust x-forwarded-for etc.
forwarded_allow_ips="*",
)
18 changes: 18 additions & 0 deletions src/mirrorsrun/sites/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from starlette.requests import Request

from mirrorsrun.proxy.direct import direct_proxy
from starlette.responses import Response


async def common(request: Request):
path = request.url.path
if path == "/":
return
if path.startswith("/alpine"):
return await direct_proxy(request, "https://dl-cdn.alpinelinux.org" + path)
if path.startswith("/ubuntu/"):
return await direct_proxy(request, "http://archive.ubuntu.com" + path)
if path.startswith("/ubuntu-ports/"):
return await direct_proxy(request, "http://ports.ubuntu.com" + path)

return Response("Not Found", status_code=404)
Loading

0 comments on commit 833a31a

Please sign in to comment.