Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cmd/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ func newProxy(logger *zap.Logger) (*proxy.Proxy, error) {
return nil, fmt.Errorf("failed to validate config %w", err)
}

cfg.ResolveTwingateHost()

registry := prometheus.NewRegistry()

p, err := proxy.NewProxy(cfg, registry, logger)
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ require (
github.com/fsnotify/fsnotify v1.10.1
github.com/golang-jwt/jwt/v5 v5.3.1
github.com/google/uuid v1.6.0
github.com/hashicorp/go-retryablehttp v0.7.8
github.com/hashicorp/vault/api v1.23.0
github.com/hashicorp/vault/api/auth/approle v0.12.0
github.com/hashicorp/vault/api/auth/aws v0.12.0
Expand Down Expand Up @@ -60,7 +61,6 @@ require (
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
github.com/hashicorp/go-hclog v1.6.3 // indirect
github.com/hashicorp/go-multierror v1.1.1 // indirect
github.com/hashicorp/go-retryablehttp v0.7.8 // indirect
github.com/hashicorp/go-rootcerts v1.0.2 // indirect
github.com/hashicorp/go-secure-stdlib/awsutil v0.3.0 // indirect
github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0 // indirect
Expand Down
40 changes: 40 additions & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ package config
import (
"errors"
"fmt"
"net/http"
"os"
"strings"
"time"

"github.com/hashicorp/go-retryablehttp"
"go.yaml.in/yaml/v4"
"golang.org/x/crypto/ssh"
)
Expand Down Expand Up @@ -187,6 +189,44 @@ func Load(path string) (*Config, error) {
return cfg, nil
}

func stripNetworkPrefix(hostname, network string) string {
return strings.TrimPrefix(hostname, network+".")
}

func resolveTwingateHostname(targetURL, defaultHost string, retryMax int) string {
client := retryablehttp.NewClient()
client.HTTPClient.Timeout = 1 * time.Second
client.HTTPClient.CheckRedirect = func(_ *http.Request, _ []*http.Request) error {
return http.ErrUseLastResponse
}
client.RetryMax = retryMax
client.Logger = nil

resp, err := client.Head(targetURL)
if err != nil {
return defaultHost
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusPermanentRedirect {
return defaultHost
}

location, err := resp.Location()
if err != nil {
return defaultHost
}

return location.Hostname()
}

func (c *Config) ResolveTwingateHost() {
targetURL := fmt.Sprintf("https://%s.%s/api/v1/jwk/ec", c.Twingate.Network, c.Twingate.Host)
resolvedHostname := resolveTwingateHostname(targetURL, c.Twingate.Host, 2)

c.Twingate.Host = stripNetworkPrefix(resolvedHostname, c.Twingate.Network)
}

func (c *Config) Validate() error {
if c.Twingate.Network == "" {
return fmt.Errorf("%w: twingate.network", ErrRequired)
Expand Down
88 changes: 88 additions & 0 deletions internal/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
package config

import (
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
Expand All @@ -13,6 +15,92 @@ import (
"github.com/stretchr/testify/require"
)

func TestStripNetworkPrefix(t *testing.T) {
tests := []struct {
name string
hostname string
network string
expected string
}{
{name: "sharded host", hostname: "acme.us1.test.com", network: "acme", expected: "us1.test.com"},
{name: "non-sharded host", hostname: "acme.test.com", network: "acme", expected: "test.com"},
{name: "no network prefix", hostname: "test.com", network: "acme", expected: "test.com"},
{name: "empty network", hostname: "us1.twingate.com", network: "", expected: "us1.twingate.com"},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, stripNetworkPrefix(tt.hostname, tt.network))
})
}
}

func TestResolveTwingateHostname(t *testing.T) {
t.Run("returns location hostname on 308 status code", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Location", "https://acme.us1.twingate.com/api/v1/jwk/ec")
w.WriteHeader(http.StatusPermanentRedirect)
}))
t.Cleanup(server.Close)

result := resolveTwingateHostname(server.URL+"/api/v1/jwk/ec", "twingate.com", 0)
assert.Equal(t, "acme.us1.twingate.com", result)
})

t.Run("returns default host on empty location", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Location", "")
w.WriteHeader(http.StatusPermanentRedirect)
}))
t.Cleanup(server.Close)

result := resolveTwingateHostname(server.URL+"/api/v1/jwk/ec", "twingate.com", 0)

assert.Equal(t, "twingate.com", result)
})

t.Run("returns default host on non 308 status code", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
t.Cleanup(server.Close)

result := resolveTwingateHostname(server.URL+"/api/v1/jwk/ec", "twingate.com", 0)

assert.Equal(t, "twingate.com", result)
})

t.Run("does not follow redirect", func(t *testing.T) {
shardServerCalled := make(chan struct{}, 1)

shardServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
shardServerCalled <- struct{}{}

w.WriteHeader(http.StatusOK)
}))
t.Cleanup(shardServer.Close)

redirectServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, shardServer.URL+r.URL.Path, http.StatusPermanentRedirect)
}))
t.Cleanup(redirectServer.Close)

resolveTwingateHostname(redirectServer.URL+"/api/v1/jwk/ec", "twingate.com", 0)

select {
case <-shardServerCalled:
t.Fatal("should not follow redirect to shard server")
default:
}
})

t.Run("returns default host on connection error", func(t *testing.T) {
result := resolveTwingateHostname("http://127.0.0.1:1/api/v1/jwk/ec", "twingate.com", 0)

assert.Equal(t, "twingate.com", result)
})
Comment thread
minhtule marked this conversation as resolved.
}

func TestLoad_Kubernetes(t *testing.T) {
yaml := `
twingate:
Expand Down
15 changes: 13 additions & 2 deletions internal/token/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package token
import (
"errors"
"fmt"
"strings"

"github.com/MicahParks/keyfunc/v3"
"github.com/golang-jwt/jwt/v5"
Expand All @@ -15,13 +16,23 @@ var errInvalidTokenType = errors.New("token type is invalid")

var allowedSigningMethods = []string{jwt.SigningMethodES256.Alg()}

var allowedIssuerByHost = map[string]string{
var allowedIssuerByDomain = map[string]string{
"test": "twingate-local",
"dev.opstg.com": "twingate-dev",
"stg.opstg.com": "twingate-stg",
"twingate.com": "twingate",
}

func getIssuer(host string) string {
for baseDomain, issuer := range allowedIssuerByDomain {
if baseDomain == host || strings.HasSuffix(host, "."+baseDomain) {
return issuer
}
}

return ""
}

type ClaimsWithHeaderType interface {
getHeaderType() string
}
Expand Down Expand Up @@ -55,7 +66,7 @@ func NewParser(config ParserConfig) (*Parser, error) {
return &Parser{
parser: jwt.NewParser(
jwt.WithValidMethods(allowedSigningMethods),
jwt.WithIssuer(allowedIssuerByHost[config.Host]),
jwt.WithIssuer(getIssuer(config.Host)),
jwt.WithAudience(config.Network),
jwt.WithIssuedAt(),
jwt.WithExpirationRequired(),
Expand Down
19 changes: 19 additions & 0 deletions internal/token/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"time"

"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -205,3 +206,21 @@ func TestNewParser(t *testing.T) {
})
}
}

func TestGetIssuer(t *testing.T) {
tests := []struct {
name string
host string
issuer string
}{
{name: "exact match", host: "twingate.com", issuer: "twingate"},
{name: "sharded host", host: "us1.twingate.com", issuer: "twingate"},
{name: "unknown host", host: "unknown-dev.opstg.com", issuer: ""},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.issuer, getIssuer(tt.host))
})
}
}