Skip to content

Commit

Permalink
Merge 8e916b4 into e63a88e
Browse files Browse the repository at this point in the history
  • Loading branch information
sydneyli committed Jan 25, 2019
2 parents e63a88e + 8e916b4 commit 7c8faac
Show file tree
Hide file tree
Showing 11 changed files with 139 additions and 151 deletions.
12 changes: 6 additions & 6 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ func apiWrapper(api apiHandler) func(w http.ResponseWriter, r *http.Request) {
}

// Checks the policy status of this domain.
func (api API) policyCheck(domain string) checker.CheckResult {
result := checker.CheckResult{Name: "policylist"}
func (api API) policyCheck(domain string) *checker.Result {
result := checker.Result{Name: "policylist"}
if _, err := api.List.Get(domain); err == nil {
return result.Success()
}
Expand All @@ -107,9 +107,9 @@ func (api API) policyCheck(domain string) checker.CheckResult {
// Performs policyCheck asynchronously.
// Should be safe since Database is safe for concurrent use, and so
// is List.
func asyncPolicyCheck(api API, domain string) <-chan checker.CheckResult {
result := make(chan checker.CheckResult)
go func() { result <- api.policyCheck(domain) }()
func asyncPolicyCheck(api API, domain string) <-chan checker.Result {
result := make(chan checker.Result)
go func() { result <- *api.policyCheck(domain) }()
return result
}

Expand All @@ -123,7 +123,7 @@ func defaultCheck(api API, domain string) (checker.DomainResult, error) {
Timeout: 3 * time.Second,
}
result := c.CheckDomain(domain, nil)
result.ExtraResults = make(map[string]checker.CheckResult)
result.ExtraResults = make(map[string]checker.Result)
result.ExtraResults["policylist"] = <-policyChan
return result, nil
}
Expand Down
8 changes: 4 additions & 4 deletions checker/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ import (
func TestSimpleCacheMap(t *testing.T) {
cache := MakeSimpleCache(time.Hour)
err := cache.PutHostnameScan("anything", HostnameResult{
ResultGroup: &ResultGroup{Status: 3},
Timestamp: time.Now(),
Result: &Result{Status: 3},
Timestamp: time.Now(),
})
if err != nil {
t.Errorf("Expected scan put to succeed: %v", err)
Expand All @@ -26,8 +26,8 @@ func TestSimpleCacheMap(t *testing.T) {
func TestSimpleCacheExpires(t *testing.T) {
cache := MakeSimpleCache(0)
cache.PutHostnameScan("anything", HostnameResult{
ResultGroup: &ResultGroup{Status: 3},
Timestamp: time.Now(),
Result: &Result{Status: 3},
Timestamp: time.Now(),
})
_, err := cache.GetHostnameScan("anything")
if err == nil {
Expand Down
2 changes: 1 addition & 1 deletion checker/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ type DomainResult struct {
// Expected MX hostnames supplied by the caller of CheckDomain.
MxHostnames []string `json:"mx_hostnames,omitempty"`
// Extra global results
ExtraResults map[string]CheckResult `json:"extra_results,omitempty"`
ExtraResults map[string]Result `json:"extra_results,omitempty"`
}

// Class satisfies raven's Interface interface.
Expand Down
46 changes: 23 additions & 23 deletions checker/domain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,25 @@ var mxLookup = map[string][]string{
}

// Fake hostname checks :)
var hostnameResults = map[string]ResultGroup{
"noconnection": ResultGroup{
Status: 3,
Checks: map[string]CheckResult{
"connectivity": {"connectivity", 3, nil},
var hostnameResults = map[string]Result{
"noconnection": Result{
Status: Error,
Checks: map[string]*Result{
"connectivity": {"connectivity", Error, nil, nil},
},
},
"nostarttls": ResultGroup{
Status: 2,
Checks: map[string]CheckResult{
"connectivity": {"connectivity", 0, nil},
"starttls": {"starttls", 2, nil},
"nostarttls": Result{
Status: Failure,
Checks: map[string]*Result{
"connectivity": {"connectivity", 0, nil, nil},
"starttls": {"starttls", Failure, nil, nil},
},
},
"nostarttlsconnect": ResultGroup{
Status: 3,
Checks: map[string]CheckResult{
"connectivity": {"connectivity", 0, nil},
"starttls": {"starttls", 3, nil},
"nostarttlsconnect": Result{
Status: Error,
Checks: map[string]*Result{
"connectivity": {"connectivity", 0, nil, nil},
"starttls": {"starttls", Error, nil, nil},
},
},
}
Expand All @@ -56,8 +56,8 @@ func mockLookupMX(domain string) ([]*net.MX, error) {
func mockCheckHostname(domain string, hostname string) HostnameResult {
if result, ok := hostnameResults[hostname]; ok {
return HostnameResult{
ResultGroup: &result,
Timestamp: time.Now(),
Result: &result,
Timestamp: time.Now(),
}
}
// For caching test: "changes" result changes after first scan
Expand All @@ -66,13 +66,13 @@ func mockCheckHostname(domain string, hostname string) HostnameResult {
}
// by default return successful check
return HostnameResult{
ResultGroup: &ResultGroup{
Result: &Result{
Status: 0,
Checks: map[string]CheckResult{
"connectivity": {"connectivity", 0, nil},
"starttls": {"starttls", 0, nil},
"certificate": {"certificate", 0, nil},
"version": {"version", 0, nil},
Checks: map[string]*Result{
"connectivity": {"connectivity", 0, nil, nil},
"starttls": {"starttls", 0, nil, nil},
"certificate": {"certificate", 0, nil, nil},
"version": {"version", 0, nil, nil},
},
},
Timestamp: time.Now(),
Expand Down
29 changes: 13 additions & 16 deletions checker/hostname.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (

// HostnameResult wraps the results of a security check against a particular hostname.
type HostnameResult struct {
*ResultGroup
*Result
Domain string `json:"domain"`
Hostname string `json:"hostname"`
Timestamp time.Time `json:"-"`
Expand Down Expand Up @@ -94,8 +94,8 @@ func smtpDialWithTimeout(hostname string, timeout time.Duration) (*smtp.Client,
}

// Simply tries to StartTLS with the server.
func checkStartTLS(client *smtp.Client) CheckResult {
result := CheckResult{Name: "starttls"}
func checkStartTLS(client *smtp.Client) *Result {
result := MakeResult("starttls")
ok, _ := client.Extension("StartTLS")
if !ok {
return result.Failure("Server does not advertise support for STARTTLS.")
Expand Down Expand Up @@ -140,8 +140,8 @@ var certRoots *x509.CertPool

// Checks that the certificate presented is valid for a particular hostname, unexpired,
// and chains to a trusted root.
func checkCert(client *smtp.Client, domain, hostname string) CheckResult {
result := CheckResult{Name: "certificate"}
func checkCert(client *smtp.Client, domain, hostname string) *Result {
result := MakeResult("certificate")
state, ok := client.TLSConnectionState()
if !ok {
return result.Error("TLS not initiated properly.")
Expand All @@ -168,8 +168,8 @@ func tlsConfigForCipher(ciphers []uint16) tls.Config {
}

// Checks to see that insecure ciphers are disabled.
func checkTLSCipher(hostname string, timeout time.Duration) CheckResult {
result := CheckResult{Name: "cipher"}
func checkTLSCipher(hostname string, timeout time.Duration) *Result {
result := MakeResult("cipher")
badCiphers := []uint16{
tls.TLS_RSA_WITH_RC4_128_SHA,
tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
Expand All @@ -187,8 +187,8 @@ func checkTLSCipher(hostname string, timeout time.Duration) CheckResult {
return result.Success()
}

func checkTLSVersion(client *smtp.Client, hostname string, timeout time.Duration) CheckResult {
result := CheckResult{Name: "version"}
func checkTLSVersion(client *smtp.Client, hostname string, timeout time.Duration) *Result {
result := MakeResult("version")

// Check the TLS version of the existing connection.
tlsConnectionState, ok := client.TLSConnectionState()
Expand Down Expand Up @@ -228,17 +228,14 @@ func (c *Checker) CheckHostname(domain string, hostname string) HostnameResult {
}

result := HostnameResult{
Domain: domain,
Hostname: hostname,
ResultGroup: &ResultGroup{
Status: Success,
Checks: make(map[string]CheckResult),
},
Domain: domain,
Hostname: hostname,
Result: MakeResult("hostnames"),
Timestamp: time.Now(),
}

// Connect to the SMTP server and use that connection to perform as many checks as possible.
connectivityResult := CheckResult{Name: "connectivity"}
connectivityResult := MakeResult("connectivity")
client, err := smtpDialWithTimeout(hostname, c.timeout())
if err != nil {
result.addCheck(connectivityResult.Error("Could not establish connection: %v", err))
Expand Down
64 changes: 32 additions & 32 deletions checker/hostname_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,10 @@ func TestPolicyMatch(t *testing.T) {
func TestNoConnection(t *testing.T) {
result := testChecker.CheckHostname("", "example.com")

expected := ResultGroup{
expected := Result{
Status: 3,
Checks: map[string]CheckResult{
"connectivity": {"connectivity", 3, nil},
Checks: map[string]*Result{
"connectivity": {"connectivity", 3, nil, nil},
},
}
compareStatuses(t, expected, result)
Expand All @@ -119,11 +119,11 @@ func TestNoTLS(t *testing.T) {

result := testChecker.CheckHostname("", ln.Addr().String())

expected := ResultGroup{
expected := Result{
Status: 2,
Checks: map[string]CheckResult{
"connectivity": {"connectivity", 0, nil},
"starttls": {"starttls", 2, nil},
Checks: map[string]*Result{
"connectivity": {"connectivity", 0, nil, nil},
"starttls": {"starttls", 2, nil, nil},
},
}
compareStatuses(t, expected, result)
Expand All @@ -139,13 +139,13 @@ func TestSelfSigned(t *testing.T) {

result := testChecker.CheckHostname("", ln.Addr().String())

expected := ResultGroup{
expected := Result{
Status: 2,
Checks: map[string]CheckResult{
"connectivity": {"connectivity", 0, nil},
"starttls": {"starttls", 0, nil},
"certificate": {"certificate", 2, nil},
"version": {"version", 0, nil},
Checks: map[string]*Result{
"connectivity": {"connectivity", 0, nil, nil},
"starttls": {"starttls", 0, nil, nil},
"certificate": {"certificate", 2, nil, nil},
"version": {"version", 0, nil, nil},
},
}
compareStatuses(t, expected, result)
Expand All @@ -165,13 +165,13 @@ func TestNoTLS12(t *testing.T) {

result := testChecker.CheckHostname("", ln.Addr().String())

expected := ResultGroup{
expected := Result{
Status: 2,
Checks: map[string]CheckResult{
"connectivity": {"connectivity", 0, nil},
"starttls": {"starttls", 0, nil},
"certificate": {"certificate", 2, nil},
"version": {"version", 1, nil},
Checks: map[string]*Result{
"connectivity": {"connectivity", 0, nil, nil},
"starttls": {"starttls", 0, nil, nil},
"certificate": {"certificate", 2, nil, nil},
"version": {"version", 1, nil, nil},
},
}
compareStatuses(t, expected, result)
Expand All @@ -197,13 +197,13 @@ func TestSuccessWithFakeCA(t *testing.T) {
addrParts := strings.Split(ln.Addr().String(), ":")
port := addrParts[len(addrParts)-1]
result := testChecker.CheckHostname("", "localhost:"+port)
expected := ResultGroup{
expected := Result{
Status: 0,
Checks: map[string]CheckResult{
"connectivity": {"connectivity", 0, nil},
"starttls": {"starttls", 0, nil},
"certificate": {"certificate", 0, nil},
"version": {"version", 0, nil},
Checks: map[string]*Result{
"connectivity": {"connectivity", 0, nil, nil},
"starttls": {"starttls", 0, nil, nil},
"certificate": {"certificate", 0, nil, nil},
"version": {"version", 0, nil, nil},
},
}
compareStatuses(t, expected, result)
Expand Down Expand Up @@ -272,13 +272,13 @@ func TestFailureWithBadHostname(t *testing.T) {
addrParts := strings.Split(ln.Addr().String(), ":")
port := addrParts[len(addrParts)-1]
result := testChecker.CheckHostname("", "localhost:"+port)
expected := ResultGroup{
expected := Result{
Status: 2,
Checks: map[string]CheckResult{
"connectivity": {"connectivity", 0, nil},
"starttls": {"starttls", 0, nil},
"certificate": {"certificate", 2, nil},
"version": {"version", 0, nil},
Checks: map[string]*Result{
"connectivity": {"connectivity", 0, nil, nil},
"starttls": {"starttls", 0, nil, nil},
"certificate": {"certificate", 2, nil, nil},
"version": {"version", 0, nil, nil},
},
}
compareStatuses(t, expected, result)
Expand Down Expand Up @@ -338,7 +338,7 @@ func containsCipherSuite(result []uint16, want uint16) bool {
}

// compareStatuses compares the status for the HostnameResult and each Check with a desired value
func compareStatuses(t *testing.T, expected ResultGroup, result HostnameResult) {
func compareStatuses(t *testing.T, expected Result, result HostnameResult) {
if result.Status != expected.Status {
t.Errorf("hostname status = %d, want %d", result.Status, expected.Status)
}
Expand Down
21 changes: 9 additions & 12 deletions checker/mta_sts.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,16 @@ func getKeyValuePairs(record string, lineDelimiter string,
return parsed
}

func checkMTASTSRecord(domain string) CheckResult {
result := CheckResult{Name: "mta-sts-txt"}
func checkMTASTSRecord(domain string) *Result {
result := MakeResult("mta-sts-txt")
records, err := net.LookupTXT(fmt.Sprintf("_mta-sts.%s", domain))
if err != nil {
return result.Failure("Couldn't find MTA-STS TXT record: %v", err)
}
return validateMTASTSRecord(records, result)
}

func validateMTASTSRecord(records []string, result CheckResult) CheckResult {
func validateMTASTSRecord(records []string, result *Result) *Result {
records = filterByPrefix(records, "v=STSv1")
if len(records) != 1 {
return result.Failure("exactly 1 MTA-STS TXT record required, found %d", len(records))
Expand All @@ -62,8 +62,8 @@ func validateMTASTSRecord(records []string, result CheckResult) CheckResult {
return result.Success()
}

func checkMTASTSPolicyFile(domain string, hostnameResults map[string]HostnameResult) CheckResult {
result := CheckResult{Name: "policy_file"}
func checkMTASTSPolicyFile(domain string, hostnameResults map[string]HostnameResult) *Result {
result := MakeResult("policy-file")
client := &http.Client{
// Don't follow redirects.
CheckRedirect: func(req *http.Request, via []*http.Request) error {
Expand Down Expand Up @@ -96,7 +96,7 @@ func checkMTASTSPolicyFile(domain string, hostnameResults map[string]HostnameRes
return validateMTASTSMXs(strings.Split(policy["mx"], " "), hostnameResults, result)
}

func validateMTASTSPolicyFile(body string, result CheckResult) (CheckResult, map[string]string) {
func validateMTASTSPolicyFile(body string, result *Result) (*Result, map[string]string) {
policy := getKeyValuePairs(body, "\n", ":")

if policy["version"] != "STSv1" {
Expand All @@ -121,7 +121,7 @@ func validateMTASTSPolicyFile(body string, result CheckResult) (CheckResult, map
}

func validateMTASTSMXs(policyFileMXs []string, dnsMXs map[string]HostnameResult,
result CheckResult) CheckResult {
result *Result) *Result {
for dnsMX, dnsMXResult := range dnsMXs {
if !dnsMXResult.couldConnect() {
// Ignore hostnames we couldn't connect to, they may be spam traps.
Expand All @@ -138,11 +138,8 @@ func validateMTASTSMXs(policyFileMXs []string, dnsMXs map[string]HostnameResult,
return result
}

func checkMTASTS(domain string, hostnameResults map[string]HostnameResult) ResultGroup {
result := ResultGroup{
Status: Success,
Checks: make(map[string]CheckResult),
}
func checkMTASTS(domain string, hostnameResults map[string]HostnameResult) *Result {
result := MakeResult("mta-sts")
result.addCheck(checkMTASTSRecord(domain))
result.addCheck(checkMTASTSPolicyFile(domain, hostnameResults))
return result
Expand Down

0 comments on commit 7c8faac

Please sign in to comment.