/
ssh_shell.go
151 lines (129 loc) · 3.61 KB
/
ssh_shell.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
package easyshell
import (
"github.com/3th1nk/easygo/util"
"github.com/3th1nk/easyshell/core"
"github.com/3th1nk/easyshell/internal/misc"
"github.com/3th1nk/easyshell/pkg/interceptor"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
"time"
)
// SshShellConfig configures the interactive shell created by NewSshShell or
// NewSshShellFromClient. Zero-valued terminal fields are filled in by
// EnsureInit; the embedded core.Config is passed through to core.New.
type SshShellConfig struct {
	core.Config
	// Credential is passed to NewSshClient by NewSshShell to open the
	// connection; NewSshShellFromClient does not read it.
	Credential *SshCredential
	// Echo requests terminal echo in the PTY modes (default false). Has no
	// effect on some network devices, which always echo.
	Echo bool
	// Term is the emulated terminal type (default "VT100").
	Term string
	// TermHeight is the emulated terminal height (default 200).
	TermHeight int
	// TermWidth is the emulated terminal width (default 80).
	TermWidth int
}
// EnsureInit fills in defaults, in place, for terminal settings that were
// left unset: Term becomes "VT100", and TermHeight/TermWidth become 200/80
// when they are zero or negative. Fields that are already set are kept.
func (c *SshShellConfig) EnsureInit() {
	const (
		defaultTerm   = "VT100"
		defaultHeight = 200
		defaultWidth  = 80
	)

	if len(c.Term) == 0 {
		c.Term = defaultTerm
	}
	// A size below 1 cannot describe a real terminal, so treat it as unset.
	if c.TermHeight < 1 {
		c.TermHeight = defaultHeight
	}
	if c.TermWidth < 1 {
		c.TermWidth = defaultWidth
	}
}
// NewSshShell connects with the configured Credential (via NewSshClient) and
// opens an interactive shell on that new connection.
//
// config is optional; when it is absent or nil, a zero SshShellConfig is used.
// The chosen config has its defaults filled in place by EnsureInit.
//
// The returned shell owns the client, so Close also closes the connection.
// If opening the shell fails, the freshly created client is closed before the
// error is returned.
func NewSshShell(config ...*SshShellConfig) (*SshShell, error) {
	cfg := &SshShellConfig{}
	if len(config) != 0 && config[0] != nil {
		cfg = config[0]
	}
	cfg.EnsureInit()

	client, dialErr := NewSshClient(cfg.Credential)
	if dialErr != nil {
		return nil, dialErr
	}

	shell, shellErr := NewSshShellFromClient(client, cfg)
	if shellErr != nil {
		// The client was created here; don't leak it when shell setup fails.
		_ = client.Close()
		return nil, shellErr
	}

	// Ownership lets Close tear down the connection together with the session.
	shell.ownClient = true
	return shell, nil
}
// NewSshShellFromClient opens a new session on an already-connected SSH
// client, requests a pseudo-terminal, starts the remote shell and reads the
// initial output (banners, log lines, prompts) so it does not interfere with
// later commands. The non-empty-trimmed initial lines are available via
// HeadLine.
//
// config is optional; when it is absent or nil, a zero SshShellConfig is used.
// The chosen config has its defaults filled in place by EnsureInit.
//
// The returned shell does not own client: Close closes the session but leaves
// client open. On failure the session (if it was created) is closed and a
// *core.Error naming the failed step ("session", "term", "pipe", "shell") is
// returned; client itself is never closed here.
func NewSshShellFromClient(client *ssh.Client, config ...*SshShellConfig) (*SshShell, error) {
	var cfg *SshShellConfig
	if len(config) > 0 && config[0] != nil {
		cfg = config[0]
	} else {
		cfg = &SshShellConfig{}
	}
	cfg.EnsureInit()

	addr := client.RemoteAddr().String()
	session, err := client.NewSession()
	if err != nil {
		return nil, &core.Error{Op: "session", Addr: addr, Err: err}
	}

	echo := util.IfInt(cfg.Echo, 1, 0)
	if err = session.RequestPty(cfg.Term, cfg.TermHeight, cfg.TermWidth, ssh.TerminalModes{
		ssh.ECHO:          uint32(echo),
		ssh.TTY_OP_ISPEED: 14400,
		ssh.TTY_OP_OSPEED: 14400,
	}); err != nil {
		_ = session.Close()
		return nil, &core.Error{Op: "term", Addr: addr, Err: err}
	}

	// The pipes must be attached before Shell() starts the remote process.
	// Their errors used to be discarded, which would hand nil streams to
	// core.New and leak the session.
	pIn, err := session.StdinPipe()
	if err != nil {
		_ = session.Close()
		return nil, &core.Error{Op: "pipe", Addr: addr, Err: err}
	}
	pOut, err := session.StdoutPipe()
	if err != nil {
		_ = session.Close()
		return nil, &core.Error{Op: "pipe", Addr: addr, Err: err}
	}
	pErr, err := session.StderrPipe()
	if err != nil {
		_ = session.Close()
		return nil, &core.Error{Op: "pipe", Addr: addr, Err: err}
	}

	if err = session.Shell(); err != nil {
		_ = session.Close()
		return nil, &core.Error{Op: "shell", Addr: addr, Err: err}
	}

	r := core.New(pIn, pOut, pErr, cfg.Config)

	// Right after the shell starts there may be output such as a welcome
	// banner, log lines or a password-change prompt; read and handle it so it
	// does not affect subsequent operations.
	// For password-change prompts: some devices say the password has expired
	// and ask whether to change it, others directly prompt for a new password.
	// Only the former is handled here, always answering "no" — the password is
	// never changed automatically.
	// The read is best-effort, so its error is deliberately ignored.
	var headLine []string
	_ = r.ReadToEndLine(5*time.Second, func(lines []string) {
		headLine = append(headLine, lines...)
	}, interceptor.AlwaysNo(true))
	headLine = misc.TrimEmptyLine(headLine)

	return &SshShell{ReadWriter: r, client: client, session: session, headLine: headLine}, nil
}
// SshShell is an interactive shell running in an SSH session. Reading and
// writing go through the embedded *core.ReadWriter, which is built by core.New
// over the session's stdin/stdout/stderr pipes.
type SshShell struct {
	*core.ReadWriter
	// client is the SSH connection the session runs on; nil after Close.
	client *ssh.Client
	// session runs the remote shell; nil after Close.
	session *ssh.Session
	// sftp, when non-nil, is closed by Close. NOTE(review): it is not set in
	// this section; presumably it is populated lazily elsewhere — confirm.
	sftp *sftp.Client
	// ownClient is true when this shell created client (NewSshShell); only
	// then does Close also close client.
	ownClient bool
	// headLine holds the output read right after the shell started,
	// after misc.TrimEmptyLine.
	headLine []string
}
// Client returns the SSH connection the shell's session runs on, or nil once
// Close has been called.
func (this *SshShell) Client() (client *ssh.Client) {
	client = this.client
	return client
}
// Session returns the SSH session running the remote shell, or nil once Close
// has been called.
func (this *SshShell) Session() (session *ssh.Session) {
	session = this.session
	return session
}
// HeadLine returns the output captured right after the shell started (banner,
// notices, prompts), as trimmed by misc.TrimEmptyLine. The slice is the
// shell's own storage, not a copy.
func (this *SshShell) HeadLine() (lines []string) {
	lines = this.headLine
	return lines
}
// Close releases the resources held by the shell in this order:
//   - the SFTP client, if one is set;
//   - the SSH session;
//   - the SSH client, but only when the shell owns it (created by NewSshShell);
//   - the embedded ReadWriter, via Stop.
//
// Each released field is set to nil. Every step is attempted even if an
// earlier one fails, and the first error encountered is returned.
//
// Stop is skipped when ReadWriter is nil (e.g. a zero-value SshShell); calling
// it unconditionally would invoke Stop on a nil *core.ReadWriter.
// NOTE(review): a second Close calls Stop again; whether that is safe depends on
// core.ReadWriter — confirm.
func (this *SshShell) Close() (err error) {
	if this.sftp != nil {
		if e := this.sftp.Close(); e != nil {
			err = e
		}
		this.sftp = nil
	}
	if this.session != nil {
		if e := this.session.Close(); e != nil && err == nil {
			err = e
		}
		this.session = nil
	}
	if this.client != nil {
		// A client passed in via NewSshShellFromClient belongs to the caller.
		if this.ownClient {
			if e := this.client.Close(); e != nil && err == nil {
				err = e
			}
		}
		this.client = nil
	}
	if this.ReadWriter != nil {
		this.ReadWriter.Stop()
	}
	return err
}