diff --git a/platform/osInterface.go b/platform/osInterface.go index c0f47ed145..116e0ee5cd 100644 --- a/platform/osInterface.go +++ b/platform/osInterface.go @@ -1,6 +1,16 @@ package platform -type execClient struct{} +import ( + "time" +) + +const ( + defaultExecTimeout = 10 +) + +type execClient struct { + Timeout time.Duration +} //nolint:revive // ExecClient make sense type ExecClient interface { @@ -8,5 +18,13 @@ type ExecClient interface { } func NewExecClient() ExecClient { - return &execClient{} + return &execClient{ + Timeout: defaultExecTimeout * time.Second, + } +} + +func NewExecClientTimeout(timeout time.Duration) ExecClient { + return &execClient{ + Timeout: timeout, + } } diff --git a/platform/os_linux.go b/platform/os_linux.go index 0403ea422b..047c14dc12 100644 --- a/platform/os_linux.go +++ b/platform/os_linux.go @@ -5,6 +5,7 @@ package platform import ( "bytes" + "context" "fmt" "os" "os/exec" @@ -50,7 +51,7 @@ func GetOSInfo() string { } func GetProcessSupport() error { - p := &execClient{} + p := NewExecClient() cmd := fmt.Sprintf("ps -p %v -o comm=", os.Getpid()) _, err := p.ExecuteCommand(cmd) return err @@ -81,7 +82,12 @@ func (p *execClient) ExecuteCommand(command string) (string, error) { var stderr bytes.Buffer var out bytes.Buffer - cmd := exec.Command("sh", "-c", command) + + // Create a new context and add a timeout to it + ctx, cancel := context.WithTimeout(context.Background(), p.Timeout) + defer cancel() // The cancel should be deferred so resources are cleaned up + + cmd := exec.CommandContext(ctx, "sh", "-c", command) cmd.Stderr = &stderr cmd.Stdout = &out @@ -94,7 +100,7 @@ func (p *execClient) ExecuteCommand(command string) (string, error) { } func SetOutboundSNAT(subnet string) error { - p := execClient{} + p := NewExecClient() cmd := fmt.Sprintf("iptables -t nat -A POSTROUTING -m iprange ! --dst-range 168.63.129.16 -m addrtype ! --dst-type local ! -d %v -j MASQUERADE", subnet) _, err := p.ExecuteCommand(cmd) @@ -112,7 +118,7 @@ func ClearNetworkConfiguration() (bool, error) { } func KillProcessByName(processName string) error { - p := &execClient{} + p := NewExecClient() cmd := fmt.Sprintf("pkill -f %v", processName) _, err := p.ExecuteCommand(cmd) return err @@ -143,7 +149,7 @@ func GetOSDetails() (map[string]string, error) { } func GetProcessNameByID(pidstr string) (string, error) { - p := &execClient{} + p := NewExecClient() pidstr = strings.Trim(pidstr, "\n") cmd := fmt.Sprintf("ps -p %s -o comm=", pidstr) out, err := p.ExecuteCommand(cmd) @@ -159,7 +165,7 @@ func GetProcessNameByID(pidstr string) (string, error) { } func PrintDependencyPackageDetails() { - p := &execClient{} + p := NewExecClient() out, err := p.ExecuteCommand("iptables --version") out = strings.TrimSuffix(out, "\n") log.Printf("[cni-net] iptable version:%s, err:%v", out, err) diff --git a/platform/os_linux_test.go b/platform/os_linux_test.go new file mode 100644 index 0000000000..1848f22d03 --- /dev/null +++ b/platform/os_linux_test.go @@ -0,0 +1,29 @@ +package platform + +import ( + "testing" + "time" +) + +// Command execution time is more than timeout, so ExecuteCommand should return error +func TestExecuteCommandTimeout(t *testing.T) { + const timeout = 2 * time.Second + client := NewExecClientTimeout(timeout) + + _, err := client.ExecuteCommand("sleep 3") + if err == nil { + t.Errorf("TestExecuteCommandTimeout should have returned timeout error") + } + t.Logf("%s", err.Error()) +} + +// Command execution time is less than timeout, so ExecuteCommand should work without error +func TestExecuteCommandNoTimeout(t *testing.T) { + const timeout = 2 * time.Second + client := NewExecClientTimeout(timeout) + + _, err := client.ExecuteCommand("sleep 1") + if err != nil { + t.Errorf("TestExecuteCommandNoTimeout failed with error %v", err) + } +}