diff --git a/config/config.go b/config/config.go index 4e0451def..98ec6b6a3 100644 --- a/config/config.go +++ b/config/config.go @@ -9,11 +9,12 @@ import ( // AdminServer represents the Admin server configuration details type AdminServer struct { - ListenURL string `json:"listen_url"` - UseTLS bool `json:"use_tls"` - CertPath string `json:"cert_path"` - KeyPath string `json:"key_path"` - CSRFKey string `json:"csrf_key"` + ListenURL string `json:"listen_url"` + UseTLS bool `json:"use_tls"` + CertPath string `json:"cert_path"` + KeyPath string `json:"key_path"` + CSRFKey string `json:"csrf_key"` + AllowedInternalHosts []string `json:"allowed_internal_hosts"` } // PhishServer represents the Phish server configuration details diff --git a/controllers/api/import.go b/controllers/api/import.go index 7cf96a013..efaf0178b 100644 --- a/controllers/api/import.go +++ b/controllers/api/import.go @@ -10,6 +10,7 @@ import ( "strings" "github.com/PuerkitoBio/goquery" + "github.com/gophish/gophish/dialer" log "github.com/gophish/gophish/logger" "github.com/gophish/gophish/models" "github.com/gophish/gophish/util" @@ -113,7 +114,9 @@ func (as *Server) ImportSite(w http.ResponseWriter, r *http.Request) { JSONResponse(w, models.Response{Success: false, Message: err.Error()}, http.StatusBadRequest) return } + restrictedDialer := dialer.Dialer() tr := &http.Transport{ + DialContext: restrictedDialer.DialContext, TLSClientConfig: &tls.Config{ InsecureSkipVerify: true, }, diff --git a/controllers/api/import_test.go b/controllers/api/import_test.go new file mode 100644 index 000000000..2278de503 --- /dev/null +++ b/controllers/api/import_test.go @@ -0,0 +1,84 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gophish/gophish/dialer" + "github.com/gophish/gophish/models" +) + +func makeImportRequest(ctx *testContext, allowedHosts []string, url string) *httptest.ResponseRecorder { + orig := dialer.DefaultDialer.AllowedHosts() + dialer.SetAllowedHosts(allowedHosts) + req := httptest.NewRequest(http.MethodPost, "/api/import/site", + bytes.NewBuffer([]byte(fmt.Sprintf(` + { + "url" : "%s" + } + `, url)))) + req.Header.Set("Content-Type", "application/json") + response := httptest.NewRecorder() + ctx.apiServer.ImportSite(response, req) + dialer.SetAllowedHosts(orig) + return response +} + +func TestDefaultDeniedImport(t *testing.T) { + ctx := setupTest(t) + metadataURL := "http://169.254.169.254/latest/meta-data/" + response := makeImportRequest(ctx, []string{}, metadataURL) + expectedCode := http.StatusBadRequest + if response.Code != expectedCode { + t.Fatalf("incorrect status code received. expected %d got %d", expectedCode, response.Code) + } + got := &models.Response{} + err := json.NewDecoder(response.Body).Decode(got) + if err != nil { + t.Fatalf("error decoding body: %v", err) + } + if !strings.Contains(got.Message, "upstream connection denied") { + t.Fatalf("incorrect response error provided: %s", got.Message) + } +} + +func TestDefaultAllowedImport(t *testing.T) { + ctx := setupTest(t) + h := "" + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, h) + })) + defer ts.Close() + response := makeImportRequest(ctx, []string{}, ts.URL) + expectedCode := http.StatusOK + if response.Code != expectedCode { + t.Fatalf("incorrect status code received. expected %d got %d", expectedCode, response.Code) + } +} + +func TestCustomDeniedImport(t *testing.T) { + ctx := setupTest(t) + h := "" + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, h) + })) + defer ts.Close() + response := makeImportRequest(ctx, []string{"192.168.1.1"}, ts.URL) + expectedCode := http.StatusBadRequest + if response.Code != expectedCode { + t.Fatalf("incorrect status code received. expected %d got %d", expectedCode, response.Code) + } + got := &models.Response{} + err := json.NewDecoder(response.Body).Decode(got) + if err != nil { + t.Fatalf("error decoding body: %v", err) + } + if !strings.Contains(got.Message, "upstream connection denied") { + t.Fatalf("incorrect response error provided: %s", got.Message) + } +} diff --git a/dialer/dialer.go b/dialer/dialer.go new file mode 100644 index 000000000..15fb402ee --- /dev/null +++ b/dialer/dialer.go @@ -0,0 +1,158 @@ +package dialer + +import ( + "fmt" + "net" + "syscall" + "time" +) + +// RestrictedDialer is used to create a net.Dialer which restricts outbound +// connections to only allowlisted IP ranges. +type RestrictedDialer struct { + allowedHosts []*net.IPNet +} + +// DefaultDialer is a global instance of a RestrictedDialer +var DefaultDialer = &RestrictedDialer{} + +// SetAllowedHosts sets the list of allowed hosts or IP ranges for the default +// dialer. +func SetAllowedHosts(allowed []string) { + DefaultDialer.SetAllowedHosts(allowed) +} + +// AllowedHosts returns the configured hosts that are allowed for the dialer. +func (d *RestrictedDialer) AllowedHosts() []string { + ranges := []string{} + for _, ipRange := range d.allowedHosts { + ranges = append(ranges, ipRange.String()) + } + return ranges +} + +// SetAllowedHosts sets the list of allowed hosts or IP ranges for the dialer. +func (d *RestrictedDialer) SetAllowedHosts(allowed []string) error { + for _, ipRange := range allowed { + // For flexibility, try to parse as an IP first since this will + // undoubtedly cause issues. If it works, then just append the + // appropriate subnet mask, then parse as CIDR + if singleIP := net.ParseIP(ipRange); singleIP != nil { + if singleIP.To4() != nil { + ipRange += "/32" + } else { + ipRange += "/128" + } + } + _, parsed, err := net.ParseCIDR(ipRange) + if err != nil { + return fmt.Errorf("provided ip range is not valid CIDR notation: %v", err) + } + d.allowedHosts = append(d.allowedHosts, parsed) + } + return nil +} + +// Dialer returns a net.Dialer that restricts outbound connections to only the +// addresses allowed by the DefaultDialer. +func Dialer() *net.Dialer { + return DefaultDialer.Dialer() +} + +// Dialer returns a net.Dialer that restricts outbound connections to only the +// allowed addresses over TCP. +// +// By default, since Gophish anticipates connections originating to hosts on +// the local network, we only deny access to the link-local addresses at +// 169.254.0.0/16. +// +// If hosts are provided, then Gophish blocks access to all local addresses +// except the ones provided. +// +// This implementation is based on the blog post by Andrew Ayer at +// https://www.agwa.name/blog/post/preventing_server_side_request_forgery_in_golang +func (d *RestrictedDialer) Dialer() *net.Dialer { + return &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + Control: restrictedControl(d.allowedHosts), + } +} + +// defaultDeny represents the list of IP ranges that we want to block unless +// explicitly overriden. +var defaultDeny = []string{ + "169.254.0.0/16", // Link-local (used for VPS instance metadata) +} + +// allInternal represents all internal hosts such that the only connections +// allowed are external ones. +var allInternal = []string{ + "0.0.0.0/8", + "127.0.0.0/8", // IPv4 loopback + "10.0.0.0/8", // RFC1918 + "100.64.0.0/10", // CGNAT + "172.16.0.0/12", // RFC1918 + "169.254.0.0/16", // RFC3927 link-local + "192.88.99.0/24", // IPv6 to IPv4 Relay + "192.168.0.0/16", // RFC1918 + "198.51.100.0/24", // TEST-NET-2 + "203.0.113.0/24", // TEST-NET-3 + "224.0.0.0/4", // Multicast + "240.0.0.0/4", // Reserved + "255.255.255.255/32", // Broadcast + "::/0", // Default route + "::/128", // Unspecified address + "::1/128", // IPv6 loopback + "::ffff:0:0/96", // IPv4 mapped addresses. + "::ffff:0:0:0/96", // IPv4 translated addresses. + "fe80::/10", // IPv6 link-local + "fc00::/7", // IPv6 unique local addr +} + +type dialControl = func(network, address string, c syscall.RawConn) error + +type restrictedDialer struct { + *net.Dialer + allowed []string +} + +func restrictedControl(allowed []*net.IPNet) dialControl { + return func(network string, address string, conn syscall.RawConn) error { + if !(network == "tcp4" || network == "tcp6") { + return fmt.Errorf("%s is not a safe network type", network) + } + + host, _, err := net.SplitHostPort(address) + if err != nil { + return fmt.Errorf("%s is not a valid host/port pair: %s", address, err) + } + + ip := net.ParseIP(host) + if ip == nil { + return fmt.Errorf("%s is not a valid IP address", host) + } + + denyList := defaultDeny + if len(allowed) > 0 { + denyList = allInternal + } + + for _, ipRange := range allowed { + if ipRange.Contains(ip) { + return nil + } + } + + for _, ipRange := range denyList { + _, parsed, err := net.ParseCIDR(ipRange) + if err != nil { + return fmt.Errorf("error parsing denied range: %v", err) + } + if parsed.Contains(ip) { + return fmt.Errorf("upstream connection denied to internal host") + } + } + return nil + } +} diff --git a/dialer/dialer_test.go b/dialer/dialer_test.go new file mode 100644 index 000000000..0b70b1a9f --- /dev/null +++ b/dialer/dialer_test.go @@ -0,0 +1,85 @@ +package dialer + +import ( + "fmt" + "net" + "strings" + "syscall" + "testing" +) + +func TestDefaultDeny(t *testing.T) { + control := restrictedControl([]*net.IPNet{}) + host := "169.254.169.254" + expected := fmt.Errorf("upstream connection denied to internal host at %s", host) + conn := new(syscall.RawConn) + got := control("tcp4", fmt.Sprintf("%s:80", host), *conn) + if !strings.Contains(got.Error(), "upstream connection denied") { + t.Fatalf("unexpected error dialing denylisted host. expected %v got %v", expected, got) + } +} + +func TestDefaultAllow(t *testing.T) { + control := restrictedControl([]*net.IPNet{}) + host := "1.1.1.1" + conn := new(syscall.RawConn) + got := control("tcp4", fmt.Sprintf("%s:80", host), *conn) + if got != nil { + t.Fatalf("error dialing allowed host. got %v", got) + } +} + +func TestCustomAllow(t *testing.T) { + host := "127.0.0.1" + _, ipRange, _ := net.ParseCIDR(fmt.Sprintf("%s/32", host)) + allowed := []*net.IPNet{ipRange} + control := restrictedControl(allowed) + conn := new(syscall.RawConn) + got := control("tcp4", fmt.Sprintf("%s:80", host), *conn) + if got != nil { + t.Fatalf("error dialing allowed host. got %v", got) + } +} + +func TestCustomDeny(t *testing.T) { + host := "127.0.0.1" + _, ipRange, _ := net.ParseCIDR(fmt.Sprintf("%s/32", host)) + allowed := []*net.IPNet{ipRange} + control := restrictedControl(allowed) + conn := new(syscall.RawConn) + expected := fmt.Errorf("upstream connection denied to internal host at %s", host) + got := control("tcp4", "192.168.1.2:80", *conn) + if !strings.Contains(got.Error(), "upstream connection denied") { + t.Fatalf("unexpected error dialing denylisted host. expected %v got %v", expected, got) + } +} + +func TestSingleIP(t *testing.T) { + orig := DefaultDialer.AllowedHosts() + host := "127.0.0.1" + DefaultDialer.SetAllowedHosts([]string{host}) + control := DefaultDialer.Dialer().Control + conn := new(syscall.RawConn) + expected := fmt.Errorf("upstream connection denied to internal host at %s", host) + got := control("tcp4", "192.168.1.2:80", *conn) + if !strings.Contains(got.Error(), "upstream connection denied") { + t.Fatalf("unexpected error dialing denylisted host. expected %v got %v", expected, got) + } + + host = "::1" + DefaultDialer.SetAllowedHosts([]string{host}) + control = DefaultDialer.Dialer().Control + conn = new(syscall.RawConn) + expected = fmt.Errorf("upstream connection denied to internal host at %s", host) + got = control("tcp4", "192.168.1.2:80", *conn) + if !strings.Contains(got.Error(), "upstream connection denied") { + t.Fatalf("unexpected error dialing denylisted host. expected %v got %v", expected, got) + } + + // Test an allowed connection + got = control("tcp4", fmt.Sprintf("[%s]:80", host), *conn) + if got != nil { + t.Fatalf("error dialing allowed host. got %v", got) + } + DefaultDialer.SetAllowedHosts(orig) +} diff --git a/go.mod b/go.mod index e0649971f..6fc6a4324 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,7 @@ require ( github.com/emersion/go-imap v1.0.4 github.com/emersion/go-message v0.12.0 github.com/go-sql-driver/mysql v1.5.0 - github.com/gophish/gomail v0.0.0-20180314010319-cf7e1a5479be + github.com/gophish/gomail v0.0.0-20200818021916-1f6d0dfd512e github.com/gorilla/context v1.1.1 github.com/gorilla/csrf v1.6.2 github.com/gorilla/handlers v1.4.2 @@ -29,7 +29,5 @@ require ( golang.org/x/crypto v0.0.0-20200128174031-69ecbb4d6d5d golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1 gopkg.in/alecthomas/kingpin.v2 v2.2.6 - gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 - gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df // indirect ) diff --git a/go.sum b/go.sum index ede800857..c9bae1df7 100644 --- a/go.sum +++ b/go.sum @@ -32,8 +32,8 @@ github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LB github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/gophish/gomail v0.0.0-20180314010319-cf7e1a5479be h1:VTe1cdyqSi/wLowKNz/shz6E0G+9/XzldZbyAmt+0Yw= -github.com/gophish/gomail v0.0.0-20180314010319-cf7e1a5479be/go.mod h1:MpSuP7kw+gRy2z+4gIFZeF3DwhhdQhEXwRmPVQYD9ig= +github.com/gophish/gomail v0.0.0-20200818021916-1f6d0dfd512e h1:URNpXdOxXAfuZ8wsr/DY27KTffVenKDjtNVAEwcR2Oo= +github.com/gophish/gomail v0.0.0-20200818021916-1f6d0dfd512e/go.mod h1:JGlHttcLdDp3F4g8bPHqqQnUUDuB3poB4zLXozQ0xCY= github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/csrf v1.6.2 h1:QqQ/OWwuFp4jMKgBFAzJVW3FMULdyUW7JoM4pEWuqKg= @@ -110,7 +110,5 @@ gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc h1:2gG gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc/go.mod h1:m7x9LTH6d71AHyAX77c9yqWCCa3UKHcVEj9y7hAtKDk= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df h1:n7WqCuqOuCbNr617RXOY0AWRXxgwEyPp2z+p0+hgMuE= -gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df/go.mod h1:LRQQ+SO6ZHR7tOkpBDuZnXENFzX8qRjMDMyPD6BRkCw= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/gophish.go b/gophish.go index 71ae7e8c8..67ccab16f 100644 --- a/gophish.go +++ b/gophish.go @@ -28,6 +28,7 @@ THE SOFTWARE. import ( "fmt" "io/ioutil" + "net/http" "os" "os/signal" @@ -35,10 +36,12 @@ import ( "github.com/gophish/gophish/config" "github.com/gophish/gophish/controllers" + "github.com/gophish/gophish/dialer" "github.com/gophish/gophish/imap" log "github.com/gophish/gophish/logger" "github.com/gophish/gophish/middleware" "github.com/gophish/gophish/models" + "github.com/gophish/gophish/webhook" ) const ( @@ -79,6 +82,13 @@ func main() { } config.Version = string(version) + // Configure our various upstream clients to make sure that we restrict + // outbound connections as needed. + dialer.SetAllowedHosts(conf.AdminConf.AllowedInternalHosts) + webhook.SetTransport(&http.Transport{ + DialContext: dialer.Dialer().DialContext, + }) + err = log.Setup(conf.Logging) if err != nil { log.Fatal(err) diff --git a/imap/imap.go b/imap/imap.go index 7e0561761..aa76a4af2 100644 --- a/imap/imap.go +++ b/imap/imap.go @@ -11,6 +11,7 @@ import ( "github.com/emersion/go-imap" "github.com/emersion/go-imap/client" "github.com/emersion/go-message/charset" + "github.com/gophish/gophish/dialer" log "github.com/gophish/gophish/logger" "github.com/gophish/gophish/models" @@ -184,12 +185,13 @@ func (mbox *Mailbox) GetUnread(markAsRead, delete bool) ([]Email, error) { func (mbox *Mailbox) newClient() (*client.Client, error) { var imapClient *client.Client var err error + restrictedDialer := dialer.Dialer() if mbox.TLS { config := new(tls.Config) config.InsecureSkipVerify = mbox.IgnoreCertErrors - imapClient, err = client.DialTLS(mbox.Host, config) + imapClient, err = client.DialWithDialerTLS(restrictedDialer, mbox.Host, config) } else { - imapClient, err = client.Dial(mbox.Host) + imapClient, err = client.DialWithDialer(restrictedDialer, mbox.Host) } if err != nil { return imapClient, err diff --git a/models/smtp.go b/models/smtp.go index 8ca8485ba..cd4d4e232 100644 --- a/models/smtp.go +++ b/models/smtp.go @@ -10,6 +10,7 @@ import ( "time" "github.com/gophish/gomail" + "github.com/gophish/gophish/dialer" log "github.com/gophish/gophish/logger" "github.com/gophish/gophish/mailer" "github.com/jinzhu/gorm" @@ -109,7 +110,8 @@ func (s *SMTP) GetDialer() (mailer.Dialer, error) { log.Error(err) return nil, err } - d := gomail.NewDialer(host, port, s.Username, s.Password) + dialer := dialer.Dialer() + d := gomail.NewWithDialer(dialer, host, port, s.Username, s.Password) d.TLSConfig = &tls.Config{ ServerName: host, InsecureSkipVerify: s.IgnoreCertErrors, diff --git a/models/smtp_test.go b/models/smtp_test.go index 7ffbaadf9..b559c2829 100644 --- a/models/smtp_test.go +++ b/models/smtp_test.go @@ -81,3 +81,15 @@ func (s *ModelsSuite) TestGetInvalidSMTP(ch *check.C) { _, err := GetSMTP(-1, 1) ch.Assert(err, check.Equals, gorm.ErrRecordNotFound) } + +func (s *ModelsSuite) TestDefaultDeniedDial(ch *check.C) { + host := "169.254.169.254" + port := 25 + smtp := SMTP{ + Host: fmt.Sprintf("%s:%d", host, port), + } + d, err := smtp.GetDialer() + ch.Assert(err, check.Equals, nil) + _, err = d.Dial() + ch.Assert(err, check.ErrorMatches, ".*upstream connection denied.*") +} diff --git a/webhook/webhook.go b/webhook/webhook.go index 0ce281b46..92ee20bf6 100644 --- a/webhook/webhook.go +++ b/webhook/webhook.go @@ -51,6 +51,11 @@ var senderInstance = &defaultSender{ }, } +// SetTransport sets the underlying transport for the default webhook client. +func SetTransport(tr *http.Transport) { + senderInstance.client.Transport = tr +} + // EndPoint represents a URL to send the webhook to, as well as a secret used // to sign the event type EndPoint struct {