diff --git a/cmd/root.go b/cmd/root.go index c7f931e9e..f1c63cec8 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -913,7 +913,7 @@ func runSignalWrapper(cmd *Command) (err error) { func quitquitquit(quitOnce *sync.Once, shutdownCh chan<- error) http.HandlerFunc { return func(rw http.ResponseWriter, req *http.Request) { - if req.Method != http.MethodPost { + if req.Method != http.MethodPost && req.Method != http.MethodGet { rw.WriteHeader(400) return } diff --git a/cmd/root_test.go b/cmd/root_test.go index b2f60b02c..1607e7164 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -1307,7 +1307,7 @@ func TestPProfServer(t *testing.T) { } } -func TestQuitQuitQuit(t *testing.T) { +func TestQuitQuitQuitHTTPPost(t *testing.T) { c := NewCommand(WithDialer(&spyDialer{})) c.SilenceUsage = true c.SilenceErrors = true @@ -1320,7 +1320,7 @@ func TestQuitQuitQuit(t *testing.T) { err := c.ExecuteContext(ctx) errCh <- err }() - resp, err := tryDial("GET", "http://localhost:9192/quitquitquit") + resp, err := tryDial("HEAD", "http://localhost:9192/quitquitquit") if err != nil { t.Fatalf("failed to dial endpoint: %v", err) } @@ -1348,6 +1348,41 @@ func TestQuitQuitQuit(t *testing.T) { } } +func TestQuitQuitQuitHTTPGet(t *testing.T) { + c := NewCommand(WithDialer(&spyDialer{})) + c.SilenceUsage = true + c.SilenceErrors = true + c.SetArgs([]string{"--quitquitquit", "--admin-port", "9194", "my-project:my-region:my-instance"}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + errCh := make(chan error) + go func() { + err := c.ExecuteContext(ctx) + errCh <- err + }() + + resp, err := tryDial("GET", "http://localhost:9194/quitquitquit") + if err != nil { + t.Fatalf("failed to dial endpoint: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected a 200 status, got = %v", resp.StatusCode) + } + + var gotErr error + select { + case err := <-errCh: + gotErr = err + case <-time.After(30 * time.Second): + t.Fatal("timeout waiting for error") + } + + if !errors.Is(gotErr, errQuitQuitQuit) { + t.Fatalf("want = %v, got = %v", errQuitQuitQuit, gotErr) + } +} + type errorDialer struct { spyDialer }