From d2831f77c03f661c79784ea534b0ef60b41a6c86 Mon Sep 17 00:00:00 2001 From: Sean Date: Mon, 7 Aug 2023 00:31:53 +0800 Subject: [PATCH] feat(baidu_netdisk): add retry on most operations; improve stability general: add local temp file creation checking of file size, avoid incomplete stream (cherry picked from commit fda3e5c2b48661c2a3241a52a93c8c994a4efea8) --- drivers/115/driver.go | 2 +- drivers/123/driver.go | 2 +- drivers/189pc/utils.go | 4 +- drivers/aliyundrive_open/upload.go | 2 +- drivers/baidu_netdisk/driver.go | 110 +++++++++++++++++------------ drivers/baidu_netdisk/util.go | 53 ++++++++------ drivers/baidu_photo/driver.go | 2 +- drivers/mediatrack/driver.go | 2 +- drivers/mopan/driver.go | 2 +- drivers/pikpak/driver.go | 2 +- drivers/quark_uc/driver.go | 2 +- drivers/terabox/driver.go | 2 +- drivers/thunder/driver.go | 2 +- drivers/weiyun/driver.go | 2 +- internal/errs/errors.go | 5 +- internal/fs/put.go | 2 +- pkg/utils/file.go | 13 ++-- 17 files changed, 121 insertions(+), 88 deletions(-) diff --git a/drivers/115/driver.go b/drivers/115/driver.go index ead83f6fa8b..b9554c89e12 100644 --- a/drivers/115/driver.go +++ b/drivers/115/driver.go @@ -83,7 +83,7 @@ func (d *Pan115) Remove(ctx context.Context, obj model.Obj) error { } func (d *Pan115) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { - tempFile, err := utils.CreateTempFile(stream.GetReadCloser()) + tempFile, err := utils.CreateTempFile(stream.GetReadCloser(), stream.GetSize()) if err != nil { return err } diff --git a/drivers/123/driver.go b/drivers/123/driver.go index ec2bf7147e0..bc1758a176b 100644 --- a/drivers/123/driver.go +++ b/drivers/123/driver.go @@ -184,7 +184,7 @@ func (d *Pan123) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr // const DEFAULT int64 = 10485760 h := md5.New() // need to calculate md5 of the full content - tempFile, err := utils.CreateTempFile(stream.GetReadCloser()) + tempFile, err := utils.CreateTempFile(stream.GetReadCloser(), stream.GetSize()) if err != nil { return err } diff --git a/drivers/189pc/utils.go b/drivers/189pc/utils.go index 5de99e2c031..96bf3aceb2b 100644 --- a/drivers/189pc/utils.go +++ b/drivers/189pc/utils.go @@ -545,7 +545,7 @@ func (y *Cloud189PC) StreamUpload(ctx context.Context, dstDir model.Obj, file mo // 快传 func (y *Cloud189PC) FastUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { // 需要获取完整文件md5,必须支持 io.Seek - tempFile, err := utils.CreateTempFile(file.GetReadCloser()) + tempFile, err := utils.CreateTempFile(file.GetReadCloser(), file.GetSize()) if err != nil { return nil, err } @@ -672,7 +672,7 @@ func (y *Cloud189PC) FastUpload(ctx context.Context, dstDir model.Obj, file mode // 旧版本上传,家庭云不支持覆盖 func (y *Cloud189PC) OldUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { // 需要获取完整文件md5,必须支持 io.Seek - tempFile, err := utils.CreateTempFile(file.GetReadCloser()) + tempFile, err := utils.CreateTempFile(file.GetReadCloser(), file.GetSize()) if err != nil { return nil, err } diff --git a/drivers/aliyundrive_open/upload.go b/drivers/aliyundrive_open/upload.go index a2b3ca5938a..f7bb7f28567 100644 --- a/drivers/aliyundrive_open/upload.go +++ b/drivers/aliyundrive_open/upload.go @@ -224,7 +224,7 @@ func (d *AliyundriveOpen) upload(ctx context.Context, dstDir model.Obj, stream m } log.Debugf("[aliyundrive_open] pre_hash matched, start rapid upload") // convert to local file - file, err := utils.CreateTempFile(stream) + file, err := utils.CreateTempFile(stream, stream.GetSize()) if err != nil { return err } diff --git a/drivers/baidu_netdisk/driver.go b/drivers/baidu_netdisk/driver.go index c81225e4fa6..470d3b062f0 100644 --- a/drivers/baidu_netdisk/driver.go +++ b/drivers/baidu_netdisk/driver.go @@ -5,18 +5,19 @@ import ( "crypto/md5" "encoding/hex" "fmt" + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/avast/retry-go" + log "github.com/sirupsen/logrus" "io" "math" "os" stdpath "path" "strconv" "strings" - - "github.com/alist-org/alist/v3/drivers/base" - "github.com/alist-org/alist/v3/internal/driver" - "github.com/alist-org/alist/v3/internal/model" - "github.com/alist-org/alist/v3/pkg/utils" - log "github.com/sirupsen/logrus" ) type BaiduNetdisk struct { @@ -24,6 +25,9 @@ type BaiduNetdisk struct { Addition } +const BaiduFileAPI = "https://d.pcs.baidu.com/rest/2.0/pcs/superfile2" +const DefaultSliceSize int64 = 4 * 1024 * 1024 + func (d *BaiduNetdisk) Config() driver.Config { return config } @@ -108,7 +112,9 @@ func (d *BaiduNetdisk) Remove(ctx context.Context, obj model.Obj) error { } func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { - tempFile, err := utils.CreateTempFile(stream.GetReadCloser()) + streamSize := stream.GetSize() + + tempFile, err := utils.CreateTempFile(stream.GetReadCloser(), stream.GetSize()) if err != nil { return err } @@ -116,19 +122,20 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F _ = tempFile.Close() _ = os.Remove(tempFile.Name()) }() - var Default int64 = 4 * 1024 * 1024 - count := int(math.Ceil(float64(stream.GetSize()) / float64(Default))) - var SliceSize int64 = 256 * 1024 + + count := int(math.Ceil(float64(streamSize) / float64(DefaultSliceSize))) + //cal md5 for first 256k data + const SliceSize int64 = 256 * 1024 // cal md5 h1 := md5.New() h2 := md5.New() - block_list := make([]string, 0) - content_md5 := "" - slice_md5 := "" - left := stream.GetSize() + blockList := make([]string, 0) + contentMd5 := "" + sliceMd5 := "" + left := streamSize for i := 0; i < count; i++ { - byteSize := Default - if left < Default { + byteSize := DefaultSliceSize + if left < DefaultSliceSize { byteSize = left } left -= byteSize @@ -136,16 +143,16 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F if err != nil { return err } - block_list = append(block_list, fmt.Sprintf("\"%s\"", hex.EncodeToString(h2.Sum(nil)))) + blockList = append(blockList, fmt.Sprintf("\"%s\"", hex.EncodeToString(h2.Sum(nil)))) h2.Reset() } - content_md5 = hex.EncodeToString(h1.Sum(nil)) + contentMd5 = hex.EncodeToString(h1.Sum(nil)) _, err = tempFile.Seek(0, io.SeekStart) if err != nil { return err } - if stream.GetSize() <= SliceSize { - slice_md5 = content_md5 + if streamSize <= SliceSize { + sliceMd5 = contentMd5 } else { sliceData := make([]byte, SliceSize) _, err = io.ReadFull(tempFile, sliceData) @@ -153,19 +160,15 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F return err } h2.Write(sliceData) - slice_md5 = hex.EncodeToString(h2.Sum(nil)) - _, err = tempFile.Seek(0, io.SeekStart) - if err != nil { - return err - } + sliceMd5 = hex.EncodeToString(h2.Sum(nil)) } rawPath := stdpath.Join(dstDir.GetPath(), stream.GetName()) path := encodeURIComponent(rawPath) - block_list_str := fmt.Sprintf("[%s]", strings.Join(block_list, ",")) + block_list_str := fmt.Sprintf("[%s]", strings.Join(blockList, ",")) data := fmt.Sprintf("path=%s&size=%d&isdir=0&autoinit=1&block_list=%s&content-md5=%s&slice-md5=%s", - path, stream.GetSize(), + path, streamSize, block_list_str, - content_md5, slice_md5) + contentMd5, sliceMd5) params := map[string]string{ "method": "precreate", } @@ -177,6 +180,7 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F } log.Debugf("%+v", precreateResp) if precreateResp.ReturnType == 2 { + //rapid upload, since got md5 match from baidu server return nil } params = map[string]string{ @@ -186,33 +190,49 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F "path": path, "uploadid": precreateResp.Uploadid, } - left = stream.GetSize() + + var offset int64 = 0 for i, partseq := range precreateResp.BlockList { - if utils.IsCanceled(ctx) { - return ctx.Err() - } - byteSize := Default - if left < Default { - byteSize = left - } - left -= byteSize - u := "https://d.pcs.baidu.com/rest/2.0/pcs/superfile2" params["partseq"] = strconv.Itoa(partseq) - res, err := base.RestyClient.R(). - SetContext(ctx). - SetQueryParams(params). - SetFileReader("file", stream.GetName(), io.LimitReader(tempFile, byteSize)). - Post(u) + byteSize := int64(math.Min(float64(streamSize-offset), float64(DefaultSliceSize))) + err := retry.Do(func() error { + return d.uploadSlice(ctx, ¶ms, stream.GetName(), tempFile, offset, byteSize) + }, + retry.Context(ctx), + retry.Attempts(3)) if err != nil { return err } - log.Debugln(res.String()) + offset += byteSize + if len(precreateResp.BlockList) > 0 { up(i * 100 / len(precreateResp.BlockList)) } } - _, err = d.create(rawPath, stream.GetSize(), 0, precreateResp.Uploadid, block_list_str) + _, err = d.create(rawPath, streamSize, 0, precreateResp.Uploadid, block_list_str) return err } +func (d *BaiduNetdisk) uploadSlice(ctx context.Context, params *map[string]string, fileName string, file *os.File, offset int64, byteSize int64) error { + _, err := file.Seek(offset, io.SeekStart) + if err != nil { + return err + } + + res, err := base.RestyClient.R(). + SetContext(ctx). + SetQueryParams(*params). + SetFileReader("file", fileName, io.LimitReader(file, byteSize)). + Post(BaiduFileAPI) + if err != nil { + return err + } + log.Debugln(res.RawResponse.Status + res.String()) + errCode := utils.Json.Get(res.Body(), "error_code").ToInt() + errNo := utils.Json.Get(res.Body(), "errno").ToInt() + if errCode != 0 || errNo != 0 { + return errs.NewErr(errs.StreamIncomplete, "error in uploading to baidu, will retry. response=%s", res.String()) + } + return nil +} var _ driver.Driver = (*BaiduNetdisk)(nil) diff --git a/drivers/baidu_netdisk/util.go b/drivers/baidu_netdisk/util.go index 5e863036e39..bb344967a9f 100644 --- a/drivers/baidu_netdisk/util.go +++ b/drivers/baidu_netdisk/util.go @@ -2,6 +2,7 @@ package baidu_netdisk import ( "fmt" + "github.com/avast/retry-go" "net/http" "net/url" "strconv" @@ -51,31 +52,37 @@ func (d *BaiduNetdisk) _refreshToken() error { } func (d *BaiduNetdisk) request(furl string, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) { - req := base.RestyClient.R() - req.SetQueryParam("access_token", d.AccessToken) - if callback != nil { - callback(req) - } - if resp != nil { - req.SetResult(resp) - } - res, err := req.Execute(method, furl) - if err != nil { - return nil, err - } - log.Debugf("[baidu_netdisk] req: %s, resp: %s", furl, res.String()) - errno := utils.Json.Get(res.Body(), "errno").ToInt() - if errno != 0 { - if utils.SliceContains([]int{111, -6}, errno) { - err = d.refreshToken() - if err != nil { - return nil, err + var result []byte + err := retry.Do(func() error { + req := base.RestyClient.R() + req.SetQueryParam("access_token", d.AccessToken) + if callback != nil { + callback(req) + } + if resp != nil { + req.SetResult(resp) + } + res, err := req.Execute(method, furl) + if err != nil { + return err + } + log.Debugf("[baidu_netdisk] req: %s, resp: %s", furl, res.String()) + errno := utils.Json.Get(res.Body(), "errno").ToInt() + if errno != 0 { + if utils.SliceContains([]int{111, -6}, errno) { + log.Info("refreshing baidu_netdisk token.") + err2 := d.refreshToken() + if err2 != nil { + return err2 + } } - return d.request(furl, method, callback, resp) + return fmt.Errorf("req: [%s] ,errno: %d, refer to https://pan.baidu.com/union/doc/", furl, errno) } - return nil, fmt.Errorf("req: [%s] ,errno: %d, refer to https://pan.baidu.com/union/doc/", furl, errno) - } - return res.Body(), nil + result = res.Body() + return nil + }, + retry.Attempts(3)) + return result, err } func (d *BaiduNetdisk) get(pathname string, params map[string]string, resp interface{}) ([]byte, error) { diff --git a/drivers/baidu_photo/driver.go b/drivers/baidu_photo/driver.go index ff47d25f784..760a5976f0a 100644 --- a/drivers/baidu_photo/driver.go +++ b/drivers/baidu_photo/driver.go @@ -212,7 +212,7 @@ func (d *BaiduPhoto) Remove(ctx context.Context, obj model.Obj) error { func (d *BaiduPhoto) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { // 需要获取完整文件md5,必须支持 io.Seek - tempFile, err := utils.CreateTempFile(stream.GetReadCloser()) + tempFile, err := utils.CreateTempFile(stream.GetReadCloser(), stream.GetSize()) if err != nil { return nil, err } diff --git a/drivers/mediatrack/driver.go b/drivers/mediatrack/driver.go index c9937505478..eeed29ad1fc 100644 --- a/drivers/mediatrack/driver.go +++ b/drivers/mediatrack/driver.go @@ -181,7 +181,7 @@ func (d *MediaTrack) Put(ctx context.Context, dstDir model.Obj, stream model.Fil if err != nil { return err } - tempFile, err := utils.CreateTempFile(stream.GetReadCloser()) + tempFile, err := utils.CreateTempFile(stream.GetReadCloser(), stream.GetSize()) if err != nil { return err } diff --git a/drivers/mopan/driver.go b/drivers/mopan/driver.go index dfe8399cc43..edbcfe3a212 100644 --- a/drivers/mopan/driver.go +++ b/drivers/mopan/driver.go @@ -212,7 +212,7 @@ func (d *MoPan) Remove(ctx context.Context, obj model.Obj) error { } func (d *MoPan) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { - file, err := utils.CreateTempFile(stream) + file, err := utils.CreateTempFile(stream, stream.GetSize()) if err != nil { return nil, err } diff --git a/drivers/pikpak/driver.go b/drivers/pikpak/driver.go index ddaadef0fe7..a86a75390da 100644 --- a/drivers/pikpak/driver.go +++ b/drivers/pikpak/driver.go @@ -124,7 +124,7 @@ func (d *PikPak) Remove(ctx context.Context, obj model.Obj) error { } func (d *PikPak) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { - tempFile, err := utils.CreateTempFile(stream.GetReadCloser()) + tempFile, err := utils.CreateTempFile(stream.GetReadCloser(), stream.GetSize()) if err != nil { return err } diff --git a/drivers/quark_uc/driver.go b/drivers/quark_uc/driver.go index a59b0bcd909..4969af5a70e 100644 --- a/drivers/quark_uc/driver.go +++ b/drivers/quark_uc/driver.go @@ -136,7 +136,7 @@ func (d *QuarkOrUC) Remove(ctx context.Context, obj model.Obj) error { } func (d *QuarkOrUC) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { - tempFile, err := utils.CreateTempFile(stream.GetReadCloser()) + tempFile, err := utils.CreateTempFile(stream.GetReadCloser(), stream.GetSize()) if err != nil { return err } diff --git a/drivers/terabox/driver.go b/drivers/terabox/driver.go index e87d3ad7053..4c4ad8b58cf 100644 --- a/drivers/terabox/driver.go +++ b/drivers/terabox/driver.go @@ -116,7 +116,7 @@ func (d *Terabox) Remove(ctx context.Context, obj model.Obj) error { } func (d *Terabox) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { - tempFile, err := utils.CreateTempFile(stream.GetReadCloser()) + tempFile, err := utils.CreateTempFile(stream.GetReadCloser(), stream.GetSize()) if err != nil { return err } diff --git a/drivers/thunder/driver.go b/drivers/thunder/driver.go index d933753527a..8b91b5a954a 100644 --- a/drivers/thunder/driver.go +++ b/drivers/thunder/driver.go @@ -333,7 +333,7 @@ func (xc *XunLeiCommon) Remove(ctx context.Context, obj model.Obj) error { } func (xc *XunLeiCommon) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { - tempFile, err := utils.CreateTempFile(stream.GetReadCloser()) + tempFile, err := utils.CreateTempFile(stream.GetReadCloser(), stream.GetSize()) if err != nil { return err } diff --git a/drivers/weiyun/driver.go b/drivers/weiyun/driver.go index 84088566e2c..3bd622a2456 100644 --- a/drivers/weiyun/driver.go +++ b/drivers/weiyun/driver.go @@ -298,7 +298,7 @@ func (d *WeiYun) Remove(ctx context.Context, obj model.Obj) error { func (d *WeiYun) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { if folder, ok := dstDir.(*Folder); ok { - file, err := utils.CreateTempFile(stream) + file, err := utils.CreateTempFile(stream, stream.GetSize()) if err != nil { return nil, err } diff --git a/internal/errs/errors.go b/internal/errs/errors.go index d3345851163..0cab41356f8 100644 --- a/internal/errs/errors.go +++ b/internal/errs/errors.go @@ -14,8 +14,9 @@ var ( MoveBetweenTwoStorages = errors.New("can't move files between two storages, try to copy") UploadNotSupported = errors.New("upload not supported") - MetaNotFound = errors.New("meta not found") - StorageNotFound = errors.New("storage not found") + MetaNotFound = errors.New("meta not found") + StorageNotFound = errors.New("storage not found") + StreamIncomplete = errors.New("upload/download stream incomplete, possible network issue") ) // NewErr wrap constant error with an extra message diff --git a/internal/fs/put.go b/internal/fs/put.go index 029b37334ad..41f6b8db5b1 100644 --- a/internal/fs/put.go +++ b/internal/fs/put.go @@ -27,7 +27,7 @@ func putAsTask(dstDirPath string, file *model.FileStream) error { return errors.WithStack(errs.UploadNotSupported) } if file.NeedStore() { - tempFile, err := utils.CreateTempFile(file) + tempFile, err := utils.CreateTempFile(file, file.GetSize()) if err != nil { return errors.Wrapf(err, "failed to create temp file") } diff --git a/pkg/utils/file.go b/pkg/utils/file.go index 2a4fc6edaeb..1eb00cc700f 100644 --- a/pkg/utils/file.go +++ b/pkg/utils/file.go @@ -2,6 +2,7 @@ package utils import ( "fmt" + "github.com/alist-org/alist/v3/internal/errs" "io" "mime" "os" @@ -111,7 +112,7 @@ func CreateNestedFile(path string) (*os.File, error) { } // CreateTempFile create temp file from io.ReadCloser, and seek to 0 -func CreateTempFile(r io.ReadCloser) (*os.File, error) { +func CreateTempFile(r io.ReadCloser, size int64) (*os.File, error) { if f, ok := r.(*os.File); ok { return f, nil } @@ -119,15 +120,19 @@ func CreateTempFile(r io.ReadCloser) (*os.File, error) { if err != nil { return nil, err } - _, err = io.Copy(f, r) + readBytes, err := io.Copy(f, r) if err != nil { _ = os.Remove(f.Name()) - return nil, err + return nil, errs.NewErr(err, "CreateTempFile failed") + } + if size != 0 && readBytes != size { + _ = os.Remove(f.Name()) + return nil, errs.NewErr(err, "CreateTempFile failed, incoming stream actual size= %s, expect = %s ", readBytes, size) } _, err = f.Seek(0, io.SeekStart) if err != nil { _ = os.Remove(f.Name()) - return nil, err + return nil, errs.NewErr(err, "CreateTempFile failed, can't seek to 0 ") } return f, nil }