Skip to content

Commit

Permalink
Merge 6c8b255 into 87ef87d
Browse files Browse the repository at this point in the history
  • Loading branch information
sydneyli committed May 6, 2019
2 parents 87ef87d + 6c8b255 commit 88e901a
Show file tree
Hide file tree
Showing 15 changed files with 213 additions and 87 deletions.
2 changes: 1 addition & 1 deletion checker/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func (c *Checker) CheckDomain(domain string, expectedHostnames []string) DomainR
}
}
result.PreferredHostnames = checkedHostnames
result.MTASTSResult = c.checkMTASTS(domain, result.HostnameResults)
result.MTASTSResult = c.CheckMTASTS(domain, result.HostnameResults)

// Derive Domain code from Hostname results.
if len(checkedHostnames) == 0 {
Expand Down
4 changes: 3 additions & 1 deletion checker/mta_sts.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,9 @@ func validateMTASTSMXs(policyFileMXs []string, dnsMXs map[string]HostnameResult,
}
}

func (c Checker) checkMTASTS(domain string, hostnameResults map[string]HostnameResult) *MTASTSResult {
// CheckMTASTS performs all associated checks for a particular domain's
// MTA-STS support.
func (c Checker) CheckMTASTS(domain string, hostnameResults map[string]HostnameResult) *MTASTSResult {
if c.checkMTASTSOverride != nil {
// Allow the Checker to mock this function.
return c.checkMTASTSOverride(domain, hostnameResults)
Expand Down
2 changes: 1 addition & 1 deletion db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type Database interface {
// Upserts domain state.
PutDomain(models.Domain) error
// Retrieves state of a domain
GetDomain(string, models.DomainState) (models.Domain, error)
GetDomainInState(string, models.DomainState) (models.Domain, error)
// Retrieves all domains in a particular state.
GetDomains(models.DomainState) ([]models.Domain, error)
SetStatus(string, models.DomainState) error
Expand Down
34 changes: 22 additions & 12 deletions db/sqldb.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,16 +216,24 @@ func (db *SQLDatabase) PutDomain(domain models.Domain) error {
return err
}

// GetDomain retrieves the status and information associated with a particular
// UpdateDomainPolicy allows us to update the internal data about a particular domain.
func (db *SQLDatabase) UpdateDomainPolicy(domain models.Domain) error {
_, err := db.conn.Exec("UPDATE domains SET data=$2, status=$3 WHERE domain=$1 AND mta_sts=TRUE",
domain.Name, strings.Join(domain.MXs[:], ","), domain.State)
return err

}

// GetDomainInState retrieves the status and information associated with a particular
// mailserver domain.
func (db SQLDatabase) GetDomain(domain string, state models.DomainState) (models.Domain, error) {
func (db SQLDatabase) GetDomainInState(domain string, state models.DomainState) (models.Domain, error) {
return db.queryDomain("SELECT %s FROM domains WHERE domain=$1 AND status=$2", domain, state)
}

// GetDomains retrieves all the domains which match a particular state,
// that are not in MTA_STS mode
func (db SQLDatabase) GetDomains(state models.DomainState) ([]models.Domain, error) {
return db.queryDomainsWhere("status=$1", state)
return db.queryDomainsWhere("status=$1 AND mta_sts=FALSE", state)
}

// GetMTASTSDomains retrieves domains which wish their policy to be queued with their MTASTS.
Expand Down Expand Up @@ -332,23 +340,25 @@ func (db SQLDatabase) DomainsToValidate() ([]string, error) {
if err != nil {
return domains, err
}
dataMTASTS, err := db.GetMTASTSDomains()
if err != nil {
return domains, err
}
for _, domainInfo := range data {
domains = append(domains, domainInfo.Name)
}
for _, domainInfo := range dataMTASTS {
domains = append(domains, domainInfo.Name)
}
return domains, nil
}

// HostnamesForDomain [interface Validator] retrieves the hostname policy for
// a particular domain.
func (db SQLDatabase) HostnamesForDomain(domain string) ([]string, error) {
data, err := db.GetDomain(domain, models.StateEnforce)
func (db SQLDatabase) GetDomain(domain string) (models.Domain, error) {
data, err := db.GetDomainInState(domain, models.StateEnforce)
if err != nil {
data, err = db.GetDomain(domain, models.StateTesting)
data, err = db.GetDomainInState(domain, models.StateTesting)
}
if err != nil {
return []string{}, err
}
return data.MXs, nil
return data, err
}

// GetHostnameScan retrives most recent scan from database.
Expand Down
56 changes: 28 additions & 28 deletions db/sqldb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func TestPutGetDomain(t *testing.T) {
if err != nil {
t.Errorf("PutDomain failed: %v\n", err)
}
retrievedData, err := database.GetDomain(data.Name, models.StateUnconfirmed)
retrievedData, err := database.GetDomainInState(data.Name, models.StateUnconfirmed)
if err != nil {
t.Errorf("GetDomain(%s) failed: %v\n", data.Name, err)
}
Expand All @@ -171,7 +171,7 @@ func TestUpsertDomain(t *testing.T) {
if err != nil {
t.Errorf("PutDomain(%s) failed: %v\n", data.Name, err)
}
retrievedData, err := database.GetDomain(data.Name, models.StateUnconfirmed)
retrievedData, err := database.GetDomainInState(data.Name, models.StateUnconfirmed)
if retrievedData.MXs[0] != "hello_darkness_my_old_friend" || retrievedData.Email != "actual_admin@testing.com" {
t.Errorf("Email and MXs should have been rewritten: %v\n", retrievedData)
}
Expand Down Expand Up @@ -220,11 +220,11 @@ func TestLastUpdatedFieldUpdates(t *testing.T) {
State: models.StateUnconfirmed,
}
database.PutDomain(data)
retrievedData, _ := database.GetDomain(data.Name, models.StateUnconfirmed)
retrievedData, _ := database.GetDomainInState(data.Name, models.StateUnconfirmed)
lastUpdated := retrievedData.LastUpdated
data.State = models.StateTesting
database.PutDomain(models.Domain{Name: data.Name, Email: "new fone who dis"})
retrievedData, _ = database.GetDomain(data.Name, models.StateUnconfirmed)
retrievedData, _ = database.GetDomainInState(data.Name, models.StateUnconfirmed)
if lastUpdated.Equal(retrievedData.LastUpdated) {
t.Errorf("Expected last_updated to be updated on change: %v", lastUpdated)
}
Expand All @@ -238,10 +238,10 @@ func TestLastUpdatedFieldDoesntUpdate(t *testing.T) {
State: models.StateUnconfirmed,
}
database.PutDomain(data)
retrievedData, _ := database.GetDomain(data.Name, models.StateUnconfirmed)
retrievedData, _ := database.GetDomainInState(data.Name, models.StateUnconfirmed)
lastUpdated := retrievedData.LastUpdated
database.PutDomain(data)
retrievedData, _ = database.GetDomain(data.Name, models.StateUnconfirmed)
retrievedData, _ = database.GetDomainInState(data.Name, models.StateUnconfirmed)
if !lastUpdated.Equal(retrievedData.LastUpdated) {
t.Errorf("Expected last_updated to stay the same if no changes were made")
}
Expand Down Expand Up @@ -270,28 +270,6 @@ func TestDomainsToValidate(t *testing.T) {
}
}

func TestHostnamesForDomain(t *testing.T) {
database.ClearTables()
database.PutDomain(models.Domain{Name: "x", MXs: []string{"x.com", "y.org"}})
database.PutDomain(models.Domain{Name: "y"})
database.SetStatus("x", models.StateTesting)
database.SetStatus("y", models.StateTesting)
result, err := database.HostnamesForDomain("x")
if err != nil {
t.Fatalf("HostnamesForDomain failed: %v\n", err)
}
if len(result) != 2 || result[0] != "x.com" || result[1] != "y.org" {
t.Errorf("Expected two hostnames, x.com and y.org\n")
}
result, err = database.HostnamesForDomain("y")
if err != nil {
t.Fatalf("HostnamesForDomain failed: %v\n", err)
}
if len(result) > 0 {
t.Errorf("Expected no hostnames to be returned, got %s\n", result[0])
}
}

func TestPutAndIsBlacklistedEmail(t *testing.T) {
defer database.ClearTables()

Expand Down Expand Up @@ -454,3 +432,25 @@ func TestGetMTASTSDomains(t *testing.T) {
}
}
}

func TestUpdateDomainPolicy(t *testing.T) {
database.ClearTables()
database.PutDomain(models.Domain{Name: "no-mtasts"})
database.PutDomain(models.Domain{Name: "mtasts", MTASTS: true, Email: "real-email"})
database.UpdateDomainPolicy(models.Domain{Name: "no-mtasts", State: models.StateEnforce})
database.UpdateDomainPolicy(models.Domain{Name: "mtasts", State: models.StateEnforce, MXs: []string{"hostname"}, Email: "fake-email"})
domain, _ := database.GetDomain("no-mtasts")
if domain.State == models.StateEnforce {
t.Errorf("Expected State to not update since unicorns isn't MTASTS")
}
domain, _ = database.GetDomain("mtasts")
if domain.State != models.StateEnforce {
t.Errorf("Expected State to update after UpdateDomainPolicy")
}
if len(domain.MXs) != 1 || domain.MXs[0] != "hostname" {
t.Errorf("Expected MXs to update after UpdateDomainPolicy")
}
if domain.Email != "real-email" {
t.Errorf("Did not expect Email to update after UpdateDomainPolicy")
}
}
9 changes: 8 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,14 @@ func main() {
}
if os.Getenv("VALIDATE_QUEUED") == "1" {
log.Println("[Starting queued validator]")
go validator.ValidateRegularly("Testing domains", db, 24*time.Hour)
v := validator.Validator{
Name: "Testing and enforced domains",
Store: db,
Interval: 24 * time.Hour,
CheckPerformer: validator.GetDBCheck(db.UpdateDomainPolicy),
}
go v.Run()
// go validator.ValidateRegularly("Testing domains", db, 24*time.Hour)
}
ServePublicEndpoints(&api, &cfg)
}
24 changes: 18 additions & 6 deletions models/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"time"

"github.com/EFForg/starttls-backend/checker"
"github.com/EFForg/starttls-backend/util"
)

/* Domain represents an email domain's TLS policy.
Expand All @@ -29,7 +30,7 @@ type Domain struct {
// domainStore is a simple interface for fetching and adding domain objects.
type domainStore interface {
PutDomain(Domain) error
GetDomain(string, DomainState) (Domain, error)
GetDomainInState(string, DomainState) (Domain, error)
GetDomains(DomainState) ([]Domain, error)
SetStatus(string, DomainState) error
RemoveDomain(string, DomainState) (Domain, error)
Expand Down Expand Up @@ -69,7 +70,7 @@ func (d *Domain) IsQueueable(domains domainStore, scans scanStore, list policyLi
if list.HasDomain(d.Name) {
return false, "Domain is already on the policy list!", scan
}
if _, err := domains.GetDomain(d.Name, StateEnforce); err == nil {
if _, err := domains.GetDomainInState(d.Name, StateEnforce); err == nil {
return false, "Domain is already on the policy list!", scan
}
// Domains without submitted MTA-STS support must match provided mx patterns.
Expand Down Expand Up @@ -142,22 +143,33 @@ func (d Domain) AsyncPolicyListCheck(store domainStore, list policyList) <-chan
return result
}

// SamePolicy checks whether the underlying policy represented by Domain
// and the one picked up by the MTA-STS check represent the same policy.
func (d *Domain) SamePolicy(result *checker.MTASTSResult) bool {
if (result.Mode == "enforce" && d.State != StateEnforce) ||
(result.Mode == "testing" && d.State != StateTesting) ||
result.Mode == "none" {
return false
}
return util.ListsEqual(d.MXs, result.MXs)
}

// GetDomain retrieves Domain with the most "important" state.
// At any given time, there can only be one domain that's either StateEnforce
// or StateTesting. If that domain exists in the store, return that one.
// Otherwise, look for a Domain policy in the unconfirmed state.
func GetDomain(store domainStore, name string) (Domain, error) {
domain, err := store.GetDomain(name, StateEnforce)
domain, err := store.GetDomainInState(name, StateEnforce)
if err == nil {
return domain, nil
}
domain, err = store.GetDomain(name, StateTesting)
domain, err = store.GetDomainInState(name, StateTesting)
if err == nil {
return domain, nil
}
domain, err = store.GetDomain(name, StateUnconfirmed)
domain, err = store.GetDomainInState(name, StateUnconfirmed)
if err == nil {
return domain, nil
}
return store.GetDomain(name, StateFailed)
return store.GetDomainInState(name, StateFailed)
}
2 changes: 1 addition & 1 deletion models/domain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func (m *mockDomainStore) SetStatus(d string, status DomainState) error {
return m.err
}

func (m *mockDomainStore) GetDomain(d string, state DomainState) (Domain, error) {
func (m *mockDomainStore) GetDomainInState(d string, state DomainState) (Domain, error) {
domain := m.domain
if state != domain.State {
return m.domain, errors.New("")
Expand Down
2 changes: 1 addition & 1 deletion models/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func (t *Token) Redeem(store domainStore, tokens tokenStore) (ret string, userEr
if err != nil {
return domain, err, nil
}
domainData, err := store.GetDomain(domain, StateUnconfirmed)
domainData, err := store.GetDomainInState(domain, StateUnconfirmed)
if err != nil {
return domain, nil, err
}
Expand Down
19 changes: 15 additions & 4 deletions policy/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"net/http"
"sync"
"time"

"github.com/EFForg/starttls-backend/models"
)

// policyURL is the default URL from which to fetch the policy JSON.
Expand Down Expand Up @@ -80,14 +82,23 @@ func (l *UpdatedList) DomainsToValidate() ([]string, error) {
return domains, nil
}

// HostnamesForDomain [interface Validator] retrieves the hostname policy for
// GetDomain [interface Validator] retrieves the domain object for
// a particular domain.
func (l *UpdatedList) HostnamesForDomain(domain string) ([]string, error) {
func (l *UpdatedList) GetDomain(domain string) (models.Domain, error) {
policy, err := l.Get(domain)
if err != nil {
return []string{}, err
return models.Domain{}, err
}
domainObj := models.Domain{
Name: domain,
MXs: policy.MXs,
}
if policy.Mode == "enforce" {
domainObj.State = models.StateEnforce
} else if policy.Mode == "testing" {
domainObj.State = models.StateTesting
}
return policy.MXs, nil
return domainObj, nil
}

// Get safely reads from the underlying policy list and returns a TLSPolicy for a domain
Expand Down
6 changes: 3 additions & 3 deletions policy/policy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,12 @@ func TestHostnamesForDomain(t *testing.T) {
var updatedList = List{Policies: map[string]TLSPolicy{
"eff.org": TLSPolicy{MXs: hostnames}}}
list := makeUpdatedList(func() (List, error) { return updatedList, nil }, time.Second)
returned, err := list.HostnamesForDomain("eff.org")
returned, err := list.GetDomain("eff.org")
if err != nil {
t.Fatalf("Encountered %v", err)
}
if !reflect.DeepEqual(returned, hostnames) {
t.Errorf("Expected %s, got %s", hostnames, returned)
if !reflect.DeepEqual(returned.MXs, hostnames) {
t.Errorf("Expected %s, got %s", hostnames, returned.MXs)
}
}

Expand Down
27 changes: 27 additions & 0 deletions util/util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package util

import (
"reflect"
)

// ListsEqual checks that two lists have the same elements,
// regardless of order.
func ListsEqual(x []string, y []string) bool {
// Transform each list into a histogram
xMap := make(map[string]uint)
yMap := make(map[string]uint)
for _, element := range x {
if _, ok := xMap[element]; !ok {
xMap[element] = 0
}
xMap[element]++
}
for _, element := range y {
if _, ok := yMap[element]; !ok {
yMap[element] = 0
}
yMap[element]++
}
// Compare the histogram maps
return reflect.DeepEqual(xMap, yMap)
}

0 comments on commit 88e901a

Please sign in to comment.