Skip to content

Commit

Permalink
Merge 31969f7 into e6b4657
Browse files Browse the repository at this point in the history
  • Loading branch information
sydneyli committed Apr 2, 2019
2 parents e6b4657 + 31969f7 commit ba8fb40
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 40 deletions.
6 changes: 3 additions & 3 deletions api.go
Expand Up @@ -180,9 +180,9 @@ func getDomainParams(r *http.Request) (models.Domain, error) {
}
mtasts := r.FormValue("mta-sts")
domain := models.Domain{
Name: name,
MTASTSMode: mtasts,
State: models.StateUnvalidated,
Name: name,
MTASTS: mtasts == "on",
State: models.StateUnvalidated,
}
email, err := getParam("email", r)
if err == nil {
Expand Down
4 changes: 2 additions & 2 deletions checker/domain.go
Expand Up @@ -102,8 +102,8 @@ func (c *Checker) lookupHostnames(domain string) ([]string, error) {
// checks on the highest priority mailservers succeed.
//
// `domain` is the mail domain to perform the lookup on.
// `mxHostnames` is the list of expected hostnames.
// If `mxHostnames` is nil, we don't validate the DNS lookup.
// `expectedHostnames` is the list of expected hostnames.
// If `expectedHostnames` is nil, we don't validate the DNS lookup.
func (c *Checker) CheckDomain(domain string, expectedHostnames []string) DomainResult {
result := DomainResult{
Domain: domain,
Expand Down
5 changes: 4 additions & 1 deletion db/scripts/init_tables.sql
Expand Up @@ -35,7 +35,8 @@ CREATE TABLE IF NOT EXISTS domains
last_updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
status VARCHAR(255) NOT NULL,
queue_weeks INTEGER DEFAULT 4,
testing_start TIMESTAMP
testing_start TIMESTAMP,
mta_sts BOOLEAN DEFAULT FALSE
);

CREATE TABLE IF NOT EXISTS blacklisted_emails
Expand Down Expand Up @@ -92,3 +93,5 @@ ALTER TABLE domains ADD COLUMN IF NOT EXISTS testing_start TIMESTAMP;

ALTER TABLE domain_totals DROP COLUMN IF EXISTS connected;
ALTER TABLE domain_totals ADD COLUMN IF NOT EXISTS with_mxs INTEGER DEFAULT 0;

ALTER TABLE domains ADD COLUMN IF NOT EXISTS mta_sts BOOLEAN DEFAULT FALSE;
52 changes: 31 additions & 21 deletions db/sqldb.go
Expand Up @@ -208,11 +208,11 @@ func (db SQLDatabase) GetAllScans(domain string) ([]models.Scan, error) {
// Subsequent puts with the same domain updates the row with the information in
// the object provided.
func (db *SQLDatabase) PutDomain(domain models.Domain) error {
_, err := db.conn.Exec("INSERT INTO domains(domain, email, data, status, queue_weeks) "+
"VALUES($1, $2, $3, $4, $6) "+
_, err := db.conn.Exec("INSERT INTO domains(domain, email, data, status, queue_weeks, mta_sts) "+
"VALUES($1, $2, $3, $4, $6, $7) "+
"ON CONFLICT (domain) DO UPDATE SET status=$5",
domain.Name, domain.Email, strings.Join(domain.MXs[:], ","),
models.StateUnvalidated, domain.State, domain.QueueWeeks)
models.StateUnvalidated, domain.State, domain.QueueWeeks, domain.MTASTS)
return err
}

Expand All @@ -231,25 +231,15 @@ func (db SQLDatabase) GetDomain(domain string) (models.Domain, error) {
return data, err
}

// GetDomains retrieves all the domains which match a particular state.
// GetDomains retrieves all the domains which match a particular state,
// that are not in MTA_STS mode
func (db SQLDatabase) GetDomains(state models.DomainState) ([]models.Domain, error) {
rows, err := db.conn.Query(
"SELECT domain, email, data, status, last_updated, queue_weeks FROM domains WHERE status=$1", state)
if err != nil {
return nil, err
}
defer rows.Close()
domains := []models.Domain{}
for rows.Next() {
var domain models.Domain
var rawMXs string
if err := rows.Scan(&domain.Name, &domain.Email, &rawMXs, &domain.State, &domain.LastUpdated, &domain.QueueWeeks); err != nil {
return nil, err
}
domain.MXs = strings.Split(rawMXs, ",")
domains = append(domains, domain)
}
return domains, nil
return db.getDomainsWhere("status=$1", state)
}

// GetMTASTSDomains retrieves domains which wish their policy to be queued with their MTASTS.
func (db SQLDatabase) GetMTASTSDomains() ([]models.Domain, error) {
return db.getDomainsWhere("mta_sts=TRUE")
}

// EMAIL BLACKLIST DB FUNCTIONS
Expand Down Expand Up @@ -294,6 +284,26 @@ func (db SQLDatabase) ClearTables() error {
})
}

func (db SQLDatabase) getDomainsWhere(condition string, args ...interface{}) ([]models.Domain, error) {
query := fmt.Sprintf("SELECT domain, email, data, status, last_updated, queue_weeks FROM domains WHERE %s", condition)
rows, err := db.conn.Query(query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
domains := []models.Domain{}
for rows.Next() {
var domain models.Domain
var rawMXs string
if err := rows.Scan(&domain.Name, &domain.Email, &rawMXs, &domain.State, &domain.LastUpdated, &domain.QueueWeeks); err != nil {
return nil, err
}
domain.MXs = strings.Split(rawMXs, ",")
domains = append(domains, domain)
}
return domains, nil
}

// DomainsToValidate [interface Validator] retrieves domains from the
// DB whose policies should be validated.
func (db SQLDatabase) DomainsToValidate() ([]string, error) {
Expand Down
21 changes: 21 additions & 0 deletions db/sqldb_test.go
Expand Up @@ -3,6 +3,7 @@ package db_test
import (
"log"
"os"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -426,3 +427,23 @@ func expectStats(ts models.TimeSeries, t *testing.T) {
}
}
}

func TestGetMTASTSDomains(t *testing.T) {
database.ClearTables()
database.PutDomain(models.Domain{Name: "unicorns"})
database.PutDomain(models.Domain{Name: "mta-sts-x", MTASTS: true})
database.PutDomain(models.Domain{Name: "mta-sts-y", MTASTS: true})
database.PutDomain(models.Domain{Name: "regular"})
domains, err := database.GetMTASTSDomains()
if err != nil {
t.Fatalf("GetMTASTSDomains() failed: %v", err)
}
if len(domains) != 2 {
t.Errorf("Expected GetMTASTSDomains() to return 2 elements")
}
for _, domain := range domains {
if !strings.HasPrefix(domain.Name, "mta-sts") {
t.Errorf("GetMTASTSDomains returned %s when it wasn't supposed to", domain.Name)
}
}
}
7 changes: 3 additions & 4 deletions models/domain.go
Expand Up @@ -13,7 +13,7 @@ type Domain struct {
Name string `json:"domain"` // Domain that is preloaded
Email string `json:"-"` // Contact e-mail for Domain
MXs []string `json:"mxs"` // MXs that are valid for this domain
MTASTSMode string `json:"mta_sts"`
MTASTS bool `json:"mta_sts"`
State DomainState `json:"state"`
LastUpdated time.Time `json:"last_updated"`
TestingStart time.Time `json:"-"`
Expand Down Expand Up @@ -62,7 +62,7 @@ func (d *Domain) IsQueueable(db scanStore, list policyList) (bool, string, Scan)
return false, "Domain is already on the policy list!", scan
}
// Domains without submitted MTA-STS support must match provided mx patterns.
if d.MTASTSMode == "" {
if !d.MTASTS {
for _, hostname := range scan.Data.PreferredHostnames {
if !checker.PolicyMatches(hostname, d.MXs) {
return false, fmt.Sprintf("Hostnames %v do not match policy %v", scan.Data.PreferredHostnames, d.MXs), scan
Expand All @@ -77,8 +77,7 @@ func (d *Domain) IsQueueable(db scanStore, list policyList) (bool, string, Scan)
// PopulateFromScan updates a Domain's fields based on a scan of that domain.
func (d *Domain) PopulateFromScan(scan Scan) {
// We should only trust MTA-STS info from a successful MTA-STS check.
if scan.Data.MTASTSResult != nil && scan.SupportsMTASTS() {
d.MTASTSMode = scan.Data.MTASTSResult.Mode
if d.MTASTS && scan.SupportsMTASTS() {
// If the domain's MXs are missing, we can take them from the scan's
// PreferredHostnames, which must be a subset of those listed in the
// MTA-STS policy file.
Expand Down
15 changes: 6 additions & 9 deletions models/domain_test.go
Expand Up @@ -96,9 +96,9 @@ func TestIsQueueable(t *testing.T) {
}
// With MTA-STS
d = Domain{
Name: "example.com",
Email: "me@example.com",
MTASTSMode: "on",
Name: "example.com",
Email: "me@example.com",
MTASTS: true,
}
ok, msg, _ := d.IsQueueable(mockScanStore{goodScan, nil}, mockList{false})
if !ok {
Expand All @@ -121,20 +121,17 @@ func TestIsQueueable(t *testing.T) {

func TestPopulateFromScan(t *testing.T) {
d := Domain{
Name: "example.com",
Email: "me@example.com",
Name: "example.com",
Email: "me@example.com",
MTASTS: true,
}
s := Scan{
Data: checker.DomainResult{
MTASTSResult: checker.MakeMTASTSResult(),
},
}
s.Data.MTASTSResult.Mode = "enforce"
s.Data.MTASTSResult.MXs = []string{"mx1.example.com", "mx2.example.com"}
d.PopulateFromScan(s)
if d.MTASTSMode != "enforce" {
t.Errorf("Expected domain MTA-STS mode to match scan, got %s", d.MTASTSMode)
}
for i, mx := range s.Data.MTASTSResult.MXs {
if mx != d.MXs[i] {
t.Errorf("Expected MXs to match scan, got %s", d.MXs)
Expand Down

0 comments on commit ba8fb40

Please sign in to comment.