Skip to content

Commit e5662ef

Browse files
committed
feat(driver): Enhanced Baidu Netdisk upload logic with dynamic URL retrieval
- Added support for dynamically retrieving upload URLs, with fallback in case of failure - Improved token refresh and error handling during uploads - Prevented uploading of empty files and added error message for invalid operations - Refactored large file uploads with segment-level progress and retry handling logic - Introduced constants for upload settings, enabling better configurability - Improved logs to include the driver name for better debugging context
1 parent b4d9beb commit e5662ef

File tree

4 files changed

+276
-76
lines changed

4 files changed

+276
-76
lines changed

drivers/baidu_netdisk/driver.go

Lines changed: 151 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,26 @@ import (
55
"crypto/md5"
66
"encoding/hex"
77
"errors"
8+
"fmt"
89
"io"
910
"net/url"
1011
"os"
1112
stdpath "path"
1213
"strconv"
14+
"strings"
15+
"sync"
1316
"time"
1417

15-
"golang.org/x/sync/semaphore"
16-
1718
"github.com/alist-org/alist/v3/drivers/base"
1819
"github.com/alist-org/alist/v3/internal/conf"
1920
"github.com/alist-org/alist/v3/internal/driver"
2021
"github.com/alist-org/alist/v3/internal/errs"
2122
"github.com/alist-org/alist/v3/internal/model"
2223
"github.com/alist-org/alist/v3/pkg/errgroup"
24+
"github.com/alist-org/alist/v3/pkg/singleflight"
2325
"github.com/alist-org/alist/v3/pkg/utils"
2426
"github.com/avast/retry-go"
27+
"github.com/go-resty/resty/v2"
2528
log "github.com/sirupsen/logrus"
2629
)
2730

@@ -31,8 +34,16 @@ type BaiduNetdisk struct {
3134

3235
uploadThread int
3336
vipType int // 会员类型,0普通用户(4G/4M)、1普通会员(10G/16M)、2超级会员(20G/32M)
37+
38+
upClient *resty.Client // 上传文件使用的http客户端
39+
uploadUrlG singleflight.Group[string]
40+
uploadUrlMu sync.RWMutex
41+
uploadUrl string // 上传域名
42+
uploadUrlUpdateTime time.Time // 上传域名上次更新时间
3443
}
3544

45+
var ErrUploadIDExpired = errors.New("uploadid expired")
46+
3647
func (d *BaiduNetdisk) Config() driver.Config {
3748
return config
3849
}
@@ -42,19 +53,26 @@ func (d *BaiduNetdisk) GetAddition() driver.Additional {
4253
}
4354

4455
func (d *BaiduNetdisk) Init(ctx context.Context) error {
56+
d.upClient = base.NewRestyClient().
57+
SetTimeout(UPLOAD_TIMEOUT).
58+
SetRetryCount(UPLOAD_RETRY_COUNT).
59+
SetRetryWaitTime(UPLOAD_RETRY_WAIT_TIME).
60+
SetRetryMaxWaitTime(UPLOAD_RETRY_MAX_WAIT_TIME)
4561
d.uploadThread, _ = strconv.Atoi(d.UploadThread)
46-
if d.uploadThread < 1 || d.uploadThread > 32 {
47-
d.uploadThread, d.UploadThread = 3, "3"
62+
if d.uploadThread < 1 {
63+
d.uploadThread, d.UploadThread = 1, "1"
64+
} else if d.uploadThread > 32 {
65+
d.uploadThread, d.UploadThread = 32, "32"
4866
}
4967

5068
if _, err := url.Parse(d.UploadAPI); d.UploadAPI == "" || err != nil {
51-
d.UploadAPI = "https://d.pcs.baidu.com"
69+
d.UploadAPI = UPLOAD_FALLBACK_API
5270
}
5371

5472
res, err := d.get("/xpan/nas", map[string]string{
5573
"method": "uinfo",
5674
}, nil)
57-
log.Debugf("[baidu] get uinfo: %s", string(res))
75+
log.Debugf("[baidu_netdisk] get uinfo: %s", string(res))
5876
if err != nil {
5977
return err
6078
}
@@ -181,6 +199,11 @@ func (d *BaiduNetdisk) PutRapid(ctx context.Context, dstDir model.Obj, stream mo
181199
// **注意**: 截至 2024/04/20 百度云盘 api 接口返回的时间永远是当前时间,而不是文件时间。
182200
// 而实际上云盘存储的时间是文件时间,所以此处需要覆盖时间,保证缓存与云盘的数据一致
183201
func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) {
202+
// 百度网盘不允许上传空文件
203+
if stream.GetSize() < 1 {
204+
return nil, ErrBaiduEmptyFilesNotAllowed
205+
}
206+
184207
// rapid upload
185208
if newObj, err := d.PutRapid(ctx, dstDir, stream); err == nil {
186209
return newObj, nil
@@ -245,7 +268,7 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
245268
}
246269
if tmpF != nil {
247270
if written != streamSize {
248-
return nil, errs.NewErr(err, "CreateTempFile failed, incoming stream actual size= %d, expect = %d ", written, streamSize)
271+
return nil, errs.NewErr(err, "CreateTempFile failed, size mismatch: %d != %d ", written, streamSize)
249272
}
250273
_, err = tmpF.Seek(0, io.SeekStart)
251274
if err != nil {
@@ -259,82 +282,97 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
259282
mtime := stream.ModTime().Unix()
260283
ctime := stream.CreateTime().Unix()
261284

262-
// step.1 预上传
263-
// 尝试获取之前的进度
285+
// step.1 尝试读取已保存进度
264286
precreateResp, ok := base.GetUploadProgress[*PrecreateResp](d, d.AccessToken, contentMd5)
265287
if !ok {
266-
params := map[string]string{
267-
"method": "precreate",
268-
}
269-
form := map[string]string{
270-
"path": path,
271-
"size": strconv.FormatInt(streamSize, 10),
272-
"isdir": "0",
273-
"autoinit": "1",
274-
"rtype": "3",
275-
"block_list": blockListStr,
276-
"content-md5": contentMd5,
277-
"slice-md5": sliceMd5,
278-
}
279-
joinTime(form, ctime, mtime)
280-
281-
log.Debugf("[baidu_netdisk] precreate data: %s", form)
282-
_, err = d.postForm("/xpan/file", params, form, &precreateResp)
288+
// 没有进度,走预上传
289+
precreateResp, err = d.precreate(ctx, path, streamSize, blockListStr, contentMd5, sliceMd5, ctime, mtime)
283290
if err != nil {
284291
return nil, err
285292
}
286-
log.Debugf("%+v", precreateResp)
287293
if precreateResp.ReturnType == 2 {
288294
//rapid upload, since got md5 match from baidu server
289295
// 修复时间,具体原因见 Put 方法注释的 **注意**
290-
precreateResp.File.Ctime = ctime
291-
precreateResp.File.Mtime = mtime
292296
return fileToObj(precreateResp.File), nil
293297
}
294298
}
299+
295300
// step.2 上传分片
296-
threadG, upCtx := errgroup.NewGroupWithContext(ctx, d.uploadThread,
297-
retry.Attempts(1),
298-
retry.Delay(time.Second),
299-
retry.DelayType(retry.BackOffDelay))
300-
sem := semaphore.NewWeighted(3)
301-
for i, partseq := range precreateResp.BlockList {
302-
if utils.IsCanceled(upCtx) {
303-
break
301+
uploadLoop:
302+
for attempt := 0; attempt < 2; attempt++ {
303+
// 获取上传域名
304+
uploadUrl := d.getUploadUrl(path, precreateResp.Uploadid)
305+
// 并发上传
306+
threadG, upCtx := errgroup.NewGroupWithContext(ctx, d.uploadThread,
307+
retry.Attempts(1),
308+
retry.Delay(time.Second),
309+
retry.DelayType(retry.BackOffDelay))
310+
311+
cacheReaderAt, okReaderAt := cache.(io.ReaderAt)
312+
if !okReaderAt {
313+
return nil, fmt.Errorf("cache object must implement io.ReaderAt interface for upload operations")
304314
}
305315

306-
i, partseq, offset, byteSize := i, partseq, int64(partseq)*sliceSize, sliceSize
307-
if partseq+1 == count {
308-
byteSize = lastBlockSize
309-
}
310-
threadG.Go(func(ctx context.Context) error {
311-
if err = sem.Acquire(ctx, 1); err != nil {
312-
return err
313-
}
314-
defer sem.Release(1)
315-
params := map[string]string{
316-
"method": "upload",
317-
"access_token": d.AccessToken,
318-
"type": "tmpfile",
319-
"path": path,
320-
"uploadid": precreateResp.Uploadid,
321-
"partseq": strconv.Itoa(partseq),
316+
totalParts := len(precreateResp.BlockList)
317+
for i, partseq := range precreateResp.BlockList {
318+
if utils.IsCanceled(upCtx) || partseq < 0 {
319+
continue
322320
}
323-
err := d.uploadSlice(ctx, params, stream.GetName(),
324-
driver.NewLimitedUploadStream(ctx, io.NewSectionReader(cache, offset, byteSize)))
325-
if err != nil {
326-
return err
321+
322+
i, partseq := i, partseq
323+
offset, size := int64(partseq)*sliceSize, sliceSize
324+
if partseq+1 == count {
325+
size = lastBlockSize
327326
}
328-
up(float64(threadG.Success()) * 100 / float64(len(precreateResp.BlockList)))
329-
precreateResp.BlockList[i] = -1
330-
return nil
331-
})
332-
}
333-
if err = threadG.Wait(); err != nil {
334-
// 如果属于用户主动取消,则保存上传进度
327+
threadG.Go(func(ctx context.Context) error {
328+
params := map[string]string{
329+
"method": "upload",
330+
"access_token": d.AccessToken,
331+
"type": "tmpfile",
332+
"path": path,
333+
"uploadid": precreateResp.Uploadid,
334+
"partseq": strconv.Itoa(partseq),
335+
}
336+
section := io.NewSectionReader(cacheReaderAt, offset, size)
337+
err := d.uploadSlice(ctx, uploadUrl, params, stream.GetName(), driver.NewLimitedUploadStream(ctx, section))
338+
if err != nil {
339+
return err
340+
}
341+
precreateResp.BlockList[i] = -1
342+
// 当前goroutine还没退出,+1才是真正成功的数量
343+
success := threadG.Success() + 1
344+
progress := float64(success) * 100 / float64(totalParts)
345+
up(progress)
346+
return nil
347+
})
348+
}
349+
350+
err = threadG.Wait()
351+
if err == nil {
352+
break uploadLoop
353+
}
354+
355+
// 保存进度(所有错误都会保存)
356+
precreateResp.BlockList = utils.SliceFilter(precreateResp.BlockList, func(s int) bool { return s >= 0 })
357+
base.SaveUploadProgress(d, precreateResp, d.AccessToken, contentMd5)
358+
335359
if errors.Is(err, context.Canceled) {
336-
precreateResp.BlockList = utils.SliceFilter(precreateResp.BlockList, func(s int) bool { return s >= 0 })
360+
return nil, err
361+
}
362+
if errors.Is(err, ErrUploadIDExpired) {
363+
log.Warn("[baidu_netdisk] uploadid expired, will restart from scratch")
364+
// 重新 precreate(所有分片都要重传)
365+
newPre, err2 := d.precreate(ctx, path, streamSize, blockListStr, "", "", ctime, mtime)
366+
if err2 != nil {
367+
return nil, err2
368+
}
369+
if newPre.ReturnType == 2 {
370+
return fileToObj(newPre.File), nil
371+
}
372+
precreateResp = newPre
373+
// 覆盖掉旧的进度
337374
base.SaveUploadProgress(d, precreateResp, d.AccessToken, contentMd5)
375+
continue uploadLoop
338376
}
339377
return nil, err
340378
}
@@ -348,23 +386,67 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
348386
// 修复时间,具体原因见 Put 方法注释的 **注意**
349387
newFile.Ctime = ctime
350388
newFile.Mtime = mtime
389+
// 上传成功清理进度
390+
base.SaveUploadProgress(d, nil, d.AccessToken, contentMd5)
351391
return fileToObj(newFile), nil
352392
}
353393

354-
func (d *BaiduNetdisk) uploadSlice(ctx context.Context, params map[string]string, fileName string, file io.Reader) error {
355-
res, err := base.RestyClient.R().
394+
// precreate 执行预上传操作,支持首次上传和 uploadid 过期重试
395+
func (d *BaiduNetdisk) precreate(ctx context.Context, path string, streamSize int64, blockListStr, contentMd5, sliceMd5 string, ctime, mtime int64) (*PrecreateResp, error) {
396+
params := map[string]string{"method": "precreate"}
397+
form := map[string]string{
398+
"path": path,
399+
"size": strconv.FormatInt(streamSize, 10),
400+
"isdir": "0",
401+
"autoinit": "1",
402+
"rtype": "3",
403+
"block_list": blockListStr,
404+
}
405+
406+
// 只有在首次上传时才包含 content-md5 和 slice-md5
407+
if contentMd5 != "" && sliceMd5 != "" {
408+
form["content-md5"] = contentMd5
409+
form["slice-md5"] = sliceMd5
410+
}
411+
412+
joinTime(form, ctime, mtime)
413+
414+
var precreateResp PrecreateResp
415+
_, err := d.postForm("/xpan/file", params, form, &precreateResp)
416+
if err != nil {
417+
return nil, err
418+
}
419+
420+
// 修复时间,具体原因见 Put 方法注释的 **注意**
421+
if precreateResp.ReturnType == 2 {
422+
precreateResp.File.Ctime = ctime
423+
precreateResp.File.Mtime = mtime
424+
}
425+
426+
return &precreateResp, nil
427+
}
428+
429+
func (d *BaiduNetdisk) uploadSlice(ctx context.Context, uploadUrl string, params map[string]string, fileName string, file io.Reader) error {
430+
res, err := d.upClient.R().
356431
SetContext(ctx).
357432
SetQueryParams(params).
358433
SetFileReader("file", fileName, file).
359-
Post(d.UploadAPI + "/rest/2.0/pcs/superfile2")
434+
Post(uploadUrl + "/rest/2.0/pcs/superfile2")
360435
if err != nil {
361436
return err
362437
}
363438
log.Debugln(res.RawResponse.Status + res.String())
364439
errCode := utils.Json.Get(res.Body(), "error_code").ToInt()
365440
errNo := utils.Json.Get(res.Body(), "errno").ToInt()
441+
respStr := res.String()
442+
lower := strings.ToLower(respStr)
443+
if strings.Contains(lower, "uploadid") &&
444+
(strings.Contains(lower, "invalid") || strings.Contains(lower, "expired") || strings.Contains(lower, "not found")) {
445+
return ErrUploadIDExpired
446+
}
447+
366448
if errCode != 0 || errNo != 0 {
367-
return errs.NewErr(errs.StreamIncomplete, "error in uploading to baidu, will retry. response=%s", res.String())
449+
return errs.NewErr(errs.StreamIncomplete, "error uploading to baidu, response=%s", res.String())
368450
}
369451
return nil
370452
}

drivers/baidu_netdisk/meta.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package baidu_netdisk
22

33
import (
4+
"time"
5+
46
"github.com/alist-org/alist/v3/internal/driver"
57
"github.com/alist-org/alist/v3/internal/op"
68
)
@@ -17,11 +19,21 @@ type Addition struct {
1719
AccessToken string
1820
UploadThread string `json:"upload_thread" default:"3" help:"1<=thread<=32"`
1921
UploadAPI string `json:"upload_api" default:"https://d.pcs.baidu.com"`
22+
UseDynamicUploadAPI bool `json:"use_dynamic_upload_api" default:"true" help:"dynamically get upload api domain, when enabled, the 'Upload API' setting will be used as a fallback if failed to get"`
2023
CustomUploadPartSize int64 `json:"custom_upload_part_size" type:"number" default:"0" help:"0 for auto"`
2124
LowBandwithUploadMode bool `json:"low_bandwith_upload_mode" default:"false"`
2225
OnlyListVideoFile bool `json:"only_list_video_file" default:"false"`
2326
}
2427

28+
const (
29+
UPLOAD_FALLBACK_API = "https://d.pcs.baidu.com" // 备用上传地址
30+
UPLOAD_URL_EXPIRE_TIME = time.Minute * 60 // 上传地址有效期(分钟)
31+
UPLOAD_TIMEOUT = time.Minute * 30 // 上传请求超时时间
32+
UPLOAD_RETRY_COUNT = 3
33+
UPLOAD_RETRY_WAIT_TIME = time.Second * 1
34+
UPLOAD_RETRY_MAX_WAIT_TIME = time.Second * 5
35+
)
36+
2537
var config = driver.Config{
2638
Name: "BaiduNetdisk",
2739
DefaultRoot: "/",

0 commit comments

Comments
 (0)