/
main.go
131 lines (108 loc) · 2.71 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
package main
import (
"bytes"
"flag"
"fmt"
"io/ioutil"
"log"
"os"
"os/exec"
"path"
"strings"
"github.com/kevinburke/ssh_config"
)
var (
helpFlag = flag.Bool("help", false, "display help message of the command line")
hostFlag = flag.String("host", "*", "Set the default host value you want to target")
listFlag = flag.Bool("list", false, "List the configurations in .ssh/config file")
confFlag = flag.String("config", ".ssh/config", "default config file of the SSH")
)
func main() {
// Parse command line flags
flag.Parse()
// Print default and terminate program
if *helpFlag {
flag.PrintDefaults()
return
}
// Return arguments in command line
identity := flag.Arg(0)
// Fetch user home directory
home, err := os.UserHomeDir()
if err != nil {
log.Fatal(err)
}
// Get current working directory
wdir, err := os.Getwd()
if err != nil {
log.Fatal(err)
}
// Appends the Home root folder to the relative route
if strings.HasPrefix(*confFlag, "./") {
*confFlag = path.Join(wdir, *confFlag)
}
// Appends the Home root folder to the relative route
if !bytes.Contains([]byte(*confFlag), []byte(home)) {
*confFlag = path.Join(home, *confFlag)
}
// Pull up file and check if there's an error
f, err := os.Open(*confFlag)
if err != nil {
log.Fatal(err)
}
// Decode the config file to a struct
c, err := ssh_config.Decode(f)
if err != nil {
log.Fatal(err)
}
// Display the list of configurations in the .ssh/config
if *listFlag {
fmt.Println(c.String())
return
}
// Check if there is at least one argument in the command line
if flag.NArg() < 1 {
panic("no argument given to the command line")
}
for _, host := range c.Hosts {
// Check if hostFlag is not `*` and that host has
// a match for *, to skip the selection for that host
if *hostFlag != "*" && host.Matches("*") {
continue
}
// Skip everything that has not matched the `-host` flag
if !host.Matches(*hostFlag) {
continue
}
for _, node := range host.Nodes {
switch t := node.(type) {
case *ssh_config.KV:
if t.Key != "IdentityFile" {
continue
}
// ssh Identity Path
identityLocation := path.Join(home, ".ssh", identity)
// Replace the default value
t.Value = identityLocation
// Add a key to the ssh-agent and the keychain
cmd := exec.Command("ssh-add", "-K", identityLocation)
// Run the command and check for errors
if err := cmd.Run(); err != nil {
log.Fatal(err)
}
}
}
// Dump changes on the terminal
fmt.Println(host.String())
}
// Marshal text to bytes
mt, err := c.MarshalText()
if err != nil {
log.Fatal(err)
}
// Write the changes to file
err = ioutil.WriteFile(*confFlag, mt, 0644)
if err != nil {
log.Fatal(err)
}
}