Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 67 additions & 16 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package main
import (
"encoding/json"
"fmt"
"html/template"
"log"
"net/http"
"os"
"strings"
"time"

Expand Down Expand Up @@ -62,30 +64,32 @@ type EmailSender interface {

// APIResponse wraps all the responses from this API.
type APIResponse struct {
StatusCode int `json:"status_code"`
Message string `json:"message"`
Response interface{} `json:"response"`
StatusCode int `json:"status_code"`
Message string `json:"message"`
Response interface{} `json:"response"`
TemplatePath string `json:"-"`
}

type apiHandler func(r *http.Request) APIResponse

func apiWrapper(api apiHandler) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
response := api(r)
if response.StatusCode != http.StatusOK {
http.Error(w, response.Message, response.StatusCode)
}
if response.StatusCode == http.StatusInternalServerError {
packet := raven.NewPacket(response.Message, raven.NewHttp(r))
raven.Capture(packet, nil)
}
writeJSON(w, response)
if strings.Contains(r.Header.Get("accept"), "text/html") {
writeHTML(w, response)
} else {
writeJSON(w, response)
}
}
}

// Checks the policy status of this domain.
func (api API) policyCheck(domain string) *checker.Result {
result := checker.Result{Name: "policylist"}
result := checker.Result{Name: checker.PolicyList}
if _, err := api.List.Get(domain); err == nil {
return result.Success()
}
Expand Down Expand Up @@ -153,16 +157,20 @@ func (api API) Scan(r *http.Request) APIResponse {
scan, err := api.Database.GetLatestScan(domain)
if err == nil && scan.Version == models.ScanVersion &&
time.Now().Before(scan.Timestamp.Add(cacheScanTime)) {
return APIResponse{StatusCode: http.StatusOK, Response: scan}
return APIResponse{
StatusCode: http.StatusOK,
Response: scan,
TemplatePath: "views/scan.html.tmpl",
}
}
// 1. Conduct scan via starttls-checker
rawScandata, err := api.CheckDomain(api, domain)
scanData, err := api.CheckDomain(api, domain)
if err != nil {
return APIResponse{StatusCode: http.StatusInternalServerError, Message: err.Error()}
}
scan = models.Scan{
Domain: domain,
Data: rawScandata,
Data: scanData,
Timestamp: time.Now(),
Version: models.ScanVersion,
}
Expand All @@ -171,7 +179,11 @@ func (api API) Scan(r *http.Request) APIResponse {
if err != nil {
return APIResponse{StatusCode: http.StatusInternalServerError, Message: err.Error()}
}
return APIResponse{StatusCode: http.StatusOK, Response: scan}
return APIResponse{
StatusCode: http.StatusOK,
Response: scan,
TemplatePath: "views/scan.html.tmpl",
}
// GET: Just fetch the most recent scan
} else if r.Method == http.MethodGet {
scan, err := api.Database.GetLatestScan(domain)
Expand Down Expand Up @@ -278,14 +290,23 @@ func (api API) Queue(r *http.Request) APIResponse {
return APIResponse{StatusCode: http.StatusInternalServerError,
Message: "Unable to send validation e-mail"}
}
return APIResponse{StatusCode: http.StatusOK, Response: domainData}
// domainData.State = Unvalidated
// or queued?
Copy link
Collaborator Author

@vbrown608 vbrown608 Jan 31, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looked into this question - resolved in follow-up PR (and I removed the cryptic comment there).

return APIResponse{
StatusCode: http.StatusOK,
Response: fmt.Sprintf("Thank you for submitting your domain. Please check postmaster@%s to validate that you control the domain.", domain),
}
// GET: Retrieve domain status from queue
// JSON only
} else if r.Method == http.MethodGet {
status, err := api.Database.GetDomain(domain)
if err != nil {
return APIResponse{StatusCode: http.StatusNotFound, Message: err.Error()}
}
return APIResponse{StatusCode: http.StatusOK, Response: status}
return APIResponse{
StatusCode: http.StatusOK,
Response: status,
}
} else {
return APIResponse{StatusCode: http.StatusMethodNotAllowed,
Message: "/api/queue only accepts POST and GET requests"}
Expand Down Expand Up @@ -352,13 +373,43 @@ func getParam(param string, r *http.Request) (string, error) {

// Writes `v` as a JSON object to http.ResponseWriter `w`. If an error
// occurs, writes `http.StatusInternalServerError` to `w`.
func writeJSON(w http.ResponseWriter, v interface{}) {
func writeJSON(w http.ResponseWriter, apiResponse APIResponse) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
b, err := json.MarshalIndent(v, "", " ")
w.WriteHeader(apiResponse.StatusCode)
b, err := json.MarshalIndent(apiResponse, "", " ")
if err != nil {
msg := fmt.Sprintf("Internal error: could not format JSON. (%s)\n", err)
http.Error(w, msg, http.StatusInternalServerError)
return
}
fmt.Fprintf(w, "%s\n", b)
}

func writeHTML(w http.ResponseWriter, apiResponse APIResponse) {
// Add some additional useful fields for use in templates.
if apiResponse.TemplatePath == "" {
apiResponse.TemplatePath = "views/default.html.tmpl"
}
data := struct {
APIResponse
BaseURL string
StatusText string
}{
APIResponse: apiResponse,
BaseURL: os.Getenv("FRONTEND_WEBSITE_LINK"),
StatusText: http.StatusText(apiResponse.StatusCode),
}
tmpl, err := template.ParseFiles(apiResponse.TemplatePath)
if err != nil {
log.Println(err)
raven.CaptureError(err, nil)
http.Error(w, "Internal error: could not parse template.", http.StatusInternalServerError)
return
}
w.WriteHeader(apiResponse.StatusCode)
err = tmpl.Execute(w, data)
if err != nil {
log.Println(err)
raven.CaptureError(err, nil)
}
}
22 changes: 22 additions & 0 deletions api_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package main

import (
"io/ioutil"
"net/http"
"net/url"
"strings"
"testing"

"github.com/EFForg/starttls-backend/checker"
Expand Down Expand Up @@ -40,3 +44,21 @@ func TestPolicyCheckWithQueuedDomain(t *testing.T) {
t.Errorf("Check should have warned.")
}
}

func testHTMLPost(path string, data url.Values, t *testing.T) ([]byte, int) {
req, err := http.NewRequest("POST", server.URL+path, strings.NewReader(data.Encode()))
if err != nil {
t.Fatal(err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("accept", "text/html")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}
body, _ := ioutil.ReadAll(resp.Body)
if !strings.Contains(strings.ToLower(string(body)), "</html") {
t.Errorf("Response should be HTML, got %s", string(body))
}
return body, resp.StatusCode
}
2 changes: 1 addition & 1 deletion checker/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func (d DomainResult) Class() string {
}

func (d DomainResult) setStatus(status DomainStatus) DomainResult {
d.Status = DomainStatus(SetStatus(CheckStatus(d.Status), CheckStatus(status)))
d.Status = DomainStatus(SetStatus(Status(d.Status), Status(status)))
return d
}

Expand Down
24 changes: 12 additions & 12 deletions checker/domain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,23 @@ var mxLookup = map[string][]string{
// Fake hostname checks :)
var hostnameResults = map[string]Result{
"noconnection": Result{
Status: Error,
Status: 3,
Checks: map[string]*Result{
"connectivity": {"connectivity", Error, nil, nil},
Connectivity: {Connectivity, 3, nil, nil},
},
},
"nostarttls": Result{
Status: Failure,
Status: 2,
Checks: map[string]*Result{
"connectivity": {"connectivity", 0, nil, nil},
"starttls": {"starttls", Failure, nil, nil},
Connectivity: {Connectivity, 0, nil, nil},
STARTTLS: {STARTTLS, 2, nil, nil},
},
},
"nostarttlsconnect": Result{
Status: Error,
Status: 3,
Checks: map[string]*Result{
"connectivity": {"connectivity", 0, nil, nil},
"starttls": {"starttls", Error, nil, nil},
Connectivity: {Connectivity, 0, nil, nil},
STARTTLS: {STARTTLS, 3, nil, nil},
},
},
}
Expand Down Expand Up @@ -73,10 +73,10 @@ func mockCheckHostname(domain string, hostname string) HostnameResult {
Result: &Result{
Status: 0,
Checks: map[string]*Result{
"connectivity": {"connectivity", 0, nil, nil},
"starttls": {"starttls", 0, nil, nil},
"certificate": {"certificate", 0, nil, nil},
"version": {"version", 0, nil, nil},
Connectivity: {Connectivity, 0, nil, nil},
STARTTLS: {STARTTLS, 0, nil, nil},
Certificate: {Certificate, 0, nil, nil},
Version: {Version, 0, nil, nil},
},
},
Timestamp: time.Now(),
Expand Down
12 changes: 6 additions & 6 deletions checker/hostname.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ func (h HostnameResult) checkSucceeded(checkName string) bool {
}

func (h HostnameResult) couldConnect() bool {
return h.checkSucceeded("connectivity")
return h.checkSucceeded(Connectivity)
}

func (h HostnameResult) couldSTARTTLS() bool {
return h.checkSucceeded("starttls")
return h.checkSucceeded(STARTTLS)
}

// Modelled after policyMatches in Appendix B of the MTA-STS RFC 8641.
Expand Down Expand Up @@ -95,7 +95,7 @@ func smtpDialWithTimeout(hostname string, timeout time.Duration) (*smtp.Client,

// Simply tries to StartTLS with the server.
func checkStartTLS(client *smtp.Client) *Result {
result := MakeResult("starttls")
result := MakeResult(STARTTLS)
ok, _ := client.Extension("StartTLS")
if !ok {
return result.Failure("Server does not advertise support for STARTTLS.")
Expand Down Expand Up @@ -141,7 +141,7 @@ var certRoots *x509.CertPool
// Checks that the certificate presented is valid for a particular hostname, unexpired,
// and chains to a trusted root.
func checkCert(client *smtp.Client, domain, hostname string) *Result {
result := MakeResult("certificate")
result := MakeResult(Certificate)
state, ok := client.TLSConnectionState()
if !ok {
return result.Error("TLS not initiated properly.")
Expand Down Expand Up @@ -188,7 +188,7 @@ func checkTLSCipher(hostname string, timeout time.Duration) *Result {
}

func checkTLSVersion(client *smtp.Client, hostname string, timeout time.Duration) *Result {
result := MakeResult("version")
result := MakeResult(Version)

// Check the TLS version of the existing connection.
tlsConnectionState, ok := client.TLSConnectionState()
Expand Down Expand Up @@ -235,7 +235,7 @@ func (c *Checker) CheckHostname(domain string, hostname string) HostnameResult {
}

// Connect to the SMTP server and use that connection to perform as many checks as possible.
connectivityResult := MakeResult("connectivity")
connectivityResult := MakeResult(Connectivity)
client, err := smtpDialWithTimeout(hostname, c.timeout())
if err != nil {
result.addCheck(connectivityResult.Error("Could not establish connection: %v", err))
Expand Down
38 changes: 19 additions & 19 deletions checker/hostname_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func TestNoConnection(t *testing.T) {
expected := Result{
Status: 3,
Checks: map[string]*Result{
"connectivity": {"connectivity", 3, nil, nil},
"connectivity": {Connectivity, 3, nil, nil},
},
}
compareStatuses(t, expected, result)
Expand All @@ -122,8 +122,8 @@ func TestNoTLS(t *testing.T) {
expected := Result{
Status: 2,
Checks: map[string]*Result{
"connectivity": {"connectivity", 0, nil, nil},
"starttls": {"starttls", 2, nil, nil},
Connectivity: {Connectivity, 0, nil, nil},
STARTTLS: {STARTTLS, 2, nil, nil},
},
}
compareStatuses(t, expected, result)
Expand All @@ -142,10 +142,10 @@ func TestSelfSigned(t *testing.T) {
expected := Result{
Status: 2,
Checks: map[string]*Result{
"connectivity": {"connectivity", 0, nil, nil},
"starttls": {"starttls", 0, nil, nil},
"certificate": {"certificate", 2, nil, nil},
"version": {"version", 0, nil, nil},
Connectivity: {Connectivity, 0, nil, nil},
STARTTLS: {STARTTLS, 0, nil, nil},
Certificate: {Certificate, 2, nil, nil},
Version: {Version, 0, nil, nil},
},
}
compareStatuses(t, expected, result)
Expand All @@ -168,10 +168,10 @@ func TestNoTLS12(t *testing.T) {
expected := Result{
Status: 2,
Checks: map[string]*Result{
"connectivity": {"connectivity", 0, nil, nil},
"starttls": {"starttls", 0, nil, nil},
"certificate": {"certificate", 2, nil, nil},
"version": {"version", 1, nil, nil},
Connectivity: {Connectivity, 0, nil, nil},
STARTTLS: {STARTTLS, 0, nil, nil},
Certificate: {Certificate, 2, nil, nil},
Version: {Version, 1, nil, nil},
},
}
compareStatuses(t, expected, result)
Expand Down Expand Up @@ -200,10 +200,10 @@ func TestSuccessWithFakeCA(t *testing.T) {
expected := Result{
Status: 0,
Checks: map[string]*Result{
"connectivity": {"connectivity", 0, nil, nil},
"starttls": {"starttls", 0, nil, nil},
"certificate": {"certificate", 0, nil, nil},
"version": {"version", 0, nil, nil},
Connectivity: {Connectivity, 0, nil, nil},
STARTTLS: {STARTTLS, 0, nil, nil},
Certificate: {Certificate, 0, nil, nil},
Version: {Version, 0, nil, nil},
},
}
compareStatuses(t, expected, result)
Expand Down Expand Up @@ -275,10 +275,10 @@ func TestFailureWithBadHostname(t *testing.T) {
expected := Result{
Status: 2,
Checks: map[string]*Result{
"connectivity": {"connectivity", 0, nil, nil},
"starttls": {"starttls", 0, nil, nil},
"certificate": {"certificate", 2, nil, nil},
"version": {"version", 0, nil, nil},
Connectivity: {Connectivity, 0, nil, nil},
STARTTLS: {STARTTLS, 0, nil, nil},
Certificate: {Certificate, 2, nil, nil},
Version: {Version, 0, nil, nil},
},
}
compareStatuses(t, expected, result)
Expand Down
Loading