Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fixed a bug in how headers are forwarded for GitHub API #25

Merged
merged 7 commits into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,5 @@ jobs:
push: true
cache-from: type=gha
cache-to: type=gha,mode=max
build-args: VERSION=${{ env.IMAGE_TAG }}
tags: us.gcr.io/${{ env.REPO }}/runner:${{ env.IMAGE_TAG }}
4 changes: 3 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ COPY go.mod go.sum ./
RUN go mod download

COPY . .
RUN CGO_ENABLED=1 GOOS=linux GOARCH=amd64 go build -o runner ./cmd/runner/*.go

ARG VERSION=1.0.0-beta.x
RUN CGO_ENABLED=1 GOOS=linux GOARCH=amd64 go build -ldflags "-X main.version=${VERSION}" -o runner ./cmd/runner/*.go

FROM alpine:3.17

Expand Down
8 changes: 4 additions & 4 deletions auth/token/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func TestSessionAuthMiddleware(t *testing.T) {
assert.NoError(t, err)

assert.Equal(t, http.StatusTemporaryRedirect, rec.Code)
assert.Equal(t, "/refresh", rec.Header().Get("Location"))
assert.Equal(t, "/refresh?redirect=/", rec.Header().Get("Location"))
})

t.Run("cookie empty", func(t *testing.T) {
Expand All @@ -48,7 +48,7 @@ func TestSessionAuthMiddleware(t *testing.T) {
assert.NoError(t, err)

assert.Equal(t, http.StatusTemporaryRedirect, rec.Code)
assert.Equal(t, "/refresh", rec.Header().Get("Location"))
assert.Equal(t, "/refresh?redirect=/", rec.Header().Get("Location"))
})

t.Run("cookie invalid", func(t *testing.T) {
Expand All @@ -63,7 +63,7 @@ func TestSessionAuthMiddleware(t *testing.T) {
assert.NoError(t, err)

assert.Equal(t, http.StatusTemporaryRedirect, rec.Code)
assert.Equal(t, "/refresh", rec.Header().Get("Location"))
assert.Equal(t, "/refresh?redirect=/", rec.Header().Get("Location"))
})

user := &model.User{
Expand All @@ -88,7 +88,7 @@ func TestSessionAuthMiddleware(t *testing.T) {
err = h(c)
assert.NoError(t, err)
assert.Equal(t, http.StatusTemporaryRedirect, rec.Code)
assert.Equal(t, "/refresh", rec.Header().Get("Location"))
assert.Equal(t, "/refresh?redirect=/", rec.Header().Get("Location"))
})

t.Run("valid token", func(t *testing.T) {
Expand Down
22 changes: 0 additions & 22 deletions cmd/runner/cors.go

This file was deleted.

10 changes: 6 additions & 4 deletions cmd/runner/provder.go → cmd/runner/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
"github.com/deepsourcecorp/runner/provider/model"
)

func GetProvider(ctx context.Context, c *config.Config, client *http.Client) (*provider.Facade, error) {
func GetProvider(_ context.Context, c *config.Config, client *http.Client) (*provider.Facade, error) {
githubApps := createGithubApps(c)
providerApps := createProviderApps(c)

Expand All @@ -24,10 +24,12 @@ func GetProvider(ctx context.Context, c *config.Config, client *http.Client) (*p
Host: c.DeepSource.Host,
}

apiFactory := github.NewAPIProxyFactory(githubApps, client)
webhookFactory := github.NewWebhookProxyFactory(runner, deepsource, githubApps, client)
appFactory := github.NewAppFactory(githubApps)

githubProvider, err := github.NewHandler(apiFactory, webhookFactory)
webhookService := github.NewWebhookService(appFactory, runner, deepsource, client)
apiService := github.NewAPIService(appFactory, client)

githubProvider, err := github.NewHandler(webhookService, apiService, appFactory, runner, deepsource, client)
if err != nil {
return nil, fmt.Errorf("error initializing provider: %w", err)
}
Expand Down
17 changes: 9 additions & 8 deletions cmd/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,25 @@ import (
"net/http"
"time"

"github.com/deepsourcecorp/runner/config"
runnerconfig "github.com/deepsourcecorp/runner/config"
runnermiddleware "github.com/deepsourcecorp/runner/middleware"
"github.com/getsentry/sentry-go"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"golang.org/x/exp/slog"
)

var version string

const (
Banner = `________
___ __ \___ ____________________________
__ /_/ / / / /_ __ \_ __ \ _ \_ ___/
_ _, _// /_/ /_ / / / / / / __/ /
/_/ |_| \__,_/ /_/ /_//_/ /_/\___//_/
------------------------------------------
By DeepSource | v%s
By DeepSource | %s
------------------------------------------`

Version = "0.1.0-beta.1"
)

const (
Expand All @@ -33,12 +34,12 @@ const (

type Server struct {
*echo.Echo
*config.Config
*runnerconfig.Config
*http.Client
cors echo.MiddlewareFunc
}

func NewServer(c *config.Config) *Server {
func NewServer(c *runnerconfig.Config) *Server {
e := echo.New()
e.HideBanner = true
e.HidePort = true
Expand All @@ -54,7 +55,7 @@ func NewServer(c *config.Config) *Server {
e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{
Format: "time=${time_rfc3339_nano} level=INFO method=${method}, uri=${uri}, status=${status}\n",
}))
cors := CorsMiddleware(c.DeepSource.Host.String())
cors := runnermiddleware.CorsMiddleware(c.DeepSource.Host.String())
return &Server{Echo: e, Config: c, cors: cors}
}

Expand All @@ -68,7 +69,7 @@ func (s *Server) Start() error {
}

func (*Server) PrintBanner() {
fmt.Println(fmt.Sprintf(Banner, Version))
fmt.Println(fmt.Sprintf(Banner, version))
}

func (s *Server) Router() (*Router, error) {
Expand Down
135 changes: 135 additions & 0 deletions forwarder/forwarder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package forwarder

import (
"bytes"
"fmt"
"io"
"net/http"
"net/textproto"
"net/url"
"strings"
)

type Opts struct {
TargetURL url.URL
Headers http.Header
Query url.Values
}

type Forwarder struct {
client *http.Client
}

func New(client *http.Client) *Forwarder {
return &Forwarder{client: client}
}

func (f *Forwarder) Forward(req *http.Request, opts *Opts) (*http.Response, error) {
defer req.Body.Close()
ctx := req.Context()

body, err := io.ReadAll(req.Body)
if err != nil {
return nil, fmt.Errorf("failed to read request body: %w", err)
}

out, err := http.NewRequestWithContext(
ctx,
req.Method,
opts.TargetURL.String(),
bytes.NewReader(body),
)

if err != nil {
return nil, fmt.Errorf("failed to create target request: %w", err)
}

copyHeader(out.Header, req.Header)
appendHeaders(out, opts.Headers)

copyQueryParams(out, req)
appendQueryParams(out, opts.Query)

removeHopHeaders(out.Header)
removeCloudflareHeaders(out.Header)

res, err := f.client.Do(out)
if err != nil {
return nil, fmt.Errorf("failed to make target request: %w", err)
}
removeHopHeaders(res.Header)
return res, nil
}

func copyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}

func removeHopHeaders(h http.Header) {
hopHeaders := []string{
"Connection",
"Proxy-Connection",
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"Te",
"Trailer",
"Transfer-Encoding",
"Upgrade",
}

for _, f := range h["Connection"] {
for _, sf := range strings.Split(f, ",") {
if sf = textproto.TrimString(sf); sf != "" {
h.Del(sf)
}
}
}

for _, k := range hopHeaders {
h.Del(k)
}
}

func removeCloudflareHeaders(h http.Header) {
cloudflareHeaders := []string{
"CF-Connecting-IP",
"CF-IPCountry",
"CF-RAY",
"CF-Visitor",
"CF-Request-ID",
"CF-Worker",
}

for _, k := range cloudflareHeaders {
h.Del(k)
}
}

func copyQueryParams(dst, src *http.Request) {
q := dst.URL.Query()
for k, v := range src.URL.Query() {
q[k] = v
}
dst.URL.RawQuery = q.Encode()
}

func appendQueryParams(req *http.Request, query url.Values) {
q := req.URL.Query()
for k, v := range query {
q[k] = v
}
req.URL.RawQuery = q.Encode()
}

func appendHeaders(req *http.Request, headers http.Header) {
for k, v := range headers {
for _, vv := range v {
req.Header.Add(k, vv)
}
}
}
50 changes: 50 additions & 0 deletions forwarder/forwarder_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package forwarder

import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"

"github.com/stretchr/testify/assert"
)

func TestProxy(t *testing.T) {
body := []byte("test-body")

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
assert.Equal(t, "original-header-value", r.Header.Get("Original-Header"))
assert.Equal(t, "extra-header-value", r.Header.Get("Extra-Header"))
assert.Equal(t, "1", r.URL.Query().Get("original-query"))
assert.Equal(t, "2", r.URL.Query().Get("extra-query"))
assert.Empty(t, r.Header.Get("Keep-Alive"))

body, _ := io.ReadAll(r.Body)
assert.Equal(t, "test-body", string(body))
assert.Equal(t, r.ContentLength, int64(len(body)))

w.WriteHeader(http.StatusOK)
}))
serverURL, _ := url.Parse(server.URL)

in := httptest.NewRequest(http.MethodGet, "https://example.com?original-query=1", bytes.NewReader(body))
in.Header.Set("Original-Header", "original-header-value")
in.Header.Set("Keep-Alive", "300")

extraHeaders := http.Header{}
extraHeaders.Set("Extra-Header", "extra-header-value")

forwarder := New(http.DefaultClient)

res, err := forwarder.Forward(in, &Opts{
TargetURL: *serverURL,
Headers: extraHeaders,
Query: map[string][]string{"extra-query": {"2"}},
})

assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
}
Loading
Loading