// ssh_runner.go
// Forked from hashicorp/nomad.

package provisioning
import (
"context"
"fmt"
"log"
"os"
"os/exec"
"path/filepath"
"strings"
"testing"
"time"
)
// SSHRunner is a ProvisioningRunner that deploys via ssh.
// Terraform does all of this more elegantly and portably in its
// ssh communicator, but by shelling out to the ssh and scp binaries
// we avoid pulling in Terraform's communicator as a Nomad dependency,
// and avoid some long-standing issues with connections to Windows
// servers. The tradeoff is losing portability, but in practice we're
// always going to run this from a Unixish machine.
type SSHRunner struct {
// Key is passed to ssh and scp with -i.
Key string // `json:"key"`
// User and Host form the user@host destination of every command.
User string // `json:"user"`
Host string // `json:"host"`
// Port is passed to ssh with -p and to scp with -P.
Port int // `json:"port"`
// None of the fields below are available at time of construction.
// Open() populates t, controlSockPath, ctx, cancelFunc and muxWait.
// NOTE(review): Open() as written here does not assign copyMethod;
// presumably whoever constructs the runner sets it -- confirm.

// t is the test whose logger Log/Logf write to.
t *testing.T
// controlSockPath is the ssh control socket shared by the master
// connection and every multiplexed ssh/scp command.
controlSockPath string
// ctx bounds the master connection and every command started through
// the runner; cancelFunc cancels it.
ctx context.Context
cancelFunc context.CancelFunc
// copyMethod is the function Copy delegates to (copyLinux and
// copyWindows have the matching signature).
copyMethod func(*SSHRunner, string, string) error
// muxWait is signalled by the goroutine started in Open once the
// master ssh process has exited; Close waits on it.
muxWait chan struct{}
}
// Open establishes the ssh connection. We keep this connection open
// so that we can multiplex subsequent ssh connections.
//
// Open records t for logging, creates a context with a 10 minute
// timeout that bounds the master connection and every command run
// through it, and starts "ssh -M" in the background with a control
// socket under ~/.ssh named after the host and the current pid. It
// returns without waiting for the connection to be established.
//
// An error is returned only when the user's home directory cannot be
// determined, in which case nothing is started.
func (runner *SSHRunner) Open(t *testing.T) error {
	runner.t = t
	runner.Logf("opening connection to %s", runner.Host)

	// Resolve the socket location before allocating anything so that a
	// failure leaves no context or goroutine behind.
	home, err := os.UserHomeDir()
	if err != nil {
		return fmt.Errorf("could not locate home directory for ssh control socket: %w", err)
	}
	runner.controlSockPath = filepath.Join(
		home, ".ssh",
		fmt.Sprintf("ssh-control-%s-%d.sock", runner.Host, os.Getpid()))

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
	runner.ctx = ctx
	runner.cancelFunc = cancel

	muxWait := make(chan struct{})
	runner.muxWait = muxWait

	cmd := exec.CommandContext(ctx,
		"ssh",
		"-M", "-S", runner.controlSockPath,
		"-o", "StrictHostKeyChecking=no", // we're those terrible cloud devs
		"-o", "UserKnownHostsFile=/dev/null",
		"-o", "LogLevel=ERROR",
		"-o", "ConnectTimeout=60", // give the target a while to come up
		"-i", runner.Key,
		"-p", fmt.Sprintf("%v", runner.Port),
		fmt.Sprintf("%s@%s", runner.User, runner.Host),
	)

	go func() {
		// Closing (rather than sending on) the channel means this
		// goroutine never blocks if nobody calls Close, and every
		// receive after the master has exited returns immediately.
		defer close(muxWait)

		// will block until command completes, we cancel, or timeout.
		// there's no point in returning the error here as we only
		// hit it when we're done and Windows unfortunately tends to
		// return 1 even when the script is complete.
		_ = cmd.Run()
	}()
	return nil
}
// Run executes script on the remote host one line at a time. Each line
// is trimmed of surrounding whitespace and blank lines are skipped, so
// a script with empty lines (or an empty script) does not send empty
// commands to the remote host.
//
// Execution stops at the first failing command: the runner's context
// is cancelled, which also aborts every other command sharing it, and
// the error is returned.
func (runner *SSHRunner) Run(script string) error {
	for _, line := range strings.Split(script, "\n") {
		command := strings.TrimSpace(line)
		if command == "" {
			continue
		}
		if err := runner.run(command); err != nil {
			// cancelFunc is nil until Open has been called; guard it
			// so the caller gets run's error instead of a nil-func
			// panic.
			if runner.cancelFunc != nil {
				runner.cancelFunc()
			}
			return err
		}
	}
	return nil
}
// run executes a single command on the remote host over the
// multiplexed connection and logs its combined stdout/stderr.
//
// It fails immediately if no control socket has been set up yet. A
// context.Canceled error from the command is treated like success;
// any other error is returned after logging the output as an error.
func (runner *SSHRunner) run(command string) error {
	sock := runner.controlSockPath
	if sock == "" {
		return fmt.Errorf("Run failed: you need to call Open() first")
	}
	runner.Logf("running '%s'", command)

	args := []string{
		"-S", sock,
		"-o", "StrictHostKeyChecking=no",
		"-o", "UserKnownHostsFile=/dev/null",
		"-o", "LogLevel=ERROR",
		"-i", runner.Key,
		"-p", fmt.Sprintf("%v", runner.Port),
		fmt.Sprintf("%s@%s", runner.User, runner.Host),
		command,
	}
	output, err := exec.CommandContext(runner.ctx, "ssh", args...).CombinedOutput()

	switch {
	case err == nil, err == context.Canceled:
		runner.LogOutput(string(output))
		return nil
	default:
		runner.LogErrOutput(string(output))
		return err
	}
}
// Copy uploads the local path to the remote path. We call into
// different copy methods for Linux vs Windows because their path
// semantics are slightly different and the typical ssh users have
// different permissions.
//
// It returns an error, rather than panicking on a nil function value,
// if no copy method has been configured on the runner.
func (runner *SSHRunner) Copy(local, remote string) error {
	if runner.copyMethod == nil {
		return fmt.Errorf("could not copy '%s': no copy method configured for SSHRunner", local)
	}
	return runner.copyMethod(runner, local, remote)
}
// copyLinux uploads local to remote on a Linux target. The upload is
// staged under /tmp with scp and then put in place with sudo, so that
// root-owned destinations can be written:
//
//   - a directory is merged into the destination with "cp -R" and the
//     staged copy is removed afterwards;
//   - a single file is moved into remote's parent directory.
//
// remote must name the destination file or directory itself; a path
// with an empty final element (e.g. one ending in "/") is rejected,
// because the staging path would then be /tmp/ and the follow-up
// commands would operate on /tmp as a whole.
//
// A failing scp cancels the runner's context. Errors reading local are
// returned to the caller.
//
// TODO: would be nice to set file owner/mode here
func copyLinux(runner *SSHRunner, local, remote string) error {
	runner.Logf("copying '%s' to '%s'", local, remote)

	remoteDir, remoteFileName := filepath.Split(remote)
	if remoteFileName == "" {
		return fmt.Errorf("could not copy '%s': remote path '%s' has no file name", local, remote)
	}

	// we stage to /tmp so that we can handle root-owned files
	tempPath := fmt.Sprintf("/tmp/%s", remoteFileName)

	cmd := exec.CommandContext(runner.ctx,
		"scp", "-r",
		"-o", fmt.Sprintf("ControlPath=%s", runner.controlSockPath),
		"-o", "StrictHostKeyChecking=no",
		"-o", "UserKnownHostsFile=/dev/null",
		"-o", "LogLevel=ERROR",
		"-i", runner.Key,
		"-P", fmt.Sprintf("%v", runner.Port),
		local,
		fmt.Sprintf("%s@%s:%s", runner.User, runner.Host, tempPath))

	stdoutStderr, err := cmd.CombinedOutput()
	if err != nil && err != context.Canceled {
		runner.LogErrOutput(string(stdoutStderr))
		runner.cancelFunc()
		return err
	}

	fi, err := os.Stat(local)
	if err != nil {
		return fmt.Errorf("could not read '%s': %w", local, err)
	}

	if fi.IsDir() {
		// this is a little inefficient but it lets us merge the contents of
		// a bundled directory with existing directories
		return runner.Run(
			fmt.Sprintf("sudo mkdir -p %s; sudo cp -R %s %s; sudo rm -r %s",
				remote, tempPath, remoteDir, tempPath))
	}
	return runner.run(fmt.Sprintf("sudo mv %s %s", tempPath, remoteDir))
}
// copyWindows uploads local to remote on a Windows target with a
// single scp. A file is copied to remote itself; a directory is copied
// into remote's parent directory.
//
// staging to Windows tempdirs is a little messier, but "fortunately"
// nobody seems to complain about connecting via ssh as Administrator on
// Windows so we can just bypass the problem.
//
// An error reading local is returned to the caller. A failing scp is
// logged and cancels the runner's context; a context.Canceled error is
// returned without logging or cancelling.
//
// NOTE(review): the parent directory is computed with filepath.Split,
// which uses the separators of the machine running this code --
// presumably remote paths are written with forward slashes; confirm.
func copyWindows(runner *SSHRunner, local, remote string) error {
	runner.Logf("copying '%s' to '%s'", local, remote)

	fi, err := os.Stat(local)
	if err != nil {
		return fmt.Errorf("could not read '%s': %w", local, err)
	}

	remotePath := remote
	if fi.IsDir() {
		remotePath, _ = filepath.Split(remote)
	}

	cmd := exec.CommandContext(runner.ctx,
		"scp", "-r",
		"-o", fmt.Sprintf("ControlPath=%s", runner.controlSockPath),
		"-o", "StrictHostKeyChecking=no",
		"-o", "UserKnownHostsFile=/dev/null",
		"-o", "LogLevel=ERROR",
		"-i", runner.Key,
		"-P", fmt.Sprintf("%v", runner.Port),
		local,
		fmt.Sprintf("%s@%s:'%s'", runner.User, runner.Host, remotePath))

	stdoutStderr, err := cmd.CombinedOutput()
	if err != nil && err != context.Canceled {
		runner.LogErrOutput(string(stdoutStderr))
		runner.cancelFunc()
	}
	return err
}
// Close shuts down the multiplexed connection: it cancels the runner's
// context, waits for the master ssh process to exit, and removes the
// control socket file.
//
// Close may be called more than once. cancelFunc and muxWait are
// checked for nil so that Close neither panics nor blocks forever when
// the connection was never started or has already been closed.
func (runner *SSHRunner) Close() {
	runner.Log("closing connection")

	if runner.cancelFunc != nil {
		runner.cancelFunc()
	}

	if runner.muxWait != nil {
		<-runner.muxWait
		// The master signals at most once; drop the channel so a
		// repeated Close does not wait for a second signal.
		runner.muxWait = nil
	}

	if runner.controlSockPath != "" {
		// Best effort: cancelling the context kills the master, so it
		// may not have removed its own socket. The error is ignored
		// because the file is usually already gone.
		_ = os.Remove(runner.controlSockPath)
	}
}
// 'go test -v' only emits logs after the entire test run is complete,
// but that makes it much harder to debug hanging deployments. These
// methods wrap the test logger or just emit directly w/ fmt.Print if
// the '-v' flag was set.

// Log writes args, prefixed with "[host]", formatted the way
// fmt.Println formats them. With -v the line goes straight to stdout;
// otherwise it goes to the test logger. The process is aborted via
// log.Fatal if no *testing.T has been configured.
func (runner *SSHRunner) Log(args ...interface{}) {
	if runner.t == nil {
		log.Fatal("no t.Testing configured for SSHRunner")
	}

	prefix := "[" + runner.Host + "]"
	if testing.Verbose() {
		// Print, not Printf: the host is data, not a format string.
		fmt.Print(prefix + " ")
		fmt.Println(args...)
		return
	}
	// Carry the same host prefix as the verbose branch and Logf.
	runner.t.Log(append([]interface{}{prefix}, args...)...)
}
// Logf formats its arguments like fmt.Sprintf and writes the result,
// prefixed with "[host] ". With -v the line goes straight to stdout;
// otherwise it goes to the test logger. The process is aborted via
// log.Fatal if no *testing.T has been configured.
func (runner *SSHRunner) Logf(format string, args ...interface{}) {
	if runner.t == nil {
		log.Fatal("no t.Testing configured for SSHRunner")
	}

	// Format the caller's message on its own so that a '%' in the host
	// (e.g. an IPv6 zone such as fe80::1%eth0) is never interpreted as
	// a formatting verb.
	msg := "[" + runner.Host + "] " + fmt.Sprintf(format, args...)
	if testing.Verbose() {
		fmt.Println(msg)
		return
	}
	runner.t.Log(msg)
}
// LogOutput emits command output. With -v it is printed to stdout
// wrapped in the ANSI escape for green; otherwise it is handed to the
// test logger unmodified.
func (runner *SSHRunner) LogOutput(output string) {
	if !testing.Verbose() {
		runner.t.Log(output)
		return
	}
	const green, reset = "\033[32m", "\033[0m"
	fmt.Println(green + output + reset)
}
// LogErrOutput emits the output of a failed command. With -v it is
// printed to stdout wrapped in the ANSI escape for red; otherwise it is
// handed to the test logger unmodified.
func (runner *SSHRunner) LogErrOutput(output string) {
	if !testing.Verbose() {
		runner.t.Log(output)
		return
	}
	const red, reset = "\033[31m", "\033[0m"
	fmt.Println(red + output + reset)
}