forked from gruntwork-io/terratest
-
Notifications
You must be signed in to change notification settings - Fork 0
/
agent.go
140 lines (122 loc) · 3.44 KB
/
agent.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
package ssh
import (
	"crypto/x509"
	"encoding/pem"
	"errors"
	"io"
	"io/ioutil"
	"net"
	"os"
	"path/filepath"
	"sync"
	"testing"

	"golang.org/x/crypto/ssh/agent"
)
// SshAgent is an SSH agent that accepts clients on a Unix socket and serves
// them from an in-memory keyring (see NewSshAgent). Create one with NewSshAgent
// or one of the SshAgentWithKeyPair* helpers, and release it with Stop.
type SshAgent struct {
	stop       chan bool    // closed by Stop to ask the accept loop (run) to exit
	stopped    chan bool    // closed by run once it has exited; Stop waits on it
	socketDir  string       // directory that Stop removes recursively
	socketFile string       // path of the Unix socket the listener is bound to
	agent      agent.Agent  // keyring that serves agent requests and receives added keys
	ln         net.Listener // Unix-socket listener that accepts agent clients
}
// NewSshAgent creates an SSH agent backed by a fresh in-memory keyring. The
// agent listens on the Unix socket at socketFile and starts serving clients in
// a background goroutine before this function returns.
//
// socketDir is only recorded so that Stop can delete it. NewSshAgent does not
// create any directory, so socketFile's parent directory must already exist.
// If the socket cannot be opened, the net.Listen error is returned and no
// goroutine is started.
//
// Shut the agent down and clean up its files afterwards, typically with
// `defer s.Stop()`.
func NewSshAgent(t *testing.T, socketDir string, socketFile string) (*SshAgent, error) {
	listener, listenErr := net.Listen("unix", socketFile)
	if listenErr != nil {
		return nil, listenErr
	}

	s := &SshAgent{
		stop:       make(chan bool),
		stopped:    make(chan bool),
		socketDir:  socketDir,
		socketFile: socketFile,
		agent:      agent.NewKeyring(),
		ln:         listener,
	}
	go s.run(t)
	return s, nil
}
// SocketFile returns the path of the Unix socket that this agent listens on,
// exactly as it was passed to NewSshAgent.
func (s *SshAgent) SocketFile() string {
	return s.socketFile
}
// run is the agent's accept loop. NewSshAgent starts it in its own goroutine.
//
// Each accepted connection is served by agent.ServeAgent in a separate
// goroutine and is closed as soon as that goroutine finishes. Deferring the
// Close inside this loop would keep every connection open, and grow the defer
// stack, until run returns.
//
// Once Stop has been requested, run closes any connections still being served,
// waits for all serving goroutines to return, and only then closes s.stopped.
// As a result, nothing logs through t after Stop has returned.
func (s *SshAgent) run(t *testing.T) {
	// wg counts serving goroutines; mu guards open, the set of live connections.
	var wg sync.WaitGroup
	var mu sync.Mutex
	open := make(map[net.Conn]struct{})

	defer func() {
		// Closing the remaining connections makes their ServeAgent calls return,
		// so the Wait below cannot block on an idle client.
		mu.Lock()
		for c := range open {
			c.Close()
		}
		mu.Unlock()
		wg.Wait()
		close(s.stopped)
	}()

	for {
		// Do not block in Accept if Stop has already been requested.
		select {
		case <-s.stop:
			return
		default:
		}

		c, err := s.ln.Accept()
		if err != nil {
			select {
			// When s.Stop() closes the listener, s.ln.Accept() returns an error that can be ignored
			// since the agent is in stopping process
			case <-s.stop:
				return
			// When s.ln.Accept() returns a legit error, we print it and continue accepting further requests
			default:
				t.Logf("could not accept connection to agent %v", err)
				continue
			}
		}

		mu.Lock()
		open[c] = struct{}{}
		mu.Unlock()

		wg.Add(1)
		go func(c net.Conn) {
			defer wg.Done()
			defer func() {
				mu.Lock()
				delete(open, c)
				mu.Unlock()
				// May already be closed by shutdown; a second Close is harmless.
				c.Close()
			}()

			// io.EOF means the client disconnected normally, which is not a failure.
			if err := agent.ServeAgent(s.agent, c); err != nil && !errors.Is(err, io.EOF) {
				select {
				case <-s.stop:
					// The error comes from shutdown closing the connection; expected.
				default:
					t.Logf("could not serve ssh agent %v", err)
				}
			}
		}(c)
	}
}
// Stop shuts the agent down and deletes its socket directory.
//
// It signals the accept loop to exit and closes the listener so that a blocked
// Accept returns. It then waits for the loop to finish and finally removes
// socketDir together with everything inside it.
//
// Stop must be called at most once. A second call panics because it closes an
// already-closed channel.
func (s *SshAgent) Stop() {
	// Signal before closing the listener so run can tell the resulting Accept
	// error apart from a genuine failure.
	close(s.stop)
	_ = s.ln.Close()

	<-s.stopped

	// Best-effort cleanup: Stop has no way to report a removal failure.
	_ = os.RemoveAll(s.socketDir)
}
// SshAgentWithKeyPair instantiates and returns an in-memory SSH agent with the
// given KeyPair already added. If the agent cannot be created, it fails the
// test immediately via t.Fatal.
//
// Stop the agent afterwards to clean up its files, e.g. with
// `defer sshAgent.Stop()`.
func SshAgentWithKeyPair(t *testing.T, keyPair *KeyPair) *SshAgent {
	// Attribute a t.Fatal below to the caller's line rather than this helper.
	t.Helper()
	sshAgent, err := SshAgentWithKeyPairE(t, keyPair)
	if err != nil {
		t.Fatal(err)
	}
	return sshAgent
}
// SshAgentWithKeyPairE instantiates and returns an in-memory SSH agent with the
// single given KeyPair already added. It is a thin wrapper around
// SshAgentWithKeyPairsE and returns that function's results unchanged.
func SshAgentWithKeyPairE(t *testing.T, keyPair *KeyPair) (*SshAgent, error) {
	return SshAgentWithKeyPairsE(t, []*KeyPair{keyPair})
}
// SshAgentWithKeyPairs instantiates and returns an in-memory SSH agent with all
// of the given KeyPairs already added. If the agent cannot be created, it fails
// the test immediately via t.Fatal.
//
// Stop the agent afterwards to clean up its files, e.g. with
// `defer sshAgent.Stop()`.
func SshAgentWithKeyPairs(t *testing.T, keyPairs []*KeyPair) *SshAgent {
	// Attribute a t.Fatal below to the caller's line rather than this helper.
	t.Helper()
	sshAgent, err := SshAgentWithKeyPairsE(t, keyPairs)
	if err != nil {
		t.Fatal(err)
	}
	return sshAgent
}
// SshAgentWithKeyPairsE instantiates and returns an in-memory SSH agent with the
// given KeyPair(s) already added. The agent listens on "ssh_auth.sock" inside a
// fresh temporary directory, which Stop removes.
//
// Each KeyPair's PrivateKey must be PEM-encoded and must contain a PKCS#1 RSA
// private key, since it is parsed with x509.ParsePKCS1PrivateKey.
//
// If any step fails, everything created so far is released before the error is
// returned: the temporary directory and, once it exists, the listener and the
// accept goroutine.
//
// On success, stop the agent afterwards to clean up its files, e.g. with
// `defer sshAgent.Stop()`.
func SshAgentWithKeyPairsE(t *testing.T, keyPairs []*KeyPair) (*SshAgent, error) {
	t.Log("Generating SSH Agent with given KeyPair(s)")

	// Instantiate a temporary SSH agent.
	socketDir, err := ioutil.TempDir("", "ssh-agent-")
	if err != nil {
		return nil, err
	}
	socketFile := filepath.Join(socketDir, "ssh_auth.sock")
	sshAgent, err := NewSshAgent(t, socketDir, socketFile)
	if err != nil {
		// No agent owns socketDir yet, so remove it here instead of leaking it.
		// Best effort: the NewSshAgent error is the one worth reporting.
		_ = os.RemoveAll(socketDir)
		return nil, err
	}

	// fail tears the partially configured agent down (goroutine, listener,
	// socketDir) so that a bad key does not leak it, then reports err.
	fail := func(err error) (*SshAgent, error) {
		sshAgent.Stop()
		return nil, err
	}

	// Add the given SSH keys to the newly created agent.
	for _, keyPair := range keyPairs {
		block, _ := pem.Decode([]byte(keyPair.PrivateKey))
		if block == nil {
			// pem.Decode returns a nil block when no PEM data is found;
			// dereferencing it would panic.
			return fail(errors.New("no PEM-encoded private key found in KeyPair"))
		}
		privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
		if err != nil {
			return fail(err)
		}
		if err := sshAgent.agent.Add(agent.AddedKey{PrivateKey: privateKey}); err != nil {
			return fail(err)
		}
	}
	return sshAgent, nil
}