Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mTLS support for server API #110

Merged
merged 1 commit into from
May 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ Configuration of the adapter is done via environment variables at startup.
| `SCANNER_API_SERVER_ADDR` | `:8080` | Binding address for the API server |
| `SCANNER_API_SERVER_TLS_CERTIFICATE` | N/A | The absolute path to the x509 certificate file |
| `SCANNER_API_SERVER_TLS_KEY` | N/A | The absolute path to the x509 private key file |
| `SCANNER_API_SERVER_CLIENT_CAS` | N/A | A list of absolute paths to x509 root certificate authorities that the api use if required to verify a client certificate |
| `SCANNER_API_SERVER_READ_TIMEOUT` | `15s` | The maximum duration for reading the entire request, including the body |
| `SCANNER_API_SERVER_WRITE_TIMEOUT` | `15s` | The maximum duration before timing out writes of the response |
| `SCANNER_API_SERVER_IDLE_TIMEOUT` | `60s` | The maximum amount of time to wait for the next request when keep-alives are enabled |
Expand Down
5 changes: 4 additions & 1 deletion cmd/scanner-trivy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ func run(info etc.BuildInfo) error {
store := redis.NewStore(config.RedisStore)
enqueuer := queue.NewEnqueuer(config.JobQueue, store)
apiHandler := v1.NewAPIHandler(info, config, enqueuer, store, trivy.NewWrapper(config.Trivy, ext.DefaultAmbassador))
apiServer := api.NewServer(config.API, apiHandler)
apiServer, err := api.NewServer(config.API, apiHandler)
if err != nil {
return fmt.Errorf("new api server: %w", err)
}

shutdownComplete := make(chan struct{})
go func() {
Expand Down
8 changes: 8 additions & 0 deletions pkg/etc/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,18 @@ func Check(config Config) (err error) {
err = fmt.Errorf("TLS certificate file does not exist: %s", config.API.TLSCertificate)
return
}

if !fileExists(config.API.TLSKey) {
err = fmt.Errorf("TLS private key file does not exist: %s", config.API.TLSKey)
return
}

for _, path := range config.API.ClientCAs {
if !fileExists(path) {
err = fmt.Errorf("ClientCA file does not exist: %s", path)
return
}
}
}

return
Expand Down
46 changes: 46 additions & 0 deletions pkg/etc/checker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,4 +146,50 @@ func TestCheck(t *testing.T) {
})
assert.EqualError(t, err, fmt.Sprintf("TLS private key file does not exist: %s", keyFile))
})

t.Run("Should return error when one of ClientCAs does not exist", func(t *testing.T) {
tempDir, err := ioutil.TempDir("", "TestCheck_*")
require.NoError(t, err)
defer func() {
_ = os.RemoveAll(tempDir)
}()

cacheDir := path.Join(tempDir, "cache")
reportsDir := path.Join(tempDir, "reports")
certFile := path.Join(tempDir, "tls.crt")
keyFile := path.Join(tempDir, "tls.key")
clientCA1File := path.Join(tempDir, "clientCA1.crt")
clientCA2File := path.Join(tempDir, "clientCA2.crt")
clientCA3File := path.Join(tempDir, "clientCA3.crt")

f, err := os.Create(certFile)
require.NoError(t, err)
_ = f.Close()

f, err = os.Create(keyFile)
require.NoError(t, err)
_ = f.Close()

f, err = os.Create(clientCA1File)
require.NoError(t, err)
_ = f.Close()

f, err = os.Create(clientCA3File)
require.NoError(t, err)
_ = f.Close()

err = Check(Config{
API: API{
TLSCertificate: certFile,
TLSKey: keyFile,
ClientCAs: []string{clientCA1File, clientCA2File, clientCA3File},
},
Trivy: Trivy{
CacheDir: cacheDir,
ReportsDir: reportsDir,
},
})

assert.EqualError(t, err, fmt.Sprintf("ClientCA file does not exist: %s", clientCA2File))
})
}
1 change: 1 addition & 0 deletions pkg/etc/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ type API struct {
Addr string `env:"SCANNER_API_SERVER_ADDR" envDefault:":8080"`
TLSCertificate string `env:"SCANNER_API_SERVER_TLS_CERTIFICATE"`
TLSKey string `env:"SCANNER_API_SERVER_TLS_KEY"`
ClientCAs []string `env:"SCANNER_API_SERVER_CLIENT_CAS"`
ReadTimeout time.Duration `env:"SCANNER_API_SERVER_READ_TIMEOUT" envDefault:"15s"`
WriteTimeout time.Duration `env:"SCANNER_API_SERVER_WRITE_TIMEOUT" envDefault:"15s"`
IdleTimeout time.Duration `env:"SCANNER_API_SERVER_IDLE_TIMEOUT" envDefault:"60s"`
Expand Down
2 changes: 2 additions & 0 deletions pkg/etc/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ func TestGetConfig(t *testing.T) {
"SCANNER_API_SERVER_ADDR": ":4200",
"SCANNER_API_SERVER_TLS_CERTIFICATE": "/certs/tls.crt",
"SCANNER_API_SERVER_TLS_KEY": "/certs/tls.key",
"SCANNER_API_SERVER_CLIENT_CAS": "/certs/tls1.crt,/certs/tls2.crt",
"SCANNER_API_SERVER_TLS_MIN_VERSION": "1.0",
"SCANNER_API_SERVER_TLS_MAX_VERSION": "1.2",
"SCANNER_API_SERVER_READ_TIMEOUT": "1h",
Expand Down Expand Up @@ -121,6 +122,7 @@ func TestGetConfig(t *testing.T) {
Addr: ":4200",
TLSCertificate: "/certs/tls.crt",
TLSKey: "/certs/tls.key",
ClientCAs: []string{"/certs/tls1.crt", "/certs/tls2.crt"},
ReadTimeout: parseDuration(t, "1h"),
WriteTimeout: parseDuration(t, "2m"),
IdleTimeout: parseDuration(t, "3m10s"),
Expand Down
25 changes: 24 additions & 1 deletion pkg/http/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ package api

import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"net/http"
"strings"

"github.com/aquasecurity/harbor-scanner-trivy/pkg/etc"
log "github.com/sirupsen/logrus"
Expand All @@ -14,7 +18,7 @@ type Server struct {
server *http.Server
}

func NewServer(config etc.API, handler http.Handler) (server *Server) {
func NewServer(config etc.API, handler http.Handler) (server *Server, err error) {
server = &Server{
config: config,
server: &http.Server{
Expand All @@ -25,6 +29,7 @@ func NewServer(config etc.API, handler http.Handler) (server *Server) {
IdleTimeout: config.IdleTimeout,
},
}

if config.IsTLSEnabled() {
server.server.TLSConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
Expand All @@ -46,7 +51,24 @@ func NewServer(config etc.API, handler http.Handler) (server *Server) {
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
},
}

if len(config.ClientCAs) > 0 {
certPool := x509.NewCertPool()

for _, clientCAPath := range config.ClientCAs {
clientCA, err := ioutil.ReadFile(clientCAPath)
if err != nil {
return nil, fmt.Errorf("cound not read file %s: %w", clientCAPath, err)
}

certPool.AppendCertsFromPEM(clientCA)
}

server.server.TLSConfig.ClientCAs = certPool
server.server.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
}
}

return
}

Expand All @@ -64,6 +86,7 @@ func (s *Server) listenAndServe() error {
log.WithFields(log.Fields{
"certificate": s.config.TLSCertificate,
"key": s.config.TLSKey,
"clientCAs": strings.Join(s.config.ClientCAs, ", "),
"addr": s.config.Addr,
}).Debug("Starting API server with TLS")
return s.server.ListenAndServeTLS(s.config.TLSCertificate, s.config.TLSKey)
Expand Down