Skip to content

Commit

Permalink
Merge 9839c1d into 3ceebcf
Browse files Browse the repository at this point in the history
  • Loading branch information
vbrown608 committed Jan 17, 2019
2 parents 3ceebcf + 9839c1d commit 33253dc
Show file tree
Hide file tree
Showing 9 changed files with 106 additions and 86 deletions.
10 changes: 8 additions & 2 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,14 @@ func asyncPolicyCheck(api API, domain string) <-chan checker.CheckResult {

func defaultCheck(api API, domain string) (checker.DomainResult, error) {
policyChan := asyncPolicyCheck(api, domain)
result := checker.CheckDomain(domain, nil, 3*time.Second,
checker.ScanCache{ScanStore: api.Database, ExpireTime: 5 * time.Minute})
c := checker.Checker{
Cache: checker.ScanCache{
ScanStore: api.Database,
ExpireTime: 5 * time.Minute,
},
Timeout: 3 * time.Second,
}
result := c.CheckDomain(domain, nil)
result.ExtraResults = make(map[string]checker.CheckResult)
result.ExtraResults["policylist"] = <-policyChan
return result, nil
Expand Down
39 changes: 39 additions & 0 deletions checker/checker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package checker

import (
"net"
"time"
)

// A Checker is used to run checks against SMTP domains and hostnames.
type Checker struct {
// Timeout specifies the maximum timeout for network requests made during
// checks.
// If nil, a default timeout of 10 seconds is used.
Timeout time.Duration

// Cache specifies the hostname scan cache store and expire time.
// Defaults to a 10-minute in-memory cache.
Cache ScanCache

// lookupMX specifies an alternate function to retrieve hostnames for a given
// domain. It is used to mock DNS lookups during testing.
lookupMX func(string) ([]*net.MX, error)

// checkHostname is used to mock checks for a single hostname.
checkHostname func(string, string) HostnameResult
}

func (c Checker) timeout() time.Duration {
if &c.Timeout != nil {
return c.Timeout
}
return 10 * time.Second
}

func (c Checker) cache() ScanCache {
if &c.Cache == nil {
c.Cache = CreateSimpleCache(10 * time.Minute)
}
return c.Cache
}
5 changes: 2 additions & 3 deletions checker/cmd/starttls-check/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"io/ioutil"
"os"
"strings"
"time"

"github.com/EFForg/starttls-backend/checker"
)
Expand Down Expand Up @@ -55,8 +54,8 @@ func main() {
flag.PrintDefaults()
os.Exit(1)
}
cache := checker.CreateSimpleCache(10 * time.Minute)
result := checker.CheckDomain(*domainStr, nil, 5*time.Second, cache)
c := checker.Checker{}
result := c.CheckDomain(*domainStr, nil)
b, err := json.Marshal(result)
if err != nil {
fmt.Printf("%q", err)
Expand Down
64 changes: 17 additions & 47 deletions checker/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,27 +31,6 @@ const (
DomainBadHostnameFailure DomainStatus = 6
)

// DomainQuery wraps the parameters we need to perform a domain check.
type DomainQuery struct {
// Domain being checked.
Domain string
// Expected hostnames in MX records for Domain
ExpectedHostnames []string
// Unexported implementations for fns that make network requests
hostnameLookup
hostnameChecker
}

// Looks up what hostnames are correlated with a particular domain.
type hostnameLookup interface {
lookupHostname(string, time.Duration) ([]string, error)
}

// Performs a series of checks on a particular domain, hostname combo.
type hostnameChecker interface {
checkHostname(string, string, time.Duration) HostnameResult
}

// DomainResult wraps all the results for a particular mail domain.
type DomainResult struct {
// Domain being checked against.
Expand All @@ -77,12 +56,6 @@ func (d DomainResult) Class() string {
return "extra"
}

type tlsChecker struct{}

func (*tlsChecker) checkHostname(domain string, hostname string, timeout time.Duration) HostnameResult {
return CheckHostname(domain, hostname, timeout)
}

func (d DomainResult) setStatus(status DomainStatus) DomainResult {
d.Status = DomainStatus(SetStatus(CheckStatus(d.Status), CheckStatus(status)))
return d
Expand All @@ -95,14 +68,19 @@ func lookupMXWithTimeout(domain string, timeout time.Duration) ([]*net.MX, error
return r.LookupMX(ctx, domain)
}

type dnsLookup struct{}

func (*dnsLookup) lookupHostname(domain string, timeout time.Duration) ([]string, error) {
// lookupHostnames retrieves the MX hostnames associated with a domain.
func (c Checker) lookupHostnames(domain string) ([]string, error) {
domainASCII, err := idna.ToASCII(domain)
if err != nil {
return nil, fmt.Errorf("domain name %s couldn't be converted to ASCII", domain)
}
mxs, err := lookupMXWithTimeout(domainASCII, timeout)
// Allow the Checker to mock DNS lookup.
var mxs []*net.MX
if c.lookupMX != nil {
mxs, err = c.lookupMX(domain)
} else {
mxs, err = lookupMXWithTimeout(domainASCII, c.timeout())
}
if err != nil || len(mxs) == 0 {
return nil, fmt.Errorf("No MX records found")
}
Expand All @@ -124,33 +102,25 @@ func (*dnsLookup) lookupHostname(domain string, timeout time.Duration) ([]string
// `domain` is the mail domain to perform the lookup on.
// `mxHostnames` is the list of expected hostnames.
// If `mxHostnames` is nil, we don't validate the DNS lookup.
func CheckDomain(domain string, mxHostnames []string, timeout time.Duration, cache ScanCache) DomainResult {
return performCheck(DomainQuery{
Domain: domain,
ExpectedHostnames: mxHostnames,
hostnameLookup: &dnsLookup{},
hostnameChecker: &tlsChecker{},
}, timeout, cache)
}

func performCheck(query DomainQuery, timeout time.Duration, cache ScanCache) DomainResult {
func (c Checker) CheckDomain(domain string, expectedHostnames []string) DomainResult {
result := DomainResult{
Domain: query.Domain,
MxHostnames: query.ExpectedHostnames,
Domain: domain,
MxHostnames: expectedHostnames,
HostnameResults: make(map[string]HostnameResult),
}
// 1. Look up hostnames
// 2. Perform and aggregate checks from those hostnames.
// 3. Set a summary message.
hostnames, err := query.lookupHostname(query.Domain, timeout)
hostnames, err := c.lookupHostnames(domain)
if err != nil {
return result.reportError(err)
return result.setStatus(DomainCouldNotConnect)
}
checkedHostnames := make([]string, 0)
for _, hostname := range hostnames {
cache := c.cache()
hostnameResult, err := cache.GetHostnameScan(hostname)
if err != nil {
hostnameResult = query.checkHostname(query.Domain, hostname, timeout)
hostnameResult = c.CheckHostname(domain, hostname)
cache.PutHostnameScan(hostname, hostnameResult)
}
result.HostnameResults[hostname] = hostnameResult
Expand All @@ -172,7 +142,7 @@ func performCheck(query DomainQuery, timeout time.Duration, cache ScanCache) Dom
return result.setStatus(DomainNoSTARTTLSFailure)
}
// Any of the connected hostnames don't have a match?
if query.ExpectedHostnames != nil && !policyMatches(hostname, query.ExpectedHostnames) {
if expectedHostnames != nil && !policyMatches(hostname, expectedHostnames) {
return result.setStatus(DomainBadHostnameFailure)
}
result = result.setStatus(DomainStatus(hostnameResult.Status))
Expand Down
31 changes: 15 additions & 16 deletions checker/domain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package checker

import (
"fmt"
"net"
"testing"
"time"
)
Expand Down Expand Up @@ -41,19 +42,18 @@ var hostnameResults = map[string]ResultGroup{
},
}

// Mock implementation for lookup and checker

type mockLookup struct{}
type mockChecker struct{}

func (*mockLookup) lookupHostname(domain string, _ time.Duration) ([]string, error) {
func mockLookupMX(domain string) ([]*net.MX, error) {
if domain == "error" {
return nil, fmt.Errorf("No MX records found")
}
return mxLookup[domain], nil
result := []*net.MX{}
for _, host := range mxLookup[domain] {
result = append(result, &net.MX{Host: host})
}
return result, nil
}

func (*mockChecker) checkHostname(domain string, hostname string, _ time.Duration) HostnameResult {
func mockCheckHostname(domain string, hostname string) HostnameResult {
if result, ok := hostnameResults[hostname]; ok {
return HostnameResult{
ResultGroup: &result,
Expand Down Expand Up @@ -103,20 +103,19 @@ func performTests(t *testing.T, tests []domainTestCase) {
}

func performTestsWithCacheTimeout(t *testing.T, tests []domainTestCase, cacheExpiry time.Duration) {
cache := CreateSimpleCache(cacheExpiry)
c := Checker{
Timeout: time.Second,
Cache: CreateSimpleCache(cacheExpiry),
lookupMX: mockLookupMX,
checkHostname: mockCheckHostname,
}
for _, test := range tests {
if test.expectedHostnames == nil {
test.expectedHostnames = mxLookup[test.domain]
}
got := performCheck(DomainQuery{
Domain: test.domain,
ExpectedHostnames: test.expectedHostnames,
hostnameLookup: &mockLookup{},
hostnameChecker: &mockChecker{},
}, time.Second, cache).Status
got := c.CheckDomain(test.domain, test.expectedHostnames).Status
test.check(t, got)
}

}

// Test cases.
Expand Down
11 changes: 8 additions & 3 deletions checker/hostname.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,12 @@ func checkTLSVersion(client *smtp.Client, hostname string, timeout time.Duration
// CheckHostname performs a series of checks against a hostname for an email domain.
// `domain` is the mail domain that this server serves email for.
// `hostname` is the hostname for this server.
func CheckHostname(domain string, hostname string, timeout time.Duration) HostnameResult {
func (c Checker) CheckHostname(domain string, hostname string) HostnameResult {
if c.checkHostname != nil {
// Allow the Checker to mock this function.
return c.checkHostname(domain, hostname)
}

result := HostnameResult{
Domain: domain,
Hostname: hostname,
Expand All @@ -234,7 +239,7 @@ func CheckHostname(domain string, hostname string, timeout time.Duration) Hostna

// Connect to the SMTP server and use that connection to perform as many checks as possible.
connectivityResult := CheckResult{Name: "connectivity"}
client, err := smtpDialWithTimeout(hostname, timeout)
client, err := smtpDialWithTimeout(hostname, c.timeout())
if err != nil {
result.addCheck(connectivityResult.Error("Could not establish connection: %v", err))
return result
Expand All @@ -250,7 +255,7 @@ func CheckHostname(domain string, hostname string, timeout time.Duration) Hostna
// result.addCheck(checkTLSCipher(hostname))

// Creates a new connection to check for SSLv2/3 support because we can't call starttls twice.
result.addCheck(checkTLSVersion(client, hostname, timeout))
result.addCheck(checkTLSVersion(client, hostname, c.timeout()))

return result
}
16 changes: 9 additions & 7 deletions checker/hostname_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ func TestMain(m *testing.M) {

const testTimeout = 250 * time.Millisecond

var testChecker = Checker{Timeout: testTimeout}

// Code follows pattern from crypto/tls/generate_cert.go
// to generate a cert from a PEM-encoded RSA private key.
func createCert(keyData string, commonName string) string {
Expand Down Expand Up @@ -100,7 +102,7 @@ func TestPolicyMatch(t *testing.T) {
}

func TestNoConnection(t *testing.T) {
result := CheckHostname("", "example.com", testTimeout)
result := testChecker.CheckHostname("", "example.com")

expected := ResultGroup{
Status: 3,
Expand All @@ -115,7 +117,7 @@ func TestNoTLS(t *testing.T) {
ln := smtpListenAndServe(t, &tls.Config{})
defer ln.Close()

result := CheckHostname("", ln.Addr().String(), testTimeout)
result := testChecker.CheckHostname("", ln.Addr().String())

expected := ResultGroup{
Status: 2,
Expand All @@ -135,7 +137,7 @@ func TestSelfSigned(t *testing.T) {
ln := smtpListenAndServe(t, &tls.Config{Certificates: []tls.Certificate{cert}})
defer ln.Close()

result := CheckHostname("", ln.Addr().String(), testTimeout)
result := testChecker.CheckHostname("", ln.Addr().String())

expected := ResultGroup{
Status: 2,
Expand All @@ -161,7 +163,7 @@ func TestNoTLS12(t *testing.T) {
})
defer ln.Close()

result := CheckHostname("", ln.Addr().String(), testTimeout)
result := testChecker.CheckHostname("", ln.Addr().String())

expected := ResultGroup{
Status: 2,
Expand Down Expand Up @@ -194,7 +196,7 @@ func TestSuccessWithFakeCA(t *testing.T) {
// conserving the port number.
addrParts := strings.Split(ln.Addr().String(), ":")
port := addrParts[len(addrParts)-1]
result := CheckHostname("", "localhost:"+port, testTimeout)
result := testChecker.CheckHostname("", "localhost:"+port)
expected := ResultGroup{
Status: 0,
Checks: map[string]CheckResult{
Expand Down Expand Up @@ -269,7 +271,7 @@ func TestFailureWithBadHostname(t *testing.T) {
// conserving the port number.
addrParts := strings.Split(ln.Addr().String(), ":")
port := addrParts[len(addrParts)-1]
result := CheckHostname("", "localhost:"+port, testTimeout)
result := testChecker.CheckHostname("", "localhost:"+port)
expected := ResultGroup{
Status: 2,
Checks: map[string]CheckResult{
Expand Down Expand Up @@ -309,7 +311,7 @@ func TestAdvertisedCiphers(t *testing.T) {

ln := smtpListenAndServe(t, tlsConfig)
defer ln.Close()
CheckHostname("", ln.Addr().String(), testTimeout)
testChecker.CheckHostname("", ln.Addr().String())

// Partial list of ciphers we want to support
expectedCipherSuites := []struct {
Expand Down
9 changes: 4 additions & 5 deletions validator/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func reportToSentry(name string, domain string, result checker.DomainResult) {
result)
}

type checkPerformer func(string, []string, time.Duration, checker.ScanCache) checker.DomainResult
type checkPerformer func(string, []string) checker.DomainResult
type reportFailure func(string, string, checker.DomainResult)

// Helper function that's agnostic to how checks are performed how to
Expand All @@ -43,15 +43,13 @@ func validateRegularly(v DomainPolicyStore, interval time.Duration,
log.Printf("[%s validator] Could not retrieve domains: %v", v.GetName(), err)
continue
}
cache := checker.CreateSimpleCache(time.Minute * 10)

for _, domain := range domains {
hostnames, err := v.HostnamesForDomain(domain)
if err != nil {
log.Printf("[%s validator] Could not retrieve policy for domain %s: %v", v.GetName(), domain, err)
continue
}
result := check(domain, hostnames, 10*time.Second, cache)
result := check(domain, hostnames)
if result.Status != 0 && report != nil {
log.Printf("[%s validator] %s failed; sending report", v.GetName(), domain)
report(v.GetName(), domain, result)
Expand All @@ -64,5 +62,6 @@ func validateRegularly(v DomainPolicyStore, interval time.Duration,
// Hostname map. Interval specifies the interval to wait between each run.
// Failures are reported to Sentry.
func ValidateRegularly(v DomainPolicyStore, interval time.Duration) {
validateRegularly(v, interval, checker.CheckDomain, reportToSentry)
c := checker.Checker{}
validateRegularly(v, interval, c.CheckDomain, reportToSentry)
}

0 comments on commit 33253dc

Please sign in to comment.