Skip to content

Commit

Permalink
Control ALPN enabled verification using an env var
Browse files Browse the repository at this point in the history
  • Loading branch information
arjan-bal committed May 3, 2024
1 parent bb33388 commit 514de04
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 35 deletions.
29 changes: 18 additions & 11 deletions credentials/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,13 @@ import (
"net/url"
"os"

"google.golang.org/grpc/grpclog"
credinternal "google.golang.org/grpc/internal/credentials"
"google.golang.org/grpc/internal/envconfig"
)

var logger = grpclog.Component("credentials")

// TLSInfo contains the auth information for a TLS authenticated connection.
// It implements the AuthInfo interface.
type TLSInfo struct {
Expand Down Expand Up @@ -113,17 +117,20 @@ func (c *tlsCreds) ClientHandshake(ctx context.Context, authority string, rawCon
return nil, nil, ctx.Err()
}

// The negotiated protocol can be either of the following:
// 1. h2: When the server supports ALPN. Only HTTP/2 can be negotiated since
// it is the only protocol advertised by the client during the handshake.
// The tls library ensures that the server chooses a protocol advertised
// by the client.
// 2. "" (empty string): If the server doesn't support ALPN. ALPN is a requirement
// for using HTTP/2 over TLS. We can terminate the connection immediately.
np := conn.ConnectionState().NegotiatedProtocol
if np == "" {
_ = conn.Close()
return nil, nil, fmt.Errorf("cannot check peer: missing selected ALPN property")
// The negotiated protocol can be either of the following:
// 1. h2: When the server supports ALPN. Only HTTP/2 can be negotiated since
// it is the only protocol advertised by the client during the handshake.
// The tls library ensures that the server chooses a protocol advertised
// by the client.
// 2. "" (empty string): If the server doesn't support ALPN. ALPN is a requirement
// for using HTTP/2 over TLS. We can terminate the connection immediately.
np := conn.ConnectionState().NegotiatedProtocol
if np == "" {
if envconfig.EnforceALPNEnabled {
_ = conn.Close()
return nil, nil, fmt.Errorf("cannot check peer: missing selected ALPN property")
}
logger.Warning("Allowing TLS connection to server %q with ALPN disabled")
}
tlsInfo := TLSInfo{
State: conn.ConnectionState(),
Expand Down
100 changes: 76 additions & 24 deletions credentials/tls_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@ import (
"crypto/x509"
"fmt"
"os"
"regexp"
"strings"
"testing"
"time"

"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/envconfig"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/status"
Expand Down Expand Up @@ -243,6 +245,11 @@ func (s) TestTLS_DisabledALPN(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

initialVal := envconfig.EnforceALPNEnabled
defer func() {
envconfig.EnforceALPNEnabled = initialVal
}()

// Start a non gRPC TLS server.
config := &tls.Config{
Certificates: []tls.Certificate{serverCert},
Expand All @@ -254,31 +261,76 @@ func (s) TestTLS_DisabledALPN(t *testing.T) {
}
defer listner.Close()

// Start listening for server requests in a new go routine.
go func() {
conn, err := listner.Accept()
if err != nil {
t.Errorf("tls.Accept failed err = %v", err)
} else {
_, _ = conn.Write([]byte("Hello, World!"))
_ = conn.Close()
}
}()

clientCreds := credentials.NewTLS(&tls.Config{
ServerName: serverName,
RootCAs: certPool,
})

cc, err := grpc.NewClient("dns:"+listner.Addr().String(), grpc.WithTransportCredentials(clientCreds))
if err != nil {
t.Fatalf("grpc.NewClient error: %v", err)
tests := []struct {
description string
alpnEnforced bool
wantErrMatchPattern string
wantErrNonMatchPattern string
}{
{
description: "enforced",
alpnEnforced: true,
wantErrMatchPattern: "transport: .*missing selected ALPN property",
},
{
description: "not_enforced",
wantErrNonMatchPattern: "transport:",
},
{
description: "default_value",
wantErrNonMatchPattern: "transport:",
alpnEnforced: initialVal,
},
}
defer cc.Close()
client := testgrpc.NewTestServiceClient(cc)

const wantStr = "missing selected ALPN property"
if _, err = client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Unavailable || !strings.Contains(status.Convert(err).Message(), wantStr) {
t.Fatalf("EmptyCall err = %v; want code=%v, message contains %q", err, codes.Unavailable, wantStr)
for _, tc := range tests {
t.Run(tc.description, func(t *testing.T) {
envconfig.EnforceALPNEnabled = tc.alpnEnforced
// Listen to one TCP connection request.
go func() {
conn, err := listner.Accept()
if err != nil {
t.Errorf("tls.Accept failed err = %v", err)
} else {
_, _ = conn.Write([]byte("Hello, World!"))
_ = conn.Close()
}
}()

clientCreds := credentials.NewTLS(&tls.Config{
ServerName: serverName,
RootCAs: certPool,
})

cc, err := grpc.NewClient("dns:"+listner.Addr().String(), grpc.WithTransportCredentials(clientCreds))
if err != nil {
t.Fatalf("grpc.NewClient error: %v", err)
}
defer cc.Close()
client := testgrpc.NewTestServiceClient(cc)
_, rpcErr := client.EmptyCall(ctx, &testpb.Empty{})

if gotCode := status.Code(rpcErr); gotCode != codes.Unavailable {
t.Errorf("EmptyCall returned unexpected code: got=%v, want=%v", gotCode, codes.Unavailable)
}

matchPat, err := regexp.Compile(tc.wantErrMatchPattern)
if err != nil {
t.Fatalf("Error message match pattern %q is invalid due to error: %v", tc.wantErrMatchPattern, err)
}

if tc.wantErrMatchPattern != "" && !matchPat.MatchString(status.Convert(rpcErr).Message()) {
t.Errorf("EmptyCall err = %v; want pattern match %q", rpcErr, matchPat)
}
nonMatchPat, err := regexp.Compile(tc.wantErrNonMatchPattern)
if err != nil {
t.Fatalf("Error message non-match pattern %q is invalid due to error: %v", tc.wantErrNonMatchPattern, err)
}

if tc.wantErrNonMatchPattern != "" && nonMatchPat.MatchString(status.Convert(rpcErr).Message()) {
t.Errorf("EmptyCall err = %v; want pattern missing %q", rpcErr, nonMatchPat)
}
})

}
}
6 changes: 6 additions & 0 deletions internal/envconfig/envconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ var (
// ALTSMaxConcurrentHandshakes is the maximum number of concurrent ALTS
// handshakes that can be performed.
ALTSMaxConcurrentHandshakes = uint64FromEnv("GRPC_ALTS_MAX_CONCURRENT_HANDSHAKES", 100, 1, 100)
// EnforceALPNEnabled is set if TLS connections to servers with ALPN disabled
// should be rejected. The HTTP/2 protocol requires ALPN to be enabled, this
// option is present for backward compatibility. This option may be overridden
// by setting the environment variable "GRPC_ENFORCE_ALPN_ENABLED" to "true"
// or "false".
EnforceALPNEnabled = boolFromEnv("GRPC_ENFORCE_ALPN_ENABLED", false)
)

func boolFromEnv(envVar string, def bool) bool {
Expand Down

0 comments on commit 514de04

Please sign in to comment.