diff --git a/api.go b/api.go index a8a72c9e..edfc68e1 100644 --- a/api.go +++ b/api.go @@ -3,8 +3,10 @@ package main import ( "encoding/json" "fmt" + "html/template" "log" "net/http" + "os" "strings" "time" @@ -62,9 +64,10 @@ type EmailSender interface { // APIResponse wraps all the responses from this API. type APIResponse struct { - StatusCode int `json:"status_code"` - Message string `json:"message"` - Response interface{} `json:"response"` + StatusCode int `json:"status_code"` + Message string `json:"message"` + Response interface{} `json:"response"` + TemplatePath string `json:"-"` } type apiHandler func(r *http.Request) APIResponse @@ -72,20 +75,21 @@ type apiHandler func(r *http.Request) APIResponse func apiWrapper(api apiHandler) func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) { response := api(r) - if response.StatusCode != http.StatusOK { - http.Error(w, response.Message, response.StatusCode) - } if response.StatusCode == http.StatusInternalServerError { packet := raven.NewPacket(response.Message, raven.NewHttp(r)) raven.Capture(packet, nil) } - writeJSON(w, response) + if strings.Contains(r.Header.Get("accept"), "text/html") { + writeHTML(w, response) + } else { + writeJSON(w, response) + } } } // Checks the policy status of this domain. func (api API) policyCheck(domain string) *checker.Result { - result := checker.Result{Name: "policylist"} + result := checker.Result{Name: checker.PolicyList} if _, err := api.List.Get(domain); err == nil { return result.Success() } @@ -153,16 +157,20 @@ func (api API) Scan(r *http.Request) APIResponse { scan, err := api.Database.GetLatestScan(domain) if err == nil && scan.Version == models.ScanVersion && time.Now().Before(scan.Timestamp.Add(cacheScanTime)) { - return APIResponse{StatusCode: http.StatusOK, Response: scan} + return APIResponse{ + StatusCode: http.StatusOK, + Response: scan, + TemplatePath: "views/scan.html.tmpl", + } } // 1. Conduct scan via starttls-checker - rawScandata, err := api.CheckDomain(api, domain) + scanData, err := api.CheckDomain(api, domain) if err != nil { return APIResponse{StatusCode: http.StatusInternalServerError, Message: err.Error()} } scan = models.Scan{ Domain: domain, - Data: rawScandata, + Data: scanData, Timestamp: time.Now(), Version: models.ScanVersion, } @@ -171,7 +179,11 @@ func (api API) Scan(r *http.Request) APIResponse { if err != nil { return APIResponse{StatusCode: http.StatusInternalServerError, Message: err.Error()} } - return APIResponse{StatusCode: http.StatusOK, Response: scan} + return APIResponse{ + StatusCode: http.StatusOK, + Response: scan, + TemplatePath: "views/scan.html.tmpl", + } // GET: Just fetch the most recent scan } else if r.Method == http.MethodGet { scan, err := api.Database.GetLatestScan(domain) @@ -278,14 +290,23 @@ func (api API) Queue(r *http.Request) APIResponse { return APIResponse{StatusCode: http.StatusInternalServerError, Message: "Unable to send validation e-mail"} } - return APIResponse{StatusCode: http.StatusOK, Response: domainData} + // domainData.State = Unvalidated + // or queued? + return APIResponse{ + StatusCode: http.StatusOK, + Response: fmt.Sprintf("Thank you for submitting your domain. Please check postmaster@%s to validate that you control the domain.", domain), + } // GET: Retrieve domain status from queue + // JSON only } else if r.Method == http.MethodGet { status, err := api.Database.GetDomain(domain) if err != nil { return APIResponse{StatusCode: http.StatusNotFound, Message: err.Error()} } - return APIResponse{StatusCode: http.StatusOK, Response: status} + return APIResponse{ + StatusCode: http.StatusOK, + Response: status, + } } else { return APIResponse{StatusCode: http.StatusMethodNotAllowed, Message: "/api/queue only accepts POST and GET requests"} @@ -352,9 +373,10 @@ func getParam(param string, r *http.Request) (string, error) { // Writes `v` as a JSON object to http.ResponseWriter `w`. If an error // occurs, writes `http.StatusInternalServerError` to `w`. -func writeJSON(w http.ResponseWriter, v interface{}) { +func writeJSON(w http.ResponseWriter, apiResponse APIResponse) { w.Header().Set("Content-Type", "application/json; charset=utf-8") - b, err := json.MarshalIndent(v, "", " ") + w.WriteHeader(apiResponse.StatusCode) + b, err := json.MarshalIndent(apiResponse, "", " ") if err != nil { msg := fmt.Sprintf("Internal error: could not format JSON. (%s)\n", err) http.Error(w, msg, http.StatusInternalServerError) @@ -362,3 +384,32 @@ func writeJSON(w http.ResponseWriter, v interface{}) { } fmt.Fprintf(w, "%s\n", b) } + +func writeHTML(w http.ResponseWriter, apiResponse APIResponse) { + // Add some additional useful fields for use in templates. + if apiResponse.TemplatePath == "" { + apiResponse.TemplatePath = "views/default.html.tmpl" + } + data := struct { + APIResponse + BaseURL string + StatusText string + }{ + APIResponse: apiResponse, + BaseURL: os.Getenv("FRONTEND_WEBSITE_LINK"), + StatusText: http.StatusText(apiResponse.StatusCode), + } + tmpl, err := template.ParseFiles(apiResponse.TemplatePath) + if err != nil { + log.Println(err) + raven.CaptureError(err, nil) + http.Error(w, "Internal error: could not parse template.", http.StatusInternalServerError) + return + } + w.WriteHeader(apiResponse.StatusCode) + err = tmpl.Execute(w, data) + if err != nil { + log.Println(err) + raven.CaptureError(err, nil) + } +} diff --git a/api_test.go b/api_test.go index 923d4b4b..f3879736 100644 --- a/api_test.go +++ b/api_test.go @@ -1,6 +1,10 @@ package main import ( + "io/ioutil" + "net/http" + "net/url" + "strings" "testing" "github.com/EFForg/starttls-backend/checker" @@ -40,3 +44,21 @@ func TestPolicyCheckWithQueuedDomain(t *testing.T) { t.Errorf("Check should have warned.") } } + +func testHTMLPost(path string, data url.Values, t *testing.T) ([]byte, int) { + req, err := http.NewRequest("POST", server.URL+path, strings.NewReader(data.Encode())) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("accept", "text/html") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + body, _ := ioutil.ReadAll(resp.Body) + if !strings.Contains(strings.ToLower(string(body)), " Failure > Warning > Success -func SetStatus(oldStatus CheckStatus, newStatus CheckStatus) CheckStatus { +func SetStatus(oldStatus Status, newStatus Status) Status { if newStatus > oldStatus { return newStatus } @@ -28,7 +43,7 @@ func SetStatus(oldStatus CheckStatus, newStatus CheckStatus) CheckStatus { // warning messages associated. type Result struct { Name string `json:"name"` - Status CheckStatus `json:"status"` + Status Status `json:"status"` Messages []string `json:"messages,omitempty"` Checks map[string]*Result `json:"checks,omitempty"` } @@ -91,3 +106,48 @@ func (r *Result) addCheck(checkResult *Result) { // SetStatus sets Result's status to the most severe of any individual check r.Status = SetStatus(r.Status, checkResult.Status) } + +// IDs for checks that can be run +const ( + Connectivity = "connectivity" + STARTTLS = "starttls" + Version = "version" + Certificate = "certificate" + MTASTS = "mta-sts" + MTASTSText = "mta-sts-text" + MTASTSPolicyFile = "mta-sts-policy-file" + PolicyList = "policylist" +) + +// Text descriptions of checks that can be run +var checkNames = map[string]string{ + Connectivity: "Server connectivity", + STARTTLS: "Support for inbound STARTTLS", + Version: "Secure version of TLS", + Certificate: "Valid certificate", + MTASTS: "Inbound MTA-STS support", + MTASTSText: "Correct MTA-STS DNS record", + MTASTSPolicyFile: "Correct MTA-STS policy file", + PolicyList: "Status on EFF's STARTTLS Everywhere policy list", +} + +// Description returns the full-text name of a check. +func (r Result) Description() string { + return checkNames[r.Name] +} + +// MarshalJSON writes Result to JSON. It adds status_text and description to +// the output. +func (r Result) MarshalJSON() ([]byte, error) { + // FakeResult lets us access the default json.Marshall result for Result. + type FakeResult Result + return json.Marshal(struct { + FakeResult + StatusText string `json:"status_text,omitempty"` + Description string `json:"description,omitempty"` + }{ + FakeResult: FakeResult(r), + StatusText: r.StatusText(), + Description: r.Description(), + }) +} diff --git a/checker/result_test.go b/checker/result_test.go new file mode 100644 index 00000000..04a7554a --- /dev/null +++ b/checker/result_test.go @@ -0,0 +1,41 @@ +package checker + +import ( + "bytes" + "encoding/json" + "testing" +) + +func TestMarshalResultJSON(t *testing.T) { + // Should set description and status_text for CheckResult w/ recognized keys + result := Result{ + Name: "starttls", + Status: Success, + } + marshalled, err := json.Marshal(result) + if err != nil { + t.Fatal(err) + } + if !bytes.Contains(marshalled, []byte("\"status_text\":\"Success\"")) { + t.Errorf("Marshalled result should contain status_text, got %s", string(marshalled)) + } + if !bytes.Contains(marshalled, []byte("\"description\":\"")) { + t.Errorf("Marshalled result should contain description, got %s", string(marshalled)) + } + + // Should survive unrecognized keys + result = Result{ + Name: "foo", + Status: 100, + } + marshalled, _ = json.Marshal(result) + if err != nil { + t.Fatal(err) + } + if bytes.Contains(marshalled, []byte("\"status_text\":\"")) { + t.Errorf("Result with unrecognized keys shouldn't output status_text, got %s", string(marshalled)) + } + if bytes.Contains(marshalled, []byte("\"description\":\"")) { + t.Errorf("Result with unrecognized keys shouldn't output status_text, got %s", string(marshalled)) + } +} diff --git a/models/scan.go b/models/scan.go index 883bc781..e39ff912 100644 --- a/models/scan.go +++ b/models/scan.go @@ -16,3 +16,13 @@ type Scan struct { Timestamp time.Time `json:"timestamp"` // Time at which this scan was conducted Version uint32 `json:"version"` // Version counter } + +// CanAddToPolicyList returns true if the domain owner should be prompted to +// add their domain to the STARTTLS Everywhere Policy List. +func (s Scan) CanAddToPolicyList() bool { + if policyResult, ok := s.Data.ExtraResults[checker.PolicyList]; ok { + return s.Data.Status == checker.DomainSuccess && + policyResult.Status == checker.Failure + } + return false +} diff --git a/queue_test.go b/queue_test.go index 5686bc67..d9f7a8d5 100644 --- a/queue_test.go +++ b/queue_test.go @@ -6,6 +6,7 @@ import ( "io/ioutil" "net/http" "net/url" + "strings" "testing" "github.com/EFForg/starttls-backend/models" @@ -23,6 +24,30 @@ func validQueueData(scan bool) url.Values { return data } +func TestQueueHTML(t *testing.T) { + defer teardown() + + body, status := testHTMLPost("/api/queue", validQueueData(true), t) + if status != http.StatusOK { + t.Errorf("HTML POST to api/queue failed with error %d", status) + } + if !strings.Contains(string(body), "Thank you for submitting your domain") { + t.Errorf("Response should describe domain status, got %s", string(body)) + } +} + +func TestQueueErrorHTML(t *testing.T) { + defer teardown() + + body, status := testHTMLPost("/api/queue", url.Values{}, t) + if status != http.StatusBadRequest { + t.Errorf("HTML POST status should be %d, got %d", http.StatusBadRequest, status) + } + if !strings.Contains(string(body), "Bad Request") { + t.Errorf("Response should contain failed status text, got %s", string(body)) + } +} + func TestGetDomainHidesEmail(t *testing.T) { defer teardown() diff --git a/scan_test.go b/scan_test.go index 2d1ef7b9..2004f000 100644 --- a/scan_test.go +++ b/scan_test.go @@ -4,13 +4,80 @@ import ( "encoding/json" "io/ioutil" "net/http" + "net/http/httptest" "net/url" "strings" "testing" + "time" + "github.com/EFForg/starttls-backend/checker" "github.com/EFForg/starttls-backend/models" ) +func TestScanHTMLRequest(t *testing.T) { + defer teardown() + + // Request a scan! + data := url.Values{} + data.Set("domain", "eff.org") + body, status := testHTMLPost("/api/scan", data, t) + if status != http.StatusOK { + t.Errorf("HTML POST to api/scan failed with error %d", status) + } + if !strings.Contains(string(body), "eff.org") { + t.Errorf("Response should contain scan domain, got %s", string(body)) + } +} + +func TestScanWriteHTML(t *testing.T) { + scan := models.Scan{ + Domain: "example.com", + Data: checker.DomainResult{ + Domain: "example.com", + Status: checker.DomainSuccess, + HostnameResults: map[string]checker.HostnameResult{ + "example.com": checker.HostnameResult{ + Domain: "example.com", + Hostname: "mx.example.com", + Result: &checker.Result{ + Checks: map[string]*checker.Result{ + checker.Connectivity: checker.MakeResult(checker.Connectivity), + checker.STARTTLS: checker.MakeResult(checker.STARTTLS), + checker.Certificate: checker.MakeResult(checker.Certificate), + checker.Version: checker.MakeResult(checker.Version), + }, + }, + }, + }, + PreferredHostnames: []string{"mx.example.com"}, + ExtraResults: map[string]*checker.Result{ + checker.PolicyList: checker.MakeResult(checker.PolicyList), + }, + }, + Timestamp: time.Now(), + Version: 1, + } + response := APIResponse{ + StatusCode: http.StatusOK, + Response: scan, + TemplatePath: "views/scan.html.tmpl", + } + + w := httptest.NewRecorder() + writeHTML(w, response) + resp := w.Result() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(strings.ToLower(string(body)), " + + STARTTLS Everywhere + + + {{ if ne .StatusCode 200 }} +

{{ .StatusText }}

+ {{ end }} + + {{ if ne .Message "" }} +

{{ .Message }}

+ {{ end }} + +

{{ .Response }}

+ + diff --git a/views/scan.html.tmpl b/views/scan.html.tmpl new file mode 100644 index 00000000..b52c9dc5 --- /dev/null +++ b/views/scan.html.tmpl @@ -0,0 +1,47 @@ + + +

Scan results for {{ .Response.Domain }}

+ You're viewing unstyled results. You can enable Javascript to view styled content. + +

Summary

+ {{ if eq .Response.Data.Status 0 }} +

Congratulations, your domain passed all checks.

+ {{ else if eq .Response.Data.Status 1 }} +

Your domain passed all checks with some warnings. See below for details.

+ {{ else }} +

There were some problems with your domain. See below for details.

+ {{ end }} + +

{{ .Response.Data.Message }}

+ +

STARTTLS Everywhere Policy List

+ {{ with index .Response.Data.ExtraResults "policylist" }} + {{ .Description }}: {{ .StatusText }} + + {{ end }} + {{ if .Response.CanAddToPolicyList }} + Add your email domain the STARTTLS Everywhere Policy List + {{ end }} + +

Mailboxes

+ {{ range $hostname, $hostnameResult := .Response.Data.HostnameResults }} +

{{ $hostname }}

+ + {{ end }} + +