Skip to content

Commit

Permalink
Merge c2a8d29 into 43ac723
Browse files Browse the repository at this point in the history
  • Loading branch information
sydneyli committed Mar 28, 2019
2 parents 43ac723 + c2a8d29 commit 6a00831
Show file tree
Hide file tree
Showing 14 changed files with 254 additions and 100 deletions.
6 changes: 3 additions & 3 deletions checker/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ func (c *Checker) lookupHostnames(domain string) ([]string, error) {
// checks on the highest priority mailservers succeed.
//
// `domain` is the mail domain to perform the lookup on.
// `mxHostnames` is the list of expected hostnames.
// If `mxHostnames` is nil, we don't validate the DNS lookup.
// `expectedHostnames` is the list of expected hostnames.
// If `expectedHostnames` is nil, we don't validate the DNS lookup.
func (c *Checker) CheckDomain(domain string, expectedHostnames []string) DomainResult {
result := DomainResult{
Domain: domain,
Expand All @@ -127,7 +127,7 @@ func (c *Checker) CheckDomain(domain string, expectedHostnames []string) DomainR
}
}
result.PreferredHostnames = checkedHostnames
result.MTASTSResult = c.checkMTASTS(domain, result.HostnameResults)
result.MTASTSResult = c.CheckMTASTS(domain, result.HostnameResults)

// Derive Domain code from Hostname results.
if len(checkedHostnames) == 0 {
Expand Down
4 changes: 3 additions & 1 deletion checker/mta_sts.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,9 @@ func validateMTASTSMXs(policyFileMXs []string, dnsMXs map[string]HostnameResult,
}
}

func (c Checker) checkMTASTS(domain string, hostnameResults map[string]HostnameResult) *MTASTSResult {
// CheckMTASTS performs all associated checks for a particular domain's
// MTA-STS support.
func (c Checker) CheckMTASTS(domain string, hostnameResults map[string]HostnameResult) *MTASTSResult {
if c.checkMTASTSOverride != nil {
// Allow the Checker to mock this function.
return c.checkMTASTSOverride(domain, hostnameResults)
Expand Down
6 changes: 5 additions & 1 deletion db/scripts/init_tables.sql
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ CREATE TABLE IF NOT EXISTS domains
last_updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
status VARCHAR(255) NOT NULL,
queue_weeks INTEGER DEFAULT 4,
testing_start TIMESTAMP
testing_start TIMESTAMP,
mta_sts BOOLEAN DEFAULT FALSE
);

CREATE TABLE IF NOT EXISTS blacklisted_emails
Expand Down Expand Up @@ -89,3 +90,6 @@ CREATE TABLE IF NOT EXISTS domain_totals
ALTER TABLE domains ADD COLUMN IF NOT EXISTS queue_weeks INTEGER DEFAULT 4;

ALTER TABLE domains ADD COLUMN IF NOT EXISTS testing_start TIMESTAMP;

ALTER TABLE domains ADD COLUMN IF NOT EXISTS mta_sts BOOLEAN DEFAULT FALSE;

77 changes: 46 additions & 31 deletions db/sqldb.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,14 +208,22 @@ func (db SQLDatabase) GetAllScans(domain string) ([]models.Scan, error) {
// Subsequent puts with the same domain updates the row with the information in
// the object provided.
func (db *SQLDatabase) PutDomain(domain models.Domain) error {
_, err := db.conn.Exec("INSERT INTO domains(domain, email, data, status, queue_weeks) "+
"VALUES($1, $2, $3, $4, $6) "+
_, err := db.conn.Exec("INSERT INTO domains(domain, email, data, status, queue_weeks, mta_sts) "+
"VALUES($1, $2, $3, $4, $6, $7) "+
"ON CONFLICT (domain) DO UPDATE SET status=$5",
domain.Name, domain.Email, strings.Join(domain.MXs[:], ","),
models.StateUnvalidated, domain.State, domain.QueueWeeks)
models.StateUnvalidated, domain.State, domain.QueueWeeks, domain.MTASTSMode == "on")
return err
}

// UpdateDomainPolicy allows us to update the internal data about a particular domain.
func (db *SQLDatabase) UpdateDomainPolicy(domain models.Domain) error {
_, err := db.conn.Exec("UPDATE domains SET data=$2, status=$3 WHERE domain=$1 AND mta_sts=TRUE",
domain.Name, strings.Join(domain.MXs[:], ","), domain.State)
return err

}

// GetDomain retrieves the status and information associated with a particular
// mailserver domain.
func (db SQLDatabase) GetDomain(domain string) (models.Domain, error) {
Expand All @@ -231,25 +239,15 @@ func (db SQLDatabase) GetDomain(domain string) (models.Domain, error) {
return data, err
}

// GetDomains retrieves all the domains which match a particular state.
// GetDomains retrieves all the domains which match a particular state,
// that are not in MTA_STS mode
func (db SQLDatabase) GetDomains(state models.DomainState) ([]models.Domain, error) {
rows, err := db.conn.Query(
"SELECT domain, email, data, status, last_updated, queue_weeks FROM domains WHERE status=$1", state)
if err != nil {
return nil, err
}
defer rows.Close()
domains := []models.Domain{}
for rows.Next() {
var domain models.Domain
var rawMXs string
if err := rows.Scan(&domain.Name, &domain.Email, &rawMXs, &domain.State, &domain.LastUpdated, &domain.QueueWeeks); err != nil {
return nil, err
}
domain.MXs = strings.Split(rawMXs, ",")
domains = append(domains, domain)
}
return domains, nil
return db.getDomainsWhere("status=$1 AND mta_sts=FALSE", state)
}

// GetMTASTSDomains retrieves domains which wish their policy to be queued with their MTASTS.
func (db SQLDatabase) GetMTASTSDomains() ([]models.Domain, error) {
return db.getDomainsWhere("mta_sts=TRUE")
}

// EMAIL BLACKLIST DB FUNCTIONS
Expand Down Expand Up @@ -294,6 +292,26 @@ func (db SQLDatabase) ClearTables() error {
})
}

func (db SQLDatabase) getDomainsWhere(condition string, args ...interface{}) ([]models.Domain, error) {
query := fmt.Sprintf("SELECT domain, email, data, status, last_updated, queue_weeks FROM domains WHERE %s", condition)
rows, err := db.conn.Query(query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
domains := []models.Domain{}
for rows.Next() {
var domain models.Domain
var rawMXs string
if err := rows.Scan(&domain.Name, &domain.Email, &rawMXs, &domain.State, &domain.LastUpdated, &domain.QueueWeeks); err != nil {
return nil, err
}
domain.MXs = strings.Split(rawMXs, ",")
domains = append(domains, domain)
}
return domains, nil
}

// DomainsToValidate [interface Validator] retrieves domains from the
// DB whose policies should be validated.
func (db SQLDatabase) DomainsToValidate() ([]string, error) {
Expand All @@ -302,20 +320,17 @@ func (db SQLDatabase) DomainsToValidate() ([]string, error) {
if err != nil {
return domains, err
}
dataMTASTS, err := db.GetMTASTSDomains()
if err != nil {
return domains, err
}
for _, domainInfo := range data {
domains = append(domains, domainInfo.Name)
}
return domains, nil
}

// HostnamesForDomain [interface Validator] retrieves the hostname policy for
// a particular domain.
func (db SQLDatabase) HostnamesForDomain(domain string) ([]string, error) {
data, err := db.GetDomain(domain)
if err != nil {
return []string{}, err
for _, domainInfo := range dataMTASTS {
domains = append(domains, domainInfo.Name)
}
return data.MXs, nil
return domains, nil
}

// GetHostnameScan retrives most recent scan from database.
Expand Down
63 changes: 43 additions & 20 deletions db/sqldb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package db_test
import (
"log"
"os"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -264,26 +265,6 @@ func TestDomainsToValidate(t *testing.T) {
}
}

func TestHostnamesForDomain(t *testing.T) {
database.ClearTables()
database.PutDomain(models.Domain{Name: "x", MXs: []string{"x.com", "y.org"}})
database.PutDomain(models.Domain{Name: "y"})
result, err := database.HostnamesForDomain("x")
if err != nil {
t.Fatalf("HostnamesForDomain failed: %v\n", err)
}
if len(result) != 2 || result[0] != "x.com" || result[1] != "y.org" {
t.Errorf("Expected two hostnames, x.com and y.org\n")
}
result, err = database.HostnamesForDomain("y")
if err != nil {
t.Fatalf("HostnamesForDomain failed: %v\n", err)
}
if len(result) > 0 {
t.Errorf("Expected no hostnames to be returned, got %s\n", result[0])
}
}

func TestPutAndIsBlacklistedEmail(t *testing.T) {
defer database.ClearTables()

Expand Down Expand Up @@ -426,3 +407,45 @@ func expectStats(ts models.TimeSeries, t *testing.T) {
}
}
}

func TestGetMTASTSDomains(t *testing.T) {
database.ClearTables()
database.PutDomain(models.Domain{Name: "unicorns"})
database.PutDomain(models.Domain{Name: "mta-sts-x", MTASTSMode: "on"})
database.PutDomain(models.Domain{Name: "mta-sts-y", MTASTSMode: "on"})
database.PutDomain(models.Domain{Name: "regular"})
domains, err := database.GetMTASTSDomains()
if err != nil {
t.Fatalf("GetMTASTSDomains() failed: %v", err)
}
if len(domains) != 2 {
t.Errorf("Expected GetMTASTSDomains() to return 2 elements")
}
for _, domain := range domains {
if !strings.HasPrefix(domain.Name, "mta-sts") {
t.Errorf("GetMTASTSDomains returned %s when it wasn't supposed to", domain.Name)
}
}
}

func TestUpdateDomainPolicy(t *testing.T) {
database.ClearTables()
database.PutDomain(models.Domain{Name: "no-mtasts"})
database.PutDomain(models.Domain{Name: "mtasts", MTASTSMode: "on", Email: "real-email"})
database.UpdateDomainPolicy(models.Domain{Name: "no-mtasts", State: models.StateEnforce})
database.UpdateDomainPolicy(models.Domain{Name: "mtasts", State: models.StateEnforce, MXs: []string{"hostname"}, Email: "fake-email"})
domain, _ := database.GetDomain("no-mtasts")
if domain.State == models.StateEnforce {
t.Errorf("Expected State to not update since unicorns isn't MTASTS")
}
domain, _ = database.GetDomain("mtasts")
if domain.State != models.StateEnforce {
t.Errorf("Expected State to update after UpdateDomainPolicy")
}
if len(domain.MXs) != 1 || domain.MXs[0] != "hostname" {
t.Errorf("Expected MXs to update after UpdateDomainPolicy")
}
if domain.Email != "real-email" {
t.Errorf("Did not expect Email to update after UpdateDomainPolicy")
}
}
9 changes: 8 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,14 @@ func main() {
}
if os.Getenv("VALIDATE_QUEUED") == "1" {
log.Println("[Starting queued validator]")
go validator.ValidateRegularly("Testing domains", db, 24*time.Hour)
v := validator.Validator{
Name: "Testing and enforced domains",
Store: db,
Interval: 24 * time.Hour,
CheckPerformer: validator.GetDBCheck(db.UpdateDomainPolicy),
}
go v.Run()
// go validator.ValidateRegularly("Testing domains", db, 24*time.Hour)
}
ServePublicEndpoints(&api, &cfg)
}
15 changes: 13 additions & 2 deletions models/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"time"

"github.com/EFForg/starttls-backend/checker"
"github.com/EFForg/starttls-backend/util"
)

// Domain stores the preload state of a single domain.
Expand Down Expand Up @@ -77,8 +78,7 @@ func (d *Domain) IsQueueable(db scanStore, list policyList) (bool, string, Scan)
// PopulateFromScan updates a Domain's fields based on a scan of that domain.
func (d *Domain) PopulateFromScan(scan Scan) {
// We should only trust MTA-STS info from a successful MTA-STS check.
if scan.Data.MTASTSResult != nil && scan.SupportsMTASTS() {
d.MTASTSMode = scan.Data.MTASTSResult.Mode
if d.MTASTSMode == "on" && scan.Data.MTASTSResult != nil && scan.SupportsMTASTS() {
// If the domain's MXs are missing, we can take them from the scan's
// PreferredHostnames, which must be a subset of those listed in the
// MTA-STS policy file.
Expand Down Expand Up @@ -129,3 +129,14 @@ func (d Domain) AsyncPolicyListCheck(store domainStore, list policyList) <-chan
go func() { result <- *d.PolicyListCheck(store, list) }()
return result
}

// SamePolicy checks whether the underlying policy represented by Domain
// and the one picked up by the MTA-STS check represent the same policy.
func (d *Domain) SamePolicy(result *checker.MTASTSResult) bool {
if (result.Mode == "enforce" && d.State != StateEnforce) ||
(result.Mode == "testing" && d.State != StateTesting) ||
result.Mode == "none" {
return false
}
return util.ListsEqual(d.MXs, result.MXs)
}
9 changes: 3 additions & 6 deletions models/domain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,20 +121,17 @@ func TestIsQueueable(t *testing.T) {

func TestPopulateFromScan(t *testing.T) {
d := Domain{
Name: "example.com",
Email: "me@example.com",
Name: "example.com",
Email: "me@example.com",
MTASTSMode: "on",
}
s := Scan{
Data: checker.DomainResult{
MTASTSResult: checker.MakeMTASTSResult(),
},
}
s.Data.MTASTSResult.Mode = "enforce"
s.Data.MTASTSResult.MXs = []string{"mx1.example.com", "mx2.example.com"}
d.PopulateFromScan(s)
if d.MTASTSMode != "enforce" {
t.Errorf("Expected domain MTA-STS mode to match scan, got %s", d.MTASTSMode)
}
for i, mx := range s.Data.MTASTSResult.MXs {
if mx != d.MXs[i] {
t.Errorf("Expected MXs to match scan, got %s", d.MXs)
Expand Down
19 changes: 15 additions & 4 deletions policy/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"net/http"
"sync"
"time"

"github.com/EFForg/starttls-backend/models"
)

// policyURL is the default URL from which to fetch the policy JSON.
Expand Down Expand Up @@ -80,14 +82,23 @@ func (l *UpdatedList) DomainsToValidate() ([]string, error) {
return domains, nil
}

// HostnamesForDomain [interface Validator] retrieves the hostname policy for
// GetDomain [interface Validator] retrieves the domain object for
// a particular domain.
func (l *UpdatedList) HostnamesForDomain(domain string) ([]string, error) {
func (l *UpdatedList) GetDomain(domain string) (models.Domain, error) {
policy, err := l.Get(domain)
if err != nil {
return []string{}, err
return models.Domain{}, err
}
domainObj := models.Domain{
Name: domain,
MXs: policy.MXs,
}
if policy.Mode == "enforce" {
domainObj.State = models.StateEnforce
} else if policy.Mode == "testing" {
domainObj.State = models.StateTesting
}
return policy.MXs, nil
return domainObj, nil
}

// Get safely reads from the underlying policy list and returns a TLSPolicy for a domain
Expand Down
6 changes: 3 additions & 3 deletions policy/policy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,12 @@ func TestHostnamesForDomain(t *testing.T) {
var updatedList = List{Policies: map[string]TLSPolicy{
"eff.org": TLSPolicy{MXs: hostnames}}}
list := makeUpdatedList(func() (List, error) { return updatedList, nil }, time.Second)
returned, err := list.HostnamesForDomain("eff.org")
returned, err := list.GetDomain("eff.org")
if err != nil {
t.Fatalf("Encountered %v", err)
}
if !reflect.DeepEqual(returned, hostnames) {
t.Errorf("Expected %s, got %s", hostnames, returned)
if !reflect.DeepEqual(returned.MXs, hostnames) {
t.Errorf("Expected %s, got %s", hostnames, returned.MXs)
}
}

Expand Down
27 changes: 27 additions & 0 deletions util/util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package util

import (
"reflect"
)

// ListsEqual checks that two lists have the same elements,
// regardless of order.
func ListsEqual(x []string, y []string) bool {
// Transform each list into a histogram
xMap := make(map[string]uint)
yMap := make(map[string]uint)
for _, element := range x {
if _, ok := xMap[element]; !ok {
xMap[element] = 0
}
xMap[element]++
}
for _, element := range y {
if _, ok := yMap[element]; !ok {
yMap[element] = 0
}
yMap[element]++
}
// Compare the histogram maps
return reflect.DeepEqual(xMap, yMap)
}

0 comments on commit 6a00831

Please sign in to comment.