Skip to content

Commit

Permalink
Merge 302bc81 into b9434aa
Browse files Browse the repository at this point in the history
  • Loading branch information
vbrown608 committed Jun 13, 2019
2 parents b9434aa + 302bc81 commit e3d24c2
Show file tree
Hide file tree
Showing 8 changed files with 181 additions and 162 deletions.
12 changes: 11 additions & 1 deletion checker/totals.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ type AggregatedScan struct {
MTASTSEnforceList []string
}

// Labels stored in an aggregated scan's Source field to record where its
// data came from.
const (
	// LocalSource marks aggregated scan data gathered from users of the web
	// frontend.
	LocalSource = "LOCAL"
	// TopDomainsSource marks aggregated scans of the top million domains.
	TopDomainsSource = "TOP_DOMAINS"
)

// TotalMTASTS returns the number of domains supporting test or enforce mode.
func (a AggregatedScan) TotalMTASTS() int {
return a.MTASTSTesting + a.MTASTSEnforce
Expand All @@ -30,7 +37,10 @@ func (a AggregatedScan) TotalMTASTS() int {
// PercentMTASTS returns the percentage of domains with MXs that support
// MTA-STS (testing or enforce mode), on a 0-100 scale.
//
// It returns 0 when no domains with MXs have been counted, so callers never
// receive NaN or Inf from a division by zero.
func (a AggregatedScan) PercentMTASTS() float64 {
	if a.WithMXs == 0 {
		return 0
	}
	return 100 * float64(a.TotalMTASTS()) / float64(a.WithMXs)
}

// HandleDomain adds the result of a single domain scan to aggregated stats.
Expand Down
7 changes: 4 additions & 3 deletions db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package db
import (
"flag"
"os"
"time"

"github.com/EFForg/starttls-backend/checker"
"github.com/EFForg/starttls-backend/models"
Expand Down Expand Up @@ -34,10 +35,10 @@ type Database interface {
PutHostnameScan(string, checker.HostnameResult) error
// Writes an aggregated scan to the database
PutAggregatedScan(checker.AggregatedScan) error
// Caches aggregated stats for the 14 days preceding the given time.
PutLocalStats(time.Time) (checker.AggregatedScan, error)
// Gets counts per day of hosts supporting MTA-STS for a given source.
GetMTASTSStats(string) (stats.Series, error)
// Gets counts per day of hosts scanned by this app supporting MTA-STS adoption.
GetMTASTSLocalStats() (stats.Series, error)
GetStats(string) (stats.Series, error)
// Upserts domain state.
PutDomain(models.Domain) error
// Retrieves state of a domain
Expand Down
84 changes: 37 additions & 47 deletions db/sqldb.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,70 +117,60 @@ func (db *SQLDatabase) PutScan(scan models.Scan) error {
return err
}

// GetMTASTSStats returns statistics about a MTA-STS adoption from a single
// GetStats returns statistics about MTA-STS adoption for a single source
// of domains, ordered by time.
func (db *SQLDatabase) GetMTASTSStats(source string) (stats.Series, error) {
func (db *SQLDatabase) GetStats(source string) (stats.Series, error) {
series := stats.Series{}
rows, err := db.conn.Query(
"SELECT time, with_mxs, mta_sts_testing, mta_sts_enforce FROM aggregated_scans WHERE source=$1", source)
`SELECT time, with_mxs, mta_sts_testing, mta_sts_enforce
FROM aggregated_scans
WHERE source=$1
ORDER BY time`, source)
if err != nil {
return stats.Series{}, err
return series, err
}
defer rows.Close()
series := stats.Series{}
for rows.Next() {
var a checker.AggregatedScan
if err := rows.Scan(&a.Time, &a.WithMXs, &a.MTASTSTesting, &a.MTASTSEnforce); err != nil {
return stats.Series{}, err
return series, err
}
series[a.Time.UTC()] = float64(a.TotalMTASTS())
series = append(series, a)
}
return series, nil
}

// GetMTASTSLocalStats returns statistics about MTA-STS adoption in
// user-initiated scans over a rolling 14-day window. Returns a map with:
// key: the final day of a two-week window. Windows last until EOD.
// value: the percent of scans supporting MTA-STS in that window
// @TODO write a simpler query that caches totals in the
// `aggregated_scans` table at the end of each 14-day period
func (db *SQLDatabase) GetMTASTSLocalStats() (stats.Series, error) {
// "day" represents truncated date (ie beginning of day), but windows should
// include the full day, so we add a day when querying timestamps.
// Getting the most recent 31 days for now, we can set the start date to the
// beginning of our MTA-STS data once we have some.
// PutLocalStats writes aggregated stats for the 14 days preceding `date` to
// the aggregated_scans table.
func (db *SQLDatabase) PutLocalStats(date time.Time) (checker.AggregatedScan, error) {
query := `
SELECT day, 100.0 * SUM(
CASE WHEN mta_sts_mode = 'testing' THEN 1 ELSE 0 END +
CASE WHEN mta_sts_mode = 'enforce' THEN 1 ELSE 0 END
) / COUNT(day) as percent
SELECT
COUNT(domain) AS total,
COALESCE ( SUM (
CASE WHEN mta_sts_mode = 'testing' THEN 1 ELSE 0 END
), 0 ) AS testing,
COALESCE ( SUM (
CASE WHEN mta_sts_mode = 'enforce' THEN 1 ELSE 0 END
), 0 ) AS enforce
FROM (
SELECT date_trunc('day', d)::date AS day
FROM generate_series(CURRENT_DATE-31, CURRENT_DATE, '1 day'::INTERVAL) d )
AS days
INNER JOIN LATERAL (
SELECT DISTINCT ON (domain) domain, timestamp, mta_sts_mode
FROM scans
WHERE timestamp BETWEEN day - '13 days'::INTERVAL AND day + '1 day'::INTERVAL
ORDER BY domain, timestamp DESC
) AS most_recent_scans ON TRUE
GROUP BY day;`

rows, err := db.conn.Query(query)
if err != nil {
return nil, err
SELECT DISTINCT ON (domain) domain, timestamp, mta_sts_mode
FROM scans
WHERE timestamp BETWEEN $1 AND $2
ORDER BY domain, timestamp DESC
) AS latest_domains;
`
start := date.Add(-14 * 24 * time.Hour)
end := date
a := checker.AggregatedScan{
Source: checker.LocalSource,
Time: date,
}
defer rows.Close()

ts := make(map[time.Time]float64)
for rows.Next() {
var t time.Time
var count float64
if err := rows.Scan(&t, &count); err != nil {
return nil, err
}
ts[t.UTC()] = count
err := db.conn.QueryRow(query, start.UTC(), end.UTC()).Scan(&a.WithMXs, &a.MTASTSTesting, &a.MTASTSEnforce)
if err != nil {
return a, err
}
return ts, nil
err = db.PutAggregatedScan(a)
return a, err
}

const mostRecentQuery = `
Expand Down
111 changes: 52 additions & 59 deletions db/sqldb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"github.com/EFForg/starttls-backend/checker"
"github.com/EFForg/starttls-backend/db"
"github.com/EFForg/starttls-backend/models"
"github.com/EFForg/starttls-backend/stats"
"github.com/joho/godotenv"
)

Expand Down Expand Up @@ -294,7 +293,7 @@ func TestHostnamesForDomain(t *testing.T) {
}

func TestPutAndIsBlacklistedEmail(t *testing.T) {
defer database.ClearTables()
database.ClearTables()

// Add an e-mail address to the blacklist.
err := database.PutBlacklistedEmail("fail@example.com", "bounce", "2017-07-21T18:47:13.498Z")
Expand Down Expand Up @@ -354,22 +353,22 @@ func dateMustParse(date string, t *testing.T) time.Time {
return parsed
}

func TestGetMTASTSStats(t *testing.T) {
func TestGetStats(t *testing.T) {
database.ClearTables()
may1 := dateMustParse("2019-May-01", t)
may2 := dateMustParse("2019-May-02", t)
data := []checker.AggregatedScan{
checker.AggregatedScan{
Time: may1,
Source: "domains-depot",
Source: checker.TopDomainsSource,
Attempted: 5,
WithMXs: 4,
MTASTSTesting: 2,
MTASTSEnforce: 1,
},
checker.AggregatedScan{
Time: may2,
Source: "domains-depot",
Source: checker.TopDomainsSource,
Attempted: 10,
WithMXs: 8,
MTASTSTesting: 1,
Expand All @@ -382,16 +381,45 @@ func TestGetMTASTSStats(t *testing.T) {
t.Fatal(err)
}
}
result, err := database.GetMTASTSStats("domains-depot")
result, err := database.GetStats(checker.TopDomainsSource)
if err != nil {
t.Fatal(err)
}
if result[may1] != 3 || result[may2] != 4 {
if result[0].TotalMTASTS() != 3 || result[1].TotalMTASTS() != 4 {
t.Errorf("Incorrect MTA-STS stats, got %v", result)
}
}

func TestGetMTASTSLocalStats(t *testing.T) {
// TestPutLocalStats checks the aggregate returned by PutLocalStats in two
// situations: with freshly cleared tables it must report 0% MTA-STS support,
// and after storing one scan from six days ago (built from the sample domain
// result) it must report 100%.
func TestPutLocalStats(t *testing.T) {
	database.ClearTables()
	a, err := database.PutLocalStats(time.Now())
	if err != nil {
		t.Fatal(err)
	}
	if a.PercentMTASTS() != 0 {
		t.Errorf("Expected PercentMTASTS with no recent scans to be 0, got %v",
			a.PercentMTASTS())
	}
	day := time.Hour * 24
	today := time.Now()
	// Six days back keeps the scan inside the 14-day window that
	// PutLocalStats aggregates over.
	sixDaysAgo := today.Add(-6 * day)
	s := models.Scan{
		Domain:    "example1.com",
		Data:      checker.NewSampleDomainResult("example1.com"),
		Timestamp: sixDaysAgo,
	}
	// Fail fast if the fixture could not be stored; otherwise a failed insert
	// would only show up below as a confusing percentage mismatch.
	if err := database.PutScan(s); err != nil {
		t.Fatal(err)
	}
	a, err = database.PutLocalStats(time.Now())
	if err != nil {
		t.Fatal(err)
	}
	if a.PercentMTASTS() != 100 {
		t.Errorf("Expected PercentMTASTS with one recent scan to be 100, got %v",
			a.PercentMTASTS())
	}
}

func TestGetLocalStats(t *testing.T) {
database.ClearTables()
day := time.Hour * 24
today := time.Now()
Expand All @@ -402,40 +430,20 @@ func TestGetMTASTSLocalStats(t *testing.T) {
s := models.Scan{
Domain: "example1.com",
Data: checker.NewSampleDomainResult("example1.com"),
Timestamp: lastWeek,
Timestamp: lastWeek.Add(1 * day),
}
database.PutScan(s)
s.Timestamp = lastWeek.Add(3 * day)
s.Data.MTASTSResult.Mode = ""
database.PutScan(s)
// Support is shown in the rolling average until the no-support scan is
// included.
expectStats(stats.Series{
lastWeek: 100,
lastWeek.Add(day): 100,
lastWeek.Add(2 * day): 100,
lastWeek.Add(3 * day): 0,
lastWeek.Add(4 * day): 0,
lastWeek.Add(5 * day): 0,
lastWeek.Add(6 * day): 0,
}, t)

// Add another recent scan, from a second domain.
s = models.Scan{
Domain: "example2.com",
Data: checker.NewSampleDomainResult("example2.com"),
Timestamp: lastWeek.Add(1 * day),
Timestamp: lastWeek.Add(2 * day),
}
database.PutScan(s)
expectStats(stats.Series{
lastWeek: 100,
lastWeek.Add(day): 100,
lastWeek.Add(2 * day): 100,
lastWeek.Add(3 * day): 50,
lastWeek.Add(4 * day): 50,
lastWeek.Add(5 * day): 50,
lastWeek.Add(6 * day): 50,
}, t)

// Add a third scan to check that floats are outputted correctly.
s = models.Scan{
Expand All @@ -444,40 +452,25 @@ func TestGetMTASTSLocalStats(t *testing.T) {
Timestamp: lastWeek.Add(6 * day),
}
database.PutScan(s)
expectStats(stats.Series{
lastWeek: 100,
lastWeek.Add(day): 100,
lastWeek.Add(2 * day): 100,
lastWeek.Add(3 * day): 50,
lastWeek.Add(4 * day): 50,
lastWeek.Add(5 * day): 50,
lastWeek.Add(6 * day): 66.66666666666667,
}, t)
}

func expectStats(ts stats.Series, t *testing.T) {
// GetMTASTSStats returns dates only (no hours, minutes, seconds). We need
// to truncate the expected times for comparison to dates and convert to UTC
// to match the database's timezone.
expected := make(map[time.Time]float64)
for kOld, v := range ts {
k := kOld.UTC().Truncate(24 * time.Hour)
expected[k] = v
// Write stats to the database for all the windows we want to check.
for i := 0; i < 7; i++ {
database.PutLocalStats(lastWeek.Add(day * time.Duration(i)))
}
got, err := database.GetMTASTSLocalStats()

stats, err := database.GetStats(checker.LocalSource)
if err != nil {
t.Fatal(err)
}
if len(expected) != len(got) {
t.Errorf("Expected MTA-STS stats to be\n %v\ngot\n %v\n", expected, got)
return
}
for expKey, expVal := range expected {
// DB query returns dates only (no hours, minutes, seconds).
key := expKey.Truncate(24 * time.Hour)
if got[key] != expVal {
t.Errorf("Expected MTA-STS stats to be\n %v\ngot\n %v\n", expected, got)
return

// Validate result
expPcts := []float64{0, 100, 100, 50, 50, 50, 100 * 2 / float64(3)}
if len(expPcts) != 7 {
t.Errorf("Expected 7 stats, got\n %v\n", stats)
}
for i, got := range stats {
if got.PercentMTASTS() != expPcts[i] {
t.Errorf("\nExpected %v%%\nGot %v\n (%v%%)", expPcts[i], got, got.PercentMTASTS())
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,6 @@ func main() {
log.Println("[Starting queued validator]")
go validator.ValidateRegularly("Testing domains", db, 24*time.Hour)
}
go stats.ImportRegularly(db, time.Hour)
go stats.UpdateRegularly(db, time.Hour)
ServePublicEndpoints(&api, &cfg)
}

0 comments on commit e3d24c2

Please sign in to comment.