Skip to content

Commit

Permalink
Merge f52f89c into 43ac723
Browse files Browse the repository at this point in the history
  • Loading branch information
vbrown608 committed Mar 28, 2019
2 parents 43ac723 + f52f89c commit 799032c
Show file tree
Hide file tree
Showing 11 changed files with 74 additions and 52 deletions.
5 changes: 3 additions & 2 deletions checker/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ type Checker struct {
// domain. It is used to mock DNS lookups during testing.
lookupMXOverride func(string) ([]*net.MX, error)

// checkHostnameOverride is used to mock checks for a single hostname.
checkHostnameOverride func(string, string) HostnameResult
// CheckHostname defines the function that should be used to check each hostname.
// If nil, FullCheckHostname (all hostname checks) will be used.
CheckHostname func(string, string, time.Duration) HostnameResult

// checkMTASTSOverride is used to mock MTA-STS checks.
checkMTASTSOverride func(string, map[string]HostnameResult) *MTASTSResult
Expand Down
3 changes: 3 additions & 0 deletions checker/cmd/starttls-check/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ func main() {

domainReader := csv.NewReader(instream)
if *aggregate {
c = checker.Checker{
CheckHostname: checker.NoopCheckHostname,
}
resultHandler = &checker.DomainTotals{
Time: time.Now(),
Source: label,
Expand Down
6 changes: 3 additions & 3 deletions checker/cmd/starttls-check/cmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ import (
func TestUpdateStats(t *testing.T) {
out = new(bytes.Buffer)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, `1,foo,example1.com
2,bar,example2.com
3,baz,example3.com`)
fmt.Fprintln(w, `1,foo,localhost
2,bar,localhost
3,baz,localhost`)
}))
defer ts.Close()

Expand Down
2 changes: 1 addition & 1 deletion checker/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func (c *Checker) CheckDomain(domain string, expectedHostnames []string) DomainR
}
checkedHostnames := make([]string, 0)
for _, hostname := range hostnames {
hostnameResult := c.CheckHostnameWithCache(domain, hostname)
hostnameResult := c.checkHostname(domain, hostname)
result.HostnameResults[hostname] = hostnameResult
if hostnameResult.couldConnect() {
checkedHostnames = append(checkedHostnames, hostname)
Expand Down
12 changes: 6 additions & 6 deletions checker/domain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func mockLookupMX(domain string) ([]*net.MX, error) {
return result, nil
}

func mockCheckHostname(domain string, hostname string) HostnameResult {
func mockCheckHostname(domain string, hostname string, _ time.Duration) HostnameResult {
if result, ok := hostnameResults[hostname]; ok {
return HostnameResult{
Result: &result,
Expand Down Expand Up @@ -110,11 +110,11 @@ func performTests(t *testing.T, tests []domainTestCase) {

func performTestsWithCacheTimeout(t *testing.T, tests []domainTestCase, cacheExpiry time.Duration) {
c := Checker{
Timeout: time.Second,
Cache: MakeSimpleCache(cacheExpiry),
lookupMXOverride: mockLookupMX,
checkHostnameOverride: mockCheckHostname,
checkMTASTSOverride: mockCheckMTASTS,
Timeout: time.Second,
Cache: MakeSimpleCache(cacheExpiry),
lookupMXOverride: mockLookupMX,
CheckHostname: mockCheckHostname,
checkMTASTSOverride: mockCheckMTASTS,
}
for _, test := range tests {
if test.expectedHostnames == nil {
Expand Down
40 changes: 26 additions & 14 deletions checker/hostname.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,29 +210,41 @@ func checkTLSVersion(client *smtp.Client, hostname string, timeout time.Duration
return result.Success()
}

// CheckHostnameWithCache returns the result of CheckHostname, using or
// updating the Checker's cache.
func (c *Checker) CheckHostnameWithCache(domain string, hostname string) HostnameResult {
// checkHostname returns the result of c.CheckHostname or FullCheckHostname,
// using or updating the Checker's cache.
func (c *Checker) checkHostname(domain string, hostname string) HostnameResult {
check := c.CheckHostname
if check == nil {
// If CheckHostname hasn't been set, default to the full set of checks.
check = FullCheckHostname
}

if c.Cache == nil {
return c.CheckHostname(domain, hostname)
return check(domain, hostname, c.timeout())
}
hostnameResult, err := c.Cache.GetHostnameScan(hostname)
if err != nil {
hostnameResult = c.CheckHostname(domain, hostname)
hostnameResult = check(domain, hostname, c.timeout())
c.Cache.PutHostnameScan(hostname, hostnameResult)
}
return hostnameResult
}

// CheckHostname performs a series of checks against a hostname for an email domain.
// `domain` is the mail domain that this server serves email for.
// `hostname` is the hostname for this server.
func (c *Checker) CheckHostname(domain string, hostname string) HostnameResult {
if c.checkHostnameOverride != nil {
// Allow the Checker to mock this function.
return c.checkHostnameOverride(domain, hostname)
// NoopCheckHostname returns a fake error result containing `domain` and `hostname`.
func NoopCheckHostname(domain string, hostname string, _ time.Duration) HostnameResult {
r := HostnameResult{
Domain: domain,
Hostname: hostname,
Result: MakeResult("hostnames"),
}
r.addCheck(MakeResult(Connectivity).Error("Skipping hostname checks"))
return r
}

// FullCheckHostname performs a series of checks against a hostname for an email domain.
// `domain` is the mail domain that this server serves email for.
// `hostname` is the hostname for this server.
func FullCheckHostname(domain string, hostname string, timeout time.Duration) HostnameResult {
result := HostnameResult{
Domain: domain,
Hostname: hostname,
Expand All @@ -242,7 +254,7 @@ func (c *Checker) CheckHostname(domain string, hostname string) HostnameResult {

// Connect to the SMTP server and use that connection to perform as many checks as possible.
connectivityResult := MakeResult(Connectivity)
client, err := smtpDialWithTimeout(hostname, c.timeout())
client, err := smtpDialWithTimeout(hostname, timeout)
if err != nil {
result.addCheck(connectivityResult.Error("Could not establish connection: %v", err))
return result
Expand All @@ -258,7 +270,7 @@ func (c *Checker) CheckHostname(domain string, hostname string) HostnameResult {
// result.addCheck(checkTLSCipher(hostname))

// Creates a new connection to check for SSLv2/3 support because we can't call starttls twice.
result.addCheck(checkTLSVersion(client, hostname, c.timeout()))
result.addCheck(checkTLSVersion(client, hostname, timeout))

return result
}
16 changes: 7 additions & 9 deletions checker/hostname_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ func TestMain(m *testing.M) {

const testTimeout = 250 * time.Millisecond

var testChecker = Checker{Timeout: testTimeout}

// Code follows pattern from crypto/tls/generate_cert.go
// to generate a cert from a PEM-encoded RSA private key.
func createCert(keyData string, commonName string) string {
Expand Down Expand Up @@ -102,7 +100,7 @@ func TestPolicyMatch(t *testing.T) {
}

func TestNoConnection(t *testing.T) {
result := testChecker.CheckHostname("", "example.com")
result := FullCheckHostname("", "example.com", testTimeout)

expected := Result{
Status: 3,
Expand All @@ -117,7 +115,7 @@ func TestNoTLS(t *testing.T) {
ln := smtpListenAndServe(t, &tls.Config{})
defer ln.Close()

result := testChecker.CheckHostname("", ln.Addr().String())
result := FullCheckHostname("", ln.Addr().String(), testTimeout)

expected := Result{
Status: 2,
Expand All @@ -137,7 +135,7 @@ func TestSelfSigned(t *testing.T) {
ln := smtpListenAndServe(t, &tls.Config{Certificates: []tls.Certificate{cert}})
defer ln.Close()

result := testChecker.CheckHostname("", ln.Addr().String())
result := FullCheckHostname("", ln.Addr().String(), testTimeout)

expected := Result{
Status: 2,
Expand All @@ -163,7 +161,7 @@ func TestNoTLS12(t *testing.T) {
})
defer ln.Close()

result := testChecker.CheckHostname("", ln.Addr().String())
result := FullCheckHostname("", ln.Addr().String(), testTimeout)

expected := Result{
Status: 2,
Expand Down Expand Up @@ -196,7 +194,7 @@ func TestSuccessWithFakeCA(t *testing.T) {
// conserving the port number.
addrParts := strings.Split(ln.Addr().String(), ":")
port := addrParts[len(addrParts)-1]
result := testChecker.CheckHostname("", "localhost:"+port)
result := FullCheckHostname("", "localhost:"+port, testTimeout)
expected := Result{
Status: 0,
Checks: map[string]*Result{
Expand Down Expand Up @@ -271,7 +269,7 @@ func TestFailureWithBadHostname(t *testing.T) {
// conserving the port number.
addrParts := strings.Split(ln.Addr().String(), ":")
port := addrParts[len(addrParts)-1]
result := testChecker.CheckHostname("", "localhost:"+port)
result := FullCheckHostname("", "localhost:"+port, testTimeout)
expected := Result{
Status: 2,
Checks: map[string]*Result{
Expand Down Expand Up @@ -311,7 +309,7 @@ func TestAdvertisedCiphers(t *testing.T) {

ln := smtpListenAndServe(t, tlsConfig)
defer ln.Close()
testChecker.CheckHostname("", ln.Addr().String())
FullCheckHostname("", ln.Addr().String(), testTimeout)

// Partial list of ciphers we want to support
expectedCipherSuites := []struct {
Expand Down
2 changes: 1 addition & 1 deletion checker/mta_sts.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func (m MTASTSResult) MarshalJSON() ([]byte, error) {
func filterByPrefix(records []string, prefix string) []string {
filtered := []string{}
for _, elem := range records {
if elem[0:len(prefix)] == prefix {
if strings.HasPrefix(elem, prefix) {
filtered = append(filtered, elem)
}
}
Expand Down
20 changes: 13 additions & 7 deletions checker/totals.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"fmt"
"io"
"log"
"os"
"strconv"
"strings"
"time"
)
Expand All @@ -15,7 +17,7 @@ type DomainTotals struct {
Time time.Time
Source string
Attempted int
Connected int // Connected to at least one mx
WithMXs int
MTASTSTesting []string
MTASTSEnforce []string
}
Expand All @@ -30,11 +32,11 @@ func (t *DomainTotals) HandleDomain(r DomainResult) {
log.Println(t.MTASTSEnforce)
}

// If DomainStatus is > 4, we couldn't connect to a mailbox.
if r.Status > 4 {
if len(r.HostnameResults) == 0 {
// No MX records - assume this isn't an email domain.
return
}
t.Connected++
t.WithMXs++
if r.MTASTSResult != nil {
switch r.MTASTSResult.Mode {
case "enforce":
Expand All @@ -46,8 +48,8 @@ func (t *DomainTotals) HandleDomain(r DomainResult) {
}

func (t DomainTotals) String() string {
s := strings.Join([]string{"time", "source", "attempted", "connected", "mta_sts_testing", "mta_sts_enforce"}, "\t") + "\n"
s += fmt.Sprintf("%v\t%s\t%d\t%d\t%d\t%d\n", t.Time, t.Source, t.Attempted, t.Connected, len(t.MTASTSTesting), len(t.MTASTSEnforce))
s := strings.Join([]string{"time", "source", "attempted", "with_mxs", "mta_sts_testing", "mta_sts_enforce"}, "\t") + "\n"
s += fmt.Sprintf("%v\t%s\t%d\t%d\t%d\t%d\n", t.Time, t.Source, t.Attempted, t.WithMXs, len(t.MTASTSTesting), len(t.MTASTSEnforce))
return s
}

Expand All @@ -57,11 +59,15 @@ type ResultHandler interface {
HandleDomain(DomainResult)
}

const poolSize = 16
const defaultPoolSize = 16

// CheckCSV runs the checker on a csv of domains, processing the results according
// to resultHandler.
func (c *Checker) CheckCSV(domains *csv.Reader, resultHandler ResultHandler, domainColumn int) {
poolSize, err := strconv.Atoi(os.Getenv("CONNECTION_POOL_SIZE"))
if err != nil || poolSize <= 0 {
poolSize = defaultPoolSize
}
work := make(chan string)
results := make(chan DomainResult)

Expand Down
16 changes: 8 additions & 8 deletions checker/totals_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,21 @@ func TestCheckCSV(t *testing.T) {
reader := csv.NewReader(strings.NewReader(in))

c := Checker{
Cache: MakeSimpleCache(10 * time.Minute),
lookupMXOverride: mockLookupMX,
checkHostnameOverride: mockCheckHostname,
checkMTASTSOverride: mockCheckMTASTS,
Cache: MakeSimpleCache(10 * time.Minute),
lookupMXOverride: mockLookupMX,
CheckHostname: mockCheckHostname,
checkMTASTSOverride: mockCheckMTASTS,
}
totals := DomainTotals{}
c.CheckCSV(reader, &totals, 0)

if totals.Attempted != 6 {
t.Errorf("Expected 6 attempted connections, got %d", totals.Attempted)
}
if totals.Connected != 4 {
t.Errorf("Expected 4 successfully connecting domains, got %d", totals.Connected)
if totals.WithMXs != 5 {
t.Errorf("Expected 5 domains with MXs, got %d", totals.WithMXs)
}
if len(totals.MTASTSTesting) != 4 {
t.Errorf("Expected 4 domains in MTA-STS testing mode, got %d", len(totals.MTASTSTesting))
if len(totals.MTASTSTesting) != 5 {
t.Errorf("Expected 5 domains in MTA-STS testing mode, got %d", len(totals.MTASTSTesting))
}
}
4 changes: 3 additions & 1 deletion db/scripts/init_tables.sql
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,13 @@ CREATE TABLE IF NOT EXISTS domain_totals
time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
source TEXT NOT NULL,
attempted INTEGER DEFAULT 0,
connected INTEGER DEFAULT 0,
with_mxs INTEGER DEFAULT 0,
mta_sts_testing INTEGER DEFAULT 0,
mta_sts_enforce INTEGER DEFAULT 0
);

ALTER TABLE domains ADD COLUMN IF NOT EXISTS queue_weeks INTEGER DEFAULT 4;

ALTER TABLE domains ADD COLUMN IF NOT EXISTS testing_start TIMESTAMP;

ALTER TABLE domain_totals ADD COLUMN IF NOT EXISTS with_mxs INTEGER DEFAULT 0;

0 comments on commit 799032c

Please sign in to comment.