From 9cfcd9871eb341f1aa171a56fd2cc22ca0be62da Mon Sep 17 00:00:00 2001 From: tangrunze <614560331@qq.com> Date: Tue, 22 Aug 2023 17:17:52 +0800 Subject: [PATCH] test: modify scp test --- config.go | 3 +- config_test.go | 31 +++++++++++++++ logger.go | 2 +- scp.go | 14 +++---- scp_test.go | 82 +++++++++++++++++++++++++++----------- tmp.go | 19 --------- utils.go | 32 +++------------ utils_test.go | 104 ++++++++++++++++++++++++++++++++++++++++++------- 8 files changed, 196 insertions(+), 91 deletions(-) create mode 100644 config_test.go delete mode 100644 tmp.go diff --git a/config.go b/config.go index 9b60876..f55fbd4 100644 --- a/config.go +++ b/config.go @@ -1,6 +1,7 @@ package scpw import ( + "fmt" "gopkg.in/yaml.v2" "os" "os/user" @@ -70,5 +71,5 @@ func LoadConfigBytes(names ...string) ([]byte, error) { return sshw, nil } } - return nil, err + return nil, fmt.Errorf("cannot find config from %s", u.HomeDir) } diff --git a/config_test.go b/config_test.go new file mode 100644 index 0000000..691d8ab --- /dev/null +++ b/config_test.go @@ -0,0 +1,31 @@ +package scpw + +import ( + "github.com/stretchr/testify/assert" + "gopkg.in/yaml.v2" + "os" + "os/user" + "path/filepath" + "testing" +) + +func TestLoadConfig(t *testing.T) { + u, err := user.Current() + os.Remove(filepath.Join(u.HomeDir, ".scpw")) + os.Remove(filepath.Join(u.HomeDir, ".scpw.yml")) + os.Remove(filepath.Join(u.HomeDir, ".scpw.yaml")) + + _, err = LoadConfig() + assert.NotNil(t, err) + + var config []*Node + config = append(config, &Node{Name: "local", Host: "127.0.0.1", User: "root", Port: "22", Password: "123", LRMap: []LRMap{{Local: "/tmp/a", Remote: "/tmp/b"}}, Typ: GET}) + b, err := yaml.Marshal(config) + assert.Nil(t, err) + _, err = os.Create(filepath.Join(u.HomeDir, ".scpw")) + assert.Nil(t, err) + assert.Nil(t, os.WriteFile(filepath.Join(u.HomeDir, ".scpw"), b, os.FileMode(0777))) + + _, err = LoadConfig() + assert.Nil(t, err) +} diff --git a/logger.go b/logger.go index 25ee1ae..b5f8dfb 100644 --- a/logger.go +++ b/logger.go @@ -69,7 +69,7 @@ func (l *logHandle) Log(args ...interface{}) { } func newLogger(name string) *logHandle { - l := &logHandle{Logger: *logrus.New(), name: name, colorful: SupportANSIColor(os.Stderr.Fd())} + l := &logHandle{Logger: *logrus.New(), name: name, colorful: true} l.Formatter = l if syslogHook != nil { l.AddHook(syslogHook) diff --git a/scp.go b/scp.go index 00d8a2b..6440780 100644 --- a/scp.go +++ b/scp.go @@ -172,7 +172,7 @@ func (scp *SCP) SwitchScpwFunc(ctx context.Context, localPath, remotePath string last := remotePath[len(remotePath)-1] if last == '\\' || last == '/' { remotePath = remotePath[:len(remotePath)-1] - err := os.Mkdir(localTmp, os.FileMode(uint32(0755))) + err := os.Mkdir(localTmp, os.FileMode(0755)) if err != nil { return err } @@ -417,26 +417,26 @@ func WalkTree(ctx context.Context, scpChan *scpChan, rootParent, root, dstPath s } func (scp *SCP) Put(ctx context.Context, srcPath, dstPath string) error { - resource, err := NewResource(srcPath) + stat, err := os.Stat(srcPath) if err != nil { return err } - if resource.IsDir() { + if stat.IsDir() { return errors.New(fmt.Sprintf("local:[%s] is dir", srcPath)) } var atime, mtime string if scp.KeepTime { - atime, mtime = StatTimeV2(resource.FileInfo) + atime, mtime = StatTimeV2(stat) } - mode, err := FileModeV1(resource.Path) + mode, err := FileModeV1(srcPath) if err != nil { return err } - open, err := os.Open(resource.Path) + open, err := os.Open(srcPath) if err != nil { return err } - return scp.put(ctx, dstPath, open, mode, resource.Size(), atime, mtime) + return scp.put(ctx, dstPath, open, mode, stat.Size(), atime, mtime) } func (scp *SCP) put(ctx context.Context, dstPath string, in io.Reader, mode string, size int64, atime, mtime string) error { diff --git a/scp_test.go b/scp_test.go index 7ac76fd..fe3c72a 100644 --- a/scp_test.go +++ b/scp_test.go @@ -106,16 +106,6 @@ func TestPutAll(t *testing.T) { assert.Nil(t, err) } -//func TestGetSwitch(t *testing.T) { -// // keep remote server has remoteFile -// local, remote := RandName(baseDir), filepath.Join(baseDir, "file1") -// ssh, err := NewSSH(testNode) -// assert.Nil(t, err) -// scpwCli := NewSCP(ssh, true) -// err = scpwCli.SwitchScpwFunc(context.Background(), local, remote, GET) -// assert.Nil(t, err) -//} - func TestGetFile(t *testing.T) { local, remote := RandName("/tmp"), RandName(baseRemoteDir) assert.Nil(t, writeFile(remote)) @@ -154,18 +144,6 @@ func TestGetFileRemoteIsDir(t *testing.T) { assert.NotNil(t, err) } -//func TestGetAllSwitch(t *testing.T) { -// local := RandName("/tmp") -// os.Mkdir(local, os.FileMode(uint32(0777))) -// log.Infof("local:%s", local) -// remote := filepath.Join(baseDir, "dir1/") -// ssh, err := NewSSH(testNode) -// assert.Nil(t, err) -// scpwCli := NewSCP(ssh, true) -// err = scpwCli.SwitchScpwFunc(context.Background(), local, remote, GET) -// assert.NotNil(t, err) -//} - func TestGetAll(t *testing.T) { local := RandName("/tmp") assert.Nil(t, mkdir(local)) @@ -177,6 +155,66 @@ func TestGetAll(t *testing.T) { assert.Nil(t, err) } +func TestPutSwitchScpwFunc(t *testing.T) { + // put file + local := RandName("/tmp") + assert.Nil(t, writeFile(local)) + remote := RandName("/tmp") + ssh, err := NewSSH(testNode) + assert.Nil(t, err) + scpwCli := NewSCP(ssh, true) + err = scpwCli.SwitchScpwFunc(context.Background(), local, remote, PUT) + assert.Nil(t, err) + + local = RandName("/tmp") + assert.Nil(t, mkdir(local)) + assert.Nil(t, writeFile(RandName(local))) + assert.Nil(t, writeFile(RandName(local))) + remote = RandName("/tmp") + assert.Nil(t, mkdir(remote)) + // put dir all + err = scpwCli.SwitchScpwFunc(context.Background(), local, remote, PUT) + assert.Nil(t, err) + + // put dir exclude root + remote = RandName("/tmp") + assert.Nil(t, mkdir(remote)) + err = scpwCli.SwitchScpwFunc(context.Background(), local+"/*", remote, PUT) + assert.Nil(t, err) + + // put file permission deny + local = "/tmp/notexist" + remote = RandName("/tmp") + err = scpwCli.SwitchScpwFunc(context.Background(), local, remote, PUT) + assert.NotNil(t, err) +} + +func TestGetSwitchScpwFunc(t *testing.T) { + // get file + local := RandName("/tmp") + assert.Nil(t, mkdir(local)) + remote := RandName("/tmp") + assert.Nil(t, writeFile(remote)) + ssh, err := NewSSH(testNode) + assert.Nil(t, err) + scpwCli := NewSCP(ssh, true) + err = scpwCli.SwitchScpwFunc(context.Background(), local, remote, GET) + assert.Nil(t, err) + + // get dir all + local = RandName("/tmp") + assert.Nil(t, mkdir(local)) + remote = baseRemoteDir + err = scpwCli.SwitchScpwFunc(context.Background(), local, remote+"/", GET) + assert.Nil(t, err) + + // get file permission deny + local = RandName("/tmp") + remote = noPermissionFile + err = scpwCli.SwitchScpwFunc(context.Background(), local, remote, GET) + assert.NotNil(t, err) +} + func TestWalkTree(t *testing.T) { scpCh := &scpChan{fileChan: make(chan File), exitChan: make(chan struct{}), closeChan: make(chan struct{})} path := "./" diff --git a/tmp.go b/tmp.go deleted file mode 100644 index 115d05d..0000000 --- a/tmp.go +++ /dev/null @@ -1,19 +0,0 @@ -package scpw - -import ( - "io/fs" - "os" -) - -type Resource struct { - fs.FileInfo - Path string -} - -func NewResource(filePath string) (*Resource, error) { - stat, err := os.Stat(filePath) - if err != nil { - return nil, err - } - return &Resource{stat, filePath}, nil -} diff --git a/utils.go b/utils.go index a351721..25c5e5b 100644 --- a/utils.go +++ b/utils.go @@ -1,20 +1,14 @@ package scpw import ( - "errors" "fmt" - "github.com/go-cmd/cmd" "github.com/google/uuid" - "github.com/mattn/go-isatty" - "golang.org/x/crypto/ssh" "os" "path/filepath" - "runtime" "strconv" ) var ( - unit = []string{"B", "KB", "GB", "TB"} letters = []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890") ) @@ -50,22 +44,6 @@ func Addr(ip, port string) string { return fmt.Sprintf("%s:%s", ip, port) } -func SupportANSIColor(fd uintptr) bool { - return isatty.IsTerminal(fd) && runtime.GOOS != "windows" -} - -func HostKey(ip string) (ssh.PublicKey, error) { - findCmd := cmd.NewCmd("ssh-keygen", "-F", ip) - statusChan := findCmd.Start() - finalStatus := <-statusChan - if finalStatus.Error != nil || len(finalStatus.Stdout) == 0 { - log.Errorf("cannot find ip:{%s} HostKey", ip) - return nil, errors.New("find HostKey fail") - } - hostKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(finalStatus.Stdout[1])) - return hostKey, err -} - func FileModeV1(root string) (string, error) { file, err := os.Stat(root) if err != nil { @@ -104,15 +82,15 @@ func StatDir(root string) (entries []os.DirEntry, name, mode, atime, mtime strin return } -func StatFile(root string) (string, string, string, string, string, error) { +func StatFile(root string) (name string, mode string, atime string, mtime string, size string, err error) { stat, err := os.Stat(root) if err != nil { return "", "", "", "", "", err } - name := stat.Name() - mode := FileModeV2(stat) - atime, mtime := StatTimeV2(stat) - size := strconv.FormatInt(stat.Size(), 10) + name = stat.Name() + mode = FileModeV2(stat) + atime, mtime = StatTimeV2(stat) + size = strconv.FormatInt(stat.Size(), 10) return name, mode, size, atime, mtime, nil } diff --git a/utils_test.go b/utils_test.go index e2be9e8..a92e75f 100644 --- a/utils_test.go +++ b/utils_test.go @@ -2,6 +2,7 @@ package scpw import ( "fmt" + "github.com/stretchr/testify/require" "math/rand" "os" "path/filepath" @@ -9,27 +10,102 @@ import ( "time" ) -func TestFileMode(t *testing.T) { +func TestMinInt(t *testing.T) { + a, b := 1, 2 + require.Equal(t, a, MinInt(a, b)) + require.Equal(t, a, MinInt(b, a)) +} + +func TestMaxInt(t *testing.T) { + a, b := 1, 2 + require.Equal(t, b, MaxInt(a, b)) + require.Equal(t, b, MaxInt(b, a)) +} + +func TestMinInt64(t *testing.T) { + a, b := int64(1), int64(2) + require.Equal(t, a, MinInt64(a, b)) + require.Equal(t, a, MinInt64(b, a)) +} + +func TestMaxInt64(t *testing.T) { + a, b := int64(1), int64(2) + require.Equal(t, b, MaxInt64(a, b)) + require.Equal(t, b, MaxInt64(b, a)) +} + +func TestAddr(t *testing.T) { + require.Equal(t, "127.0.0.1:80", Addr("127.0.0.1", "80")) +} + +func TestFileModeV1(t *testing.T) { mode, err := FileModeV1("./cmd/scpw/main.go") - if err != nil { - panic(err) - } - log.Infof("mode:%s", mode) + require.Nil(t, err) + log.Infof(mode) + + mode, err = FileModeV1("./notexist.go") + require.NotNil(t, err) } -func TestStatTimeV2(t *testing.T) { - open, _ := os.Stat("./cmd/scpw/main.go") - atime, mtime := StatTimeV2(open) - log.Infof("atime:%s mtime:%s", atime, mtime) +func TestFileModeV2(t *testing.T) { + stat, err := os.Stat("./cmd/scpw/main.go") + require.Nil(t, err) + v2 := FileModeV2(stat) + log.Infof(v2) +} + +func TestStatDirMeta(t *testing.T) { + _, _, _, _, _, err := StatDirMeta("./notexist") + require.NotNil(t, err) + + name, mode, atime, mtime, dir, err := StatDirMeta("./cmd") + require.Nil(t, err) + log.Infof("name:%s mode:%s atime:%s mtime:%s dir:%v", name, mode, atime, mtime, dir) +} + +func TestStatDirChild(t *testing.T) { + child, err := StatDirChild("./cmd") + require.Nil(t, err) + log.Infof("child:%v", child) +} + +func TestStatDir(t *testing.T) { + entries, name, mode, atime, mtime, isDir, err := StatDir("./cmd") + require.Nil(t, err) + log.Infof("entries:%v name:%s mode:%s atime:%s mtime:%s isDir:%v", entries, name, mode, atime, mtime, isDir) + + _, _, _, _, _, _, err = StatDir("./notexist") + require.NotNil(t, err) +} + +func TestFile(t *testing.T) { + name, mode, atime, mtime, size, err := StatFile("./cmd/scpw/main.go") + require.Nil(t, err) + log.Infof("name:%s mode:%s atime:%s mtime:%s size:%s", name, mode, atime, mtime, size) + + _, _, _, _, _, err = StatFile("./notexist") + require.NotNil(t, err) +} + +func TestParseUnit32(t *testing.T) { + unit32, err := ParseUnit32("0777") + require.Nil(t, err) + require.Equal(t, uint32(511), unit32) + + _, err = ParseUnit32("3gadfsgasd") + require.NotNil(t, err) } func TestParseInt64(t *testing.T) { res, err := ParseInt64("1446425371") - if err != nil { - panic(err) - } - time := res - log.Infof("time:%d", time) + require.Nil(t, err) + require.Equal(t, int64(1446425371), res) +} + +func TestStatTimeV2(t *testing.T) { + open, _ := os.Stat("./cmd/scpw/main.go") + atime, mtime := StatTimeV2(open) + log.Infof("atime:%s mtime:%s", atime, mtime) } func TestParseInt8(t *testing.T) {