Skip to content
This repository has been archived by the owner on Apr 2, 2020. It is now read-only.

Commit

Permalink
Merge pull request #8 from daveschaefer/arg-patch
Browse files Browse the repository at this point in the history
Add support for prompting for password
  • Loading branch information
razvanm committed Jun 8, 2016
2 parents 3b25a79 + 41fbf10 commit e72f99b
Showing 1 changed file with 45 additions and 4 deletions.
49 changes: 45 additions & 4 deletions main.go
Expand Up @@ -33,15 +33,20 @@ import (
"io/ioutil"
"log"
"os"
"regexp"
"strings"
"syscall"
"time"

"github.com/go-sql-driver/mysql"
"golang.org/x/crypto/ssh/terminal"
)

var (
dump = flag.String("dump", "", "MySQL dump file")
dsn = flag.String("dsn", "root:root@tcp(0.0.0.0:3306)/", "MySQL Data Source Name")
dsn = flag.String("dsn", "user:password@tcp(0.0.0.0:3306)/", "MySQL Data Source Name")
enableSsl = flag.Bool("enable_ssl", false, "Connect to MySQL with SSL")
prompt = flag.Bool("prompt", false, "Prompt for password rather than specifying in the command. Change dsn format to 'user@tcp(0.0.0.0:3306)/'")
sslCa = flag.String("ssl_ca", "server-ca.pem", "MySQL Server certificate")
sslCert = flag.String("ssl_cert", "client-cert.pem", "MySQL Client PEM cert file")
sslKey = flag.String("ssl_key", "client-key.pem", "MySQL Client PEM key file")
Expand Down Expand Up @@ -134,12 +139,13 @@ func main() {
log.Fatalf("no -dump file specified")
}

var finalDsn = *dsn
if *enableSsl {
rootCertPool := x509.NewCertPool()
pem, err := ioutil.ReadFile(*sslCa)
if err != nil {
log.Fatalln("ioutil.Readline:", err)
}
rootCertPool := x509.NewCertPool()
if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
log.Fatal("Failed to append CA certificate PEM.")
}
Expand All @@ -149,13 +155,48 @@ func main() {
log.Fatalln("tls.LoadX509KeyPair:", err)
}
clientCert = append(clientCert, certs)
mysql.RegisterTLSConfig("custom", &tls.Config{
const customTLSName = "custom"
tlserr := mysql.RegisterTLSConfig(customTLSName, &tls.Config{
RootCAs: rootCertPool,
Certificates: clientCert,
ServerName: *serverName,
})
if tlserr != nil {
log.Fatalln("mysql.RegisterTLSConfig:", tlserr)
}
finalDsn = strings.Join([]string{finalDsn, "?tls=", customTLSName}, "")
}
db, err := sql.Open("mysql", *dsn)

if *prompt {
// DSN strings look like:
// user:password@tcp(0.0.0.0:3306)/
// With this flag the user can avoid typing their password:
// user@tcp(0.0.0.0:3306)/
// Save text before ':' and after '@' so we can insert the password
// to create a proper DSN string.
dsnRegex := regexp.MustCompile(`(\w*):?\w*(@.+)`)
matches := dsnRegex.FindStringSubmatch(finalDsn)
if matches == nil {
fmt.Print("Incorrect format for dsn. Usage:\n")
flag.PrintDefaults()
os.Exit(1)
}

fmt.Print("Enter password: ")
// Don't echo password to screen during input.
password, err := terminal.ReadPassword(int(syscall.Stdin))
if err != nil {
log.Fatalln("Error reading password:", err)
}
// ReadPassword() leaves cursor on the input line,
// so begin output on the next line
fmt.Print("\n")

// Insert password into the connection string.
finalDsn = strings.Join([]string{matches[1], ":", string(password), matches[2]}, "")
}

db, err := sql.Open("mysql", finalDsn)
if err != nil {
log.Fatalln("sql.Open:", err)
}
Expand Down

0 comments on commit e72f99b

Please sign in to comment.