Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions api/http_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,50 @@ func NewHTTPClient(opts HTTPClientOptions) (*http.Client, error) {
return client, nil
}

// ExternalHTTPClientOptions holds options for creating an external HTTP client.
type ExternalHTTPClientOptions struct {
AppVersion string
Log io.Writer
LogColorize bool
Transport http.RoundTripper
}

// NewExternalHTTPClient creates an HTTP client for talking to non-GitHub hosts.
// It includes debug logging and a User-Agent header but does not attach any
// authentication tokens or GitHub-specific headers.
func NewExternalHTTPClient(opts ExternalHTTPClientOptions) (*http.Client, error) {
clientOpts := ghAPI.ClientOptions{
Host: "none",
AuthToken: "none",
LogIgnoreEnv: true,
SkipDefaultHeaders: true,
Transport: opts.Transport,
}

debugEnabled, debugValue := utils.IsDebugEnabled()
logVerboseHTTP := false
if strings.Contains(debugValue, "api") {
logVerboseHTTP = true
}

if logVerboseHTTP || debugEnabled {
clientOpts.Log = opts.Log
clientOpts.LogColorize = opts.LogColorize
clientOpts.LogVerboseHTTP = logVerboseHTTP
}

clientOpts.Headers = map[string]string{
userAgent: fmt.Sprintf("GitHub CLI %s", opts.AppVersion),
}

client, err := ghAPI.NewHTTPClient(clientOpts)
if err != nil {
return nil, err
}

return client, nil
}

func NewCachedHTTPClient(httpClient *http.Client, ttl time.Duration) *http.Client {
newClient := *httpClient
newClient.Transport = AddCacheTTLHeader(httpClient.Transport, ttl)
Expand Down
49 changes: 49 additions & 0 deletions api/http_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,55 @@ func TestNewHTTPClientWithoutTelemetryDisabler(t *testing.T) {
assert.Equal(t, 204, res.StatusCode)
}

func TestNewExternalHTTPClient(t *testing.T) {
tests := []struct {
name string
url string
}{
{
name: "third-party host",
url: "https://example.com/path",
},
{
// Even when talking to GitHub, the external client must not set
// authorization or any GitHub-specific headers.
name: "github.com host",
url: "https://api.github.com/repos/cli/cli",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var gotReq *http.Request
transport := &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) {
gotReq = req
return &http.Response{StatusCode: 204, Body: io.NopCloser(strings.NewReader(""))}, nil
}}

client, err := NewExternalHTTPClient(ExternalHTTPClientOptions{
AppVersion: "v1.2.3",
Transport: transport,
})
require.NoError(t, err)

req, err := http.NewRequest("GET", tt.url, nil)
require.NoError(t, err)

res, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, 204, res.StatusCode)

// No headers should be set by default, except for User-Agent which should include the app version.
assert.Equal(t, []string{"GitHub CLI v1.2.3"}, gotReq.Header.Values("user-agent"))
assert.Empty(t, gotReq.Header.Values("authorization"))
assert.Empty(t, gotReq.Header.Values("x-github-api-version"))
assert.Empty(t, gotReq.Header.Values("accept"))
assert.Empty(t, gotReq.Header.Values("content-type"))
assert.Empty(t, gotReq.Header.Values("time-zone"))
})
}
}

type fakeTelemetryDisabler struct {
disabled bool
}
Expand Down
30 changes: 14 additions & 16 deletions internal/codespaces/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,11 @@ const (

// API is the interface to the codespace service.
type API struct {
client func() (*http.Client, error)
githubAPI string
githubServer string
retryBackoff time.Duration
client func() (*http.Client, error)
externalClient func() (*http.Client, error)
githubAPI string
githubServer string
retryBackoff time.Duration
}

// New creates a new API client connecting to the configured endpoints with the HTTP client.
Expand Down Expand Up @@ -93,10 +94,11 @@ func New(f *cmdutil.Factory) *API {
}

return &API{
client: f.HttpClient,
githubAPI: strings.TrimSuffix(apiURL, "/"),
githubServer: strings.TrimSuffix(serverURL, "/"),
retryBackoff: 100 * time.Millisecond,
client: f.HttpClient,
externalClient: f.ExternalHttpClient,
githubAPI: strings.TrimSuffix(apiURL, "/"),
githubServer: strings.TrimSuffix(serverURL, "/"),
retryBackoff: 100 * time.Millisecond,
}
}

Expand Down Expand Up @@ -1214,12 +1216,8 @@ func (a *API) withRetry(f func() (*http.Response, error)) (*http.Response, error
}, backoff.WithMaxRetries(bo, 3))
}

// HTTPClient returns the HTTP client used to make requests to the API.
func (a *API) HTTPClient() (*http.Client, error) {
httpClient, err := a.client()
if err != nil {
return nil, err
}

return httpClient, nil
// ExternalHTTPClient returns an HTTP client for requests to non-GitHub hosts.
// It must not carry GitHub authentication credentials.
func (a *API) ExternalHTTPClient() (*http.Client, error) {
return a.externalClient()
}
6 changes: 3 additions & 3 deletions internal/codespaces/codespaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func connectionReady(codespace *api.Codespace) bool {
type apiClient interface {
GetCodespace(ctx context.Context, name string, includeConnection bool) (*api.Codespace, error)
StartCodespace(ctx context.Context, name string) error
HTTPClient() (*http.Client, error)
ExternalHTTPClient() (*http.Client, error)
}

type progressIndicator interface {
Expand All @@ -66,12 +66,12 @@ func GetCodespaceConnection(ctx context.Context, progress progressIndicator, api
progress.StartProgressIndicatorWithLabel("Connecting to codespace")
defer progress.StopProgressIndicator()

httpClient, err := apiClient.HTTPClient()
externalHttpClient, err := apiClient.ExternalHTTPClient()
if err != nil {
return nil, fmt.Errorf("error getting http client: %w", err)
}

return connection.NewCodespaceConnection(ctx, codespace, httpClient)
return connection.NewCodespaceConnection(ctx, codespace, externalHttpClient)
}

// waitUntilCodespaceConnectionReady waits for a Codespace to be running and is able to be connected to.
Expand Down
4 changes: 2 additions & 2 deletions internal/codespaces/codespaces_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,8 @@ func (m *mockApiClient) GetCodespace(ctx context.Context, name string, includeCo
return m.onGetCodespace()
}

func (m *mockApiClient) HTTPClient() (*http.Client, error) {
panic("Not implemented")
func (m *mockApiClient) ExternalHTTPClient() (*http.Client, error) {
return nil, nil
}

type mockProgressIndicator struct{}
Expand Down
23 changes: 12 additions & 11 deletions pkg/cmd/attestation/api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"io"
"net/http"
neturl "net/url"
"strings"
"time"

Expand Down Expand Up @@ -67,18 +68,18 @@ type Client interface {
}

type LiveClient struct {
githubAPI githubApiClient
httpClient httpClient
host string
logger *ioconfig.Handler
githubAPI githubApiClient
externalHttpClient httpClient
host string
logger *ioconfig.Handler
}

func NewLiveClient(hc *http.Client, host string, l *ioconfig.Handler) *LiveClient {
func NewLiveClient(hc *http.Client, externalClient *http.Client, host string, l *ioconfig.Handler) *LiveClient {
return &LiveClient{
githubAPI: api.NewClientFromHTTP(hc),
host: strings.TrimSuffix(host, "/"),
httpClient: hc,
logger: l,
githubAPI: api.NewClientFromHTTP(hc),
host: strings.TrimSuffix(host, "/"),
externalHttpClient: externalClient,
logger: l,
}
}

Expand Down Expand Up @@ -121,7 +122,7 @@ func (c *LiveClient) buildRequestURL(params FetchParams) (string, error) {
// ref: https://github.com/cli/go-gh/blob/d32c104a9a25c9de3d7c7b07a43ae0091441c858/example_gh_test.go#L96
url = fmt.Sprintf("%s?per_page=%d", url, perPage)
if params.PredicateType != "" {
url = fmt.Sprintf("%s&predicate_type=%s", url, params.PredicateType)
url = fmt.Sprintf("%s&predicate_type=%s", url, neturl.QueryEscape(params.PredicateType))
}
return url, nil
}
Expand Down Expand Up @@ -225,7 +226,7 @@ func (c *LiveClient) getBundle(url string) (*bundle.Bundle, error) {
var sgBundle *bundle.Bundle
bo := backoff.NewConstantBackOff(getAttestationRetryInterval)
err := backoff.Retry(func() error {
resp, err := c.httpClient.Get(url)
resp, err := c.externalHttpClient.Get(url)
if err != nil {
return fmt.Errorf("request to fetch bundle from URL failed: %w", err)
}
Expand Down
52 changes: 26 additions & 26 deletions pkg/cmd/attestation/api/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,17 @@ func NewClientWithMockGHClient(hasNextPage bool) Client {
githubAPI: mockAPIClient{
OnRESTWithNext: fetcher.OnRESTSuccessWithNextPage,
},
httpClient: httpClient,
logger: l,
externalHttpClient: httpClient,
logger: l,
}
}

return &LiveClient{
githubAPI: mockAPIClient{
OnRESTWithNext: fetcher.OnRESTSuccess,
},
httpClient: httpClient,
logger: l,
externalHttpClient: httpClient,
logger: l,
}
}

Expand Down Expand Up @@ -137,8 +137,8 @@ func TestGetByDigest_NoAttestationsFound(t *testing.T) {
githubAPI: mockAPIClient{
OnRESTWithNext: fetcher.OnRESTWithNextNoAttestations,
},
httpClient: httpClient,
logger: io.NewTestHandler(),
externalHttpClient: httpClient,
logger: io.NewTestHandler(),
}

attestations, err := c.GetByDigest(testFetchParamsWithRepo)
Expand Down Expand Up @@ -167,8 +167,8 @@ func TestGetByDigest_Error(t *testing.T) {
func TestFetchBundleFromAttestations_BundleURL(t *testing.T) {
httpClient := &mockHttpClient{}
client := LiveClient{
httpClient: httpClient,
logger: io.NewTestHandler(),
externalHttpClient: httpClient,
logger: io.NewTestHandler(),
}

att1 := makeTestAttestation()
Expand All @@ -184,8 +184,8 @@ func TestFetchBundleFromAttestations_BundleURL(t *testing.T) {
func TestFetchBundleFromAttestations_MissingBundleAndBundleURLFields(t *testing.T) {
httpClient := &mockHttpClient{}
client := LiveClient{
httpClient: httpClient,
logger: io.NewTestHandler(),
externalHttpClient: httpClient,
logger: io.NewTestHandler(),
}

// If both the BundleURL and Bundle fields are empty, the function should
Expand All @@ -207,8 +207,8 @@ func TestFetchBundleFromAttestations_FailOnTheSecondAttestation(t *testing.T) {
}

c := &LiveClient{
httpClient: mockHTTPClient,
logger: io.NewTestHandler(),
externalHttpClient: mockHTTPClient,
logger: io.NewTestHandler(),
}

att1 := makeTestAttestation()
Expand All @@ -223,8 +223,8 @@ func TestFetchBundleFromAttestations_FailAfterRetrying(t *testing.T) {
mockHTTPClient := &reqFailHttpClient{}

c := &LiveClient{
httpClient: mockHTTPClient,
logger: io.NewTestHandler(),
externalHttpClient: mockHTTPClient,
logger: io.NewTestHandler(),
}

a := makeTestAttestation()
Expand All @@ -239,8 +239,8 @@ func TestFetchBundleFromAttestations_FallbackToBundleField(t *testing.T) {
mockHTTPClient := &mockHttpClient{}

c := &LiveClient{
httpClient: mockHTTPClient,
logger: io.NewTestHandler(),
externalHttpClient: mockHTTPClient,
logger: io.NewTestHandler(),
}

// If the bundle URL is empty, the code will fallback to the bundle field
Expand All @@ -257,8 +257,8 @@ func TestGetBundle(t *testing.T) {
mockHTTPClient := &mockHttpClient{}

c := &LiveClient{
httpClient: mockHTTPClient,
logger: io.NewTestHandler(),
externalHttpClient: mockHTTPClient,
logger: io.NewTestHandler(),
}

b, err := c.getBundle("https://mybundleurl.com")
Expand All @@ -276,8 +276,8 @@ func TestGetBundle_SuccessfulRetry(t *testing.T) {
}

c := &LiveClient{
httpClient: mockHTTPClient,
logger: io.NewTestHandler(),
externalHttpClient: mockHTTPClient,
logger: io.NewTestHandler(),
}

b, err := c.getBundle("mybundleurl")
Expand All @@ -290,8 +290,8 @@ func TestGetBundle_SuccessfulRetry(t *testing.T) {
func TestGetBundle_PermanentBackoffFail(t *testing.T) {
mockHTTPClient := &invalidBundleClient{}
c := &LiveClient{
httpClient: mockHTTPClient,
logger: io.NewTestHandler(),
externalHttpClient: mockHTTPClient,
logger: io.NewTestHandler(),
}

b, err := c.getBundle("mybundleurl")
Expand All @@ -307,8 +307,8 @@ func TestGetBundle_RequestFail(t *testing.T) {
mockHTTPClient := &reqFailHttpClient{}

c := &LiveClient{
httpClient: mockHTTPClient,
logger: io.NewTestHandler(),
externalHttpClient: mockHTTPClient,
logger: io.NewTestHandler(),
}

b, err := c.getBundle("mybundleurl")
Expand Down Expand Up @@ -360,8 +360,8 @@ func TestGetAttestationsRetries(t *testing.T) {
githubAPI: mockAPIClient{
OnRESTWithNext: fetcher.FlakyOnRESTSuccessWithNextPageHandler(),
},
httpClient: &mockHttpClient{},
logger: io.NewTestHandler(),
externalHttpClient: &mockHttpClient{},
logger: io.NewTestHandler(),
}

testFetchParamsWithRepo.Limit = 30
Expand Down
Loading
Loading