Skip to content

Commit 0957c48

Browse files
committed
chore: 优化s3获取文件列表慢的问题,和支持更多的进度条显示
1 parent 2f5cc24 commit 0957c48

File tree

8 files changed

+226
-54
lines changed

8 files changed

+226
-54
lines changed

core/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
11
import asyncio
22
import base64
33
import hashlib
4-
import hmac
54
import io
65
import time
76
from typing import Optional
@@ -10,11 +9,11 @@
109
import fastapi
1110

1211
from core import units
13-
from core.abc import ResponseFile, ResponseFileLocal, ResponseFileMemory, ResponseFileRemote
12+
from core.abc import ResponseFileLocal, ResponseFileMemory, ResponseFileRemote
1413

1514
from .locale import load_languages
1615
from .cluster import ClusterManager
17-
from .config import API_VERSION, ROOT_PATH, VERSION, cfg
16+
from .config import API_VERSION, VERSION, cfg
1817
from .logger import logger
1918
from .utils import runtime
2019
from . import web

core/cluster.py

Lines changed: 34 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -402,12 +402,12 @@ def __init__(
402402
self._missing_files = missing_files
403403
self._clusters = clusters
404404
self._storages = storages
405-
self._pbar = tqdm.tqdm(
406-
total=sum(f.size for f in missing_files),
407-
desc="Download",
408-
unit="b",
409-
unit_divisor=1024,
405+
self._pbar = utils.MultiTQDM(
406+
total=sum(missing_file.size for missing_file in missing_files),
407+
description="Download",
408+
unit="B",
410409
unit_scale=True,
410+
unit_divisor=1024,
411411
)
412412
self._failed = 0
413413
self._success = 0
@@ -430,13 +430,16 @@ async def download(self):
430430
configuration = await self.get_configurations()
431431
logger.tinfo("download.configuration", source=configuration.source, concurrency=configuration.concurrency)
432432
async with anyio.create_task_group() as task_group:
433-
missfiles = deque(self._missing_files)
433+
queue = utils.Queue()
434+
for file in self._missing_files:
435+
queue.put_item(file)
434436
for _ in range(configuration.concurrency):
435-
task_group.start_soon(self._download_files, missfiles)
437+
task_group.start_soon(self._download_files, queue, _)
436438

437439
async def _download_files(
438440
self,
439-
files: deque[BMCLAPIFile]
441+
files: utils.Queue[BMCLAPIFile],
442+
worker_id: int
440443
):
441444
async with aiohttp.ClientSession(
442445
base_url=cfg.base_url,
@@ -445,14 +448,21 @@ async def _download_files(
445448
"User-Agent": USER_AGENT,
446449
}
447450
) as session:
448-
while len(files) != 0:
449-
file = files.popleft()
450-
await self._download_file(file, session)
451+
with self._pbar.sub(0, f"Worker {worker_id}", unit="B", unit_scale=True, unit_divisor=1024) as pbar:
452+
while len(files) != 0:
453+
file = await files.get_item()
454+
pbar._tqdm.total = file.size
455+
pbar._tqdm.n = 0
456+
pbar._tqdm.set_description_str(file.path)
457+
pbar._tqdm.refresh()
458+
pbar._tqdm.update(0)
459+
await self._download_file(file, session, pbar)
451460

452461
async def _download_file(
453462
self,
454463
file: BMCLAPIFile,
455-
session: aiohttp.ClientSession
464+
session: aiohttp.ClientSession,
465+
pbar: utils.SubTQDM
456466
):
457467
last_error = None
458468
for _ in range(10):
@@ -461,6 +471,7 @@ async def _download_file(
461471
with tempfile.NamedTemporaryFile(
462472
dir=self._cache_dir,
463473
) as tmp_file:
474+
print(pbar.position)
464475
try:
465476
async with session.get(
466477
file.path
@@ -471,13 +482,15 @@ async def _download_file(
471482
inc = len(data)
472483
size += inc
473484
self._pbar.update(inc)
485+
pbar.update(inc)
474486
if hash.hexdigest() != file.hash or size != file.size:
475487
await anyio.sleep(50)
476488
raise Exception(f"hash mismatch, got {hash.hexdigest()} expected {file.hash}")
477489

478490
except Exception as e:
479491
last_error = e
480492
self._pbar.update(-size)
493+
pbar.update(-size)
481494
self.update_failed()
482495
continue
483496
self.update_success()
@@ -585,11 +598,15 @@ async def sync(self):
585598
logger.tinfo("cluster.sync.files", count=len(files), size=units.format_bytes(total_size), last_modified=units.format_datetime_from_timestamp(last_modified))
586599

587600
check_storages = [CheckStorage(storage) for storage in self.storages.storages]
588-
missing_files: set[BMCLAPIFile] = set().union(
589-
*await utils.gather(*[
590-
check_storage.get_missing_files(files) for check_storage in check_storages
591-
])
592-
)
601+
with utils.MultiTQDM(
602+
len(check_storages),
603+
description="Listing files"
604+
) as pbar:
605+
missing_files: set[BMCLAPIFile] = set().union(
606+
*await utils.gather(*[
607+
check_storage.get_missing_files(files, pbar) for check_storage in check_storages
608+
])
609+
)
593610
if len(missing_files) > 0:
594611
logger.tinfo("cluster.sync.missing_files", count=len(missing_files), size=units.format_bytes(sum([f.size for f in missing_files])))
595612
download_manager = DownloadManager(missing_files, self.clusters, check_storages)

core/storage/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -109,9 +109,9 @@ def __init__(
109109
self.files: dict[str, FileInfo] = {}
110110
self.missing_files: set[BMCLAPIFile] = set()
111111

112-
async def get_missing_files(self, bmclapi_files: set[BMCLAPIFile]):
112+
async def get_missing_files(self, bmclapi_files: set[BMCLAPIFile], muitlpbar: utils.MultiTQDM):
113113

114-
for file in await self.storage.list_download_files():
114+
for file in await self.storage.list_download_files(muitlpbar):
115115
self.files[file.name] = file
116116

117117
# start check

core/storage/abc.py

Lines changed: 14 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -68,16 +68,22 @@ async def list_files(
6868

6969
async def list_download_files(
7070
self,
71+
muitlpbar: utils.MultiTQDM
7172
):
7273
res: list[FileInfo] = []
73-
async def works(root_ids: list[int]):
74-
for root_id in root_ids:
75-
files = await self.list_files(f"download/{root_id:02x}")
76-
res.extend(files)
77-
async with anyio.create_task_group() as task_group:
78-
work = utils.split_workload(list(RANGE), 10)
79-
for w in work:
80-
task_group.start_soon(works, w)
74+
with muitlpbar.sub(
75+
256,
76+
description=f"Listing files in {self.name}({self.type})"
77+
) as pbar:
78+
async def works(root_ids: list[int]):
79+
for root_id in root_ids:
80+
files = await self.list_files(f"download/{root_id:02x}")
81+
res.extend(files)
82+
pbar.update(1)
83+
async with anyio.create_task_group() as task_group:
84+
work = utils.split_workload(list(RANGE), 10)
85+
for w in work:
86+
task_group.start_soon(works, w)
8187
return res
8288

8389
@abc.abstractmethod

core/storage/alist.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@ def __repr__(self):
2424
return f"<AlistResponse code={self.code} data={self.data} message={self.message}>"
2525

2626
class AlistStorage(abc.Storage):
27+
type = "alist"
2728
def __init__(
2829
self,
2930
name: str,

core/storage/s3.py

Lines changed: 54 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,7 @@
1+
import inspect
12
import tempfile
23
import time
4+
from typing import Any
35
import aioboto3.session
46
import anyio.abc
57
import anyio.to_thread
@@ -37,6 +39,7 @@ def __repr__(
3739
return f"S3Response(metadata={self.metadata}, data={self.raw_data})"
3840

3941
class S3Storage(abc.Storage):
42+
type = "s3"
4043
def __init__(
4144
self,
4245
name: str,
@@ -53,12 +56,21 @@ def __init__(
5356
self.bucket = bucket
5457
self.access_key = access_key
5558
self.secret_key = secret_key
59+
self.region = kwargs.get("region")
5660
self.custom_s3_host = kwargs.get("custom_s3_host", "")
5761
self.public_endpoint = kwargs.get("public_endpoint", "")
5862
self.session = aioboto3.Session()
5963
self.list_lock = anyio.Lock()
6064
self.cache_list_bucket: dict[str, abc.FileInfo] = {}
6165
self.last_cache: float = 0
66+
self._config = {
67+
"endpoint_url": self.endpoint,
68+
"aws_access_key_id": self.access_key,
69+
"aws_secret_access_key": self.secret_key,
70+
}
71+
if self.region:
72+
self._config["region_name"] = self.region
73+
6274

6375
async def setup(
6476
self,
@@ -71,37 +83,52 @@ async def setup(
7183
async def list_bucket(
7284
self,
7385
):
74-
async with self.list_lock:
75-
if time.perf_counter() - self.last_cache < 60:
76-
return
77-
async with self.session.resource(
78-
"s3",
79-
endpoint_url=self.endpoint,
80-
aws_access_key_id=self.access_key,
81-
aws_secret_access_key=self.secret_key,
82-
) as resource:
83-
bucket = await resource.Bucket(self.bucket)
84-
self.cache_list_bucket = {}
85-
async for obj in bucket.objects.all():
86-
cp = abc.CPath("/" + obj.key)
87-
self.cache_list_bucket[str(cp)] = abc.FileInfo(
88-
path=str(cp),
89-
name=cp.name,
90-
size=await obj.size,
91-
)
92-
self.last_cache = time.perf_counter()
86+
...
9387

9488
async def list_files(
9589
self,
9690
path: str
9791
) -> list[abc.FileInfo]:
98-
await self.list_bucket()
99-
# find by keys
10092
p = str(self.path / path)
10193
res = []
102-
for key in self.cache_list_bucket.keys():
103-
if str(abc.CPath(key).parents[-1]) == p:
104-
res.append(self.cache_list_bucket[key])
94+
async with self.session.client(
95+
"s3",
96+
endpoint_url=self.endpoint,
97+
aws_access_key_id=self.access_key,
98+
aws_secret_access_key=self.secret_key,
99+
region_name=self.region
100+
) as client: # type: ignore
101+
continuation_token = None
102+
while True:
103+
kwargs = {
104+
"Bucket": self.bucket,
105+
"Prefix": p[1:],
106+
#"Delimiter": "/", # 使用分隔符来模拟文件夹结构
107+
#"MaxKeys": 1000
108+
}
109+
if continuation_token:
110+
kwargs["ContinuationToken"] = continuation_token
111+
112+
response = await client.list_objects_v2(**kwargs)
113+
contents = response.get("Contents", [])
114+
for content in contents:
115+
file_path = f"/{content['Key']}"
116+
if "/" in file_path:
117+
file_name = file_path.rsplit("/", 1)[1]
118+
else:
119+
file_name = file_path[1:]
120+
res.append(abc.FileInfo(
121+
name=file_name,
122+
size=content["Size"],
123+
path=f'/{content["Key"]}',
124+
))
125+
126+
#res.extend(response.get("Contents", [])) # 添加文件
127+
#res.extend(response.get("CommonPrefixes", [])) # 添加子目录
128+
129+
if "NextContinuationToken" not in response:
130+
break
131+
continuation_token = response["NextContinuationToken"]
105132
return res
106133

107134

@@ -115,6 +142,7 @@ async def upload(
115142
endpoint_url=self.endpoint,
116143
aws_access_key_id=self.access_key,
117144
aws_secret_access_key=self.secret_key,
145+
region_name=self.region
118146
) as resource:
119147
bucket = await resource.Bucket(self.bucket)
120148
obj = await bucket.Object(str(self.path / path))
@@ -152,6 +180,7 @@ async def get_response_file(self, hash: str) -> abc.ResponseFile:
152180
endpoint_url=self.endpoint,
153181
aws_access_key_id=self.access_key,
154182
aws_secret_access_key=self.secret_key,
183+
region_name=self.region
155184
) as client: # type: ignore
156185
url = await client.generate_presigned_url(
157186
ClientMethod="get_object",
@@ -182,6 +211,7 @@ async def get_response_file(self, hash: str) -> abc.ResponseFile:
182211
endpoint_url=self.endpoint,
183212
aws_access_key_id=self.access_key,
184213
aws_secret_access_key=self.secret_key,
214+
region_name=self.region
185215
) as resource:
186216
bucket = await resource.Bucket(self.bucket)
187217
obj = await bucket.Object(cpath)

core/storage/webdav.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
import aiowebdav
1414

1515
class WebDavStorage(abc.Storage):
16+
type = "webdav"
1617
def __init__(
1718
self,
1819
name: str,

0 commit comments

Comments
 (0)