Skip to content

Commit

Permalink
fix(smb): remount smb before each operation (close #2123 pr #2140)
Browse files Browse the repository at this point in the history
  • Loading branch information
BoYanZh committed Oct 30, 2022
1 parent 18165eb commit a3b631f
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 1 deletion.
44 changes: 43 additions & 1 deletion drivers/smb/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"path/filepath"
"time"

"github.com/alist-org/alist/v3/internal/driver"
"github.com/alist-org/alist/v3/internal/model"
Expand All @@ -15,7 +16,8 @@ import (
type SMB struct {
model.Storage
Addition
fs *smb2.Share
fs *smb2.Share
lastConnTime time.Time
}

func (d *SMB) Config() driver.Config {
Expand Down Expand Up @@ -43,11 +45,16 @@ func (d *SMB) Drop(ctx context.Context) error {
}

func (d *SMB) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) {
if err := d.checkConn(); err != nil {
return nil, err
}
fullPath := d.getSMBPath(dir)
rawFiles, err := d.fs.ReadDir(fullPath)
if err != nil {
d.cleanLastConnTime()
return nil, err
}
d.updateLastConnTime()
var files []model.Obj
for _, f := range rawFiles {
file := model.ObjThumb{
Expand All @@ -69,46 +76,69 @@ func (d *SMB) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]m
//}

func (d *SMB) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) {
if err := d.checkConn(); err != nil {
return nil, err
}
fullPath := d.getSMBPath(file)
remoteFile, err := d.fs.Open(fullPath)
if err != nil {
d.cleanLastConnTime()
return nil, err
}
d.updateLastConnTime()
return &model.Link{
Data: remoteFile,
}, nil
}

func (d *SMB) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error {
if err := d.checkConn(); err != nil {
return err
}
fullPath := filepath.Join(d.getSMBPath(parentDir), dirName)
err := d.fs.MkdirAll(fullPath, 0700)
if err != nil {
d.cleanLastConnTime()
return err
}
d.updateLastConnTime()
return nil
}

func (d *SMB) Move(ctx context.Context, srcObj, dstDir model.Obj) error {
if err := d.checkConn(); err != nil {
return err
}
srcPath := d.getSMBPath(srcObj)
dstPath := filepath.Join(d.getSMBPath(dstDir), srcObj.GetName())
err := d.fs.Rename(srcPath, dstPath)
if err != nil {
d.cleanLastConnTime()
return err
}
d.updateLastConnTime()
return nil
}

func (d *SMB) Rename(ctx context.Context, srcObj model.Obj, newName string) error {
if err := d.checkConn(); err != nil {
return err
}
srcPath := d.getSMBPath(srcObj)
dstPath := filepath.Join(filepath.Dir(srcPath), newName)
err := d.fs.Rename(srcPath, dstPath)
if err != nil {
d.cleanLastConnTime()
return err
}
d.updateLastConnTime()
return nil
}

func (d *SMB) Copy(ctx context.Context, srcObj, dstDir model.Obj) error {
if err := d.checkConn(); err != nil {
return err
}
srcPath := d.getSMBPath(srcObj)
dstPath := filepath.Join(d.getSMBPath(dstDir), srcObj.GetName())
var err error
Expand All @@ -118,12 +148,17 @@ func (d *SMB) Copy(ctx context.Context, srcObj, dstDir model.Obj) error {
err = d.CopyFile(srcPath, dstPath)
}
if err != nil {
d.cleanLastConnTime()
return err
}
d.updateLastConnTime()
return nil
}

func (d *SMB) Remove(ctx context.Context, obj model.Obj) error {
if err := d.checkConn(); err != nil {
return err
}
var err error
fullPath := d.getSMBPath(obj)
if obj.IsDir() {
Expand All @@ -132,17 +167,24 @@ func (d *SMB) Remove(ctx context.Context, obj model.Obj) error {
err = d.fs.Remove(fullPath)
}
if err != nil {
d.cleanLastConnTime()
return err
}
d.updateLastConnTime()
return nil
}

func (d *SMB) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
if err := d.checkConn(); err != nil {
return err
}
fullPath := filepath.Join(d.getSMBPath(dstDir), stream.GetName())
out, err := d.fs.Create(fullPath)
if err != nil {
d.cleanLastConnTime()
return err
}
d.updateLastConnTime()
defer func() {
_ = out.Close()
if errors.Is(err, context.Canceled) {
Expand Down
20 changes: 20 additions & 0 deletions drivers/smb/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,20 @@ import (
"net"
"os"
"path/filepath"
"time"

"github.com/alist-org/alist/v3/internal/model"
"github.com/hirochachacha/go-smb2"
)

func (d *SMB) updateLastConnTime() {
d.lastConnTime = time.Now()
}

func (d *SMB) cleanLastConnTime() {
d.lastConnTime = time.Now().AddDate(0, 0, -1)
}

func (d *SMB) initFS() error {
conn, err := net.Dial("tcp", d.Address)
if err != nil {
Expand All @@ -30,9 +39,20 @@ func (d *SMB) initFS() error {
if err != nil {
return err
}
d.updateLastConnTime()
return err
}

func (d *SMB) checkConn() error {
if time.Since(d.lastConnTime) < 5*time.Minute {
return nil
}
if d.fs != nil {
_ = d.fs.Umount()
}
return d.initFS()
}

func (d *SMB) getSMBPath(dir model.Obj) string {
fullPath := dir.GetPath()
if fullPath[0:1] != "." {
Expand Down

0 comments on commit a3b631f

Please sign in to comment.