Skip to content

Commit

Permalink
Merge 847db62 into 601e9e9
Browse files Browse the repository at this point in the history
  • Loading branch information
vbrown608 committed Jan 3, 2019
2 parents 601e9e9 + 847db62 commit 70bdf1f
Show file tree
Hide file tree
Showing 11 changed files with 415 additions and 116 deletions.
10 changes: 8 additions & 2 deletions checker/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ import (

func TestSimpleCacheMap(t *testing.T) {
cache := CreateSimpleCache(time.Hour)
err := cache.PutHostnameScan("anything", HostnameResult{Status: 3, Timestamp: time.Now()})
err := cache.PutHostnameScan("anything", HostnameResult{
ResultGroup: &ResultGroup{Status: 3},
Timestamp: time.Now(),
})
if err != nil {
t.Errorf("Expected scan put to succeed: %v", err)
}
Expand All @@ -22,7 +25,10 @@ func TestSimpleCacheMap(t *testing.T) {

func TestSimpleCacheExpires(t *testing.T) {
cache := CreateSimpleCache(0)
cache.PutHostnameScan("anything", HostnameResult{Status: 3})
cache.PutHostnameScan("anything", HostnameResult{
ResultGroup: &ResultGroup{Status: 3},
Timestamp: time.Now(),
})
_, err := cache.GetHostnameScan("anything")
if err == nil {
t.Errorf("Expected cache to expire and scan get to fail: %v", err)
Expand Down
38 changes: 30 additions & 8 deletions checker/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,36 +42,58 @@ func (c *CheckResult) ensureInit() {
// Error adds an error message to this check result.
// The Error status will override any other existing status for this check.
// Typically, when a check encounters an error, it stops executing.
func (c CheckResult) Error(format string, a ...interface{}) CheckResult {
func (c *CheckResult) Error(format string, a ...interface{}) CheckResult {
c.ensureInit()
c.Status = SetStatus(c.Status, Error)
c.Messages = append(c.Messages, fmt.Sprintf("Error: "+format, a...))
return c
return *c
}

// Failure adds a failure message to this check result.
// The Failure status will override any Status other than Error.
// Whenever Failure is called, the entire check is failed.
func (c CheckResult) Failure(format string, a ...interface{}) CheckResult {
func (c *CheckResult) Failure(format string, a ...interface{}) CheckResult {
c.ensureInit()
c.Status = SetStatus(c.Status, Failure)
c.Messages = append(c.Messages, fmt.Sprintf("Failure: "+format, a...))
return c
return *c
}

// Warning adds a warning message to this check result.
// The Warning status only supercedes the Success status.
func (c CheckResult) Warning(format string, a ...interface{}) CheckResult {
func (c *CheckResult) Warning(format string, a ...interface{}) CheckResult {
c.ensureInit()
c.Status = SetStatus(c.Status, Warning)
c.Messages = append(c.Messages, fmt.Sprintf("Warning: "+format, a...))
return c
return *c
}

// Success simply sets the status of CheckResult to a Success.
// Status is set if no other status has been declared on this check.
func (c CheckResult) Success() CheckResult {
func (c *CheckResult) Success() CheckResult {
c.ensureInit()
c.Status = SetStatus(c.Status, Success)
return c
return *c
}

// ResultGroup wraps the results of a security check against a particular hostname.
type ResultGroup struct {
Status CheckStatus `json:"status"`
Checks map[string]CheckResult `json:"checks"`
}

// Returns result of specified check.
// If called before that check occurs, returns false.
func (r ResultGroup) checkSucceeded(checkName string) bool {
if result, ok := r.Checks[checkName]; ok {
return result.Status == Success
}
return false
}

// Wrapping helper function to set the status of this hostname.
func (r *ResultGroup) addCheck(checkResult CheckResult) {
r.Checks[checkResult.Name] = checkResult
// SetStatus sets ResultGroup's status to the most severe of any individual check
r.Status = SetStatus(r.Status, checkResult.Status)
}
3 changes: 2 additions & 1 deletion checker/cmd/starttls-check/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ func main() {
flag.PrintDefaults()
}
domainStr := flag.String("domain", "", "Required: Domain to check TLS for.")
domainsFileStr := flag.String("domains", "", "Required: Domain to check TLS for.")
flag.Parse()
if *domainStr == "" {
if *domainStr == "" && *domainsFileStr == "" {
flag.PrintDefaults()
os.Exit(1)
}
Expand Down
2 changes: 1 addition & 1 deletion checker/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ func performCheck(query DomainQuery, timeout time.Duration, cache ScanCache) Dom
return result.setStatus(DomainNoSTARTTLSFailure)
}
// Any of the connected hostnames don't have a match?
if query.ExpectedHostnames != nil && !hasValidName(query.ExpectedHostnames, hostname) {
if query.ExpectedHostnames != nil && !policyMatches(hostname, query.ExpectedHostnames) {
return result.setStatus(DomainBadHostnameFailure)
}
result = result.setStatus(DomainStatus(hostnameResult.Status))
Expand Down
31 changes: 20 additions & 11 deletions checker/domain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,21 @@ var mxLookup = map[string][]string{
}

// Fake hostname checks :)
var hostnameResults = map[string]HostnameResult{
"noconnection": HostnameResult{
var hostnameResults = map[string]ResultGroup{
"noconnection": ResultGroup{
Status: 3,
Checks: map[string]CheckResult{
"connectivity": {"connectivity", 3, nil},
},
},
"nostarttls": HostnameResult{
"nostarttls": ResultGroup{
Status: 2,
Checks: map[string]CheckResult{
"connectivity": {"connectivity", 0, nil},
"starttls": {"starttls", 2, nil},
},
},
"nostarttlsconnect": HostnameResult{
"nostarttlsconnect": ResultGroup{
Status: 3,
Checks: map[string]CheckResult{
"connectivity": {"connectivity", 0, nil},
Expand All @@ -54,19 +54,28 @@ func (*mockLookup) lookupHostname(domain string, _ time.Duration) ([]string, err

func (*mockChecker) checkHostname(domain string, hostname string, _ time.Duration) HostnameResult {
if result, ok := hostnameResults[hostname]; ok {
result.Timestamp = time.Now()
return result
return HostnameResult{
ResultGroup: &result,
Timestamp: time.Now(),
}
}
// For caching test: "changes" result changes after first scan
if hostname == "changes" {
hostnameResults["changes"] = hostnameResults["nostarttls"]
}
// by default return successful check
return HostnameResult{Status: 0, Checks: map[string]CheckResult{
"connectivity": {"connectivity", 0, nil},
"starttls": {"starttls", 0, nil},
"certificate": {"certificate", 0, nil},
"version": {"version", 0, nil}}, Timestamp: time.Now()}
return HostnameResult{
ResultGroup: &ResultGroup{
Status: 0,
Checks: map[string]CheckResult{
"connectivity": {"connectivity", 0, nil},
"starttls": {"starttls", 0, nil},
"certificate": {"certificate", 0, nil},
"version": {"version", 0, nil},
},
},
Timestamp: time.Now(),
}
}

// Test helpers.
Expand Down
105 changes: 35 additions & 70 deletions checker/hostname.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@ import (

// HostnameResult wraps the results of a security check against a particular hostname.
type HostnameResult struct {
Domain string `json:"domain"`
Hostname string `json:"hostname"`
Status CheckStatus `json:"status"`
Checks map[string]CheckResult `json:"checks"`
Timestamp time.Time `json:"-"`
*ResultGroup
Domain string `json:"domain"`
Hostname string `json:"hostname"`
Timestamp time.Time `json:"-"`
}

// Returns result of specifiedcheck.
// Returns result of specified check.
// If called before that check occurs, returns false.
func (h HostnameResult) checkSucceeded(checkName string) bool {
if result, ok := h.Checks[checkName]; ok {
Expand All @@ -36,35 +35,28 @@ func (h HostnameResult) couldSTARTTLS() bool {
return h.checkSucceeded("starttls")
}

// Modelled after isWildcardMatch in Appendix B of the MTA-STS draft.
// From draft v17:
// Senders who are comparing a "suffix" MX pattern with a wildcard
// identifier should thus strip the wildcard and ensure that the two
// sides match label-by-label, until all labels of the shorter side
// (if unequal length) are consumed.
func wildcardMatch(hostname string, pattern string) bool {
if strings.HasPrefix(pattern, ".") {
parts := strings.SplitAfterN(hostname, ".", 2)
if len(parts) > 1 && parts[1] == pattern[1:] {
// Modelled after policyMatches in Appendix B of the MTA-STS RFC 8641.
// Also used to validate hostnames on the STARTTLS Everywhere policy list.
func policyMatches(mx string, patterns []string) bool {
mx = strings.TrimSuffix(mx, ".") // If FQDN, might end with .
mx = withoutPort(mx) // If URL, might include port
mx = strings.ToLower(mx) // Lowercase for comparison
for _, pattern := range patterns {
pattern = strings.ToLower(pattern)

// Literal match
if pattern == mx {
return true
}
}
return false
}

// Modelled after certMatches in Appendix B of the MTA-STS draft.
func policyMatch(certName string, policyMx string) bool {
// Lowercase both names for comparison
certName = strings.ToLower(certName)
policyMx = strings.ToLower(policyMx)
if strings.HasPrefix(certName, "*") {
certName = certName[1:]
if !strings.HasPrefix(certName, ".") { // Invalid wildcard domain
return false
// Wildcard match
if strings.HasPrefix(pattern, "*.") {
mxParts := strings.SplitN(mx, ".", 2)
if len(mxParts) > 1 && mxParts[1] == pattern[2:] {
return true
}
}
}
return certName == policyMx || wildcardMatch(certName, policyMx) ||
wildcardMatch(policyMx, certName)
return false
}

func withoutPort(url string) string {
Expand All @@ -74,22 +66,6 @@ func withoutPort(url string) string {
return url
}

// Checks certificate names against a list of expected MX patterns.
// The expected MX patterns are in the format described by MTA-STS,
// and validation is done according to this RFC as well.
func hasValidName(names []string, hostname string) bool {
// If FQDN, might end with '.'; strip it!
hostname = strings.TrimSuffix(hostname, ".")
// If URL, might include port #; strip it!
hostname = withoutPort(hostname)
for _, name := range names {
if policyMatch(name, hostname) {
return true
}
}
return false
}

// Retrieves this machine's hostname, if specified.
func getThisHostname() string {
hostname := os.Getenv("HOSTNAME")
Expand Down Expand Up @@ -130,15 +106,6 @@ func checkStartTLS(client *smtp.Client) CheckResult {
return result.Success()
}

// Retrieves valid names from certificate. If the certificate has
// SAN, retrieves all SAN domains; otherwise returns a list containing only the CN.
func getNamesFromCert(cert *x509.Certificate) []string {
if cert.DNSNames != nil && len(cert.DNSNames) > 0 {
return cert.DNSNames
}
return []string{cert.Subject.CommonName}
}

// If no MX matching policy was provided, then we'll default to accepting matches
// based on the mail domain and the MX hostname.
//
Expand Down Expand Up @@ -179,10 +146,13 @@ func checkCert(client *smtp.Client, domain, hostname string) CheckResult {
return result.Error("TLS not initiated properly.")
}
cert := state.PeerCertificates[0]
if !hasValidName(getNamesFromCert(cert), hostname) {
result = result.Failure("Name in cert doesn't match hostname.")
// If hostname is an FQDN, it might end with '.'
hostname = strings.TrimSuffix(hostname, ".")
err := cert.VerifyHostname(withoutPort(hostname))
if err != nil {
result.Failure("Name in cert doesn't match hostname: %v", err)
}
err := verifyCertChain(state)
err = verifyCertChain(state)
if err != nil {
return result.Failure("Certificate root is not trusted: %v", err)
}
Expand Down Expand Up @@ -247,22 +217,17 @@ func checkTLSVersion(client *smtp.Client, hostname string, timeout time.Duration
return result.Success()
}

// Wrapping helper function to set the status of this hostname.
func (h *HostnameResult) addCheck(checkResult CheckResult) {
h.Checks[checkResult.Name] = checkResult
// SetStatus sets HostnameResult's status to the most severe of any individual check
h.Status = SetStatus(h.Status, checkResult.Status)
}

// CheckHostname performs a series of checks against a hostname for an email domain.
// `domain` is the mail domain that this server serves email for.
// `hostname` is the hostname for this server.
func CheckHostname(domain string, hostname string, timeout time.Duration) HostnameResult {
result := HostnameResult{
Status: Success,
Domain: domain,
Hostname: hostname,
Checks: make(map[string]CheckResult),
Domain: domain,
Hostname: hostname,
ResultGroup: &ResultGroup{
Status: Success,
Checks: make(map[string]CheckResult),
},
Timestamp: time.Now(),
}

Expand Down

0 comments on commit 70bdf1f

Please sign in to comment.