Skip to content

Commit

Permalink
Merge dcf6daa into 01f0d03
Browse files Browse the repository at this point in the history
  • Loading branch information
vbrown608 committed Jan 31, 2019
2 parents 01f0d03 + dcf6daa commit 5cda1ae
Show file tree
Hide file tree
Showing 15 changed files with 410 additions and 74 deletions.
83 changes: 67 additions & 16 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package main
import (
"encoding/json"
"fmt"
"html/template"
"log"
"net/http"
"os"
"strings"
"time"

Expand Down Expand Up @@ -62,30 +64,32 @@ type EmailSender interface {

// APIResponse wraps all the responses from this API.
type APIResponse struct {
StatusCode int `json:"status_code"`
Message string `json:"message"`
Response interface{} `json:"response"`
StatusCode int `json:"status_code"`
Message string `json:"message"`
Response interface{} `json:"response"`
TemplatePath string `json:"-"`
}

type apiHandler func(r *http.Request) APIResponse

func apiWrapper(api apiHandler) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
response := api(r)
if response.StatusCode != http.StatusOK {
http.Error(w, response.Message, response.StatusCode)
}
if response.StatusCode == http.StatusInternalServerError {
packet := raven.NewPacket(response.Message, raven.NewHttp(r))
raven.Capture(packet, nil)
}
writeJSON(w, response)
if strings.Contains(r.Header.Get("accept"), "text/html") {
writeHTML(w, response)
} else {
writeJSON(w, response)
}
}
}

// Checks the policy status of this domain.
func (api API) policyCheck(domain string) *checker.Result {
result := checker.Result{Name: "policylist"}
result := checker.Result{Name: checker.PolicyList}
if _, err := api.List.Get(domain); err == nil {
return result.Success()
}
Expand Down Expand Up @@ -153,16 +157,20 @@ func (api API) Scan(r *http.Request) APIResponse {
scan, err := api.Database.GetLatestScan(domain)
if err == nil && scan.Version == models.ScanVersion &&
time.Now().Before(scan.Timestamp.Add(cacheScanTime)) {
return APIResponse{StatusCode: http.StatusOK, Response: scan}
return APIResponse{
StatusCode: http.StatusOK,
Response: scan,
TemplatePath: "views/scan.html.tmpl",
}
}
// 1. Conduct scan via starttls-checker
rawScandata, err := api.CheckDomain(api, domain)
scanData, err := api.CheckDomain(api, domain)
if err != nil {
return APIResponse{StatusCode: http.StatusInternalServerError, Message: err.Error()}
}
scan = models.Scan{
Domain: domain,
Data: rawScandata,
Data: scanData,
Timestamp: time.Now(),
Version: models.ScanVersion,
}
Expand All @@ -171,7 +179,11 @@ func (api API) Scan(r *http.Request) APIResponse {
if err != nil {
return APIResponse{StatusCode: http.StatusInternalServerError, Message: err.Error()}
}
return APIResponse{StatusCode: http.StatusOK, Response: scan}
return APIResponse{
StatusCode: http.StatusOK,
Response: scan,
TemplatePath: "views/scan.html.tmpl",
}
// GET: Just fetch the most recent scan
} else if r.Method == http.MethodGet {
scan, err := api.Database.GetLatestScan(domain)
Expand Down Expand Up @@ -278,14 +290,23 @@ func (api API) Queue(r *http.Request) APIResponse {
return APIResponse{StatusCode: http.StatusInternalServerError,
Message: "Unable to send validation e-mail"}
}
return APIResponse{StatusCode: http.StatusOK, Response: domainData}
// domainData.State = Unvalidated
// or queued?
return APIResponse{
StatusCode: http.StatusOK,
Response: fmt.Sprintf("Thank you for submitting your domain. Please check postmaster@%s to validate that you control the domain.", domain),
}
// GET: Retrieve domain status from queue
// JSON only
} else if r.Method == http.MethodGet {
status, err := api.Database.GetDomain(domain)
if err != nil {
return APIResponse{StatusCode: http.StatusNotFound, Message: err.Error()}
}
return APIResponse{StatusCode: http.StatusOK, Response: status}
return APIResponse{
StatusCode: http.StatusOK,
Response: status,
}
} else {
return APIResponse{StatusCode: http.StatusMethodNotAllowed,
Message: "/api/queue only accepts POST and GET requests"}
Expand Down Expand Up @@ -352,13 +373,43 @@ func getParam(param string, r *http.Request) (string, error) {

// Writes `v` as a JSON object to http.ResponseWriter `w`. If an error
// occurs, writes `http.StatusInternalServerError` to `w`.
func writeJSON(w http.ResponseWriter, v interface{}) {
func writeJSON(w http.ResponseWriter, apiResponse APIResponse) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
b, err := json.MarshalIndent(v, "", " ")
w.WriteHeader(apiResponse.StatusCode)
b, err := json.MarshalIndent(apiResponse, "", " ")
if err != nil {
msg := fmt.Sprintf("Internal error: could not format JSON. (%s)\n", err)
http.Error(w, msg, http.StatusInternalServerError)
return
}
fmt.Fprintf(w, "%s\n", b)
}

func writeHTML(w http.ResponseWriter, apiResponse APIResponse) {
// Add some additional useful fields for use in templates.
if apiResponse.TemplatePath == "" {
apiResponse.TemplatePath = "views/default.html.tmpl"
}
data := struct {
APIResponse
BaseURL string
StatusText string
}{
APIResponse: apiResponse,
BaseURL: os.Getenv("FRONTEND_WEBSITE_LINK"),
StatusText: http.StatusText(apiResponse.StatusCode),
}
tmpl, err := template.ParseFiles(apiResponse.TemplatePath)
if err != nil {
log.Println(err)
raven.CaptureError(err, nil)
http.Error(w, "Internal error: could not parse template.", http.StatusInternalServerError)
return
}
w.WriteHeader(apiResponse.StatusCode)
err = tmpl.Execute(w, data)
if err != nil {
log.Println(err)
raven.CaptureError(err, nil)
}
}
22 changes: 22 additions & 0 deletions api_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package main

import (
"io/ioutil"
"net/http"
"net/url"
"strings"
"testing"

"github.com/EFForg/starttls-backend/checker"
Expand Down Expand Up @@ -40,3 +44,21 @@ func TestPolicyCheckWithQueuedDomain(t *testing.T) {
t.Errorf("Check should have warned.")
}
}

func testHTMLPost(path string, data url.Values, t *testing.T) ([]byte, int) {
req, err := http.NewRequest("POST", server.URL+path, strings.NewReader(data.Encode()))
if err != nil {
t.Fatal(err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("accept", "text/html")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}
body, _ := ioutil.ReadAll(resp.Body)
if !strings.Contains(strings.ToLower(string(body)), "</html") {
t.Errorf("Response should be HTML, got %s", string(body))
}
return body, resp.StatusCode
}
2 changes: 1 addition & 1 deletion checker/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func (d DomainResult) Class() string {
}

func (d DomainResult) setStatus(status DomainStatus) DomainResult {
d.Status = DomainStatus(SetStatus(CheckStatus(d.Status), CheckStatus(status)))
d.Status = DomainStatus(SetStatus(Status(d.Status), Status(status)))
return d
}

Expand Down
24 changes: 12 additions & 12 deletions checker/domain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,23 @@ var mxLookup = map[string][]string{
// Fake hostname checks :)
var hostnameResults = map[string]Result{
"noconnection": Result{
Status: Error,
Status: 3,
Checks: map[string]*Result{
"connectivity": {"connectivity", Error, nil, nil},
Connectivity: {Connectivity, 3, nil, nil},
},
},
"nostarttls": Result{
Status: Failure,
Status: 2,
Checks: map[string]*Result{
"connectivity": {"connectivity", 0, nil, nil},
"starttls": {"starttls", Failure, nil, nil},
Connectivity: {Connectivity, 0, nil, nil},
STARTTLS: {STARTTLS, 2, nil, nil},
},
},
"nostarttlsconnect": Result{
Status: Error,
Status: 3,
Checks: map[string]*Result{
"connectivity": {"connectivity", 0, nil, nil},
"starttls": {"starttls", Error, nil, nil},
Connectivity: {Connectivity, 0, nil, nil},
STARTTLS: {STARTTLS, 3, nil, nil},
},
},
}
Expand Down Expand Up @@ -73,10 +73,10 @@ func mockCheckHostname(domain string, hostname string) HostnameResult {
Result: &Result{
Status: 0,
Checks: map[string]*Result{
"connectivity": {"connectivity", 0, nil, nil},
"starttls": {"starttls", 0, nil, nil},
"certificate": {"certificate", 0, nil, nil},
"version": {"version", 0, nil, nil},
Connectivity: {Connectivity, 0, nil, nil},
STARTTLS: {STARTTLS, 0, nil, nil},
Certificate: {Certificate, 0, nil, nil},
Version: {Version, 0, nil, nil},
},
},
Timestamp: time.Now(),
Expand Down
12 changes: 6 additions & 6 deletions checker/hostname.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ func (h HostnameResult) checkSucceeded(checkName string) bool {
}

func (h HostnameResult) couldConnect() bool {
return h.checkSucceeded("connectivity")
return h.checkSucceeded(Connectivity)
}

func (h HostnameResult) couldSTARTTLS() bool {
return h.checkSucceeded("starttls")
return h.checkSucceeded(STARTTLS)
}

// Modelled after policyMatches in Appendix B of the MTA-STS RFC 8641.
Expand Down Expand Up @@ -95,7 +95,7 @@ func smtpDialWithTimeout(hostname string, timeout time.Duration) (*smtp.Client,

// Simply tries to StartTLS with the server.
func checkStartTLS(client *smtp.Client) *Result {
result := MakeResult("starttls")
result := MakeResult(STARTTLS)
ok, _ := client.Extension("StartTLS")
if !ok {
return result.Failure("Server does not advertise support for STARTTLS.")
Expand Down Expand Up @@ -141,7 +141,7 @@ var certRoots *x509.CertPool
// Checks that the certificate presented is valid for a particular hostname, unexpired,
// and chains to a trusted root.
func checkCert(client *smtp.Client, domain, hostname string) *Result {
result := MakeResult("certificate")
result := MakeResult(Certificate)
state, ok := client.TLSConnectionState()
if !ok {
return result.Error("TLS not initiated properly.")
Expand Down Expand Up @@ -188,7 +188,7 @@ func checkTLSCipher(hostname string, timeout time.Duration) *Result {
}

func checkTLSVersion(client *smtp.Client, hostname string, timeout time.Duration) *Result {
result := MakeResult("version")
result := MakeResult(Version)

// Check the TLS version of the existing connection.
tlsConnectionState, ok := client.TLSConnectionState()
Expand Down Expand Up @@ -235,7 +235,7 @@ func (c *Checker) CheckHostname(domain string, hostname string) HostnameResult {
}

// Connect to the SMTP server and use that connection to perform as many checks as possible.
connectivityResult := MakeResult("connectivity")
connectivityResult := MakeResult(Connectivity)
client, err := smtpDialWithTimeout(hostname, c.timeout())
if err != nil {
result.addCheck(connectivityResult.Error("Could not establish connection: %v", err))
Expand Down
38 changes: 19 additions & 19 deletions checker/hostname_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func TestNoConnection(t *testing.T) {
expected := Result{
Status: 3,
Checks: map[string]*Result{
"connectivity": {"connectivity", 3, nil, nil},
"connectivity": {Connectivity, 3, nil, nil},
},
}
compareStatuses(t, expected, result)
Expand All @@ -122,8 +122,8 @@ func TestNoTLS(t *testing.T) {
expected := Result{
Status: 2,
Checks: map[string]*Result{
"connectivity": {"connectivity", 0, nil, nil},
"starttls": {"starttls", 2, nil, nil},
Connectivity: {Connectivity, 0, nil, nil},
STARTTLS: {STARTTLS, 2, nil, nil},
},
}
compareStatuses(t, expected, result)
Expand All @@ -142,10 +142,10 @@ func TestSelfSigned(t *testing.T) {
expected := Result{
Status: 2,
Checks: map[string]*Result{
"connectivity": {"connectivity", 0, nil, nil},
"starttls": {"starttls", 0, nil, nil},
"certificate": {"certificate", 2, nil, nil},
"version": {"version", 0, nil, nil},
Connectivity: {Connectivity, 0, nil, nil},
STARTTLS: {STARTTLS, 0, nil, nil},
Certificate: {Certificate, 2, nil, nil},
Version: {Version, 0, nil, nil},
},
}
compareStatuses(t, expected, result)
Expand All @@ -168,10 +168,10 @@ func TestNoTLS12(t *testing.T) {
expected := Result{
Status: 2,
Checks: map[string]*Result{
"connectivity": {"connectivity", 0, nil, nil},
"starttls": {"starttls", 0, nil, nil},
"certificate": {"certificate", 2, nil, nil},
"version": {"version", 1, nil, nil},
Connectivity: {Connectivity, 0, nil, nil},
STARTTLS: {STARTTLS, 0, nil, nil},
Certificate: {Certificate, 2, nil, nil},
Version: {Version, 1, nil, nil},
},
}
compareStatuses(t, expected, result)
Expand Down Expand Up @@ -200,10 +200,10 @@ func TestSuccessWithFakeCA(t *testing.T) {
expected := Result{
Status: 0,
Checks: map[string]*Result{
"connectivity": {"connectivity", 0, nil, nil},
"starttls": {"starttls", 0, nil, nil},
"certificate": {"certificate", 0, nil, nil},
"version": {"version", 0, nil, nil},
Connectivity: {Connectivity, 0, nil, nil},
STARTTLS: {STARTTLS, 0, nil, nil},
Certificate: {Certificate, 0, nil, nil},
Version: {Version, 0, nil, nil},
},
}
compareStatuses(t, expected, result)
Expand Down Expand Up @@ -275,10 +275,10 @@ func TestFailureWithBadHostname(t *testing.T) {
expected := Result{
Status: 2,
Checks: map[string]*Result{
"connectivity": {"connectivity", 0, nil, nil},
"starttls": {"starttls", 0, nil, nil},
"certificate": {"certificate", 2, nil, nil},
"version": {"version", 0, nil, nil},
Connectivity: {Connectivity, 0, nil, nil},
STARTTLS: {STARTTLS, 0, nil, nil},
Certificate: {Certificate, 2, nil, nil},
Version: {Version, 0, nil, nil},
},
}
compareStatuses(t, expected, result)
Expand Down

0 comments on commit 5cda1ae

Please sign in to comment.