Skip to content

Commit

Permalink
Merge 429a734 into 01f0d03
Browse files Browse the repository at this point in the history
  • Loading branch information
vbrown608 committed Jan 30, 2019
2 parents 01f0d03 + 429a734 commit 1c2c8d6
Show file tree
Hide file tree
Showing 16 changed files with 373 additions and 79 deletions.
102 changes: 84 additions & 18 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package main
import (
"encoding/json"
"fmt"
"html/template"
"log"
"net/http"
"os"
"strings"
"time"

Expand Down Expand Up @@ -43,6 +45,7 @@ type API struct {
List PolicyList
DontScan map[string]bool
Emailer EmailSender
Templates map[string]*template.Template
}

// PolicyList interface wraps a policy-list like structure.
Expand All @@ -62,30 +65,32 @@ type EmailSender interface {

// APIResponse wraps all the responses from this API.
type APIResponse struct {
StatusCode int `json:"status_code"`
Message string `json:"message"`
Response interface{} `json:"response"`
StatusCode int `json:"status_code"`
Message string `json:"message"`
Response interface{} `json:"response"`
templateName string `json:"-"`
}

type apiHandler func(r *http.Request) APIResponse

func apiWrapper(api apiHandler) func(w http.ResponseWriter, r *http.Request) {
func (api *API) wrapper(handler apiHandler) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
response := api(r)
if response.StatusCode != http.StatusOK {
http.Error(w, response.Message, response.StatusCode)
}
response := handler(r)
if response.StatusCode == http.StatusInternalServerError {
packet := raven.NewPacket(response.Message, raven.NewHttp(r))
raven.Capture(packet, nil)
}
writeJSON(w, response)
if strings.Contains(r.Header.Get("accept"), "text/html") {
api.writeHTML(w, response)
} else {
api.writeJSON(w, response)
}
}
}

// Checks the policy status of this domain.
func (api API) policyCheck(domain string) *checker.Result {
result := checker.Result{Name: "policylist"}
result := checker.Result{Name: checker.PolicyList}
if _, err := api.List.Get(domain); err == nil {
return result.Success()
}
Expand Down Expand Up @@ -153,16 +158,20 @@ func (api API) Scan(r *http.Request) APIResponse {
scan, err := api.Database.GetLatestScan(domain)
if err == nil && scan.Version == models.ScanVersion &&
time.Now().Before(scan.Timestamp.Add(cacheScanTime)) {
return APIResponse{StatusCode: http.StatusOK, Response: scan}
return APIResponse{
StatusCode: http.StatusOK,
Response: scan,
templateName: "scan",
}
}
// 1. Conduct scan via starttls-checker
rawScandata, err := api.CheckDomain(api, domain)
scanData, err := api.CheckDomain(api, domain)
if err != nil {
return APIResponse{StatusCode: http.StatusInternalServerError, Message: err.Error()}
}
scan = models.Scan{
Domain: domain,
Data: rawScandata,
Data: scanData,
Timestamp: time.Now(),
Version: models.ScanVersion,
}
Expand All @@ -171,7 +180,11 @@ func (api API) Scan(r *http.Request) APIResponse {
if err != nil {
return APIResponse{StatusCode: http.StatusInternalServerError, Message: err.Error()}
}
return APIResponse{StatusCode: http.StatusOK, Response: scan}
return APIResponse{
StatusCode: http.StatusOK,
Response: scan,
templateName: "scan",
}
// GET: Just fetch the most recent scan
} else if r.Method == http.MethodGet {
scan, err := api.Database.GetLatestScan(domain)
Expand Down Expand Up @@ -278,14 +291,23 @@ func (api API) Queue(r *http.Request) APIResponse {
return APIResponse{StatusCode: http.StatusInternalServerError,
Message: "Unable to send validation e-mail"}
}
return APIResponse{StatusCode: http.StatusOK, Response: domainData}
// domainData.State = Unvalidated
// or queued?
return APIResponse{
StatusCode: http.StatusOK,
Response: fmt.Sprintf("Thank you for submitting your domain. Please check postmaster@%s to validate that you control the domain.", domain),
}
// GET: Retrieve domain status from queue
// JSON only
} else if r.Method == http.MethodGet {
status, err := api.Database.GetDomain(domain)
if err != nil {
return APIResponse{StatusCode: http.StatusNotFound, Message: err.Error()}
}
return APIResponse{StatusCode: http.StatusOK, Response: status}
return APIResponse{
StatusCode: http.StatusOK,
Response: status,
}
} else {
return APIResponse{StatusCode: http.StatusMethodNotAllowed,
Message: "/api/queue only accepts POST and GET requests"}
Expand Down Expand Up @@ -352,13 +374,57 @@ func getParam(param string, r *http.Request) (string, error) {

// Writes `v` as a JSON object to http.ResponseWriter `w`. If an error
// occurs, writes `http.StatusInternalServerError` to `w`.
func writeJSON(w http.ResponseWriter, v interface{}) {
func (api *API) writeJSON(w http.ResponseWriter, apiResponse APIResponse) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
b, err := json.MarshalIndent(v, "", " ")
w.WriteHeader(apiResponse.StatusCode)
b, err := json.MarshalIndent(apiResponse, "", " ")
if err != nil {
msg := fmt.Sprintf("Internal error: could not format JSON. (%s)\n", err)
http.Error(w, msg, http.StatusInternalServerError)
return
}
fmt.Fprintf(w, "%s\n", b)
}

func (api *API) parseTemplates() {
names := []string{"default", "scan"}
api.Templates = make(map[string]*template.Template)
for _, name := range names {
path := fmt.Sprintf("views/%s.html.tmpl", name)
tmpl, err := template.ParseFiles(path)
if err != nil {
raven.CaptureError(err, nil)
log.Fatal(err)
}
api.Templates[name] = tmpl
}
}

func (api *API) writeHTML(w http.ResponseWriter, apiResponse APIResponse) {
// Add some additional useful fields for use in templates.
data := struct {
APIResponse
BaseURL string
StatusText string
}{
APIResponse: apiResponse,
BaseURL: os.Getenv("FRONTEND_WEBSITE_LINK"),
StatusText: http.StatusText(apiResponse.StatusCode),
}
if apiResponse.templateName == "" {
apiResponse.templateName = "default"
}
tmpl, ok := api.Templates[apiResponse.templateName]
if !ok {
err := fmt.Errorf("Template not found: %s", apiResponse.templateName)
raven.CaptureError(err, nil)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.WriteHeader(apiResponse.StatusCode)
err := tmpl.Execute(w, data)
if err != nil {
log.Println(err)
raven.CaptureError(err, nil)
}
}
22 changes: 22 additions & 0 deletions api_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package main

import (
"io/ioutil"
"net/http"
"net/url"
"strings"
"testing"

"github.com/EFForg/starttls-backend/checker"
Expand Down Expand Up @@ -40,3 +44,21 @@ func TestPolicyCheckWithQueuedDomain(t *testing.T) {
t.Errorf("Check should have warned.")
}
}

func testHTMLPost(path string, data url.Values, t *testing.T) ([]byte, int) {
req, err := http.NewRequest("POST", server.URL+path, strings.NewReader(data.Encode()))
if err != nil {
t.Fatal(err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("accept", "text/html")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}
body, _ := ioutil.ReadAll(resp.Body)
if !strings.Contains(strings.ToLower(string(body)), "<html") {
t.Errorf("Response should be HTML, got %s", string(body))
}
return body, resp.StatusCode
}
2 changes: 1 addition & 1 deletion checker/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func (d DomainResult) Class() string {
}

func (d DomainResult) setStatus(status DomainStatus) DomainResult {
d.Status = DomainStatus(SetStatus(CheckStatus(d.Status), CheckStatus(status)))
d.Status = DomainStatus(SetStatus(Status(d.Status), Status(status)))
return d
}

Expand Down
24 changes: 12 additions & 12 deletions checker/domain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,23 @@ var mxLookup = map[string][]string{
// Fake hostname checks :)
var hostnameResults = map[string]Result{
"noconnection": Result{
Status: Error,
Status: 3,
Checks: map[string]*Result{
"connectivity": {"connectivity", Error, nil, nil},
Connectivity: {Connectivity, 3, nil, nil},
},
},
"nostarttls": Result{
Status: Failure,
Status: 2,
Checks: map[string]*Result{
"connectivity": {"connectivity", 0, nil, nil},
"starttls": {"starttls", Failure, nil, nil},
Connectivity: {Connectivity, 0, nil, nil},
STARTTLS: {STARTTLS, 2, nil, nil},
},
},
"nostarttlsconnect": Result{
Status: Error,
Status: 3,
Checks: map[string]*Result{
"connectivity": {"connectivity", 0, nil, nil},
"starttls": {"starttls", Error, nil, nil},
Connectivity: {Connectivity, 0, nil, nil},
STARTTLS: {STARTTLS, 3, nil, nil},
},
},
}
Expand Down Expand Up @@ -73,10 +73,10 @@ func mockCheckHostname(domain string, hostname string) HostnameResult {
Result: &Result{
Status: 0,
Checks: map[string]*Result{
"connectivity": {"connectivity", 0, nil, nil},
"starttls": {"starttls", 0, nil, nil},
"certificate": {"certificate", 0, nil, nil},
"version": {"version", 0, nil, nil},
Connectivity: {Connectivity, 0, nil, nil},
STARTTLS: {STARTTLS, 0, nil, nil},
Certificate: {Certificate, 0, nil, nil},
Version: {Version, 0, nil, nil},
},
},
Timestamp: time.Now(),
Expand Down
12 changes: 6 additions & 6 deletions checker/hostname.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ func (h HostnameResult) checkSucceeded(checkName string) bool {
}

func (h HostnameResult) couldConnect() bool {
return h.checkSucceeded("connectivity")
return h.checkSucceeded(Connectivity)
}

func (h HostnameResult) couldSTARTTLS() bool {
return h.checkSucceeded("starttls")
return h.checkSucceeded(STARTTLS)
}

// Modelled after policyMatches in Appendix B of the MTA-STS RFC 8641.
Expand Down Expand Up @@ -95,7 +95,7 @@ func smtpDialWithTimeout(hostname string, timeout time.Duration) (*smtp.Client,

// Simply tries to StartTLS with the server.
func checkStartTLS(client *smtp.Client) *Result {
result := MakeResult("starttls")
result := MakeResult(STARTTLS)
ok, _ := client.Extension("StartTLS")
if !ok {
return result.Failure("Server does not advertise support for STARTTLS.")
Expand Down Expand Up @@ -141,7 +141,7 @@ var certRoots *x509.CertPool
// Checks that the certificate presented is valid for a particular hostname, unexpired,
// and chains to a trusted root.
func checkCert(client *smtp.Client, domain, hostname string) *Result {
result := MakeResult("certificate")
result := MakeResult(Certificate)
state, ok := client.TLSConnectionState()
if !ok {
return result.Error("TLS not initiated properly.")
Expand Down Expand Up @@ -188,7 +188,7 @@ func checkTLSCipher(hostname string, timeout time.Duration) *Result {
}

func checkTLSVersion(client *smtp.Client, hostname string, timeout time.Duration) *Result {
result := MakeResult("version")
result := MakeResult(Version)

// Check the TLS version of the existing connection.
tlsConnectionState, ok := client.TLSConnectionState()
Expand Down Expand Up @@ -235,7 +235,7 @@ func (c *Checker) CheckHostname(domain string, hostname string) HostnameResult {
}

// Connect to the SMTP server and use that connection to perform as many checks as possible.
connectivityResult := MakeResult("connectivity")
connectivityResult := MakeResult(Connectivity)
client, err := smtpDialWithTimeout(hostname, c.timeout())
if err != nil {
result.addCheck(connectivityResult.Error("Could not establish connection: %v", err))
Expand Down

0 comments on commit 1c2c8d6

Please sign in to comment.