Skip to content

Commit

Permalink
home: imp tls
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneOne1 committed Oct 27, 2022
1 parent 9c9d6b4 commit 540efcb
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 90 deletions.
173 changes: 89 additions & 84 deletions internal/home/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error
defer func() {
if err != nil {
status.WarningValidation = err.Error()
if errors.As(err, new(warningError)) {
// Do not return warnings since those aren't critical.
err = nil
}
}
}()

Expand Down Expand Up @@ -199,10 +203,15 @@ func loadTLSConf(tlsConf *tlsConfigSettings, status *tlsConfigStatus) (err error
tlsConf.ServerName,
)
if err != nil {
return fmt.Errorf("validating certificate pair: %w", err)
if warning := (warningError{}); errors.As(err, &warning) {
warning.error = fmt.Errorf("validating certificate pair: %w", warning.error)
err = warning
} else {
err = fmt.Errorf("validating certificate pair: %w", err)
}
}

return nil
return err
}

// tlsConfigStatus contains the status of a certificate chain and key pair.
Expand Down Expand Up @@ -482,38 +491,26 @@ func validatePorts(
return nil
}

// validateCertChain validates the certificate chain and sets data in status.
// The returned error is also set in status.WarningValidation.
func validateCertChain(status *tlsConfigStatus, certChain []byte, serverName string) (err error) {
defer func() {
if err != nil {
status.WarningValidation = err.Error()
}
}()

log.Debug("tls: got certificate chain: %d bytes", len(certChain))
// validateCertChain validates the certificate chain. It returns the first
// certificate within the chain. mainCert is guaranteed to be non-nil if the
// returned error is nil or a [warningError].
func validateCertChain(chain []byte, serverName string) (mainCert *x509.Certificate, err error) {
log.Debug("tls: got certificate chain: %d bytes", len(chain))

var certs []*pem.Block
pemblock := certChain
for {
var decoded *pem.Block
decoded, pemblock = pem.Decode(pemblock)
if decoded == nil {
break
}

for decoded, pemblock := pem.Decode(chain); decoded != nil; {
if decoded.Type == "CERTIFICATE" {
certs = append(certs, decoded)
}

decoded, pemblock = pem.Decode(pemblock)
}

parsedCerts, err := parsePEMCerts(certs)
if err != nil {
return err
return nil, err
}

status.ValidCert = true

opts := x509.VerifyOptions{
DNSName: serverName,
Roots: Context.tlsRoots,
Expand All @@ -529,24 +526,14 @@ func validateCertChain(status *tlsConfigStatus, certChain []byte, serverName str

opts.Intermediates = pool

mainCert := parsedCerts[0]
mainCert = parsedCerts[0]
_, err = mainCert.Verify(opts)
if err != nil {
// Let self-signed certs through and don't return this error.
status.WarningValidation = fmt.Sprintf("certificate does not verify: %s", err)
} else {
status.ValidChain = true
err = warningError{fmt.Errorf("certificate does not verify: %s", err)}
}

if mainCert != nil {
status.Subject = mainCert.Subject.String()
status.Issuer = mainCert.Issuer.String()
status.NotAfter = mainCert.NotAfter
status.NotBefore = mainCert.NotBefore
status.DNSNames = mainCert.DNSNames
}

return nil
return mainCert, err
}

// parsePEMCerts parses multiple PEM-encoded certificates.
Expand All @@ -568,56 +555,65 @@ func parsePEMCerts(certs []*pem.Block) (parsedCerts []*x509.Certificate, err err
return parsedCerts, nil
}

// validatePKey validates the private key and sets data in status. The returned
// error is also set in status.WarningValidation.
func validatePKey(status *tlsConfigStatus, pkey []byte) (err error) {
defer func() {
if err != nil {
status.WarningValidation = err.Error()
}
}()

// validatePKey validates the private key, returning its type.
func validatePKey(pkey []byte) (keyType string, err error) {
var key *pem.Block

// Go through all pem blocks, but take first valid pem block and drop the
// rest.
pemblock := []byte(pkey)
for {
var decoded *pem.Block
decoded, pemblock = pem.Decode(pemblock)
if decoded == nil {
break
}

for decoded, pemblock := pem.Decode([]byte(pkey)); decoded != nil; {
if decoded.Type == "PRIVATE KEY" || strings.HasSuffix(decoded.Type, " PRIVATE KEY") {
key = decoded

break
}

decoded, pemblock = pem.Decode(pemblock)
}

if key == nil {
return errors.Error("no valid keys were found")
return "", errors.Error("no valid keys were found")
}

_, keyType, err := parsePrivateKey(key.Bytes)
_, keyType, err = parsePrivateKey(key.Bytes)
if err != nil {
return fmt.Errorf("parsing private key: %w", err)
return "", fmt.Errorf("parsing private key: %w", err)
}

if keyType == keyTypeED25519 {
return errors.Error(
return "", errors.Error(
"ED25519 keys are not supported by browsers; " +
"did you mean to use X25519 for key exchange?",
)
}

status.ValidKey = true
status.KeyType = keyType
return keyType, nil
}

return nil
// warningError is a non-critical error to be reported. It capitalizes self
// string representation, assuming that a wrapped error's message is ASCII-only.
type warningError struct{ error }

// type check
var _ error = warningError{}

// Error implements the [error] interface for warningError. It returns the
// capitalized string representation of werr.
func (werr warningError) Error() (msg string) {
msg = werr.error.Error()
if werr.error != nil {
msg = strings.ToUpper(msg[:1]) + msg[1:]
}

return msg
}

// type check
var _ errors.Wrapper = warningError{}

// Unwrap implements the [errors.Wrapper] interface for warningError.
func (werr warningError) Unwrap() (err error) { return werr.error }

// validateCertificates processes certificate data and its private key. All
// parameters are optional. status must not be nil. The returned error is also
// set in status.WarningValidation.
Expand All @@ -627,47 +623,56 @@ func validateCertificates(
pkey []byte,
serverName string,
) (err error) {
defer func() {
// Capitalize the warning for the UI. Assume that warnings are all
// ASCII-only.
//
// TODO(a.garipov): Figure out a better way to do this. Perhaps a
// custom string or error type.
if w := status.WarningValidation; w != "" {
status.WarningValidation = strings.ToUpper(w[:1]) + w[1:]
}
}()

// Check only the public certificate separately from the key.
if len(certChain) > 0 {
err = validateCertChain(status, certChain, serverName)
if err != nil {
return err
mainCert, verr := validateCertChain(certChain, serverName)
if verr != nil {
if !errors.As(verr, new(warningError)) {
return verr
} else {
err = verr
}
} else {
status.ValidChain = true
}

status.ValidCert = true
status.Subject = mainCert.Subject.String()
status.Issuer = mainCert.Issuer.String()
status.NotAfter = mainCert.NotAfter
status.NotBefore = mainCert.NotBefore
status.DNSNames = mainCert.DNSNames

if err == nil && len(mainCert.IPAddresses) == 0 {
const errMsg errors.Error = `certificate has no IP addresses` +
`, this may cause issues with DNS-over-TLS clients.`

err = warningError{errMsg}
}
}

// Validate the private key by parsing it.
if len(pkey) > 0 {
err = validatePKey(status, pkey)
if err != nil {
return err
keyType, verr := validatePKey(pkey)
if verr != nil {
return verr
}

status.ValidKey = true
status.KeyType = keyType
}

// If both are set, validate together.
if len(certChain) > 0 && len(pkey) > 0 {
_, err = tls.X509KeyPair(certChain, pkey)
if err != nil {
err = fmt.Errorf("certificate-key pair: %w", err)
status.WarningValidation = err.Error()

return err
_, verr := tls.X509KeyPair(certChain, pkey)
if verr != nil {
return fmt.Errorf("certificate-key pair: %w", verr)
}

status.ValidPair = true
}

return nil
return err
}

// Key types.
Expand Down
10 changes: 4 additions & 6 deletions internal/home/tls_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"testing"
"time"

"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -43,29 +44,26 @@ func TestValidateCertificates(t *testing.T) {
t.Run("bad_certificate", func(t *testing.T) {
status := &tlsConfigStatus{}
err := validateCertificates(status, []byte("bad cert"), nil, "")
assert.Error(t, err)
assert.NotEmpty(t, status.WarningValidation)
testutil.AssertErrorMsg(t, "empty certificate", err)
assert.False(t, status.ValidCert)
assert.False(t, status.ValidChain)
})

t.Run("bad_private_key", func(t *testing.T) {
status := &tlsConfigStatus{}
err := validateCertificates(status, nil, []byte("bad priv key"), "")
assert.Error(t, err)
assert.NotEmpty(t, status.WarningValidation)
testutil.AssertErrorMsg(t, "no valid keys were found", err)
assert.False(t, status.ValidKey)
})

t.Run("valid", func(t *testing.T) {
status := &tlsConfigStatus{}
err := validateCertificates(status, testCertChainData, testPrivateKeyData, "")
assert.NoError(t, err)
assert.ErrorAs(t, err, new(warningError))

notBefore := time.Date(2019, 2, 27, 9, 24, 23, 0, time.UTC)
notAfter := time.Date(2046, 7, 14, 9, 24, 23, 0, time.UTC)

assert.NotEmpty(t, status.WarningValidation)
assert.True(t, status.ValidCert)
assert.False(t, status.ValidChain)
assert.True(t, status.ValidKey)
Expand Down

0 comments on commit 540efcb

Please sign in to comment.