Skip to content

Commit

Permalink
test: modify scp test
Browse files Browse the repository at this point in the history
  • Loading branch information
T-TRz879 committed Aug 22, 2023
1 parent 78999d9 commit 9cfcd98
Show file tree
Hide file tree
Showing 8 changed files with 196 additions and 91 deletions.
3 changes: 2 additions & 1 deletion config.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package scpw

import (
"fmt"
"gopkg.in/yaml.v2"
"os"
"os/user"
Expand Down Expand Up @@ -70,5 +71,5 @@ func LoadConfigBytes(names ...string) ([]byte, error) {
return sshw, nil
}
}
return nil, err
return nil, fmt.Errorf("cannot find config from %s", u.HomeDir)
}
31 changes: 31 additions & 0 deletions config_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package scpw

import (
"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v2"
"os"
"os/user"
"path/filepath"
"testing"
)

func TestLoadConfig(t *testing.T) {
u, err := user.Current()
os.Remove(filepath.Join(u.HomeDir, ".scpw"))
os.Remove(filepath.Join(u.HomeDir, ".scpw.yml"))
os.Remove(filepath.Join(u.HomeDir, ".scpw.yaml"))

_, err = LoadConfig()
assert.NotNil(t, err)

var config []*Node
config = append(config, &Node{Name: "local", Host: "127.0.0.1", User: "root", Port: "22", Password: "123", LRMap: []LRMap{{Local: "/tmp/a", Remote: "/tmp/b"}}, Typ: GET})
b, err := yaml.Marshal(config)
assert.Nil(t, err)
_, err = os.Create(filepath.Join(u.HomeDir, ".scpw"))
assert.Nil(t, err)
assert.Nil(t, os.WriteFile(filepath.Join(u.HomeDir, ".scpw"), b, os.FileMode(0777)))

_, err = LoadConfig()
assert.Nil(t, err)
}
2 changes: 1 addition & 1 deletion logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (l *logHandle) Log(args ...interface{}) {
}

func newLogger(name string) *logHandle {
l := &logHandle{Logger: *logrus.New(), name: name, colorful: SupportANSIColor(os.Stderr.Fd())}
l := &logHandle{Logger: *logrus.New(), name: name, colorful: true}
l.Formatter = l
if syslogHook != nil {
l.AddHook(syslogHook)
Expand Down
14 changes: 7 additions & 7 deletions scp.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ func (scp *SCP) SwitchScpwFunc(ctx context.Context, localPath, remotePath string
last := remotePath[len(remotePath)-1]
if last == '\\' || last == '/' {
remotePath = remotePath[:len(remotePath)-1]
err := os.Mkdir(localTmp, os.FileMode(uint32(0755)))
err := os.Mkdir(localTmp, os.FileMode(0755))
if err != nil {
return err
}
Expand Down Expand Up @@ -417,26 +417,26 @@ func WalkTree(ctx context.Context, scpChan *scpChan, rootParent, root, dstPath s
}

func (scp *SCP) Put(ctx context.Context, srcPath, dstPath string) error {
resource, err := NewResource(srcPath)
stat, err := os.Stat(srcPath)
if err != nil {
return err
}
if resource.IsDir() {
if stat.IsDir() {
return errors.New(fmt.Sprintf("local:[%s] is dir", srcPath))
}
var atime, mtime string
if scp.KeepTime {
atime, mtime = StatTimeV2(resource.FileInfo)
atime, mtime = StatTimeV2(stat)
}
mode, err := FileModeV1(resource.Path)
mode, err := FileModeV1(srcPath)
if err != nil {
return err
}
open, err := os.Open(resource.Path)
open, err := os.Open(srcPath)
if err != nil {
return err
}
return scp.put(ctx, dstPath, open, mode, resource.Size(), atime, mtime)
return scp.put(ctx, dstPath, open, mode, stat.Size(), atime, mtime)
}

func (scp *SCP) put(ctx context.Context, dstPath string, in io.Reader, mode string, size int64, atime, mtime string) error {
Expand Down
82 changes: 60 additions & 22 deletions scp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,16 +106,6 @@ func TestPutAll(t *testing.T) {
assert.Nil(t, err)
}

//func TestGetSwitch(t *testing.T) {
// // keep remote server has remoteFile
// local, remote := RandName(baseDir), filepath.Join(baseDir, "file1")
// ssh, err := NewSSH(testNode)
// assert.Nil(t, err)
// scpwCli := NewSCP(ssh, true)
// err = scpwCli.SwitchScpwFunc(context.Background(), local, remote, GET)
// assert.Nil(t, err)
//}

func TestGetFile(t *testing.T) {
local, remote := RandName("/tmp"), RandName(baseRemoteDir)
assert.Nil(t, writeFile(remote))
Expand Down Expand Up @@ -154,18 +144,6 @@ func TestGetFileRemoteIsDir(t *testing.T) {
assert.NotNil(t, err)
}

//func TestGetAllSwitch(t *testing.T) {
// local := RandName("/tmp")
// os.Mkdir(local, os.FileMode(uint32(0777)))
// log.Infof("local:%s", local)
// remote := filepath.Join(baseDir, "dir1/")
// ssh, err := NewSSH(testNode)
// assert.Nil(t, err)
// scpwCli := NewSCP(ssh, true)
// err = scpwCli.SwitchScpwFunc(context.Background(), local, remote, GET)
// assert.NotNil(t, err)
//}

func TestGetAll(t *testing.T) {
local := RandName("/tmp")
assert.Nil(t, mkdir(local))
Expand All @@ -177,6 +155,66 @@ func TestGetAll(t *testing.T) {
assert.Nil(t, err)
}

func TestPutSwitchScpwFunc(t *testing.T) {
// put file
local := RandName("/tmp")
assert.Nil(t, writeFile(local))
remote := RandName("/tmp")
ssh, err := NewSSH(testNode)
assert.Nil(t, err)
scpwCli := NewSCP(ssh, true)
err = scpwCli.SwitchScpwFunc(context.Background(), local, remote, PUT)
assert.Nil(t, err)

local = RandName("/tmp")
assert.Nil(t, mkdir(local))
assert.Nil(t, writeFile(RandName(local)))
assert.Nil(t, writeFile(RandName(local)))
remote = RandName("/tmp")
assert.Nil(t, mkdir(remote))
// put dir all
err = scpwCli.SwitchScpwFunc(context.Background(), local, remote, PUT)
assert.Nil(t, err)

// put dir exclude root
remote = RandName("/tmp")
assert.Nil(t, mkdir(remote))
err = scpwCli.SwitchScpwFunc(context.Background(), local+"/*", remote, PUT)
assert.Nil(t, err)

// put file permission deny
local = "/tmp/notexist"
remote = RandName("/tmp")
err = scpwCli.SwitchScpwFunc(context.Background(), local, remote, PUT)
assert.NotNil(t, err)
}

func TestGetSwitchScpwFunc(t *testing.T) {
// get file
local := RandName("/tmp")
assert.Nil(t, mkdir(local))
remote := RandName("/tmp")
assert.Nil(t, writeFile(remote))
ssh, err := NewSSH(testNode)
assert.Nil(t, err)
scpwCli := NewSCP(ssh, true)
err = scpwCli.SwitchScpwFunc(context.Background(), local, remote, GET)
assert.Nil(t, err)

// get dir all
local = RandName("/tmp")
assert.Nil(t, mkdir(local))
remote = baseRemoteDir
err = scpwCli.SwitchScpwFunc(context.Background(), local, remote+"/", GET)
assert.Nil(t, err)

// get file permission deny
local = RandName("/tmp")
remote = noPermissionFile
err = scpwCli.SwitchScpwFunc(context.Background(), local, remote, GET)
assert.NotNil(t, err)
}

func TestWalkTree(t *testing.T) {
scpCh := &scpChan{fileChan: make(chan File), exitChan: make(chan struct{}), closeChan: make(chan struct{})}
path := "./"
Expand Down
19 changes: 0 additions & 19 deletions tmp.go

This file was deleted.

32 changes: 5 additions & 27 deletions utils.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,14 @@
package scpw

import (
"errors"
"fmt"
"github.com/go-cmd/cmd"
"github.com/google/uuid"
"github.com/mattn/go-isatty"
"golang.org/x/crypto/ssh"
"os"
"path/filepath"
"runtime"
"strconv"
)

var (
unit = []string{"B", "KB", "GB", "TB"}
letters = []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890")
)

Expand Down Expand Up @@ -50,22 +44,6 @@ func Addr(ip, port string) string {
return fmt.Sprintf("%s:%s", ip, port)
}

func SupportANSIColor(fd uintptr) bool {
return isatty.IsTerminal(fd) && runtime.GOOS != "windows"
}

func HostKey(ip string) (ssh.PublicKey, error) {
findCmd := cmd.NewCmd("ssh-keygen", "-F", ip)
statusChan := findCmd.Start()
finalStatus := <-statusChan
if finalStatus.Error != nil || len(finalStatus.Stdout) == 0 {
log.Errorf("cannot find ip:{%s} HostKey", ip)
return nil, errors.New("find HostKey fail")
}
hostKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(finalStatus.Stdout[1]))
return hostKey, err
}

func FileModeV1(root string) (string, error) {
file, err := os.Stat(root)
if err != nil {
Expand Down Expand Up @@ -104,15 +82,15 @@ func StatDir(root string) (entries []os.DirEntry, name, mode, atime, mtime strin
return
}

func StatFile(root string) (string, string, string, string, string, error) {
func StatFile(root string) (name string, mode string, atime string, mtime string, size string, err error) {
stat, err := os.Stat(root)
if err != nil {
return "", "", "", "", "", err
}
name := stat.Name()
mode := FileModeV2(stat)
atime, mtime := StatTimeV2(stat)
size := strconv.FormatInt(stat.Size(), 10)
name = stat.Name()
mode = FileModeV2(stat)
atime, mtime = StatTimeV2(stat)
size = strconv.FormatInt(stat.Size(), 10)
return name, mode, size, atime, mtime, nil
}

Expand Down
Loading

0 comments on commit 9cfcd98

Please sign in to comment.