Skip to content

Commit

Permalink
Merge branch 'master' into noscript-with-models
Browse files Browse the repository at this point in the history
  • Loading branch information
vbrown608 committed Jan 25, 2019
2 parents b4ac5ce + 934a4e1 commit ab6a7bf
Show file tree
Hide file tree
Showing 18 changed files with 261 additions and 248 deletions.
22 changes: 14 additions & 8 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ func apiWrapper(api apiHandler) func(w http.ResponseWriter, r *http.Request) {
}

// Checks the policy status of this domain.
func (api API) policyCheck(domain string) checker.CheckResult {
result := checker.CheckResult{CheckType: checker.CheckType{Name: "policylist"}}
func (api API) policyCheck(domain string) *checker.Result {
result := checker.Result{CheckType: checker.CheckType{Name: "policylist"}}
if _, err := api.List.Get(domain); err == nil {
return result.Success()
}
Expand All @@ -116,17 +116,23 @@ func (api API) policyCheck(domain string) checker.CheckResult {
// Performs policyCheck asynchronously.
// Should be safe since Database is safe for concurrent use, and so
// is List.
func asyncPolicyCheck(api API, domain string) <-chan checker.CheckResult {
result := make(chan checker.CheckResult)
go func() { result <- api.policyCheck(domain) }()
func asyncPolicyCheck(api API, domain string) <-chan checker.Result {
result := make(chan checker.Result)
go func() { result <- *api.policyCheck(domain) }()
return result
}

func defaultCheck(api API, domain string) (checker.DomainResult, error) {
policyChan := asyncPolicyCheck(api, domain)
result := checker.CheckDomain(domain, nil, 3*time.Second,
checker.ScanCache{ScanStore: api.Database, ExpireTime: 5 * time.Minute})
result.ExtraResults = make(map[string]checker.CheckResult)
c := checker.Checker{
Cache: &checker.ScanCache{
ScanStore: api.Database,
ExpireTime: 5 * time.Minute,
},
Timeout: 3 * time.Second,
}
result := c.CheckDomain(domain, nil)
result.ExtraResults = make(map[string]checker.Result)
result.ExtraResults["policylist"] = <-policyChan
return result, nil
}
Expand Down
6 changes: 3 additions & 3 deletions checker/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ func (l *SimpleStore) PutHostnameScan(hostname string, result HostnameResult) er
return nil
}

// CreateSimpleCache creates a cache with a SimpleStore backing it.
func CreateSimpleCache(expiryTime time.Duration) ScanCache {
// MakeSimpleCache creates a cache with a SimpleStore backing it.
func MakeSimpleCache(expiryTime time.Duration) *ScanCache {
store := SimpleStore{m: make(map[string]HostnameResult)}
return ScanCache{ScanStore: &store, ExpireTime: expiryTime}
return &ScanCache{ScanStore: &store, ExpireTime: expiryTime}
}
12 changes: 6 additions & 6 deletions checker/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ import (
)

func TestSimpleCacheMap(t *testing.T) {
cache := CreateSimpleCache(time.Hour)
cache := MakeSimpleCache(time.Hour)
err := cache.PutHostnameScan("anything", HostnameResult{
ResultGroup: &ResultGroup{Status: 3},
Timestamp: time.Now(),
Result: &Result{Status: 3},
Timestamp: time.Now(),
})
if err != nil {
t.Errorf("Expected scan put to succeed: %v", err)
Expand All @@ -24,10 +24,10 @@ func TestSimpleCacheMap(t *testing.T) {
}

func TestSimpleCacheExpires(t *testing.T) {
cache := CreateSimpleCache(0)
cache := MakeSimpleCache(0)
cache.PutHostnameScan("anything", HostnameResult{
ResultGroup: &ResultGroup{Status: 3},
Timestamp: time.Now(),
Result: &Result{Status: 3},
Timestamp: time.Now(),
})
_, err := cache.GetHostnameScan("anything")
if err == nil {
Expand Down
39 changes: 39 additions & 0 deletions checker/checker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package checker

import (
"net"
"time"
)

// A Checker is used to run checks against SMTP domains and hostnames.
type Checker struct {
// Timeout specifies the maximum timeout for network requests made during
// checks.
// If nil, a default timeout of 10 seconds is used.
Timeout time.Duration

// Cache specifies the hostname scan cache store and expire time.
// Defaults to a 10-minute in-memory cache.
Cache *ScanCache

// lookupMX specifies an alternate function to retrieve hostnames for a given
// domain. It is used to mock DNS lookups during testing.
lookupMX func(string) ([]*net.MX, error)

// checkHostname is used to mock checks for a single hostname.
checkHostname func(string, string) HostnameResult
}

func (c *Checker) timeout() time.Duration {
if c.Timeout != 0 {
return c.Timeout
}
return 10 * time.Second
}

func (c *Checker) cache() *ScanCache {
if c.Cache == nil {
c.Cache = MakeSimpleCache(10 * time.Minute)
}
return c.Cache
}
5 changes: 2 additions & 3 deletions checker/cmd/starttls-check/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"io/ioutil"
"os"
"strings"
"time"

"github.com/EFForg/starttls-backend/checker"
)
Expand Down Expand Up @@ -55,8 +54,8 @@ func main() {
flag.PrintDefaults()
os.Exit(1)
}
cache := checker.CreateSimpleCache(10 * time.Minute)
result := checker.CheckDomain(*domainStr, nil, 5*time.Second, cache)
c := checker.Checker{}
result := c.CheckDomain(*domainStr, nil)
b, err := json.Marshal(result)
if err != nil {
fmt.Printf("%q", err)
Expand Down
66 changes: 18 additions & 48 deletions checker/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,27 +31,6 @@ const (
DomainBadHostnameFailure DomainStatus = 6
)

// DomainQuery wraps the parameters we need to perform a domain check.
type DomainQuery struct {
// Domain being checked.
Domain string
// Expected hostnames in MX records for Domain
ExpectedHostnames []string
// Unexported implementations for fns that make network requests
hostnameLookup
hostnameChecker
}

// Looks up what hostnames are correlated with a particular domain.
type hostnameLookup interface {
lookupHostname(string, time.Duration) ([]string, error)
}

// Performs a series of checks on a particular domain, hostname combo.
type hostnameChecker interface {
checkHostname(string, string, time.Duration) HostnameResult
}

// DomainResult wraps all the results for a particular mail domain.
type DomainResult struct {
// Domain being checked against.
Expand All @@ -68,7 +47,7 @@ type DomainResult struct {
// Expected MX hostnames supplied by the caller of CheckDomain.
MxHostnames []string `json:"mx_hostnames,omitempty"`
// Extra global results
ExtraResults map[string]CheckResult `json:"extra_results,omitempty"`
ExtraResults map[string]Result `json:"extra_results,omitempty"`
}

// Class satisfies raven's Interface interface.
Expand All @@ -77,12 +56,6 @@ func (d DomainResult) Class() string {
return "extra"
}

type tlsChecker struct{}

func (*tlsChecker) checkHostname(domain string, hostname string, timeout time.Duration) HostnameResult {
return CheckHostname(domain, hostname, timeout)
}

func (d DomainResult) setStatus(status DomainStatus) DomainResult {
d.Status = DomainStatus(SetStatus(CheckStatus(d.Status), CheckStatus(status)))
return d
Expand All @@ -95,14 +68,19 @@ func lookupMXWithTimeout(domain string, timeout time.Duration) ([]*net.MX, error
return r.LookupMX(ctx, domain)
}

type dnsLookup struct{}

func (*dnsLookup) lookupHostname(domain string, timeout time.Duration) ([]string, error) {
// lookupHostnames retrieves the MX hostnames associated with a domain.
func (c *Checker) lookupHostnames(domain string) ([]string, error) {
domainASCII, err := idna.ToASCII(domain)
if err != nil {
return nil, fmt.Errorf("domain name %s couldn't be converted to ASCII", domain)
}
mxs, err := lookupMXWithTimeout(domainASCII, timeout)
// Allow the Checker to mock DNS lookup.
var mxs []*net.MX
if c.lookupMX != nil {
mxs, err = c.lookupMX(domain)
} else {
mxs, err = lookupMXWithTimeout(domainASCII, c.timeout())
}
if err != nil || len(mxs) == 0 {
return nil, fmt.Errorf("No MX records found")
}
Expand All @@ -124,33 +102,25 @@ func (*dnsLookup) lookupHostname(domain string, timeout time.Duration) ([]string
// `domain` is the mail domain to perform the lookup on.
// `mxHostnames` is the list of expected hostnames.
// If `mxHostnames` is nil, we don't validate the DNS lookup.
func CheckDomain(domain string, mxHostnames []string, timeout time.Duration, cache ScanCache) DomainResult {
return performCheck(DomainQuery{
Domain: domain,
ExpectedHostnames: mxHostnames,
hostnameLookup: &dnsLookup{},
hostnameChecker: &tlsChecker{},
}, timeout, cache)
}

func performCheck(query DomainQuery, timeout time.Duration, cache ScanCache) DomainResult {
func (c *Checker) CheckDomain(domain string, expectedHostnames []string) DomainResult {
result := DomainResult{
Domain: query.Domain,
MxHostnames: query.ExpectedHostnames,
Domain: domain,
MxHostnames: expectedHostnames,
HostnameResults: make(map[string]HostnameResult),
}
// 1. Look up hostnames
// 2. Perform and aggregate checks from those hostnames.
// 3. Set a summary message.
hostnames, err := query.lookupHostname(query.Domain, timeout)
hostnames, err := c.lookupHostnames(domain)
if err != nil {
return result.reportError(err)
return result.setStatus(DomainCouldNotConnect)
}
checkedHostnames := make([]string, 0)
for _, hostname := range hostnames {
cache := c.cache()
hostnameResult, err := cache.GetHostnameScan(hostname)
if err != nil {
hostnameResult = query.checkHostname(query.Domain, hostname, timeout)
hostnameResult = c.CheckHostname(domain, hostname)
cache.PutHostnameScan(hostname, hostnameResult)
}
result.HostnameResults[hostname] = hostnameResult
Expand All @@ -172,7 +142,7 @@ func performCheck(query DomainQuery, timeout time.Duration, cache ScanCache) Dom
return result.setStatus(DomainNoSTARTTLSFailure)
}
// Any of the connected hostnames don't have a match?
if query.ExpectedHostnames != nil && !policyMatches(hostname, query.ExpectedHostnames) {
if expectedHostnames != nil && !policyMatches(hostname, expectedHostnames) {
return result.setStatus(DomainBadHostnameFailure)
}
result = result.setStatus(DomainStatus(hostnameResult.Status))
Expand Down
77 changes: 38 additions & 39 deletions checker/domain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package checker

import (
"fmt"
"net"
"testing"
"time"
)
Expand All @@ -18,46 +19,45 @@ var mxLookup = map[string][]string{
}

// Fake hostname checks :)
var hostnameResults = map[string]ResultGroup{
"noconnection": ResultGroup{
Status: 3,
Checks: map[string]CheckResult{
"connectivity": {Connectivity, 3, nil},
var hostnameResults = map[string]Result{
"noconnection": Result{
Status: Error,
Checks: map[string]*Result{
"connectivity": {Connectivity, Error, nil, nil},
},
},
"nostarttls": ResultGroup{
Status: 2,
Checks: map[string]CheckResult{
"connectivity": {Connectivity, 0, nil},
"starttls": {STARTTLS, 2, nil},
"nostarttls": Result{
Status: Failure,
Checks: map[string]*Result{
"connectivity": {Connectivity, 0, nil, nil},
"starttls": {STARTTLS, Failure, nil, nil},
},
},
"nostarttlsconnect": ResultGroup{
Status: 3,
Checks: map[string]CheckResult{
"connectivity": {Connectivity, 0, nil},
"starttls": {STARTTLS, 3, nil},
"nostarttlsconnect": Result{
Status: Error,
Checks: map[string]*Result{
"connectivity": {Connectivity, 0, nil, nil},
"starttls": {STARTTLS, Error, nil, nil},
},
},
}

// Mock implementation for lookup and checker

type mockLookup struct{}
type mockChecker struct{}

func (*mockLookup) lookupHostname(domain string, _ time.Duration) ([]string, error) {
func mockLookupMX(domain string) ([]*net.MX, error) {
if domain == "error" {
return nil, fmt.Errorf("No MX records found")
}
return mxLookup[domain], nil
result := []*net.MX{}
for _, host := range mxLookup[domain] {
result = append(result, &net.MX{Host: host})
}
return result, nil
}

func (*mockChecker) checkHostname(domain string, hostname string, _ time.Duration) HostnameResult {
func mockCheckHostname(domain string, hostname string) HostnameResult {
if result, ok := hostnameResults[hostname]; ok {
return HostnameResult{
ResultGroup: &result,
Timestamp: time.Now(),
Result: &result,
Timestamp: time.Now(),
}
}
// For caching test: "changes" result changes after first scan
Expand All @@ -66,13 +66,13 @@ func (*mockChecker) checkHostname(domain string, hostname string, _ time.Duratio
}
// by default return successful check
return HostnameResult{
ResultGroup: &ResultGroup{
Result: &Result{
Status: 0,
Checks: map[string]CheckResult{
"connectivity": {Connectivity, 0, nil},
"starttls": {STARTTLS, 0, nil},
"certificate": {Certificate, 0, nil},
"version": {Version, 0, nil},
Checks: map[string]*Result{
"connectivity": {Connectivity, 0, nil, nil},
"starttls": {STARTTLS, 0, nil, nil},
"certificate": {Certificate, 0, nil, nil},
"version": {Version, 0, nil, nil},
},
},
Timestamp: time.Now(),
Expand Down Expand Up @@ -103,20 +103,19 @@ func performTests(t *testing.T, tests []domainTestCase) {
}

func performTestsWithCacheTimeout(t *testing.T, tests []domainTestCase, cacheExpiry time.Duration) {
cache := CreateSimpleCache(cacheExpiry)
c := Checker{
Timeout: time.Second,
Cache: MakeSimpleCache(cacheExpiry),
lookupMX: mockLookupMX,
checkHostname: mockCheckHostname,
}
for _, test := range tests {
if test.expectedHostnames == nil {
test.expectedHostnames = mxLookup[test.domain]
}
got := performCheck(DomainQuery{
Domain: test.domain,
ExpectedHostnames: test.expectedHostnames,
hostnameLookup: &mockLookup{},
hostnameChecker: &mockChecker{},
}, time.Second, cache).Status
got := c.CheckDomain(test.domain, test.expectedHostnames).Status
test.check(t, got)
}

}

// Test cases.
Expand Down

0 comments on commit ab6a7bf

Please sign in to comment.