Skip to content

Commit

Permalink
fix(sftp): reconnect to server when connection was broken (#6416 close
Browse files Browse the repository at this point in the history
…#6403)

* fix(sftp): reconnect to server when conn was broken (close #6403)

* fix(sftp): fix typo

---------

Co-authored-by: George Chen <gchen@isimarkets.com>
  • Loading branch information
okcy1016 and George Chen committed May 9, 2024
1 parent f261ef5 commit b57afd0
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 1 deletion.
24 changes: 23 additions & 1 deletion drivers/sftp/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ import (
type SFTP struct {
model.Storage
Addition
client *sftp.Client
client *sftp.Client
clientConnectionError error
}

func (d *SFTP) Config() driver.Config {
Expand All @@ -39,6 +40,9 @@ func (d *SFTP) Drop(ctx context.Context) error {
}

func (d *SFTP) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) {
if err := d.clientReconnectOnConnectionError(); err != nil {
return nil, err
}
log.Debugf("[sftp] list dir: %s", dir.GetPath())
files, err := d.client.ReadDir(dir.GetPath())
if err != nil {
Expand All @@ -51,6 +55,9 @@ func (d *SFTP) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]
}

func (d *SFTP) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) {
if err := d.clientReconnectOnConnectionError(); err != nil {
return nil, err
}
remoteFile, err := d.client.Open(file.GetPath())
if err != nil {
return nil, err
Expand All @@ -62,14 +69,23 @@ func (d *SFTP) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*
}

func (d *SFTP) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error {
if err := d.clientReconnectOnConnectionError(); err != nil {
return err
}
return d.client.MkdirAll(path.Join(parentDir.GetPath(), dirName))
}

func (d *SFTP) Move(ctx context.Context, srcObj, dstDir model.Obj) error {
if err := d.clientReconnectOnConnectionError(); err != nil {
return err
}
return d.client.Rename(srcObj.GetPath(), path.Join(dstDir.GetPath(), srcObj.GetName()))
}

func (d *SFTP) Rename(ctx context.Context, srcObj model.Obj, newName string) error {
if err := d.clientReconnectOnConnectionError(); err != nil {
return err
}
return d.client.Rename(srcObj.GetPath(), path.Join(path.Dir(srcObj.GetPath()), newName))
}

Expand All @@ -78,10 +94,16 @@ func (d *SFTP) Copy(ctx context.Context, srcObj, dstDir model.Obj) error {
}

func (d *SFTP) Remove(ctx context.Context, obj model.Obj) error {
if err := d.clientReconnectOnConnectionError(); err != nil {
return err
}
return d.remove(obj.GetPath())
}

func (d *SFTP) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
if err := d.clientReconnectOnConnectionError(); err != nil {
return err
}
dstFile, err := d.client.Create(path.Join(dstDir.GetPath(), stream.GetName()))
if err != nil {
return err
Expand Down
18 changes: 18 additions & 0 deletions drivers/sftp/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"path"

"github.com/pkg/sftp"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
)

Expand All @@ -30,6 +31,23 @@ func (d *SFTP) initClient() error {
return err
}
d.client, err = sftp.NewClient(conn)
if err == nil {
d.clientConnectionError = nil
go func(d *SFTP) {
d.clientConnectionError = d.client.Wait()
}(d)
}
return err
}

func (d *SFTP) clientReconnectOnConnectionError() error {
err := d.clientConnectionError
if err == nil {
return nil
}
log.Debugf("[sftp] discarding closed sftp connection: %v", err)
_ = d.client.Close()
err = d.initClient()
return err
}

Expand Down

0 comments on commit b57afd0

Please sign in to comment.