Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ func apiWrapper(api apiHandler) func(w http.ResponseWriter, r *http.Request) {
}

// Checks the policy status of this domain.
func (api API) policyCheck(domain string) checker.CheckResult {
result := checker.CheckResult{Name: "policylist"}
func (api API) policyCheck(domain string) *checker.Result {
result := checker.Result{Name: "policylist"}
if _, err := api.List.Get(domain); err == nil {
return result.Success()
}
Expand All @@ -107,9 +107,9 @@ func (api API) policyCheck(domain string) checker.CheckResult {
// Performs policyCheck asynchronously.
// Should be safe since Database is safe for concurrent use, and so
// is List.
func asyncPolicyCheck(api API, domain string) <-chan checker.CheckResult {
result := make(chan checker.CheckResult)
go func() { result <- api.policyCheck(domain) }()
func asyncPolicyCheck(api API, domain string) <-chan checker.Result {
result := make(chan checker.Result)
go func() { result <- *api.policyCheck(domain) }()
return result
}

Expand All @@ -123,7 +123,7 @@ func defaultCheck(api API, domain string) (checker.DomainResult, error) {
Timeout: 3 * time.Second,
}
result := c.CheckDomain(domain, nil)
result.ExtraResults = make(map[string]checker.CheckResult)
result.ExtraResults = make(map[string]checker.Result)
result.ExtraResults["policylist"] = <-policyChan
return result, nil
}
Expand Down
8 changes: 4 additions & 4 deletions checker/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ import (
func TestSimpleCacheMap(t *testing.T) {
cache := MakeSimpleCache(time.Hour)
err := cache.PutHostnameScan("anything", HostnameResult{
ResultGroup: &ResultGroup{Status: 3},
Timestamp: time.Now(),
Result: &Result{Status: 3},
Timestamp: time.Now(),
})
if err != nil {
t.Errorf("Expected scan put to succeed: %v", err)
Expand All @@ -26,8 +26,8 @@ func TestSimpleCacheMap(t *testing.T) {
func TestSimpleCacheExpires(t *testing.T) {
cache := MakeSimpleCache(0)
cache.PutHostnameScan("anything", HostnameResult{
ResultGroup: &ResultGroup{Status: 3},
Timestamp: time.Now(),
Result: &Result{Status: 3},
Timestamp: time.Now(),
})
_, err := cache.GetHostnameScan("anything")
if err == nil {
Expand Down
2 changes: 1 addition & 1 deletion checker/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ type DomainResult struct {
// Expected MX hostnames supplied by the caller of CheckDomain.
MxHostnames []string `json:"mx_hostnames,omitempty"`
// Extra global results
ExtraResults map[string]CheckResult `json:"extra_results,omitempty"`
ExtraResults map[string]Result `json:"extra_results,omitempty"`
}

// Class satisfies raven's Interface interface.
Expand Down
46 changes: 23 additions & 23 deletions checker/domain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,25 @@ var mxLookup = map[string][]string{
}

// Fake hostname checks :)
var hostnameResults = map[string]ResultGroup{
"noconnection": ResultGroup{
Status: 3,
Checks: map[string]CheckResult{
"connectivity": {"connectivity", 3, nil},
var hostnameResults = map[string]Result{
"noconnection": Result{
Status: Error,
Checks: map[string]*Result{
"connectivity": {"connectivity", Error, nil, nil},
},
},
"nostarttls": ResultGroup{
Status: 2,
Checks: map[string]CheckResult{
"connectivity": {"connectivity", 0, nil},
"starttls": {"starttls", 2, nil},
"nostarttls": Result{
Status: Failure,
Checks: map[string]*Result{
"connectivity": {"connectivity", 0, nil, nil},
"starttls": {"starttls", Failure, nil, nil},
},
},
"nostarttlsconnect": ResultGroup{
Status: 3,
Checks: map[string]CheckResult{
"connectivity": {"connectivity", 0, nil},
"starttls": {"starttls", 3, nil},
"nostarttlsconnect": Result{
Status: Error,
Checks: map[string]*Result{
"connectivity": {"connectivity", 0, nil, nil},
"starttls": {"starttls", Error, nil, nil},
},
},
}
Expand All @@ -56,8 +56,8 @@ func mockLookupMX(domain string) ([]*net.MX, error) {
func mockCheckHostname(domain string, hostname string) HostnameResult {
if result, ok := hostnameResults[hostname]; ok {
return HostnameResult{
ResultGroup: &result,
Timestamp: time.Now(),
Result: &result,
Timestamp: time.Now(),
}
}
// For caching test: "changes" result changes after first scan
Expand All @@ -66,13 +66,13 @@ func mockCheckHostname(domain string, hostname string) HostnameResult {
}
// by default return successful check
return HostnameResult{
ResultGroup: &ResultGroup{
Result: &Result{
Status: 0,
Checks: map[string]CheckResult{
"connectivity": {"connectivity", 0, nil},
"starttls": {"starttls", 0, nil},
"certificate": {"certificate", 0, nil},
"version": {"version", 0, nil},
Checks: map[string]*Result{
"connectivity": {"connectivity", 0, nil, nil},
"starttls": {"starttls", 0, nil, nil},
"certificate": {"certificate", 0, nil, nil},
"version": {"version", 0, nil, nil},
},
},
Timestamp: time.Now(),
Expand Down
29 changes: 13 additions & 16 deletions checker/hostname.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (

// HostnameResult wraps the results of a security check against a particular hostname.
type HostnameResult struct {
*ResultGroup
*Result
Domain string `json:"domain"`
Hostname string `json:"hostname"`
Timestamp time.Time `json:"-"`
Expand Down Expand Up @@ -94,8 +94,8 @@ func smtpDialWithTimeout(hostname string, timeout time.Duration) (*smtp.Client,
}

// Simply tries to StartTLS with the server.
func checkStartTLS(client *smtp.Client) CheckResult {
result := CheckResult{Name: "starttls"}
func checkStartTLS(client *smtp.Client) *Result {
result := MakeResult("starttls")
ok, _ := client.Extension("StartTLS")
if !ok {
return result.Failure("Server does not advertise support for STARTTLS.")
Expand Down Expand Up @@ -140,8 +140,8 @@ var certRoots *x509.CertPool

// Checks that the certificate presented is valid for a particular hostname, unexpired,
// and chains to a trusted root.
func checkCert(client *smtp.Client, domain, hostname string) CheckResult {
result := CheckResult{Name: "certificate"}
func checkCert(client *smtp.Client, domain, hostname string) *Result {
result := MakeResult("certificate")
state, ok := client.TLSConnectionState()
if !ok {
return result.Error("TLS not initiated properly.")
Expand All @@ -168,8 +168,8 @@ func tlsConfigForCipher(ciphers []uint16) tls.Config {
}

// Checks to see that insecure ciphers are disabled.
func checkTLSCipher(hostname string, timeout time.Duration) CheckResult {
result := CheckResult{Name: "cipher"}
func checkTLSCipher(hostname string, timeout time.Duration) *Result {
result := MakeResult("cipher")
badCiphers := []uint16{
tls.TLS_RSA_WITH_RC4_128_SHA,
tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
Expand All @@ -187,8 +187,8 @@ func checkTLSCipher(hostname string, timeout time.Duration) CheckResult {
return result.Success()
}

func checkTLSVersion(client *smtp.Client, hostname string, timeout time.Duration) CheckResult {
result := CheckResult{Name: "version"}
func checkTLSVersion(client *smtp.Client, hostname string, timeout time.Duration) *Result {
result := MakeResult("version")

// Check the TLS version of the existing connection.
tlsConnectionState, ok := client.TLSConnectionState()
Expand Down Expand Up @@ -228,17 +228,14 @@ func (c *Checker) CheckHostname(domain string, hostname string) HostnameResult {
}

result := HostnameResult{
Domain: domain,
Hostname: hostname,
ResultGroup: &ResultGroup{
Status: Success,
Checks: make(map[string]CheckResult),
},
Domain: domain,
Hostname: hostname,
Result: MakeResult("hostnames"),
Timestamp: time.Now(),
}

// Connect to the SMTP server and use that connection to perform as many checks as possible.
connectivityResult := CheckResult{Name: "connectivity"}
connectivityResult := MakeResult("connectivity")
client, err := smtpDialWithTimeout(hostname, c.timeout())
if err != nil {
result.addCheck(connectivityResult.Error("Could not establish connection: %v", err))
Expand Down
64 changes: 32 additions & 32 deletions checker/hostname_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,10 @@ func TestPolicyMatch(t *testing.T) {
func TestNoConnection(t *testing.T) {
result := testChecker.CheckHostname("", "example.com")

expected := ResultGroup{
expected := Result{
Status: 3,
Checks: map[string]CheckResult{
"connectivity": {"connectivity", 3, nil},
Checks: map[string]*Result{
"connectivity": {"connectivity", 3, nil, nil},
},
}
compareStatuses(t, expected, result)
Expand All @@ -119,11 +119,11 @@ func TestNoTLS(t *testing.T) {

result := testChecker.CheckHostname("", ln.Addr().String())

expected := ResultGroup{
expected := Result{
Status: 2,
Checks: map[string]CheckResult{
"connectivity": {"connectivity", 0, nil},
"starttls": {"starttls", 2, nil},
Checks: map[string]*Result{
"connectivity": {"connectivity", 0, nil, nil},
"starttls": {"starttls", 2, nil, nil},
},
}
compareStatuses(t, expected, result)
Expand All @@ -139,13 +139,13 @@ func TestSelfSigned(t *testing.T) {

result := testChecker.CheckHostname("", ln.Addr().String())

expected := ResultGroup{
expected := Result{
Status: 2,
Checks: map[string]CheckResult{
"connectivity": {"connectivity", 0, nil},
"starttls": {"starttls", 0, nil},
"certificate": {"certificate", 2, nil},
"version": {"version", 0, nil},
Checks: map[string]*Result{
"connectivity": {"connectivity", 0, nil, nil},
"starttls": {"starttls", 0, nil, nil},
"certificate": {"certificate", 2, nil, nil},
"version": {"version", 0, nil, nil},
},
}
compareStatuses(t, expected, result)
Expand All @@ -165,13 +165,13 @@ func TestNoTLS12(t *testing.T) {

result := testChecker.CheckHostname("", ln.Addr().String())

expected := ResultGroup{
expected := Result{
Status: 2,
Checks: map[string]CheckResult{
"connectivity": {"connectivity", 0, nil},
"starttls": {"starttls", 0, nil},
"certificate": {"certificate", 2, nil},
"version": {"version", 1, nil},
Checks: map[string]*Result{
"connectivity": {"connectivity", 0, nil, nil},
"starttls": {"starttls", 0, nil, nil},
"certificate": {"certificate", 2, nil, nil},
"version": {"version", 1, nil, nil},
},
}
compareStatuses(t, expected, result)
Expand All @@ -197,13 +197,13 @@ func TestSuccessWithFakeCA(t *testing.T) {
addrParts := strings.Split(ln.Addr().String(), ":")
port := addrParts[len(addrParts)-1]
result := testChecker.CheckHostname("", "localhost:"+port)
expected := ResultGroup{
expected := Result{
Status: 0,
Checks: map[string]CheckResult{
"connectivity": {"connectivity", 0, nil},
"starttls": {"starttls", 0, nil},
"certificate": {"certificate", 0, nil},
"version": {"version", 0, nil},
Checks: map[string]*Result{
"connectivity": {"connectivity", 0, nil, nil},
"starttls": {"starttls", 0, nil, nil},
"certificate": {"certificate", 0, nil, nil},
"version": {"version", 0, nil, nil},
},
}
compareStatuses(t, expected, result)
Expand Down Expand Up @@ -272,13 +272,13 @@ func TestFailureWithBadHostname(t *testing.T) {
addrParts := strings.Split(ln.Addr().String(), ":")
port := addrParts[len(addrParts)-1]
result := testChecker.CheckHostname("", "localhost:"+port)
expected := ResultGroup{
expected := Result{
Status: 2,
Checks: map[string]CheckResult{
"connectivity": {"connectivity", 0, nil},
"starttls": {"starttls", 0, nil},
"certificate": {"certificate", 2, nil},
"version": {"version", 0, nil},
Checks: map[string]*Result{
"connectivity": {"connectivity", 0, nil, nil},
"starttls": {"starttls", 0, nil, nil},
"certificate": {"certificate", 2, nil, nil},
"version": {"version", 0, nil, nil},
},
}
compareStatuses(t, expected, result)
Expand Down Expand Up @@ -338,7 +338,7 @@ func containsCipherSuite(result []uint16, want uint16) bool {
}

// compareStatuses compares the status for the HostnameResult and each Check with a desired value
func compareStatuses(t *testing.T, expected ResultGroup, result HostnameResult) {
func compareStatuses(t *testing.T, expected Result, result HostnameResult) {
if result.Status != expected.Status {
t.Errorf("hostname status = %d, want %d", result.Status, expected.Status)
}
Expand Down
21 changes: 9 additions & 12 deletions checker/mta_sts.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,16 @@ func getKeyValuePairs(record string, lineDelimiter string,
return parsed
}

func checkMTASTSRecord(domain string) CheckResult {
result := CheckResult{Name: "mta-sts-txt"}
func checkMTASTSRecord(domain string) *Result {
result := MakeResult("mta-sts-txt")
records, err := net.LookupTXT(fmt.Sprintf("_mta-sts.%s", domain))
if err != nil {
return result.Failure("Couldn't find MTA-STS TXT record: %v", err)
}
return validateMTASTSRecord(records, result)
}

func validateMTASTSRecord(records []string, result CheckResult) CheckResult {
func validateMTASTSRecord(records []string, result *Result) *Result {
records = filterByPrefix(records, "v=STSv1")
if len(records) != 1 {
return result.Failure("exactly 1 MTA-STS TXT record required, found %d", len(records))
Expand All @@ -62,8 +62,8 @@ func validateMTASTSRecord(records []string, result CheckResult) CheckResult {
return result.Success()
}

func checkMTASTSPolicyFile(domain string, hostnameResults map[string]HostnameResult) CheckResult {
result := CheckResult{Name: "policy_file"}
func checkMTASTSPolicyFile(domain string, hostnameResults map[string]HostnameResult) *Result {
result := MakeResult("policy-file")
client := &http.Client{
// Don't follow redirects.
CheckRedirect: func(req *http.Request, via []*http.Request) error {
Expand Down Expand Up @@ -96,7 +96,7 @@ func checkMTASTSPolicyFile(domain string, hostnameResults map[string]HostnameRes
return validateMTASTSMXs(strings.Split(policy["mx"], " "), hostnameResults, result)
}

func validateMTASTSPolicyFile(body string, result CheckResult) (CheckResult, map[string]string) {
func validateMTASTSPolicyFile(body string, result *Result) (*Result, map[string]string) {
policy := getKeyValuePairs(body, "\n", ":")

if policy["version"] != "STSv1" {
Expand All @@ -121,7 +121,7 @@ func validateMTASTSPolicyFile(body string, result CheckResult) (CheckResult, map
}

func validateMTASTSMXs(policyFileMXs []string, dnsMXs map[string]HostnameResult,
result CheckResult) CheckResult {
result *Result) *Result {
for dnsMX, dnsMXResult := range dnsMXs {
if !dnsMXResult.couldConnect() {
// Ignore hostnames we couldn't connect to, they may be spam traps.
Expand All @@ -138,11 +138,8 @@ func validateMTASTSMXs(policyFileMXs []string, dnsMXs map[string]HostnameResult,
return result
}

func checkMTASTS(domain string, hostnameResults map[string]HostnameResult) ResultGroup {
result := ResultGroup{
Status: Success,
Checks: make(map[string]CheckResult),
}
func checkMTASTS(domain string, hostnameResults map[string]HostnameResult) *Result {
result := MakeResult("mta-sts")
result.addCheck(checkMTASTSRecord(domain))
result.addCheck(checkMTASTSPolicyFile(domain, hostnameResults))
return result
Expand Down
Loading