Skip to content

Commit

Permalink
feat: add support for credentials file (#1151)
Browse files Browse the repository at this point in the history
  • Loading branch information
enocom committed May 5, 2022
1 parent 0ac17a5 commit f1ae7ea
Show file tree
Hide file tree
Showing 11 changed files with 220 additions and 71 deletions.
2 changes: 2 additions & 0 deletions .envrc.example
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,5 @@ export SQLSERVER_CONNECTION_NAME="project:region:instance"
export SQLSERVER_USER="sqlserver-user"
export SQLSERVER_PASS="sqlserver-password"
export SQLSERVER_DB="sqlserver-db-name"

export GOOGLE_APPLICATION_CREDENTIALS=/path/to/key.json
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,7 @@
# Compiled binary
/cmd/cloud_sql_proxy/cloud_sql_proxy
/cloud_sql_proxy
# v2 binary
/cloudsql-proxy

/key.json
22 changes: 19 additions & 3 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ connecting to Cloud SQL instances. It listens on a local port and forwards conne
to your instance's IP address, providing a secure connection without having to manage
any client SSL certificates.`,
Args: func(cmd *cobra.Command, args []string) error {
err := parseConfig(c.conf, args)
err := parseConfig(cmd, c.conf, args)
if err != nil {
return err
}
Expand All @@ -108,6 +108,8 @@ any client SSL certificates.`,
// Global-only flags
cmd.PersistentFlags().StringVarP(&c.conf.Token, "token", "t", "",
"Bearer token used for authorization.")
cmd.PersistentFlags().StringVarP(&c.conf.CredentialsFile, "credentials-file", "c", "",
"Path to a service account key to use for authentication.")

// Global and per instance flags
cmd.PersistentFlags().StringVarP(&c.conf.Addr, "address", "a", "127.0.0.1",
Expand All @@ -119,7 +121,7 @@ any client SSL certificates.`,
return c
}

func parseConfig(conf *proxy.Config, args []string) error {
func parseConfig(cmd *cobra.Command, conf *proxy.Config, args []string) error {
// If no instance connection names were provided, error.
if len(args) == 0 {
return newBadCommandError("missing instance_connection_name (e.g., project:region:instance)")
Expand All @@ -129,6 +131,20 @@ func parseConfig(conf *proxy.Config, args []string) error {
return newBadCommandError(fmt.Sprintf("not a valid IP address: %q", conf.Addr))
}

// If both token and credentials file were set, error.
if conf.Token != "" && conf.CredentialsFile != "" {
return newBadCommandError("Cannot specify --token and --credentials-file flags at the same time")
}

switch {
case conf.Token != "":
cmd.Printf("Authorizing with the -token flag\n")
case conf.CredentialsFile != "":
cmd.Printf("Authorizing with the credentials file at %q\n", conf.CredentialsFile)
default:
cmd.Printf("Authorizing with Application Default Credentials")
}

var ics []proxy.InstanceConnConfig
for _, a := range args {
// Assume no query params initially
Expand Down Expand Up @@ -211,8 +227,8 @@ func runSignalWrapper(cmd *Command) error {
// Otherwise, initialize a new one.
d := cmd.conf.Dialer
if d == nil {
var err error
opts := append(cmd.conf.DialerOpts(), cloudsqlconn.WithUserAgent(userAgent))
var err error
d, err = cloudsqlconn.NewDialer(ctx, opts...)
if err != nil {
shutdownCh <- fmt.Errorf("error initializing dialer: %v", err)
Expand Down
20 changes: 20 additions & 0 deletions cmd/root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,20 @@ func TestNewCommandArguments(t *testing.T) {
Token: "MYCOOLTOKEN",
}),
},
{
desc: "using the credentiale file flag",
args: []string{"--credentials-file", "/path/to/file", "proj:region:inst"},
want: withDefaults(&proxy.Config{
CredentialsFile: "/path/to/file",
}),
},
{
desc: "using the (short) credentiale file flag",
args: []string{"-c", "/path/to/file", "proj:region:inst"},
want: withDefaults(&proxy.Config{
CredentialsFile: "/path/to/file",
}),
},
}

for _, tc := range tcs {
Expand Down Expand Up @@ -186,6 +200,12 @@ func TestNewCommandWithErrors(t *testing.T) {
desc: "when the port query param is not a number",
args: []string{"proj:region:inst?port=hi"},
},
{
desc: "when both token and credentials file is set",
args: []string{
"--token", "my-token",
"--credentials-file", "/path/to/file", "proj:region:inst"},
},
}

for _, tc := range tcs {
Expand Down
16 changes: 9 additions & 7 deletions internal/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ type Config struct {
// Token is the Bearer token used for authorization.
Token string

// CredentialsFile is the path to a service account key.
CredentialsFile string

// Addr is the address on which to bind all instances.
Addr string

Expand All @@ -61,18 +64,17 @@ type Config struct {
Dialer cloudsql.Dialer
}

// NewConfig initializes a Config struct using the default database engine
// ports.
func NewConfig() *Config {
return &Config{}
}

func (c *Config) DialerOpts() []cloudsqlconn.Option {
var opts []cloudsqlconn.Option
if c.Token != "" {
switch {
case c.Token != "":
opts = append(opts, cloudsqlconn.WithTokenSource(
oauth2.StaticTokenSource(&oauth2.Token{AccessToken: c.Token}),
))
case c.CredentialsFile != "":
opts = append(opts, cloudsqlconn.WithCredentialsFile(
c.CredentialsFile,
))
}
return opts
}
Expand Down
27 changes: 0 additions & 27 deletions internal/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,30 +167,3 @@ func TestClientInitialization(t *testing.T) {
})
}
}

func TestConfigDialerOpts(t *testing.T) {
tcs := []struct {
desc string
config proxy.Config
wantLen int
}{
{
desc: "when there are no options",
config: proxy.Config{},
wantLen: 0,
},
{
desc: "when a token is present",
config: proxy.Config{Token: "my-token"},
wantLen: 1,
},
}

for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
if got := tc.config.DialerOpts(); tc.wantLen != len(got) {
t.Errorf("want len = %v, got = %v", tc.wantLen, len(got))
}
})
}
}
15 changes: 10 additions & 5 deletions tests/alldb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"fmt"
"net/http"
"testing"
"time"
)

// requireAllVars skips the given test if at least one environment variable is undefined.
Expand All @@ -43,13 +44,17 @@ func TestMultiInstanceDial(t *testing.T) {
t.Skip("skipping Health Check integration tests")
}
requireAllVars(t)
ctx := context.Background()

var args []string
args = append(args, fmt.Sprintf("-instances=%s=tcp:%d,%s=tcp:%d,%s=tcp:%d", *mysqlConnName, mysqlPort, *postgresConnName, postgresPort, *sqlserverConnName, sqlserverPort))
args = append(args, "-use_http_health_check")
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()

// Start the proxy.
args := []string{
// This test doesn't care what the instance port is, so use "0" which
// means, let the runtime pick a random port.
fmt.Sprintf("-instances=%s=tcp:0,%s=tcp:0,%s=tcp:0",
*mysqlConnName, *postgresConnName, *sqlserverConnName),
"-use_http_health_check",
}
p, err := StartProxy(ctx, args...)
if err != nil {
t.Fatalf("unable to start proxy: %v", err)
Expand Down
38 changes: 34 additions & 4 deletions testsV2/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,45 @@ package tests
import (
"context"
"database/sql"
"os"
"testing"
"time"

"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"google.golang.org/api/sqladmin/v1"
)

// proxyConnTest is a test helper to verify the proxy works with a basic connectivity test.
func proxyConnTest(t *testing.T, connName, driver, dsn string, port int, dir string) {
ctx := context.Background()
const connTestTimeout = time.Minute

args := []string{connName}
// removeAuthEnvVar retrieves an OAuth2 token and a path to a service account key
// and then unsets GOOGLE_APPLICATION_CREDENTIALS. It returns a cleanup function
// that restores the original setup.
func removeAuthEnvVar(t *testing.T) (*oauth2.Token, string, func()) {
ts, err := google.DefaultTokenSource(context.Background(), sqladmin.SqlserviceAdminScope)
if err != nil {
t.Errorf("failed to resolve token source: %v", err)
}
tok, err := ts.Token()
if err != nil {
t.Errorf("failed to get token: %v", err)
}
path, ok := os.LookupEnv("GOOGLE_APPLICATION_CREDENTIALS")
if !ok {
t.Fatalf("GOOGLE_APPLICATION_CREDENTIALS was not set in the environment")
}
if err := os.Unsetenv("GOOGLE_APPLICATION_CREDENTIALS"); err != nil {
t.Fatalf("failed to unset GOOGLE_APPLICATION_CREDENTIALS")
}
return tok, path, func() {
os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", path)
}
}

// proxyConnTest is a test helper to verify the proxy works with a basic connectivity test.
func proxyConnTest(t *testing.T, args []string, driver, dsn string) {
ctx, cancel := context.WithTimeout(context.Background(), connTestTimeout)
defer cancel()
// Start the proxy
p, err := StartProxy(ctx, args...)
if err != nil {
Expand Down
58 changes: 49 additions & 9 deletions testsV2/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,36 +27,76 @@ var (
mysqlConnName = flag.String("mysql_conn_name", os.Getenv("MYSQL_CONNECTION_NAME"), "Cloud SQL MYSQL instance connection name, in the form of 'project:region:instance'.")
mysqlUser = flag.String("mysql_user", os.Getenv("MYSQL_USER"), "Name of database user.")
mysqlPass = flag.String("mysql_pass", os.Getenv("MYSQL_PASS"), "Password for the database user; be careful when entering a password on the command line (it may go into your terminal's history).")
mysqlDb = flag.String("mysql_db", os.Getenv("MYSQL_DB"), "Name of the database to connect to.")

mysqlPort = 3306
mysqlDB = flag.String("mysql_db", os.Getenv("MYSQL_DB"), "Name of the database to connect to.")
)

func requireMysqlVars(t *testing.T) {
func requireMySQLVars(t *testing.T) {
switch "" {
case *mysqlConnName:
t.Fatal("'mysql_conn_name' not set")
case *mysqlUser:
t.Fatal("'mysql_user' not set")
case *mysqlPass:
t.Fatal("'mysql_pass' not set")
case *mysqlDb:
case *mysqlDB:
t.Fatal("'mysql_db' not set")
}
}

func TestMysqlTcp(t *testing.T) {
func TestMySQLTCP(t *testing.T) {
if testing.Short() {
t.Skip("skipping MySQL integration tests")
}
requireMySQLVars(t)
cfg := mysql.Config{
User: *mysqlUser,
Passwd: *mysqlPass,
DBName: *mysqlDB,
AllowNativePasswords: true,
Addr: "127.0.0.1:3306",
Net: "tcp",
}
proxyConnTest(t, []string{*mysqlConnName}, "mysql", cfg.FormatDSN())
}

func TestMySQLAuthWithToken(t *testing.T) {
if testing.Short() {
t.Skip("skipping MySQL integration tests")
}
requireMysqlVars(t)
requireMySQLVars(t)
tok, _, cleanup := removeAuthEnvVar(t)
defer cleanup()

cfg := mysql.Config{
User: *mysqlUser,
Passwd: *mysqlPass,
DBName: *mysqlDB,
AllowNativePasswords: true,
Addr: "127.0.0.1:3306",
Net: "tcp",
}
proxyConnTest(t,
[]string{"--token", tok.AccessToken, *mysqlConnName},
"mysql", cfg.FormatDSN())
}

func TestMySQLAuthWithCredentialsFile(t *testing.T) {
if testing.Short() {
t.Skip("skipping MySQL integration tests")
}
requireMySQLVars(t)
_, path, cleanup := removeAuthEnvVar(t)
defer cleanup()

cfg := mysql.Config{
User: *mysqlUser,
Passwd: *mysqlPass,
DBName: *mysqlDb,
DBName: *mysqlDB,
AllowNativePasswords: true,
Addr: "127.0.0.1:3306",
Net: "tcp",
}
proxyConnTest(t, *mysqlConnName, "mysql", cfg.FormatDSN(), mysqlPort, "")
proxyConnTest(t,
[]string{"--credentials-file", path, *mysqlConnName},
"mysql", cfg.FormatDSN())
}
Loading

0 comments on commit f1ae7ea

Please sign in to comment.