Skip to content

Commit

Permalink
Merge ec01f31 into 1b2fb37
Browse files Browse the repository at this point in the history
  • Loading branch information
sydneyli committed Jun 22, 2019
2 parents 1b2fb37 + ec01f31 commit 5236b0a
Show file tree
Hide file tree
Showing 20 changed files with 745 additions and 751 deletions.
76 changes: 29 additions & 47 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type checkPerformer func(API, string) (checker.DomainResult, error)
// Any POST request accepts either URL query parameters or data value parameters,
// and prefers the latter if both are present.
type API struct {
Database db.Database
Database *db.SQLDatabase
CheckDomain checkPerformer
List PolicyList
DontScan map[string]bool
Expand All @@ -61,7 +61,7 @@ type PolicyList interface {
type EmailSender interface {
// SendValidation sends a validation e-mail for a particular domain,
// with a particular validation token.
SendValidation(*models.Domain, string) error
SendValidation(*models.PolicySubmission, string) error
}

// APIResponse wraps all the responses from this API.
Expand Down Expand Up @@ -90,7 +90,7 @@ func (api *API) wrapper(handler apiHandler) func(w http.ResponseWriter, r *http.
}

func defaultCheck(api API, domain string) (checker.DomainResult, error) {
policyChan := models.Domain{Name: domain}.AsyncPolicyListCheck(api.Database, api.List)
policyChan := models.PolicySubmission{Name: domain}.AsyncPolicyListCheck(api.Database.PendingPolicies, api.Database.Policies, api.List)
c := checker.Checker{
Cache: &checker.ScanCache{
ScanStore: api.Database,
Expand Down Expand Up @@ -171,47 +171,42 @@ func (api API) Scan(r *http.Request) APIResponse {
// MaxHostnames is the maximum number of hostnames that can be specified for a single domain's TLS policy.
const MaxHostnames = 8

// Extracts relevant parameters from http.Request for a POST to /api/queue
// TODO: also validate hostnames as FQDNs.
func getDomainParams(r *http.Request) (models.Domain, error) {
// Extracts relevant parameters from http.Request for a POST to /api/queue into PolicySubmission
// If MTASTS is set, doesn't try to extract hostnames. Otherwise, expects between 1 and MaxHostnames
// valid hostnames to be given in |r|.
func getDomainParams(r *http.Request) (models.PolicySubmission, error) {
name, err := getASCIIDomain(r)
if err != nil {
return models.Domain{}, err
return models.PolicySubmission{}, err
}
email, err := getParam("email", r)
if err != nil {
email = validationAddress(name)
}
mtasts := r.FormValue("mta-sts")
domain := models.Domain{
domain := models.PolicySubmission{
Name: name,
Email: email,
MTASTS: mtasts == "on",
State: models.StateUnconfirmed,
}
email, err := getParam("email", r)
if err == nil {
domain.Email = email
} else {
domain.Email = validationAddress(&domain)
}
queueWeeks, err := getInt("weeks", r, 4, 52, 4)
if err != nil {
return domain, err
}
domain.QueueWeeks = queueWeeks

if mtasts != "on" {
if !domain.MTASTS {
p := policy.TLSPolicy{Mode: "testing", MXs: make([]string, 0)}
for _, hostname := range r.PostForm["hostnames"] {
if len(hostname) == 0 {
continue
}
if !validDomainName(strings.TrimPrefix(hostname, ".")) {
return domain, fmt.Errorf("Hostname %s is invalid", hostname)
}
domain.MXs = append(domain.MXs, hostname)
p.MXs = append(p.MXs, hostname)
}
if len(domain.MXs) == 0 {
return domain, fmt.Errorf("No MX hostnames supplied for domain %s", domain.Name)
if len(p.MXs) == 0 {
return domain, fmt.Errorf("No MX hostnames supplied for domain %s", name)
}
if len(domain.MXs) > MaxHostnames {
if len(p.MXs) > MaxHostnames {
return domain, fmt.Errorf("No more than 8 MX hostnames are permitted")
}
domain.Policy = &p
}
return domain, nil
}
Expand All @@ -221,7 +216,7 @@ func getDomainParams(r *http.Request) (models.Domain, error) {
// domain: Mail domain to queue a TLS policy for.
// mta_sts: "on" if domain supports MTA-STS, else "".
// hostnames: List of MX hostnames to put into this domain's TLS policy. Up to 8.
// Sets models.Domain object as response.
// Sets models.PolicySubmission object as response.
// weeks (optional, default 4): How many weeks is this domain queued for.
// email (optional): Contact email associated with domain.
// GET /api/queue?domain=<domain>
Expand All @@ -233,12 +228,14 @@ func (api API) Queue(r *http.Request) APIResponse {
if err != nil {
return badRequest(err.Error())
}
ok, msg, scan := domain.IsQueueable(api.Database, api.Database, api.List)
if !domain.CanUpdate(api.Database.Policies) {
return badRequest("existing submission can't be updated")
}
ok, msg := domain.HasValidScan(api.Database)
if !ok {
return badRequest(msg)
}
domain.PopulateFromScan(scan)
token, err := domain.InitializeWithToken(api.Database, api.Database)
token, err := domain.InitializeWithToken(api.Database.PendingPolicies, api.Database)
if err != nil {
return serverError(err.Error())
}
Expand All @@ -251,23 +248,8 @@ func (api API) Queue(r *http.Request) APIResponse {
Response: fmt.Sprintf("Thank you for submitting your domain. Please check postmaster@%s to validate that you control the domain.", domain.Name),
}
}
// GET: Retrieve domain status from queue
if r.Method == http.MethodGet {
domainName, err := getASCIIDomain(r)
if err != nil {
return badRequest(err.Error())
}
domainObj, err := models.GetDomain(api.Database, domainName)
if err != nil {
return APIResponse{StatusCode: http.StatusNotFound, Message: err.Error()}
}
return APIResponse{
StatusCode: http.StatusOK,
Response: domainObj,
}
}
return APIResponse{StatusCode: http.StatusMethodNotAllowed,
Message: "/api/queue only accepts POST and GET requests"}
Message: "/api/queue only accepts POST requests"}
}

// Validate handles requests to /api/validate
Expand All @@ -284,7 +266,7 @@ func (api API) Validate(r *http.Request) APIResponse {
Message: "/api/validate only accepts POST requests"}
}
tokenData := models.Token{Token: token}
domain, userErr, dbErr := tokenData.Redeem(api.Database, api.Database)
domain, userErr, dbErr := tokenData.Redeem(api.Database.PendingPolicies, api.Database.Policies, api.Database)
if userErr != nil {
return badRequest(userErr.Error())
}
Expand Down
1 change: 1 addition & 0 deletions db/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Database structure
8 changes: 0 additions & 8 deletions db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,6 @@ type Database interface {
PutLocalStats(time.Time) (checker.AggregatedScan, error)
// Gets counts per day of hosts supporting MTA-STS for a given source.
GetStats(string) (stats.Series, error)
// Upserts domain state.
PutDomain(models.Domain) error
// Retrieves state of a domain
GetDomain(string, models.DomainState) (models.Domain, error)
// Retrieves all domains in a particular state.
GetDomains(models.DomainState) ([]models.Domain, error)
SetStatus(string, models.DomainState) error
RemoveDomain(string, models.DomainState) (models.Domain, error)
ClearTables() error
}

Expand Down
122 changes: 122 additions & 0 deletions db/policy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package db

import (
"database/sql"
"fmt"
"strings"

"github.com/EFForg/starttls-backend/models"
"github.com/EFForg/starttls-backend/policy"
)

// PolicyDB is a database of PolicySubmissions.
type PolicyDB struct {
tableName string
conn *sql.DB
strict bool
}

func (p *PolicyDB) formQuery(query string) string {
return fmt.Sprintf(query, p.tableName, "domain, email, mta_sts, mxs, mode")
}

type scanner interface {
Scan(dest ...interface{}) error
}

func (p *PolicyDB) scanPolicy(result scanner) (models.PolicySubmission, error) {
data := models.PolicySubmission{Policy: new(policy.TLSPolicy)}
var rawMXs string
err := result.Scan(
&data.Name, &data.Email,
&data.MTASTS, &rawMXs, &data.Policy.Mode)
data.Policy.MXs = strings.Split(rawMXs, ",")
return data, err
}

// GetPolicies returns a list of policy submissions that match
// the mtasts status given.
func (p *PolicyDB) GetPolicies(mtasts bool) ([]models.PolicySubmission, error) {
rows, err := p.conn.Query(p.formQuery(
"SELECT %[2]s FROM %[1]s WHERE mta_sts=$1"), mtasts)
if err != nil {
return nil, err
}
defer rows.Close()
policies := []models.PolicySubmission{}
for rows.Next() {
policy, err := p.scanPolicy(rows)
if err != nil {
return nil, err
}
policies = append(policies, policy)
}
return policies, nil
}

// GetPolicy returns the policy submission for the given domain.
func (p *PolicyDB) GetPolicy(domainName string) (models.PolicySubmission, bool, error) {
row := p.conn.QueryRow(p.formQuery(
"SELECT %[2]s FROM %[1]s WHERE domain=$1"), domainName)
result, err := p.scanPolicy(row)
if err == sql.ErrNoRows {
return result, false, nil
}
return result, true, err
}

// RemovePolicy removes the policy submission with the given domain from
// the database.
func (p *PolicyDB) RemovePolicy(domainName string) (models.PolicySubmission, error) {
row := p.conn.QueryRow(p.formQuery(
"DELETE FROM %[1]s WHERE domain=$1 RETURNING %[2]s"), domainName)
return p.scanPolicy(row)
}

// PutOrUpdatePolicy upserts the given policy into the data store, if
// CanUpdate passes.
func (p *PolicyDB) PutOrUpdatePolicy(ps *models.PolicySubmission) error {
if p.strict && !ps.CanUpdate(p) {
return fmt.Errorf("can't update policy in restricted table")
}
if p.strict && ps.Policy == nil {
return fmt.Errorf("can't degrade policy in restricted table")
}
if ps.Policy == nil {
ps.Policy = &policy.TLSPolicy{MXs: []string{}, Mode: ""}
}
_, err := p.conn.Exec(p.formQuery(
"INSERT INTO %[1]s(%[2]s) VALUES($1, $2, $3, $4, $5) "+
"ON CONFLICT (domain) DO UPDATE SET "+
"email=$2, mta_sts=$3, mxs=$4, mode=$5"),
ps.Name, ps.Email, ps.MTASTS,
strings.Join(ps.Policy.MXs[:], ","), ps.Policy.Mode)
return err
}

// DomainsToValidate [interface Validator] retrieves domains from the
// DB whose policies should be validated-- all Pending policies.
func (p *PolicyDB) DomainsToValidate() ([]string, error) {
domains := []string{}
data, err := p.GetPolicies(true)
if err != nil {
return domains, err
}
for _, domainInfo := range data {
domains = append(domains, domainInfo.Name)
}
return domains, nil
}

// HostnamesForDomain [interface Validator] retrieves the hostname policy for
// a particular domain in Pending.
func (db SQLDatabase) HostnamesForDomain(domain string) ([]string, error) {
data, ok, err := db.PendingPolicies.GetPolicy(domain)
if !ok {
err = fmt.Errorf("domain %s not in database", domain)
}
if err != nil {
return []string{}, err
}
return data.Policy.MXs, nil
}
19 changes: 19 additions & 0 deletions db/scripts/init_tables.sql
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,25 @@ CREATE TABLE IF NOT EXISTS blacklisted_emails
timestamp TIMESTAMP
);

CREATE TABLE IF NOT EXISTS pending_policies
(
domain TEXT NOT NULL PRIMARY KEY,
email TEXT NOT NULL,
mta_sts BOOLEAN DEFAULT FALSE,
mxs TEXT NOT NULL,
mode VARCHAR(255) NOT NULL
);


CREATE TABLE IF NOT EXISTS policies
(
domain TEXT NOT NULL PRIMARY KEY,
email TEXT NOT NULL,
mta_sts BOOLEAN DEFAULT FALSE,
mxs TEXT NOT NULL,
mode VARCHAR(255) NOT NULL
);

-- Schema change: add "last_updated" timestamp column if it doesn't exist.

ALTER TABLE domains ADD COLUMN IF NOT EXISTS last_updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP;
Expand Down

0 comments on commit 5236b0a

Please sign in to comment.