Skip to content
This repository has been archived by the owner on Apr 19, 2024. It is now read-only.

Commit

Permalink
Add OnInterrupt to run function code on CTRL+C (#301)
Browse files Browse the repository at this point in the history
* Add OnSIGINTFunc to run function on CTRL+C

* Add question specific SIGINTFunc options.

* Minor renaming.

* Trigger Build

* Rename to OnInterrupt.

* Rename forgotten function.

* Add tests.
  • Loading branch information
infalmo committed Dec 11, 2020
1 parent 0fa4cd6 commit 90b418e
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 11 deletions.
20 changes: 20 additions & 0 deletions survey.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ import (
"github.com/AlecAivazis/survey/v2/terminal"
)

// OnInterrupt is the function to run when
// SIGINT (CTRL+C) is sent to the process.
var OnInterrupt func()

// DefaultAskOptions is the default options on ask, using the OS stdio.
func defaultAskOptions() *AskOptions {
return &AskOptions{
Expand Down Expand Up @@ -56,6 +60,7 @@ func defaultAskOptions() *AskOptions {
},
KeepFilter: false,
},
OnInterrupt: OnInterrupt,
}
}
func defaultPromptConfig() *PromptConfig {
Expand Down Expand Up @@ -137,6 +142,7 @@ type AskOptions struct {
Stdio terminal.Stdio
Validators []Validator
PromptConfig PromptConfig
OnInterrupt func()
}

// WithStdio specifies the standard input, output and error files survey
Expand Down Expand Up @@ -182,6 +188,16 @@ func WithValidator(v Validator) AskOpt {
}
}

// WithInterruptFunc specifies a function to run on recieving
// SIGINT (aka CTRL+C) during prompt.
func WithInterruptFunc(fn func()) AskOpt {
return func(options *AskOptions) error {
options.OnInterrupt = fn
// nothing went wrong
return nil
}
}

type wantsStdio interface {
WithStdio(terminal.Stdio)
}
Expand Down Expand Up @@ -291,6 +307,10 @@ func Ask(qs []*Question, response interface{}, opts ...AskOpt) error {

// grab the user input and save it
ans, err := q.Prompt.Prompt(&options.PromptConfig)
// if SIGINT is recieved.
if err == terminal.InterruptErr {
options.OnInterrupt()
}
// if there was a problem
if err != nil {
return err
Expand Down
5 changes: 3 additions & 2 deletions survey_posix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"github.com/stretchr/testify/require"
)

func RunTest(t *testing.T, procedure func(*expect.Console), test func(terminal.Stdio) error) {
func RunTest(t *testing.T, procedure func(*expect.Console), test func(terminal.Stdio) error) error {
t.Parallel()

// Multiplex output to a buffer as well for the raw bytes.
Expand All @@ -28,7 +28,6 @@ func RunTest(t *testing.T, procedure func(*expect.Console), test func(terminal.S
}()

err = test(Stdio(c))
require.Nil(t, err)

// Close the slave end of the pty, and read the remaining bytes from the master end.
c.Tty().Close()
Expand All @@ -38,4 +37,6 @@ func RunTest(t *testing.T, procedure func(*expect.Console), test func(terminal.S

// Dump the terminal's screen.
t.Logf("\n%s", expect.StripTrailingEmptyLines(state.String()))

return err
}
72 changes: 64 additions & 8 deletions survey_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ type PromptTest struct {

func RunPromptTest(t *testing.T, test PromptTest) {
var answer interface{}
RunTest(t, test.procedure, func(stdio terminal.Stdio) error {
err := RunTest(t, test.procedure, func(stdio terminal.Stdio) error {
var err error
if p, ok := test.prompt.(wantsStdio); ok {
p.WithStdio(stdio)
Expand All @@ -40,12 +40,13 @@ func RunPromptTest(t *testing.T, test PromptTest) {
answer, err = test.prompt.Prompt(defaultPromptConfig())
return err
})
require.Nil(t, err)
require.Equal(t, test.expected, answer)
}

func RunPromptTestKeepFilter(t *testing.T, test PromptTest) {
var answer interface{}
RunTest(t, test.procedure, func(stdio terminal.Stdio) error {
err := RunTest(t, test.procedure, func(stdio terminal.Stdio) error {
var err error
if p, ok := test.prompt.(wantsStdio); ok {
p.WithStdio(stdio)
Expand All @@ -55,6 +56,7 @@ func RunPromptTestKeepFilter(t *testing.T, test PromptTest) {
answer, err = test.prompt.Prompt(config)
return err
})
require.Nil(t, err)
require.Equal(t, test.expected, answer)
}

Expand Down Expand Up @@ -133,7 +135,6 @@ func TestPagination_lastHalf(t *testing.T) {

func TestAsk(t *testing.T) {
t.Skip()
return
tests := []struct {
name string
questions []*Question
Expand Down Expand Up @@ -250,10 +251,10 @@ func TestAsk(t *testing.T) {
"pizza": true,
"commit-message": "Add editor prompt tests\n",
"commit-message-validated": "Add editor prompt tests\n",
"name": "Johnny Appleseed",
"day": []string{"Monday", "Wednesday"},
"password": "secret",
"color": "yellow",
"name": "Johnny Appleseed",
"day": []string{"Monday", "Wednesday"},
"password": "secret",
"color": "yellow",
},
},
{
Expand Down Expand Up @@ -305,9 +306,10 @@ func TestAsk(t *testing.T) {
test := test
t.Run(test.name, func(t *testing.T) {
answers := make(map[string]interface{})
RunTest(t, test.procedure, func(stdio terminal.Stdio) error {
err := RunTest(t, test.procedure, func(stdio terminal.Stdio) error {
return Ask(test.questions, &answers, WithStdio(stdio.In, stdio.Out, stdio.Err))
})
require.Nil(t, err)
require.Equal(t, test.expected, answers)
})
}
Expand All @@ -323,3 +325,57 @@ func TestAsk_returnsErrorIfTargetIsNil(t *testing.T) {
t.Error("Did not encounter error when asking with no where to record.")
}
}

func TestOnInterruptFunc(t *testing.T) {
// No Interrupt function set.
t.Run("No OnInterrupt", func(t *testing.T) {
answer := ""
err := RunTest(t, func(e *expect.Console) {
e.ExpectString("Are you a bot?")
e.SendLine(string(terminal.KeyInterrupt))
e.ExpectEOF()
}, func(t terminal.Stdio) error {
return AskOne(&Input{Message: "Are you a bot?"}, &answer,
WithStdio(t.In, t.Out, t.Err))
})

require.Equal(t, terminal.InterruptErr, err)
require.Equal(t, "", answer)
})

// Set global Interrupt function.
OnInterrupt = func() { fmt.Println("Ended abruptly!") }
t.Run("Global OnInterrupt", func(t *testing.T) {
answer := ""
err := RunTest(t, func(e *expect.Console) {
e.ExpectString("Are you a bot?")
e.SendLine(string(terminal.KeyInterrupt))
e.ExpectString("Ended abruptly!")
e.ExpectEOF()
}, func(t terminal.Stdio) error {
return AskOne(&Input{Message: "Are you a bot?"}, &answer,
WithStdio(t.In, t.Out, t.Err))
})

require.Equal(t, terminal.InterruptErr, err)
require.Equal(t, "", answer)
})

// Set local Interrupt function (overide global).
t.Run("Local Interrupt (override global", func(t *testing.T) {
answer := ""
err := RunTest(t, func(e *expect.Console) {
e.ExpectString("Are you a bot?")
e.SendLine(string(terminal.KeyInterrupt))
e.ExpectString("The end.")
e.ExpectEOF()
}, func(t terminal.Stdio) error {
return AskOne(&Input{Message: "Are you a bot?"}, &answer,
WithStdio(t.In, t.Out, t.Err),
WithInterruptFunc(func() { fmt.Println("The end.") }))
})

require.Equal(t, terminal.InterruptErr, err)
require.Equal(t, "", answer)
})
}
3 changes: 2 additions & 1 deletion survey_windows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
expect "github.com/Netflix/go-expect"
)

func RunTest(t *testing.T, procedure func(*expect.Console), test func(terminal.Stdio) error) {
func RunTest(t *testing.T, procedure func(*expect.Console), test func(terminal.Stdio) error) error {
t.Skip("Windows does not support psuedoterminals")
return nil
}

0 comments on commit 90b418e

Please sign in to comment.