Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@
import org.apache.sshd.sftp.client.fs.SftpFileSystem;

import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
Expand Down Expand Up @@ -235,23 +235,28 @@ public String runRemote(String command) throws IOException {
private int runRemoteAndProcessLines(String command, Consumer<String> lineConsumer) throws IOException {
try (
ChannelExec channel = getSession().createExecChannel(command);
ByteArrayOutputStream out = new ByteArrayOutputStream();
ByteArrayOutputStream err = new ByteArrayOutputStream()) {
InputStream out = channel.getInvertedOut();
channel.setOut(out);
channel.setErr(err);
channel.open();
channel.waitFor(EnumSet.of(ClientChannelEvent.CLOSED), 0);
int readLines = 0;
try (BufferedReader reader = new BufferedReader(new InputStreamReader(out, StandardCharsets.UTF_8))) {
try (
BufferedReader reader = new BufferedReader(
new InputStreamReader(new ByteArrayInputStream(out.toByteArray()),
StandardCharsets.UTF_8))) {
Comment thread
leocook marked this conversation as resolved.
String line;
while ((line = reader.readLine()) != null) {
readLines++;
lineConsumer.accept(line);
}
Comment thread
leocook marked this conversation as resolved.
}
channel.waitFor(EnumSet.of(ClientChannelEvent.CLOSED), 0);
Integer exitStatus = channel.getExitStatus();
if (exitStatus == null || exitStatus != 0) {
throw new TaskException(
"Remote shell task error, exitStatus: " + exitStatus + " error message: " + err);
"Remote shell task error, exitStatus: " + exitStatus + " error message: "
+ new String(err.toByteArray(), StandardCharsets.UTF_8));
}
return readLines;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

package org.apache.dolphinscheduler.plugin.task.remoteshell;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
Expand All @@ -32,16 +34,13 @@
import org.apache.dolphinscheduler.plugin.datasource.ssh.param.SSHDataSourceProcessor;
import org.apache.dolphinscheduler.plugin.task.api.TaskException;

import org.apache.commons.io.IOUtils;
import org.apache.commons.io.input.NullInputStream;
import org.apache.commons.lang3.SystemUtils;
import org.apache.sshd.client.channel.ChannelExec;
import org.apache.sshd.client.channel.ClientChannelEvent;
import org.apache.sshd.client.session.ClientSession;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.EnumSet;

Expand Down Expand Up @@ -82,23 +81,29 @@
sshConnectionUtilsMockedStatic.close();
}

/** Helper: make channel write data to whichever OutputStream is passed to setOut(). */
private void mockChannelOutput(ChannelExec channel, String data) throws IOException {

Check warning on line 85 in dolphinscheduler-task-plugin/dolphinscheduler-task-remoteshell/src/test/java/org/apache/dolphinscheduler/plugin/task/remoteshell/RemoteExecutorTest.java

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Remove the declaration of thrown exception 'java.io.IOException', as it cannot be thrown from method's body.

See more on https://sonarcloud.io/project/issues?id=apache-dolphinscheduler&issues=AZ3otEeoIQccHbaj5o8K&open=AZ3otEeoIQccHbaj5o8K&pullRequest=18210
doAnswer(invocation -> {
OutputStream out = invocation.getArgument(0);
out.write(data.getBytes(StandardCharsets.UTF_8));
return null;
}).when(channel).setOut(any(OutputStream.class));
}

@Test
void testRunRemote() throws IOException {
RemoteExecutor remoteExecutor = spy(new RemoteExecutor(sshConnectionParam));
ChannelExec channel = Mockito.mock(ChannelExec.class, RETURNS_DEEP_STUBS);
when(clientSession.auth().verify().isSuccess()).thenReturn(true);
when(clientSession.createExecChannel(Mockito.anyString())).thenReturn(channel);
when(channel.getExitStatus()).thenReturn(1);
when(channel.getInvertedOut()).thenReturn(new NullInputStream());
Assertions.assertThrows(TaskException.class, () -> remoteExecutor.runRemote("ls -l"));

// Mock the streaming runRemote to simulate log output
String output = "total 26392\n" +
"dr-xr-xr-x. 6 root root 3072 Aug 15 2023 boot\n" +
"drwxr-xr-x 18 root root 3120 Sep 23 2023 dev\n" +
"drwxr-xr-x. 91 root root 4096 Sep 23 2023 etc\n";
InputStream inputStream = IOUtils.toInputStream(output, StandardCharsets.UTF_8);
when(channel.getInvertedOut()).thenReturn(inputStream);
mockChannelOutput(channel, output);
when(channel.getExitStatus()).thenReturn(0);
String actualOut = Assertions.assertDoesNotThrow(() -> remoteExecutor.runRemote("ls -l"));
Assertions.assertEquals(output, actualOut);
Expand Down Expand Up @@ -157,7 +162,6 @@
void getAllRemotePidStr() throws IOException {

RemoteExecutor remoteExecutor = spy(new RemoteExecutor(sshConnectionParam));
// Mock pstree output based on OS
if (SystemUtils.IS_OS_MAC) {
doReturn("-+= 9527 root\n \\-+= 9528 root").when(remoteExecutor).runRemote(anyString());
} else {
Expand Down Expand Up @@ -186,37 +190,32 @@
String taskId = "1234";
ChannelExec channel = Mockito.mock(ChannelExec.class, RETURNS_DEEP_STUBS);

// Mock getTaskPid to control the loop, return a valid pid 2 times, then return empty
doReturn("9527")
.doReturn("9527")
.doReturn("").when(remoteExecutor).getTaskPid(taskId);
when(clientSession.auth().verify().isSuccess()).thenReturn(true);
when(clientSession.createExecChannel(anyString())).thenReturn(channel);

// Mock the streaming runRemote to simulate log output
String logContent = "some log line 1\n"
+ "echo \"${setValue(my_prop=my_value)}\"\n"
+ "some log line 2\n";
InputStream inputStream = IOUtils.toInputStream(logContent, StandardCharsets.UTF_8);
when(channel.getInvertedOut()).thenReturn(inputStream);
mockChannelOutput(channel, logContent);
when(channel.getExitStatus()).thenReturn(0);

remoteExecutor.track(taskId);

// Verify that the output parameter was parsed and stored
Assertions.assertEquals(1, remoteExecutor.getTaskOutputParams().size());
Assertions.assertEquals("my_value", remoteExecutor.getTaskOutputParams().get("my_prop"));
}

@Test
void testRunRemoteWithEmptyOutput() throws Exception {
// Test empty output scenario (readLines = 0)
RemoteExecutor remoteExecutor = spy(new RemoteExecutor(sshConnectionParam));
ChannelExec channel = Mockito.mock(ChannelExec.class, RETURNS_DEEP_STUBS);

when(clientSession.auth().verify().isSuccess()).thenReturn(true);
when(clientSession.createExecChannel(anyString())).thenReturn(channel);
when(channel.getInvertedOut()).thenReturn(new ByteArrayInputStream(new byte[0]));
// no mockChannelOutput → setOut() is a no-op → ByteArrayOutputStream stays empty
when(channel.getExitStatus()).thenReturn(0);
when(channel.waitFor(EnumSet.of(ClientChannelEvent.CLOSED), 0))
.thenReturn(EnumSet.of(ClientChannelEvent.CLOSED));
Expand All @@ -227,47 +226,48 @@

@Test
void testRunRemoteWithNonZeroExitStatus() throws Exception {
// Test command failure scenario (exitStatus != 0)
RemoteExecutor remoteExecutor = spy(new RemoteExecutor(sshConnectionParam));
ChannelExec channel = Mockito.mock(ChannelExec.class, RETURNS_DEEP_STUBS);

when(clientSession.auth().verify().isSuccess()).thenReturn(true);
when(clientSession.createExecChannel(anyString())).thenReturn(channel);
when(channel.getInvertedOut()).thenReturn(IOUtils.toInputStream("error output", StandardCharsets.UTF_8));
mockChannelOutput(channel, "partial output before failure\n");
when(channel.getExitStatus()).thenReturn(1);
when(channel.waitFor(EnumSet.of(ClientChannelEvent.CLOSED), 0))
.thenReturn(EnumSet.of(ClientChannelEvent.CLOSED));

Assertions.assertThrows(TaskException.class, () -> remoteExecutor.runRemote("failing_command"));
TaskException ex =
Assertions.assertThrows(TaskException.class, () -> remoteExecutor.runRemote("failing_command"));
Assertions.assertTrue(ex.getMessage().contains("exitStatus: 1"));
}

@Test
void testRunRemoteWithNullExitStatus() throws Exception {
// Test null exitStatus scenario
RemoteExecutor remoteExecutor = spy(new RemoteExecutor(sshConnectionParam));
ChannelExec channel = Mockito.mock(ChannelExec.class, RETURNS_DEEP_STUBS);

when(clientSession.auth().verify().isSuccess()).thenReturn(true);
when(clientSession.createExecChannel(anyString())).thenReturn(channel);
when(channel.getInvertedOut()).thenReturn(IOUtils.toInputStream("some output", StandardCharsets.UTF_8));
mockChannelOutput(channel, "some output\n");
when(channel.getExitStatus()).thenReturn(null);
when(channel.waitFor(EnumSet.of(ClientChannelEvent.CLOSED), 0))
.thenReturn(EnumSet.of(ClientChannelEvent.CLOSED));

Assertions.assertThrows(TaskException.class, () -> remoteExecutor.runRemote("command"));
TaskException ex =
Assertions.assertThrows(TaskException.class, () -> remoteExecutor.runRemote("command"));
Assertions.assertTrue(ex.getMessage().contains("exitStatus: null"));
}

@Test
void testTrackWithEmptyLogOutput() throws Exception {
// Test track with empty log output (readLines = 0 scenario in track loop)
RemoteExecutor remoteExecutor = spy(new RemoteExecutor(sshConnectionParam));
String taskId = "1234";
ChannelExec channel = Mockito.mock(ChannelExec.class, RETURNS_DEEP_STUBS);

doReturn("9527").doReturn("").when(remoteExecutor).getTaskPid(taskId);
when(clientSession.auth().verify().isSuccess()).thenReturn(true);
when(clientSession.createExecChannel(anyString())).thenReturn(channel);
when(channel.getInvertedOut()).thenReturn(new ByteArrayInputStream(new byte[0]));
// no output → readLines=0 → sleep once, then getTaskPid returns "" → exit
when(channel.getExitStatus()).thenReturn(0);
when(channel.waitFor(EnumSet.of(ClientChannelEvent.CLOSED), 0))
.thenReturn(EnumSet.of(ClientChannelEvent.CLOSED));
Expand Down
Loading