Skip to content

Commit

Permalink
feat(baidu_netdisk): add retry on most operations; improve stability
Browse files Browse the repository at this point in the history
general: add local temp file creation checking of file size, avoid incomplete stream
  • Loading branch information
SeanHeuc committed Aug 6, 2023
1 parent 3ec8629 commit fda3e5c
Show file tree
Hide file tree
Showing 17 changed files with 121 additions and 88 deletions.
2 changes: 1 addition & 1 deletion drivers/115/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func (d *Pan115) Remove(ctx context.Context, obj model.Obj) error {
}

func (d *Pan115) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
tempFile, err := utils.CreateTempFile(stream.GetReadCloser())
tempFile, err := utils.CreateTempFile(stream.GetReadCloser(), stream.GetSize())
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion drivers/123/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ func (d *Pan123) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr
// const DEFAULT int64 = 10485760
h := md5.New()
// need to calculate md5 of the full content
tempFile, err := utils.CreateTempFile(stream.GetReadCloser())
tempFile, err := utils.CreateTempFile(stream.GetReadCloser(), stream.GetSize())
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions drivers/189pc/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ func (y *Cloud189PC) StreamUpload(ctx context.Context, dstDir model.Obj, file mo
// 快传
func (y *Cloud189PC) FastUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) {
// 需要获取完整文件md5,必须支持 io.Seek
tempFile, err := utils.CreateTempFile(file.GetReadCloser())
tempFile, err := utils.CreateTempFile(file.GetReadCloser(), file.GetSize())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -672,7 +672,7 @@ func (y *Cloud189PC) FastUpload(ctx context.Context, dstDir model.Obj, file mode
// 旧版本上传,家庭云不支持覆盖
func (y *Cloud189PC) OldUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) {
// 需要获取完整文件md5,必须支持 io.Seek
tempFile, err := utils.CreateTempFile(file.GetReadCloser())
tempFile, err := utils.CreateTempFile(file.GetReadCloser(), file.GetSize())
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion drivers/aliyundrive_open/upload.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ func (d *AliyundriveOpen) upload(ctx context.Context, dstDir model.Obj, stream m
}
log.Debugf("[aliyundrive_open] pre_hash matched, start rapid upload")
// convert to local file
file, err := utils.CreateTempFile(stream)
file, err := utils.CreateTempFile(stream, stream.GetSize())
if err != nil {
return err
}
Expand Down
110 changes: 65 additions & 45 deletions drivers/baidu_netdisk/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,29 @@ import (
"crypto/md5"
"encoding/hex"
"fmt"
"github.com/alist-org/alist/v3/drivers/base"
"github.com/alist-org/alist/v3/internal/driver"
"github.com/alist-org/alist/v3/internal/errs"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/pkg/utils"
"github.com/avast/retry-go"
log "github.com/sirupsen/logrus"
"io"
"math"
"os"
stdpath "path"
"strconv"
"strings"

"github.com/alist-org/alist/v3/drivers/base"
"github.com/alist-org/alist/v3/internal/driver"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/pkg/utils"
log "github.com/sirupsen/logrus"
)

type BaiduNetdisk struct {
model.Storage
Addition
}

const BaiduFileAPI = "https://d.pcs.baidu.com/rest/2.0/pcs/superfile2"
const DefaultSliceSize int64 = 4 * 1024 * 1024

func (d *BaiduNetdisk) Config() driver.Config {
return config
}
Expand Down Expand Up @@ -108,64 +112,63 @@ func (d *BaiduNetdisk) Remove(ctx context.Context, obj model.Obj) error {
}

func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
tempFile, err := utils.CreateTempFile(stream.GetReadCloser())
streamSize := stream.GetSize()

tempFile, err := utils.CreateTempFile(stream.GetReadCloser(), stream.GetSize())
if err != nil {
return err
}
defer func() {
_ = tempFile.Close()
_ = os.Remove(tempFile.Name())
}()
var Default int64 = 4 * 1024 * 1024
count := int(math.Ceil(float64(stream.GetSize()) / float64(Default)))
var SliceSize int64 = 256 * 1024

count := int(math.Ceil(float64(streamSize) / float64(DefaultSliceSize)))
//cal md5 for first 256k data
const SliceSize int64 = 256 * 1024
// cal md5
h1 := md5.New()
h2 := md5.New()
block_list := make([]string, 0)
content_md5 := ""
slice_md5 := ""
left := stream.GetSize()
blockList := make([]string, 0)
contentMd5 := ""
sliceMd5 := ""
left := streamSize
for i := 0; i < count; i++ {
byteSize := Default
if left < Default {
byteSize := DefaultSliceSize
if left < DefaultSliceSize {
byteSize = left
}
left -= byteSize
_, err = io.Copy(io.MultiWriter(h1, h2), io.LimitReader(tempFile, byteSize))
if err != nil {
return err
}
block_list = append(block_list, fmt.Sprintf("\"%s\"", hex.EncodeToString(h2.Sum(nil))))
blockList = append(blockList, fmt.Sprintf("\"%s\"", hex.EncodeToString(h2.Sum(nil))))
h2.Reset()
}
content_md5 = hex.EncodeToString(h1.Sum(nil))
contentMd5 = hex.EncodeToString(h1.Sum(nil))
_, err = tempFile.Seek(0, io.SeekStart)
if err != nil {
return err
}
if stream.GetSize() <= SliceSize {
slice_md5 = content_md5
if streamSize <= SliceSize {
sliceMd5 = contentMd5
} else {
sliceData := make([]byte, SliceSize)
_, err = io.ReadFull(tempFile, sliceData)
if err != nil {
return err
}
h2.Write(sliceData)
slice_md5 = hex.EncodeToString(h2.Sum(nil))
_, err = tempFile.Seek(0, io.SeekStart)
if err != nil {
return err
}
sliceMd5 = hex.EncodeToString(h2.Sum(nil))
}
rawPath := stdpath.Join(dstDir.GetPath(), stream.GetName())
path := encodeURIComponent(rawPath)
block_list_str := fmt.Sprintf("[%s]", strings.Join(block_list, ","))
block_list_str := fmt.Sprintf("[%s]", strings.Join(blockList, ","))
data := fmt.Sprintf("path=%s&size=%d&isdir=0&autoinit=1&block_list=%s&content-md5=%s&slice-md5=%s",
path, stream.GetSize(),
path, streamSize,
block_list_str,
content_md5, slice_md5)
contentMd5, sliceMd5)
params := map[string]string{
"method": "precreate",
}
Expand All @@ -177,6 +180,7 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
}
log.Debugf("%+v", precreateResp)
if precreateResp.ReturnType == 2 {
//rapid upload, since got md5 match from baidu server
return nil
}
params = map[string]string{
Expand All @@ -186,33 +190,49 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F
"path": path,
"uploadid": precreateResp.Uploadid,
}
left = stream.GetSize()

var offset int64 = 0
for i, partseq := range precreateResp.BlockList {
if utils.IsCanceled(ctx) {
return ctx.Err()
}
byteSize := Default
if left < Default {
byteSize = left
}
left -= byteSize
u := "https://d.pcs.baidu.com/rest/2.0/pcs/superfile2"
params["partseq"] = strconv.Itoa(partseq)
res, err := base.RestyClient.R().
SetContext(ctx).
SetQueryParams(params).
SetFileReader("file", stream.GetName(), io.LimitReader(tempFile, byteSize)).
Post(u)
byteSize := int64(math.Min(float64(streamSize-offset), float64(DefaultSliceSize)))
err := retry.Do(func() error {
return d.uploadSlice(ctx, &params, stream.GetName(), tempFile, offset, byteSize)
},
retry.Context(ctx),
retry.Attempts(3))
if err != nil {
return err
}
log.Debugln(res.String())
offset += byteSize

if len(precreateResp.BlockList) > 0 {
up(i * 100 / len(precreateResp.BlockList))
}
}
_, err = d.create(rawPath, stream.GetSize(), 0, precreateResp.Uploadid, block_list_str)
_, err = d.create(rawPath, streamSize, 0, precreateResp.Uploadid, block_list_str)
return err
}
func (d *BaiduNetdisk) uploadSlice(ctx context.Context, params *map[string]string, fileName string, file *os.File, offset int64, byteSize int64) error {
_, err := file.Seek(offset, io.SeekStart)
if err != nil {
return err
}

res, err := base.RestyClient.R().
SetContext(ctx).
SetQueryParams(*params).
SetFileReader("file", fileName, io.LimitReader(file, byteSize)).
Post(BaiduFileAPI)
if err != nil {
return err
}
log.Debugln(res.RawResponse.Status + res.String())
errCode := utils.Json.Get(res.Body(), "error_code").ToInt()
errNo := utils.Json.Get(res.Body(), "errno").ToInt()
if errCode != 0 || errNo != 0 {
return errs.NewErr(errs.StreamIncomplete, "error in uploading to baidu, will retry. response=%s", res.String())
}
return nil
}

var _ driver.Driver = (*BaiduNetdisk)(nil)
53 changes: 30 additions & 23 deletions drivers/baidu_netdisk/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package baidu_netdisk

import (
"fmt"
"github.com/avast/retry-go"
"net/http"
"net/url"
"strconv"
Expand Down Expand Up @@ -51,31 +52,37 @@ func (d *BaiduNetdisk) _refreshToken() error {
}

func (d *BaiduNetdisk) request(furl string, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) {
req := base.RestyClient.R()
req.SetQueryParam("access_token", d.AccessToken)
if callback != nil {
callback(req)
}
if resp != nil {
req.SetResult(resp)
}
res, err := req.Execute(method, furl)
if err != nil {
return nil, err
}
log.Debugf("[baidu_netdisk] req: %s, resp: %s", furl, res.String())
errno := utils.Json.Get(res.Body(), "errno").ToInt()
if errno != 0 {
if utils.SliceContains([]int{111, -6}, errno) {
err = d.refreshToken()
if err != nil {
return nil, err
var result []byte
err := retry.Do(func() error {
req := base.RestyClient.R()
req.SetQueryParam("access_token", d.AccessToken)
if callback != nil {
callback(req)
}
if resp != nil {
req.SetResult(resp)
}
res, err := req.Execute(method, furl)
if err != nil {
return err
}
log.Debugf("[baidu_netdisk] req: %s, resp: %s", furl, res.String())
errno := utils.Json.Get(res.Body(), "errno").ToInt()
if errno != 0 {
if utils.SliceContains([]int{111, -6}, errno) {
log.Info("refreshing baidu_netdisk token.")
err2 := d.refreshToken()
if err2 != nil {
return err2
}
}
return d.request(furl, method, callback, resp)
return fmt.Errorf("req: [%s] ,errno: %d, refer to https://pan.baidu.com/union/doc/", furl, errno)
}
return nil, fmt.Errorf("req: [%s] ,errno: %d, refer to https://pan.baidu.com/union/doc/", furl, errno)
}
return res.Body(), nil
result = res.Body()
return nil
},
retry.Attempts(3))
return result, err
}

func (d *BaiduNetdisk) get(pathname string, params map[string]string, resp interface{}) ([]byte, error) {
Expand Down
2 changes: 1 addition & 1 deletion drivers/baidu_photo/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ func (d *BaiduPhoto) Remove(ctx context.Context, obj model.Obj) error {

func (d *BaiduPhoto) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) {
// 需要获取完整文件md5,必须支持 io.Seek
tempFile, err := utils.CreateTempFile(stream.GetReadCloser())
tempFile, err := utils.CreateTempFile(stream.GetReadCloser(), stream.GetSize())
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion drivers/mediatrack/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ func (d *MediaTrack) Put(ctx context.Context, dstDir model.Obj, stream model.Fil
if err != nil {
return err
}
tempFile, err := utils.CreateTempFile(stream.GetReadCloser())
tempFile, err := utils.CreateTempFile(stream.GetReadCloser(), stream.GetSize())
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion drivers/mopan/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ func (d *MoPan) Remove(ctx context.Context, obj model.Obj) error {
}

func (d *MoPan) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) {
file, err := utils.CreateTempFile(stream)
file, err := utils.CreateTempFile(stream, stream.GetSize())
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion drivers/pikpak/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func (d *PikPak) Remove(ctx context.Context, obj model.Obj) error {
}

func (d *PikPak) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
tempFile, err := utils.CreateTempFile(stream.GetReadCloser())
tempFile, err := utils.CreateTempFile(stream.GetReadCloser(), stream.GetSize())
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion drivers/quark_uc/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func (d *QuarkOrUC) Remove(ctx context.Context, obj model.Obj) error {
}

func (d *QuarkOrUC) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
tempFile, err := utils.CreateTempFile(stream.GetReadCloser())
tempFile, err := utils.CreateTempFile(stream.GetReadCloser(), stream.GetSize())
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion drivers/terabox/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func (d *Terabox) Remove(ctx context.Context, obj model.Obj) error {
}

func (d *Terabox) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
tempFile, err := utils.CreateTempFile(stream.GetReadCloser())
tempFile, err := utils.CreateTempFile(stream.GetReadCloser(), stream.GetSize())
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion drivers/thunder/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ func (xc *XunLeiCommon) Remove(ctx context.Context, obj model.Obj) error {
}

func (xc *XunLeiCommon) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
tempFile, err := utils.CreateTempFile(stream.GetReadCloser())
tempFile, err := utils.CreateTempFile(stream.GetReadCloser(), stream.GetSize())
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion drivers/weiyun/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ func (d *WeiYun) Remove(ctx context.Context, obj model.Obj) error {

func (d *WeiYun) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) {
if folder, ok := dstDir.(*Folder); ok {
file, err := utils.CreateTempFile(stream)
file, err := utils.CreateTempFile(stream, stream.GetSize())
if err != nil {
return nil, err
}
Expand Down
5 changes: 3 additions & 2 deletions internal/errs/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ var (
MoveBetweenTwoStorages = errors.New("can't move files between two storages, try to copy")
UploadNotSupported = errors.New("upload not supported")

MetaNotFound = errors.New("meta not found")
StorageNotFound = errors.New("storage not found")
MetaNotFound = errors.New("meta not found")
StorageNotFound = errors.New("storage not found")
StreamIncomplete = errors.New("upload/download stream incomplete, possible network issue")
)

// NewErr wrap constant error with an extra message
Expand Down
Loading

0 comments on commit fda3e5c

Please sign in to comment.