1+ import inspect
12import tempfile
23import time
4+ from typing import Any
35import aioboto3 .session
46import anyio .abc
57import anyio .to_thread
@@ -37,6 +39,7 @@ def __repr__(
3739 return f"S3Response(metadata={ self .metadata } , data={ self .raw_data } )"
3840
3941class S3Storage (abc .Storage ):
42+ type = "s3"
4043 def __init__ (
4144 self ,
4245 name : str ,
@@ -53,12 +56,21 @@ def __init__(
5356 self .bucket = bucket
5457 self .access_key = access_key
5558 self .secret_key = secret_key
59+ self .region = kwargs .get ("region" )
5660 self .custom_s3_host = kwargs .get ("custom_s3_host" , "" )
5761 self .public_endpoint = kwargs .get ("public_endpoint" , "" )
5862 self .session = aioboto3 .Session ()
5963 self .list_lock = anyio .Lock ()
6064 self .cache_list_bucket : dict [str , abc .FileInfo ] = {}
6165 self .last_cache : float = 0
66+ self ._config = {
67+ "endpoint_url" : self .endpoint ,
68+ "aws_access_key_id" : self .access_key ,
69+ "aws_secret_access_key" : self .secret_key ,
70+ }
71+ if self .region :
72+ self ._config ["region_name" ] = self .region
73+
6274
6375 async def setup (
6476 self ,
@@ -71,37 +83,52 @@ async def setup(
async def list_bucket(
    self,
):
    """Do nothing; kept so existing callers of ``list_bucket`` still work.

    This method no longer walks the bucket or fills ``cache_list_bucket``.
    ``list_files`` queries S3 directly on every call instead.
    """
    ...
9387
async def list_files(
    self,
    path: str
) -> list[abc.FileInfo]:
    """List the objects stored under ``path`` in the configured bucket.

    The lookup prefix is ``self.path / path`` without its leading slash.
    No ``Delimiter`` is sent with the request, so S3 returns every key that
    starts with the prefix, including keys nested below it.

    Args:
        path: Location relative to ``self.path`` to list.

    Returns:
        One ``abc.FileInfo`` per returned object. ``path`` is the object key
        prefixed with ``/``, ``name`` is the text after the last ``/`` and
        ``size`` is the ``Size`` value reported by S3.
    """
    p = str(self.path / path)
    # Drop the leading slash, but only when it is actually there, so a
    # path without one does not lose a real character.
    prefix = p[1:] if p.startswith("/") else p
    res = []
    async with self.session.client(
        "s3",
        endpoint_url=self.endpoint,
        aws_access_key_id=self.access_key,
        aws_secret_access_key=self.secret_key,
        region_name=self.region,
    ) as client:  # type: ignore
        continuation_token = None
        while True:
            kwargs = {
                "Bucket": self.bucket,
                "Prefix": prefix,
            }
            if continuation_token:
                kwargs["ContinuationToken"] = continuation_token

            response = await client.list_objects_v2(**kwargs)
            for content in response.get("Contents", []):
                # file_path always contains "/" because of the prefix added
                # here, so rsplit always yields a second element.
                file_path = f"/{content['Key']}"
                res.append(abc.FileInfo(
                    name=file_path.rsplit("/", 1)[1],
                    size=content["Size"],
                    path=file_path,
                ))

            # Results are paginated; a missing or empty token marks the
            # last page.
            continuation_token = response.get("NextContinuationToken")
            if not continuation_token:
                break
    return res
106133
107134
@@ -115,6 +142,7 @@ async def upload(
115142 endpoint_url = self .endpoint ,
116143 aws_access_key_id = self .access_key ,
117144 aws_secret_access_key = self .secret_key ,
145+ region_name = self .region
118146 ) as resource :
119147 bucket = await resource .Bucket (self .bucket )
120148 obj = await bucket .Object (str (self .path / path ))
@@ -152,6 +180,7 @@ async def get_response_file(self, hash: str) -> abc.ResponseFile:
152180 endpoint_url = self .endpoint ,
153181 aws_access_key_id = self .access_key ,
154182 aws_secret_access_key = self .secret_key ,
183+ region_name = self .region
155184 ) as client : # type: ignore
156185 url = await client .generate_presigned_url (
157186 ClientMethod = "get_object" ,
@@ -182,6 +211,7 @@ async def get_response_file(self, hash: str) -> abc.ResponseFile:
182211 endpoint_url = self .endpoint ,
183212 aws_access_key_id = self .access_key ,
184213 aws_secret_access_key = self .secret_key ,
214+ region_name = self .region
185215 ) as resource :
186216 bucket = await resource .Bucket (self .bucket )
187217 obj = await bucket .Object (cpath )
0 commit comments