diff --git a/github/login.go b/github/login.go index 0c50801..c8d2520 100644 --- a/github/login.go +++ b/github/login.go @@ -2,7 +2,9 @@ package github import ( "errors" + "fmt" "net/http" + "net/url" "github.com/dghubble/gologin" oauth2Login "github.com/dghubble/gologin/oauth2" @@ -37,7 +39,17 @@ func LoginHandler(config *oauth2.Config, failure http.Handler) http.Handler { // access token and User to the ctx. If authentication succeeds, handling // delegates to the success handler, otherwise to the failure handler. func CallbackHandler(config *oauth2.Config, success, failure http.Handler) http.Handler { - success = githubHandler(config, success, failure) + success = githubHandler(config, false, success, failure) + return oauth2Login.CallbackHandler(config, success, failure) +} + +// EnterpriseCallbackHandler handles Github Enterprise redirection URI requests +// and adds the Github access token and User to the ctx. If authentication +// succeeds,handling delegates to the success handler, otherwise to the failure +// handler. The Github Enterprise API URL is inferred from the OAuth2 config's +// AuthURL endpoint. +func EnterpriseCallbackHandler(config *oauth2.Config, success, failure http.Handler) http.Handler { + success = githubHandler(config, true, success, failure) return oauth2Login.CallbackHandler(config, success, failure) } @@ -45,7 +57,7 @@ func CallbackHandler(config *oauth2.Config, success, failure http.Handler) http. // get the corresponding Github User. If successful, the User is added to the // ctx and the success handler is called. Otherwise, the failure handler is // called. -func githubHandler(config *oauth2.Config, success, failure http.Handler) http.Handler { +func githubHandler(config *oauth2.Config, isEnterprise bool, success, failure http.Handler) http.Handler { if failure == nil { failure = gologin.DefaultFailureHandler } @@ -57,8 +69,19 @@ func githubHandler(config *oauth2.Config, success, failure http.Handler) http.Ha failure.ServeHTTP(w, req.WithContext(ctx)) return } + httpClient := config.Client(ctx, token) - githubClient := github.NewClient(httpClient) + var githubClient *github.Client + if isEnterprise { + githubClient, err = enterpriseGithubClientFromAuthURL(config.Endpoint.AuthURL, httpClient) + if err != nil { + ctx = gologin.WithError(ctx, fmt.Errorf("github: error creating Client: %v", err)) + failure.ServeHTTP(w, req.WithContext(ctx)) + return + } + } else { + githubClient = github.NewClient(httpClient) + } user, resp, err := githubClient.Users.Get(ctx, "") err = validateResponse(user, resp, err) if err != nil { @@ -83,3 +106,20 @@ func validateResponse(user *github.User, resp *github.Response, err error) error } return nil } + +// enterpriseGithubClientFromAuthURL returns a Github client that targets a GHE instance. +func enterpriseGithubClientFromAuthURL(authURL string, httpClient *http.Client) (*github.Client, error) { + client := github.NewClient(httpClient) + + // convert authURL to GHE baseURL https://.com/api/v3/ + baseURL, err := url.Parse(authURL) + if err != nil { + return nil, fmt.Errorf("github: error parsing Endoint.AuthURL: %s", authURL) + } + + baseURL.Path = "/api/v3/" + client.BaseURL = baseURL + client.UploadURL = baseURL + + return client, nil +} diff --git a/github/login_test.go b/github/login_test.go index fb0bca6..be1c448 100644 --- a/github/login_test.go +++ b/github/login_test.go @@ -18,7 +18,7 @@ import ( func TestGithubHandler(t *testing.T) { jsonData := `{"id": 917408, "name": "Alyssa Hacker"}` expectedUser := &github.User{ID: github.Int64(917408), Name: github.String("Alyssa Hacker")} - proxyClient, server := newGithubTestServer(jsonData) + proxyClient, server := newGithubTestServer("", jsonData) defer server.Close() // oauth2 Client will use the proxy client's base Transport @@ -41,7 +41,7 @@ func TestGithubHandler(t *testing.T) { // - github User is obtained from the Github API // - success handler is called // - github User is added to the ctx of the success handler - githubHandler := githubHandler(config, http.HandlerFunc(success), failure) + githubHandler := githubHandler(config, false, http.HandlerFunc(success), failure) w := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/", nil) githubHandler.ServeHTTP(w, req.WithContext(ctx)) @@ -63,7 +63,7 @@ func TestGithubHandler_MissingCtxToken(t *testing.T) { // GithubHandler called without Token in ctx, assert that: // - failure handler is called // - error about ctx missing token is added to the failure handler ctx - githubHandler := githubHandler(config, success, http.HandlerFunc(failure)) + githubHandler := githubHandler(config, false, success, http.HandlerFunc(failure)) w := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/", nil) githubHandler.ServeHTTP(w, req) @@ -92,13 +92,47 @@ func TestGithubHandler_ErrorGettingUser(t *testing.T) { // GithubHandler cannot get Github User, assert that: // - failure handler is called // - error cannot get Github User added to the failure handler ctx - githubHandler := githubHandler(config, success, http.HandlerFunc(failure)) + githubHandler := githubHandler(config, false, success, http.HandlerFunc(failure)) w := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/", nil) githubHandler.ServeHTTP(w, req.WithContext(ctx)) assert.Equal(t, "failure handler called", w.Body.String()) } +func TestGithubHandler_Enterprise(t *testing.T) { + jsonData := `{"id": 917408, "name": "Alyssa Hacker"}` + expectedUser := &github.User{ID: github.Int64(917408), Name: github.String("Alyssa Hacker")} + proxyClient, server := newGithubTestServer("/api/v3", jsonData) + defer server.Close() + + // oauth2 Client will use the proxy client's base Transport + ctx := context.WithValue(context.Background(), oauth2.HTTPClient, proxyClient) + anyToken := &oauth2.Token{AccessToken: "any-token"} + ctx = oauth2Login.WithToken(ctx, anyToken) + + config := &oauth2.Config{} + config.Endpoint.AuthURL = "https://github.mycompany.com/login/oauth/authorize" + success := func(w http.ResponseWriter, req *http.Request) { + ctx := req.Context() + githubUser, err := UserFromContext(ctx) + assert.Nil(t, err) + assert.Equal(t, expectedUser, githubUser) + fmt.Fprintf(w, "success handler called") + } + failure := testutils.AssertFailureNotCalled(t) + + // GithubHandler assert that: + // - Token is read from the ctx and passed to the Github API + // - github User is obtained from the Github API + // - success handler is called + // - github User is added to the ctx of the success handler + githubHandler := githubHandler(config, true, http.HandlerFunc(success), failure) + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/", nil) + githubHandler.ServeHTTP(w, req.WithContext(ctx)) + assert.Equal(t, "success handler called", w.Body.String()) +} + func TestValidateResponse(t *testing.T) { validUser := &github.User{ID: github.Int64(123)} validResponse := &github.Response{Response: &http.Response{StatusCode: 200}} @@ -108,3 +142,18 @@ func TestValidateResponse(t *testing.T) { assert.Equal(t, ErrUnableToGetGithubUser, validateResponse(validUser, invalidResponse, nil)) assert.Equal(t, ErrUnableToGetGithubUser, validateResponse(&github.User{}, validResponse, nil)) } + +func Test_enterpriseGithubClientFromAuthURL(t *testing.T) { + cases := []struct { + authURL string + expClientBaseURL string + }{ + {"https://github.mycompany.com/login/oauth/authorize", "https://github.mycompany.com/api/v3/"}, + {"http://github.mycompany.com/login/oauth/authorize", "http://github.mycompany.com/api/v3/"}, + } + for _, c := range cases { + client, err := enterpriseGithubClientFromAuthURL(c.authURL, nil) + assert.Nil(t, err) + assert.Equal(t, client.BaseURL.String(), c.expClientBaseURL) + } +} diff --git a/github/server_test.go b/github/server_test.go index 53a3d6b..9e4c2f4 100644 --- a/github/server_test.go +++ b/github/server_test.go @@ -8,12 +8,14 @@ import ( "github.com/dghubble/gologin/testutils" ) -// newGithubTestServer returns a new httptest.Server which mocks the Github -// user endpoint and a client which proxies requests to the server. The server -// responds with the given json data. The caller must close the server. -func newGithubTestServer(jsonData string) (*http.Client, *httptest.Server) { +// newGithubTestServer returns a new httptest.Server which mocks the Github user +// endpoint and a client which proxies requests to the server. The server +// responds with the given json data. The caller must close the server. The +// routePrefix parameter specifies an optional route prefix that should be set +// to the API root route (empty for github.com, "/api/v3" for GHE). +func newGithubTestServer(routePrefix, jsonData string) (*http.Client, *httptest.Server) { client, mux, server := testutils.TestServer() - mux.HandleFunc("/user", func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc(routePrefix+"/user", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") fmt.Fprintf(w, jsonData) })