Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 127 additions & 48 deletions api.go → api/api.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package main
package api

import (
"encoding/json"
"fmt"
"html/template"
"io/ioutil"
"log"
"net/http"
"os"
Expand All @@ -15,9 +16,11 @@ import (

"github.com/EFForg/starttls-backend/checker"
"github.com/EFForg/starttls-backend/db"
"github.com/EFForg/starttls-backend/email"
"github.com/EFForg/starttls-backend/models"
"github.com/EFForg/starttls-backend/policy"
"github.com/getsentry/raven-go"
"github.com/EFForg/starttls-backend/util"
raven "github.com/getsentry/raven-go"
)

////////////////////////////////
Expand All @@ -32,7 +35,7 @@ const cacheScanTime = time.Minute
type checkPerformer func(API, string) (checker.DomainResult, error)

// API is the HTTP API that this service provides.
// All requests respond with an APIResponse JSON, with fields:
// All requests respond with an response JSON, with fields:
// {
// status_code // HTTP status code of request
// message // Any error message accompanying the status_code. If 200, empty.
Expand All @@ -41,12 +44,12 @@ type checkPerformer func(API, string) (checker.DomainResult, error)
// Any POST request accepts either URL query parameters or data value parameters,
// and prefers the latter if both are present.
type API struct {
Database db.Database
CheckDomain checkPerformer
List PolicyList
DontScan map[string]bool
Emailer EmailSender
Templates map[string]*template.Template
Database db.Database
checkDomainOverride checkPerformer
List PolicyList
DontScan map[string]bool
Emailer EmailSender
Templates map[string]*template.Template
}

// PolicyList interface wraps a policy-list like structure.
Expand All @@ -64,15 +67,21 @@ type EmailSender interface {
SendValidation(*models.Domain, string) error
}

// APIResponse wraps all the responses from this API.
type APIResponse struct {
type response struct {
StatusCode int `json:"status_code"`
Message string `json:"message"`
Response interface{} `json:"response"`
templateName string `json:"-"`
}

type apiHandler func(r *http.Request) APIResponse
type apiHandler func(r *http.Request) response

func (api *API) checkDomain(domain string) (checker.DomainResult, error) {
if api.checkDomainOverride == nil {
return defaultCheck(*api, domain)
}
return api.checkDomainOverride(*api, domain)
}

func (api *API) wrapper(handler apiHandler) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -89,6 +98,24 @@ func (api *API) wrapper(handler apiHandler) func(w http.ResponseWriter, r *http.
}
}

func pingHandler(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/json")
}

// RegisterHandlers binds API functions to the given http server,
// and returns the resulting handler.
func (api *API) RegisterHandlers(mux *http.ServeMux) http.Handler {
mux.HandleFunc("/sns", HandleSESNotification(api.Database))
mux.HandleFunc("/api/scan", api.wrapper(api.scan))
mux.Handle("/api/queue",
throttleHandler(time.Hour, 20, http.HandlerFunc(api.wrapper(api.queue))))
mux.HandleFunc("/api/validate", api.wrapper(api.validate))
mux.HandleFunc("/api/stats", api.wrapper(api.stats))
mux.HandleFunc("/api/ping", pingHandler)
return middleware(mux)
}

func defaultCheck(api API, domain string) (checker.DomainResult, error) {
policyChan := models.Domain{Name: domain}.AsyncPolicyListCheck(api.Database, api.List)
c := checker.Checker{
Expand All @@ -111,15 +138,15 @@ func defaultCheck(api API, domain string) (checker.DomainResult, error) {
// GET /api/scan?domain=<domain>
// Retrieves most recent scan for domain.
// Both set a models.Scan JSON as the response.
func (api API) Scan(r *http.Request) APIResponse {
func (api API) scan(r *http.Request) response {
domain, err := getASCIIDomain(r)
if err != nil {
return APIResponse{StatusCode: http.StatusBadRequest, Message: err.Error()}
return response{StatusCode: http.StatusBadRequest, Message: err.Error()}
}
// Check if we shouldn't scan this domain
if api.DontScan != nil {
if _, ok := api.DontScan[domain]; ok {
return APIResponse{StatusCode: http.StatusTooManyRequests}
return response{StatusCode: http.StatusTooManyRequests}
}
}
// POST: Force scan to be conducted
Expand All @@ -128,16 +155,16 @@ func (api API) Scan(r *http.Request) APIResponse {
scan, err := api.Database.GetLatestScan(domain)
if err == nil && scan.Version == models.ScanVersion &&
time.Now().Before(scan.Timestamp.Add(cacheScanTime)) {
return APIResponse{
return response{
StatusCode: http.StatusOK,
Response: scan,
templateName: "scan",
}
}
// 1. Conduct scan via starttls-checker
scanData, err := api.CheckDomain(api, domain)
scanData, err := api.checkDomain(domain)
if err != nil {
return APIResponse{StatusCode: http.StatusInternalServerError, Message: err.Error()}
return response{StatusCode: http.StatusInternalServerError, Message: err.Error()}
}
scan = models.Scan{
Domain: domain,
Expand All @@ -148,9 +175,9 @@ func (api API) Scan(r *http.Request) APIResponse {
// 2. Put scan into DB
err = api.Database.PutScan(scan)
if err != nil {
return APIResponse{StatusCode: http.StatusInternalServerError, Message: err.Error()}
return response{StatusCode: http.StatusInternalServerError, Message: err.Error()}
}
return APIResponse{
return response{
StatusCode: http.StatusOK,
Response: scan,
templateName: "scan",
Expand All @@ -159,11 +186,11 @@ func (api API) Scan(r *http.Request) APIResponse {
} else if r.Method == http.MethodGet {
scan, err := api.Database.GetLatestScan(domain)
if err != nil {
return APIResponse{StatusCode: http.StatusNotFound, Message: err.Error()}
return response{StatusCode: http.StatusNotFound, Message: err.Error()}
}
return APIResponse{StatusCode: http.StatusOK, Response: scan}
return response{StatusCode: http.StatusOK, Response: scan}
} else {
return APIResponse{StatusCode: http.StatusMethodNotAllowed,
return response{StatusCode: http.StatusMethodNotAllowed,
Message: "/api/scan only accepts POST and GET requests"}
}
}
Expand All @@ -184,11 +211,11 @@ func getDomainParams(r *http.Request) (models.Domain, error) {
MTASTS: mtasts == "on",
State: models.StateUnconfirmed,
}
email, err := getParam("email", r)
givenEmail, err := getParam("email", r)
if err == nil {
domain.Email = email
domain.Email = givenEmail
} else {
domain.Email = validationAddress(&domain)
domain.Email = email.ValidationAddress(&domain)
}
queueWeeks, err := getInt("weeks", r, 4, 52, 4)
if err != nil {
Expand All @@ -201,7 +228,7 @@ func getDomainParams(r *http.Request) (models.Domain, error) {
if len(hostname) == 0 {
continue
}
if !validDomainName(strings.TrimPrefix(hostname, ".")) {
if !util.ValidDomainName(strings.TrimPrefix(hostname, ".")) {
return domain, fmt.Errorf("Hostname %s is invalid", hostname)
}
domain.MXs = append(domain.MXs, hostname)
Expand All @@ -226,7 +253,7 @@ func getDomainParams(r *http.Request) (models.Domain, error) {
// email (optional): Contact email associated with domain.
// GET /api/queue?domain=<domain>
// Sets models.Domain object as response.
func (api API) Queue(r *http.Request) APIResponse {
func (api API) queue(r *http.Request) response {
// POST: Insert this domain into the queue
if r.Method == http.MethodPost {
domain, err := getDomainParams(r)
Expand All @@ -246,7 +273,7 @@ func (api API) Queue(r *http.Request) APIResponse {
log.Print(err)
return serverError("Unable to send validation e-mail")
}
return APIResponse{
return response{
StatusCode: http.StatusOK,
Response: fmt.Sprintf("Thank you for submitting your domain. Please check postmaster@%s to validate that you control the domain.", domain.Name),
}
Expand All @@ -259,28 +286,28 @@ func (api API) Queue(r *http.Request) APIResponse {
}
domainObj, err := models.GetDomain(api.Database, domainName)
if err != nil {
return APIResponse{StatusCode: http.StatusNotFound, Message: err.Error()}
return response{StatusCode: http.StatusNotFound, Message: err.Error()}
}
return APIResponse{
return response{
StatusCode: http.StatusOK,
Response: domainObj,
}
}
return APIResponse{StatusCode: http.StatusMethodNotAllowed,
return response{StatusCode: http.StatusMethodNotAllowed,
Message: "/api/queue only accepts POST and GET requests"}
}

// Validate handles requests to /api/validate
// POST /api/validate
// token: token to validate/redeem
// Sets the queued domain name as response.
func (api API) Validate(r *http.Request) APIResponse {
func (api API) validate(r *http.Request) response {
token, err := getParam("token", r)
if err != nil {
return APIResponse{StatusCode: http.StatusBadRequest, Message: err.Error()}
return response{StatusCode: http.StatusBadRequest, Message: err.Error()}
}
if r.Method != http.MethodPost {
return APIResponse{StatusCode: http.StatusMethodNotAllowed,
return response{StatusCode: http.StatusMethodNotAllowed,
Message: "/api/validate only accepts POST requests"}
}
tokenData := models.Token{Token: token}
Expand All @@ -291,7 +318,7 @@ func (api API) Validate(r *http.Request) APIResponse {
if dbErr != nil {
return serverError(dbErr.Error())
}
return APIResponse{StatusCode: http.StatusOK, Response: domain}
return response{StatusCode: http.StatusOK, Response: domain}
}

// Retrieve "domain" parameter from request as ASCII
Expand Down Expand Up @@ -341,7 +368,7 @@ func getInt(param string, r *http.Request, lowInc int, highExc int, defaultNum i

// Writes `v` as a JSON object to http.ResponseWriter `w`. If an error
// occurs, writes `http.StatusInternalServerError` to `w`.
func (api *API) writeJSON(w http.ResponseWriter, apiResponse APIResponse) {
func (api *API) writeJSON(w http.ResponseWriter, apiResponse response) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(apiResponse.StatusCode)
b, err := json.MarshalIndent(apiResponse, "", " ")
Expand All @@ -353,11 +380,12 @@ func (api *API) writeJSON(w http.ResponseWriter, apiResponse APIResponse) {
fmt.Fprintf(w, "%s\n", b)
}

func (api *API) parseTemplates() {
// ParseTemplates initializes our HTML template data
func (api *API) ParseTemplates() {
names := []string{"default", "scan"}
api.Templates = make(map[string]*template.Template)
for _, name := range names {
path := fmt.Sprintf("views/%s.html.tmpl", name)
path := fmt.Sprintf("../views/%s.html.tmpl", name)
tmpl, err := template.ParseFiles(path)
if err != nil {
raven.CaptureError(err, nil)
Expand All @@ -367,16 +395,16 @@ func (api *API) parseTemplates() {
}
}

func (api *API) writeHTML(w http.ResponseWriter, apiResponse APIResponse) {
func (api *API) writeHTML(w http.ResponseWriter, apiResponse response) {
// Add some additional useful fields for use in templates.
data := struct {
APIResponse
response
BaseURL string
StatusText string
}{
APIResponse: apiResponse,
BaseURL: os.Getenv("FRONTEND_WEBSITE_LINK"),
StatusText: http.StatusText(apiResponse.StatusCode),
response: apiResponse,
BaseURL: os.Getenv("FRONTEND_WEBSITE_LINK"),
StatusText: http.StatusText(apiResponse.StatusCode),
}
if apiResponse.templateName == "" {
apiResponse.templateName = "default"
Expand All @@ -396,16 +424,67 @@ func (api *API) writeHTML(w http.ResponseWriter, apiResponse APIResponse) {
}
}

func badRequest(format string, a ...interface{}) APIResponse {
return APIResponse{
func badRequest(format string, a ...interface{}) response {
return response{
StatusCode: http.StatusBadRequest,
Message: fmt.Sprintf(format, a...),
}
}

func serverError(format string, a ...interface{}) APIResponse {
return APIResponse{
func serverError(format string, a ...interface{}) response {
return response{
StatusCode: http.StatusInternalServerError,
Message: fmt.Sprintf(format, a...),
}
}

type ravenExtraContent string

// Class satisfies raven's Interface interface so we can send this as extra context.
// https://github.com/getsentry/raven-go/issues/125
func (r ravenExtraContent) Class() string {
return "extra"
}

func (r ravenExtraContent) MarshalJSON() ([]byte, error) {
return []byte(r), nil
}

// HandleSESNotification handles AWS SES bounces and complaints submitted to a webhook
// via AWS SNS (Simple Notification Service).
// The SNS webhook is configured to include a secret API key stored in the environment.
func HandleSESNotification(database db.Database) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
keyParam := r.URL.Query()["amazon_authorize_key"]
if len(keyParam) == 0 || keyParam[0] != os.Getenv("AMAZON_AUTHORIZE_KEY") {
w.WriteHeader(http.StatusUnauthorized)
return
}

body, err := ioutil.ReadAll(r.Body)
if err != nil {
raven.CaptureError(err, nil)
return
}

data := &email.BlacklistRequest{}
err = json.Unmarshal(body, data)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
raven.CaptureError(err, nil, ravenExtraContent(body))
return
}

tags := map[string]string{"notification_type": data.Reason}
raven.CaptureMessage("Received SES notification", tags, ravenExtraContent(data.Raw))

for _, recipient := range data.Recipients {
err = database.PutBlacklistedEmail(recipient.EmailAddress, data.Reason, data.Timestamp)
if err != nil {
raven.CaptureError(err, nil)
}
}

w.WriteHeader(http.StatusOK)
}
}
Loading