diff --git a/cmd/start.go b/cmd/start.go index a91dfe0..82de448 100644 --- a/cmd/start.go +++ b/cmd/start.go @@ -49,6 +49,8 @@ func newProxy(logger *zap.Logger) (*proxy.Proxy, error) { return nil, fmt.Errorf("failed to validate config %w", err) } + cfg.ResolveTwingateHost() + registry := prometheus.NewRegistry() p, err := proxy.NewProxy(cfg, registry, logger) diff --git a/go.mod b/go.mod index d86d32c..814bdf1 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/fsnotify/fsnotify v1.10.1 github.com/golang-jwt/jwt/v5 v5.3.1 github.com/google/uuid v1.6.0 + github.com/hashicorp/go-retryablehttp v0.7.8 github.com/hashicorp/vault/api v1.23.0 github.com/hashicorp/vault/api/auth/approle v0.12.0 github.com/hashicorp/vault/api/auth/aws v0.12.0 @@ -60,7 +61,6 @@ require ( github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/go-hclog v1.6.3 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect - github.com/hashicorp/go-retryablehttp v0.7.8 // indirect github.com/hashicorp/go-rootcerts v1.0.2 // indirect github.com/hashicorp/go-secure-stdlib/awsutil v0.3.0 // indirect github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0 // indirect diff --git a/internal/config/config.go b/internal/config/config.go index 45bdcbb..46f76e9 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -6,10 +6,12 @@ package config import ( "errors" "fmt" + "net/http" "os" "strings" "time" + "github.com/hashicorp/go-retryablehttp" "go.yaml.in/yaml/v4" "golang.org/x/crypto/ssh" ) @@ -187,6 +189,44 @@ func Load(path string) (*Config, error) { return cfg, nil } +func stripNetworkPrefix(hostname, network string) string { + return strings.TrimPrefix(hostname, network+".") +} + +func resolveTwingateHostname(targetURL, defaultHost string, retryMax int) string { + client := retryablehttp.NewClient() + client.HTTPClient.Timeout = 1 * time.Second + client.HTTPClient.CheckRedirect = func(_ *http.Request, _ []*http.Request) error { + return http.ErrUseLastResponse + } + client.RetryMax = retryMax + client.Logger = nil + + resp, err := client.Head(targetURL) + if err != nil { + return defaultHost + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusPermanentRedirect { + return defaultHost + } + + location, err := resp.Location() + if err != nil { + return defaultHost + } + + return location.Hostname() +} + +func (c *Config) ResolveTwingateHost() { + targetURL := fmt.Sprintf("https://%s.%s/api/v1/jwk/ec", c.Twingate.Network, c.Twingate.Host) + resolvedHostname := resolveTwingateHostname(targetURL, c.Twingate.Host, 2) + + c.Twingate.Host = stripNetworkPrefix(resolvedHostname, c.Twingate.Network) +} + func (c *Config) Validate() error { if c.Twingate.Network == "" { return fmt.Errorf("%w: twingate.network", ErrRequired) diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 746bc04..42947a1 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -4,6 +4,8 @@ package config import ( + "net/http" + "net/http/httptest" "os" "path/filepath" "testing" @@ -13,6 +15,92 @@ import ( "github.com/stretchr/testify/require" ) +func TestStripNetworkPrefix(t *testing.T) { + tests := []struct { + name string + hostname string + network string + expected string + }{ + {name: "sharded host", hostname: "acme.us1.test.com", network: "acme", expected: "us1.test.com"}, + {name: "non-sharded host", hostname: "acme.test.com", network: "acme", expected: "test.com"}, + {name: "no network prefix", hostname: "test.com", network: "acme", expected: "test.com"}, + {name: "empty network", hostname: "us1.twingate.com", network: "", expected: "us1.twingate.com"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, stripNetworkPrefix(tt.hostname, tt.network)) + }) + } +} + +func TestResolveTwingateHostname(t *testing.T) { + t.Run("returns location hostname on 308 status code", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Location", "https://acme.us1.twingate.com/api/v1/jwk/ec") + w.WriteHeader(http.StatusPermanentRedirect) + })) + t.Cleanup(server.Close) + + result := resolveTwingateHostname(server.URL+"/api/v1/jwk/ec", "twingate.com", 0) + assert.Equal(t, "acme.us1.twingate.com", result) + }) + + t.Run("returns default host on empty location", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Location", "") + w.WriteHeader(http.StatusPermanentRedirect) + })) + t.Cleanup(server.Close) + + result := resolveTwingateHostname(server.URL+"/api/v1/jwk/ec", "twingate.com", 0) + + assert.Equal(t, "twingate.com", result) + }) + + t.Run("returns default host on non 308 status code", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(server.Close) + + result := resolveTwingateHostname(server.URL+"/api/v1/jwk/ec", "twingate.com", 0) + + assert.Equal(t, "twingate.com", result) + }) + + t.Run("does not follow redirect", func(t *testing.T) { + shardServerCalled := make(chan struct{}, 1) + + shardServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + shardServerCalled <- struct{}{} + + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(shardServer.Close) + + redirectServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, shardServer.URL+r.URL.Path, http.StatusPermanentRedirect) + })) + t.Cleanup(redirectServer.Close) + + resolveTwingateHostname(redirectServer.URL+"/api/v1/jwk/ec", "twingate.com", 0) + + select { + case <-shardServerCalled: + t.Fatal("should not follow redirect to shard server") + default: + } + }) + + t.Run("returns default host on connection error", func(t *testing.T) { + result := resolveTwingateHostname("http://127.0.0.1:1/api/v1/jwk/ec", "twingate.com", 0) + + assert.Equal(t, "twingate.com", result) + }) +} + func TestLoad_Kubernetes(t *testing.T) { yaml := ` twingate: diff --git a/internal/token/parser.go b/internal/token/parser.go index e0f649e..2a13063 100644 --- a/internal/token/parser.go +++ b/internal/token/parser.go @@ -6,6 +6,7 @@ package token import ( "errors" "fmt" + "strings" "github.com/MicahParks/keyfunc/v3" "github.com/golang-jwt/jwt/v5" @@ -15,13 +16,23 @@ var errInvalidTokenType = errors.New("token type is invalid") var allowedSigningMethods = []string{jwt.SigningMethodES256.Alg()} -var allowedIssuerByHost = map[string]string{ +var allowedIssuerByDomain = map[string]string{ "test": "twingate-local", "dev.opstg.com": "twingate-dev", "stg.opstg.com": "twingate-stg", "twingate.com": "twingate", } +func getIssuer(host string) string { + for baseDomain, issuer := range allowedIssuerByDomain { + if baseDomain == host || strings.HasSuffix(host, "."+baseDomain) { + return issuer + } + } + + return "" +} + type ClaimsWithHeaderType interface { getHeaderType() string } @@ -55,7 +66,7 @@ func NewParser(config ParserConfig) (*Parser, error) { return &Parser{ parser: jwt.NewParser( jwt.WithValidMethods(allowedSigningMethods), - jwt.WithIssuer(allowedIssuerByHost[config.Host]), + jwt.WithIssuer(getIssuer(config.Host)), jwt.WithAudience(config.Network), jwt.WithIssuedAt(), jwt.WithExpirationRequired(), diff --git a/internal/token/parser_test.go b/internal/token/parser_test.go index e1b8390..ff21833 100644 --- a/internal/token/parser_test.go +++ b/internal/token/parser_test.go @@ -12,6 +12,7 @@ import ( "time" "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -205,3 +206,21 @@ func TestNewParser(t *testing.T) { }) } } + +func TestGetIssuer(t *testing.T) { + tests := []struct { + name string + host string + issuer string + }{ + {name: "exact match", host: "twingate.com", issuer: "twingate"}, + {name: "sharded host", host: "us1.twingate.com", issuer: "twingate"}, + {name: "unknown host", host: "unknown-dev.opstg.com", issuer: ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.issuer, getIssuer(tt.host)) + }) + } +}