@@ -5,23 +5,26 @@ import (
55 "crypto/md5"
66 "encoding/hex"
77 "errors"
8+ "fmt"
89 "io"
910 "net/url"
1011 "os"
1112 stdpath "path"
1213 "strconv"
14+ "strings"
15+ "sync"
1316 "time"
1417
15- "golang.org/x/sync/semaphore"
16-
1718 "github.com/alist-org/alist/v3/drivers/base"
1819 "github.com/alist-org/alist/v3/internal/conf"
1920 "github.com/alist-org/alist/v3/internal/driver"
2021 "github.com/alist-org/alist/v3/internal/errs"
2122 "github.com/alist-org/alist/v3/internal/model"
2223 "github.com/alist-org/alist/v3/pkg/errgroup"
24+ "github.com/alist-org/alist/v3/pkg/singleflight"
2325 "github.com/alist-org/alist/v3/pkg/utils"
2426 "github.com/avast/retry-go"
27+ "github.com/go-resty/resty/v2"
2528 log "github.com/sirupsen/logrus"
2629)
2730
@@ -31,8 +34,16 @@ type BaiduNetdisk struct {
3134
3235 uploadThread int
3336 vipType int // 会员类型,0普通用户(4G/4M)、1普通会员(10G/16M)、2超级会员(20G/32M)
37+
38+ upClient * resty.Client // 上传文件使用的http客户端
39+ uploadUrlG singleflight.Group [string ]
40+ uploadUrlMu sync.RWMutex
41+ uploadUrl string // 上传域名
42+ uploadUrlUpdateTime time.Time // 上传域名上次更新时间
3443}
3544
45+ var ErrUploadIDExpired = errors .New ("uploadid expired" )
46+
3647func (d * BaiduNetdisk ) Config () driver.Config {
3748 return config
3849}
@@ -42,19 +53,26 @@ func (d *BaiduNetdisk) GetAddition() driver.Additional {
4253}
4354
4455func (d * BaiduNetdisk ) Init (ctx context.Context ) error {
56+ d .upClient = base .NewRestyClient ().
57+ SetTimeout (UPLOAD_TIMEOUT ).
58+ SetRetryCount (UPLOAD_RETRY_COUNT ).
59+ SetRetryWaitTime (UPLOAD_RETRY_WAIT_TIME ).
60+ SetRetryMaxWaitTime (UPLOAD_RETRY_MAX_WAIT_TIME )
4561 d .uploadThread , _ = strconv .Atoi (d .UploadThread )
46- if d .uploadThread < 1 || d .uploadThread > 32 {
47- d .uploadThread , d .UploadThread = 3 , "3"
62+ if d .uploadThread < 1 {
63+ d .uploadThread , d .UploadThread = 1 , "1"
64+ } else if d .uploadThread > 32 {
65+ d .uploadThread , d .UploadThread = 32 , "32"
4866 }
4967
5068 if _ , err := url .Parse (d .UploadAPI ); d .UploadAPI == "" || err != nil {
51- d .UploadAPI = "https://d.pcs.baidu.com"
69+ d .UploadAPI = UPLOAD_FALLBACK_API
5270 }
5371
5472 res , err := d .get ("/xpan/nas" , map [string ]string {
5573 "method" : "uinfo" ,
5674 }, nil )
57- log .Debugf ("[baidu ] get uinfo: %s" , string (res ))
75+ log .Debugf ("[baidu_netdisk ] get uinfo: %s" , string (res ))
5876 if err != nil {
5977 return err
6078 }
@@ -181,6 +199,11 @@ func (d *BaiduNetdisk) PutRapid(ctx context.Context, dstDir model.Obj, stream mo
181199// **注意**: 截至 2024/04/20 百度云盘 api 接口返回的时间永远是当前时间,而不是文件时间。
182200// 而实际上云盘存储的时间是文件时间,所以此处需要覆盖时间,保证缓存与云盘的数据一致
183201func (d * BaiduNetdisk ) Put (ctx context.Context , dstDir model.Obj , stream model.FileStreamer , up driver.UpdateProgress ) (model.Obj , error ) {
202+ // 百度网盘不允许上传空文件
203+ if stream .GetSize () < 1 {
204+ return nil , ErrBaiduEmptyFilesNotAllowed
205+ }
206+
184207 // rapid upload
185208 if newObj , err := d .PutRapid (ctx , dstDir , stream ); err == nil {
186209 return newObj , nil
@@ -245,7 +268,7 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
245268 }
246269 if tmpF != nil {
247270 if written != streamSize {
248- return nil , errs .NewErr (err , "CreateTempFile failed, incoming stream actual size= %d, expect = %d " , written , streamSize )
271+ return nil , errs .NewErr (err , "CreateTempFile failed, size mismatch: %d ! = %d " , written , streamSize )
249272 }
250273 _ , err = tmpF .Seek (0 , io .SeekStart )
251274 if err != nil {
@@ -259,82 +282,97 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
259282 mtime := stream .ModTime ().Unix ()
260283 ctime := stream .CreateTime ().Unix ()
261284
262- // step.1 预上传
263- // 尝试获取之前的进度
285+ // step.1 尝试读取已保存进度
264286 precreateResp , ok := base .GetUploadProgress [* PrecreateResp ](d , d .AccessToken , contentMd5 )
265287 if ! ok {
266- params := map [string ]string {
267- "method" : "precreate" ,
268- }
269- form := map [string ]string {
270- "path" : path ,
271- "size" : strconv .FormatInt (streamSize , 10 ),
272- "isdir" : "0" ,
273- "autoinit" : "1" ,
274- "rtype" : "3" ,
275- "block_list" : blockListStr ,
276- "content-md5" : contentMd5 ,
277- "slice-md5" : sliceMd5 ,
278- }
279- joinTime (form , ctime , mtime )
280-
281- log .Debugf ("[baidu_netdisk] precreate data: %s" , form )
282- _ , err = d .postForm ("/xpan/file" , params , form , & precreateResp )
288+ // 没有进度,走预上传
289+ precreateResp , err = d .precreate (ctx , path , streamSize , blockListStr , contentMd5 , sliceMd5 , ctime , mtime )
283290 if err != nil {
284291 return nil , err
285292 }
286- log .Debugf ("%+v" , precreateResp )
287293 if precreateResp .ReturnType == 2 {
288294 //rapid upload, since got md5 match from baidu server
289295 // 修复时间,具体原因见 Put 方法注释的 **注意**
290- precreateResp .File .Ctime = ctime
291- precreateResp .File .Mtime = mtime
292296 return fileToObj (precreateResp .File ), nil
293297 }
294298 }
299+
295300 // step.2 上传分片
296- threadG , upCtx := errgroup .NewGroupWithContext (ctx , d .uploadThread ,
297- retry .Attempts (1 ),
298- retry .Delay (time .Second ),
299- retry .DelayType (retry .BackOffDelay ))
300- sem := semaphore .NewWeighted (3 )
301- for i , partseq := range precreateResp .BlockList {
302- if utils .IsCanceled (upCtx ) {
303- break
301+ uploadLoop:
302+ for attempt := 0 ; attempt < 2 ; attempt ++ {
303+ // 获取上传域名
304+ uploadUrl := d .getUploadUrl (path , precreateResp .Uploadid )
305+ // 并发上传
306+ threadG , upCtx := errgroup .NewGroupWithContext (ctx , d .uploadThread ,
307+ retry .Attempts (1 ),
308+ retry .Delay (time .Second ),
309+ retry .DelayType (retry .BackOffDelay ))
310+
311+ cacheReaderAt , okReaderAt := cache .(io.ReaderAt )
312+ if ! okReaderAt {
313+ return nil , fmt .Errorf ("cache object must implement io.ReaderAt interface for upload operations" )
304314 }
305315
306- i , partseq , offset , byteSize := i , partseq , int64 (partseq )* sliceSize , sliceSize
307- if partseq + 1 == count {
308- byteSize = lastBlockSize
309- }
310- threadG .Go (func (ctx context.Context ) error {
311- if err = sem .Acquire (ctx , 1 ); err != nil {
312- return err
313- }
314- defer sem .Release (1 )
315- params := map [string ]string {
316- "method" : "upload" ,
317- "access_token" : d .AccessToken ,
318- "type" : "tmpfile" ,
319- "path" : path ,
320- "uploadid" : precreateResp .Uploadid ,
321- "partseq" : strconv .Itoa (partseq ),
316+ totalParts := len (precreateResp .BlockList )
317+ for i , partseq := range precreateResp .BlockList {
318+ if utils .IsCanceled (upCtx ) || partseq < 0 {
319+ continue
322320 }
323- err := d .uploadSlice (ctx , params , stream .GetName (),
324- driver .NewLimitedUploadStream (ctx , io .NewSectionReader (cache , offset , byteSize )))
325- if err != nil {
326- return err
321+
322+ i , partseq := i , partseq
323+ offset , size := int64 (partseq )* sliceSize , sliceSize
324+ if partseq + 1 == count {
325+ size = lastBlockSize
327326 }
328- up (float64 (threadG .Success ()) * 100 / float64 (len (precreateResp .BlockList )))
329- precreateResp .BlockList [i ] = - 1
330- return nil
331- })
332- }
333- if err = threadG .Wait (); err != nil {
334- // 如果属于用户主动取消,则保存上传进度
327+ threadG .Go (func (ctx context.Context ) error {
328+ params := map [string ]string {
329+ "method" : "upload" ,
330+ "access_token" : d .AccessToken ,
331+ "type" : "tmpfile" ,
332+ "path" : path ,
333+ "uploadid" : precreateResp .Uploadid ,
334+ "partseq" : strconv .Itoa (partseq ),
335+ }
336+ section := io .NewSectionReader (cacheReaderAt , offset , size )
337+ err := d .uploadSlice (ctx , uploadUrl , params , stream .GetName (), driver .NewLimitedUploadStream (ctx , section ))
338+ if err != nil {
339+ return err
340+ }
341+ precreateResp .BlockList [i ] = - 1
342+ // 当前goroutine还没退出,+1才是真正成功的数量
343+ success := threadG .Success () + 1
344+ progress := float64 (success ) * 100 / float64 (totalParts )
345+ up (progress )
346+ return nil
347+ })
348+ }
349+
350+ err = threadG .Wait ()
351+ if err == nil {
352+ break uploadLoop
353+ }
354+
355+ // 保存进度(所有错误都会保存)
356+ precreateResp .BlockList = utils .SliceFilter (precreateResp .BlockList , func (s int ) bool { return s >= 0 })
357+ base .SaveUploadProgress (d , precreateResp , d .AccessToken , contentMd5 )
358+
335359 if errors .Is (err , context .Canceled ) {
336- precreateResp .BlockList = utils .SliceFilter (precreateResp .BlockList , func (s int ) bool { return s >= 0 })
360+ return nil , err
361+ }
362+ if errors .Is (err , ErrUploadIDExpired ) {
363+ log .Warn ("[baidu_netdisk] uploadid expired, will restart from scratch" )
364+ // 重新 precreate(所有分片都要重传)
365+ newPre , err2 := d .precreate (ctx , path , streamSize , blockListStr , "" , "" , ctime , mtime )
366+ if err2 != nil {
367+ return nil , err2
368+ }
369+ if newPre .ReturnType == 2 {
370+ return fileToObj (newPre .File ), nil
371+ }
372+ precreateResp = newPre
373+ // 覆盖掉旧的进度
337374 base .SaveUploadProgress (d , precreateResp , d .AccessToken , contentMd5 )
375+ continue uploadLoop
338376 }
339377 return nil , err
340378 }
@@ -348,23 +386,67 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
348386 // 修复时间,具体原因见 Put 方法注释的 **注意**
349387 newFile .Ctime = ctime
350388 newFile .Mtime = mtime
389+ // 上传成功清理进度
390+ base .SaveUploadProgress (d , nil , d .AccessToken , contentMd5 )
351391 return fileToObj (newFile ), nil
352392}
353393
354- func (d * BaiduNetdisk ) uploadSlice (ctx context.Context , params map [string ]string , fileName string , file io.Reader ) error {
355- res , err := base .RestyClient .R ().
394+ // precreate 执行预上传操作,支持首次上传和 uploadid 过期重试
395+ func (d * BaiduNetdisk ) precreate (ctx context.Context , path string , streamSize int64 , blockListStr , contentMd5 , sliceMd5 string , ctime , mtime int64 ) (* PrecreateResp , error ) {
396+ params := map [string ]string {"method" : "precreate" }
397+ form := map [string ]string {
398+ "path" : path ,
399+ "size" : strconv .FormatInt (streamSize , 10 ),
400+ "isdir" : "0" ,
401+ "autoinit" : "1" ,
402+ "rtype" : "3" ,
403+ "block_list" : blockListStr ,
404+ }
405+
406+ // 只有在首次上传时才包含 content-md5 和 slice-md5
407+ if contentMd5 != "" && sliceMd5 != "" {
408+ form ["content-md5" ] = contentMd5
409+ form ["slice-md5" ] = sliceMd5
410+ }
411+
412+ joinTime (form , ctime , mtime )
413+
414+ var precreateResp PrecreateResp
415+ _ , err := d .postForm ("/xpan/file" , params , form , & precreateResp )
416+ if err != nil {
417+ return nil , err
418+ }
419+
420+ // 修复时间,具体原因见 Put 方法注释的 **注意**
421+ if precreateResp .ReturnType == 2 {
422+ precreateResp .File .Ctime = ctime
423+ precreateResp .File .Mtime = mtime
424+ }
425+
426+ return & precreateResp , nil
427+ }
428+
429+ func (d * BaiduNetdisk ) uploadSlice (ctx context.Context , uploadUrl string , params map [string ]string , fileName string , file io.Reader ) error {
430+ res , err := d .upClient .R ().
356431 SetContext (ctx ).
357432 SetQueryParams (params ).
358433 SetFileReader ("file" , fileName , file ).
359- Post (d . UploadAPI + "/rest/2.0/pcs/superfile2" )
434+ Post (uploadUrl + "/rest/2.0/pcs/superfile2" )
360435 if err != nil {
361436 return err
362437 }
363438 log .Debugln (res .RawResponse .Status + res .String ())
364439 errCode := utils .Json .Get (res .Body (), "error_code" ).ToInt ()
365440 errNo := utils .Json .Get (res .Body (), "errno" ).ToInt ()
441+ respStr := res .String ()
442+ lower := strings .ToLower (respStr )
443+ if strings .Contains (lower , "uploadid" ) &&
444+ (strings .Contains (lower , "invalid" ) || strings .Contains (lower , "expired" ) || strings .Contains (lower , "not found" )) {
445+ return ErrUploadIDExpired
446+ }
447+
366448 if errCode != 0 || errNo != 0 {
367- return errs .NewErr (errs .StreamIncomplete , "error in uploading to baidu, will retry. response=%s" , res .String ())
449+ return errs .NewErr (errs .StreamIncomplete , "error uploading to baidu, response=%s" , res .String ())
368450 }
369451 return nil
370452}
0 commit comments