diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 5cbd5aba18f..a1d4cdcb240 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -27,7 +27,8 @@ on: - 'docs/**' - '*.md' - '*.html' - - 'src/main/python/docs/**' + - 'src/test/**' + - 'src/assembly/**' - 'dev/**' branches: - main @@ -36,7 +37,7 @@ on: - 'docs/**' - '*.md' - '*.html' - - 'src/main/python/docs/**' + - 'src/test/**' - 'dev/**' branches: - main @@ -94,7 +95,7 @@ jobs: architecture: 'x64' - name: Install pip Dependencies - run: pip install numpy py4j wheel scipy sklearn requests pandas + run: pip install numpy py4j wheel scipy sklearn requests pandas unittest-parallel - name: Build Python Package run: | @@ -107,15 +108,16 @@ jobs: export PATH=$SYSTEMDS_ROOT/bin:$PATH export SYSDS_QUIET=1 cd src/main/python - python -m unittest discover -s tests -p 'test_*.py' + unittest-parallel -t . -s tests --module-fixtures + # python -m unittest discover -s tests -p 'test_*.py' echo "Exit Status: " $? - # TODO debug and fix JDK11 environment - #- name: Run all python tests no environment - # run: | - # cd src/main/python - # python -m unittest discover -s tests -p 'test_*.py' - # echo "Exit Status: " $? + - name: Run all python tests no environment + run: | + cd src/main/python + unittest-parallel -t . -s tests --module-fixtures + # python -m unittest discover -s tests -p 'test_*.py' + echo "Exit Status: " $? - name: Run Federated Python Tests run: | diff --git a/src/main/java/org/apache/sysds/api/DMLOptions.java b/src/main/java/org/apache/sysds/api/DMLOptions.java index 5e10d73141e..1b911bcff0e 100644 --- a/src/main/java/org/apache/sysds/api/DMLOptions.java +++ b/src/main/java/org/apache/sysds/api/DMLOptions.java @@ -71,6 +71,7 @@ public class DMLOptions { public boolean lineage_debugger = false; // whether enable lineage debugger public boolean fedWorker = false; public int fedWorkerPort = -1; + public int pythonPort = -1; public boolean checkPrivacy = false; // Check which privacy constraints are loaded and checked during federated execution public boolean federatedCompilation = false; // Compile federated instructions based on input federation state and privacy constraints. @@ -242,6 +243,10 @@ else if (lineageType.equalsIgnoreCase("debugger")) } } + if (line.hasOption("python")){ + dmlOptions.pythonPort = Integer.parseInt(line.getOptionValue("python")); + } + // Named arguments map is created as ("$K, 123), ("$X", "X.csv"), etc if (line.hasOption("nvargs")){ String varNameRegex = "^[a-zA-Z]([a-zA-Z0-9_])*$"; @@ -302,8 +307,8 @@ private static Options createCLIOptions() { .hasOptionalArg().create("gpu"); Option debugOpt = OptionBuilder.withDescription("runs in debug mode; default off") .create("debug"); - Option pythonOpt = OptionBuilder.withDescription("parses Python-like DML") - .create("python"); + Option pythonOpt = OptionBuilder.withDescription("Python Context start with port argument for communication to python") + .isRequired().hasArg().create("python"); Option fileOpt = OptionBuilder.withArgName("filename") .withDescription("specifies dml/pydml file to execute; path can be local/hdfs/gpfs (prefixed with appropriate URI)") .isRequired().hasArg().create("f"); @@ -332,7 +337,6 @@ private static Options createCLIOptions() { options.addOption(execOpt); options.addOption(gpuOpt); options.addOption(debugOpt); - options.addOption(pythonOpt); options.addOption(lineageOpt); options.addOption(fedOpt); options.addOption(checkPrivacy); @@ -344,7 +348,8 @@ private static Options createCLIOptions() { .addOption(fileOpt) .addOption(cleanOpt) .addOption(helpOpt) - .addOption(fedOpt); + .addOption(fedOpt) + .addOption(pythonOpt); fileOrScriptOpt.setRequired(true); options.addOptionGroup(fileOrScriptOpt); diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java b/src/main/java/org/apache/sysds/api/DMLScript.java index 658b99352b6..ebb7f3d3108 100644 --- a/src/main/java/org/apache/sysds/api/DMLScript.java +++ b/src/main/java/org/apache/sysds/api/DMLScript.java @@ -356,7 +356,7 @@ public static String readDMLScript( boolean isFile, String scriptOrFilename ) // (core compilation and execute) //////// - private static void loadConfiguration(String fnameOptConfig) throws IOException { + public static void loadConfiguration(String fnameOptConfig) throws IOException { DMLConfig dmlconf = DMLConfig.readConfigurationFile(fnameOptConfig); ConfigurationManager.setGlobalConfig(dmlconf); CompilerConfig cconf = OptimizerUtils.constructCompilerConfig(dmlconf); diff --git a/src/main/java/org/apache/sysds/api/PythonDMLScript.java b/src/main/java/org/apache/sysds/api/PythonDMLScript.java index e7251e784bd..d93409d8b2a 100644 --- a/src/main/java/org/apache/sysds/api/PythonDMLScript.java +++ b/src/main/java/org/apache/sysds/api/PythonDMLScript.java @@ -26,40 +26,29 @@ import py4j.GatewayServer; import py4j.GatewayServerListener; +import py4j.Py4JNetworkException; import py4j.Py4JServerConnection; public class PythonDMLScript { - private static final Log LOG = LogFactory.getLog(PythonDMLScript.class.getName()); + private Connection _connection; /** * Entry point for Python API. * - * The system returns with exit code 1, if the startup process fails, and 0 if the startup was successful. - * * @param args Command line arguments. + * @throws Exception Throws exceptions if there is issues in startup or while running. */ - public static void main(String[] args) { - if(args.length != 1) { - throw new IllegalArgumentException("Python DML Script should be initialized with a singe number argument"); - } - else { - int port = Integer.parseInt(args[0]); - start(port); - } + public static void main(String[] args) throws Exception { + final DMLOptions dmlOptions = DMLOptions.parseCLArguments(args); + DMLScript.loadConfiguration(dmlOptions.configFile); + start(dmlOptions.pythonPort); } - private static void start(int port) { - try { - // TODO Add argument parsing here. - GatewayServer GwS = new GatewayServer(new PythonDMLScript(), port); - GwS.addListener(new DMLGateWayListener()); - GwS.start(); - } - catch(py4j.Py4JNetworkException ex) { - LOG.error("Py4JNetworkException while executing the GateWay. Is a server instance already running?"); - System.exit(-1); - } + private static void start(int port) throws Py4JNetworkException { + GatewayServer GwS = new GatewayServer(new PythonDMLScript(), port); + GwS.addListener(new DMLGateWayListener()); + GwS.start(); } private PythonDMLScript() { @@ -79,50 +68,53 @@ private PythonDMLScript() { public Connection getConnection() { return _connection; } -} - -class DMLGateWayListener implements GatewayServerListener { - private static final Log LOG = LogFactory.getLog(DMLGateWayListener.class.getName()); - - @Override - public void connectionError(Exception e) { - LOG.warn("Connection error: " + e.getMessage()); - } - - @Override - public void connectionStarted(Py4JServerConnection gatewayConnection) { - LOG.debug("Connection Started: " + gatewayConnection.toString()); - } - - @Override - public void connectionStopped(Py4JServerConnection gatewayConnection) { - LOG.debug("Connection stopped: " + gatewayConnection.toString()); - } - - @Override - public void serverError(Exception e) { - LOG.error("Server Error " + e.getMessage()); - } - - @Override - public void serverPostShutdown() { - LOG.info("Shutdown done"); - System.exit(0); - } - - @Override - public void serverPreShutdown() { - LOG.info("Starting JVM shutdown"); - } - - @Override - public void serverStarted() { - // message the python interface that the JVM is ready. - System.out.println("GatewayServer Started"); - } - - @Override - public void serverStopped() { - System.out.println("GatewayServer Stopped"); + + protected static class DMLGateWayListener implements GatewayServerListener { + private static final Log LOG = LogFactory.getLog(DMLGateWayListener.class.getName()); + + @Override + public void connectionError(Exception e) { + LOG.warn("Connection error: " + e.getMessage()); + System.exit(1); + } + + @Override + public void connectionStarted(Py4JServerConnection gatewayConnection) { + LOG.debug("Connection Started: " + gatewayConnection.toString()); + } + + @Override + public void connectionStopped(Py4JServerConnection gatewayConnection) { + LOG.debug("Connection stopped: " + gatewayConnection.toString()); + } + + @Override + public void serverError(Exception e) { + LOG.error("Server Error " + e.getMessage()); + } + + @Override + public void serverPostShutdown() { + LOG.info("Shutdown done"); + System.exit(0); + } + + @Override + public void serverPreShutdown() { + LOG.info("Starting JVM shutdown"); + } + + @Override + public void serverStarted() { + // message the python interface that the JVM is ready. + System.out.println("GatewayServer Started"); + } + + @Override + public void serverStopped() { + System.out.println("GatewayServer Stopped"); + System.exit(0); + } } } + diff --git a/src/main/java/org/apache/sysds/runtime/io/IOUtilFunctions.java b/src/main/java/org/apache/sysds/runtime/io/IOUtilFunctions.java index 1a366c6f9ac..9ad3f0aeeea 100644 --- a/src/main/java/org/apache/sysds/runtime/io/IOUtilFunctions.java +++ b/src/main/java/org/apache/sysds/runtime/io/IOUtilFunctions.java @@ -93,7 +93,7 @@ public static FileSystem getFileSystem(Configuration conf) throws IOException { try{ return FileSystem.get(conf); } catch(NoClassDefFoundError err) { - throw new IOException(err.getMessage()); + throw new IOException(err.getMessage(), err); } } @@ -101,7 +101,7 @@ public static FileSystem getFileSystem(Path fname, Configuration conf) throws IO try { return FileSystem.get(fname.toUri(), conf); } catch(NoClassDefFoundError err) { - throw new IOException(err.getMessage()); + throw new IOException(err.getMessage(), err); } } diff --git a/src/main/python/systemds/context/systemds_context.py b/src/main/python/systemds/context/systemds_context.py index afa38c23286..474b7ac40f7 100644 --- a/src/main/python/systemds/context/systemds_context.py +++ b/src/main/python/systemds/context/systemds_context.py @@ -25,6 +25,7 @@ import os import socket import sys +import tracemalloc from glob import glob from queue import Queue from subprocess import PIPE, Popen @@ -42,6 +43,7 @@ from systemds.utils.consts import VALID_INPUT_TYPES from systemds.utils.helpers import get_module_dir +tracemalloc.start() class SystemDSContext(object): """A context with a connection to a java instance with which SystemDS operations are executed. @@ -60,26 +62,13 @@ def __init__(self, port: int = -1): Standard out and standard error form the JVM is also handled in this class, filling up Queues, that can be read from to get the printed statements from the JVM. """ - command = self.__build_startup_command() - process, port = self.__try_startup(command, port) - - # Handle Std out from the subprocess. - self.__stdout = Queue() - self.__stderr = Queue() - - self.__stdout_thread = Thread(target=self.__enqueue_output, args=( - process.stdout, self.__stdout), daemon=True) - - self.__stderr_thread = Thread(target=self.__enqueue_output, args=( - process.stderr, self.__stderr), daemon=True) - - self.__stdout_thread.start() - self.__stderr_thread.start() + actual_port = self.__start(port) + process = self.__process + if process.poll() is None: + self.__start_gateway(actual_port) + else: + self.exception_and_close("Java process stopped before gateway could connect") - # Py4j connect to the started process. - gwp = GatewayParameters(port=port, eager_load=True) - self.java_gateway = JavaGateway( - gateway_parameters=gwp, java_process=process) def get_stdout(self, lines: int = -1): """Getter for the stdout of the java subprocess @@ -103,14 +92,13 @@ def get_stderr(self, lines: int = -1): else: return [self.__stderr.get() for x in range(lines)] - def exception_and_close(self, exception_str: str, trace_back_limit : int = None): + def exception_and_close(self, exception_str: str, trace_back_limit: int = None): """ Method for printing exception, printing stdout and error, while also closing the context correctly. :param e: the exception thrown """ - # e = sys.exc_info()[0] message = "" stdOut = self.get_stdout() if stdOut: @@ -118,101 +106,166 @@ def exception_and_close(self, exception_str: str, trace_back_limit : int = None) stdErr = self.get_stderr() if stdErr: message += "standard error :\n" + "\n".join(stdErr) + message += "\n\n" message += exception_str sys.tracebacklimit = trace_back_limit self.close() raise RuntimeError(message) - def __try_startup(self, command, port, rep=0): - """ Try to perform startup of system. + def __try_startup(self, command) -> bool: - :param command: The command to execute for starting JMLC content - :param port: The port to try to connect to to. - :param rep: The number of repeated tries to startup the jvm. - """ - if port == -1: - assignedPort = self.__get_open_port() - elif rep == 0: - assignedPort = port - else: - assignedPort = self.__get_open_port() - fullCommand = [] - fullCommand.extend(command) - fullCommand.append(str(assignedPort)) - process = Popen(fullCommand, stdout=PIPE, stdin=PIPE, stderr=PIPE) + self.__process = Popen(command, stdout=PIPE, stdin=PIPE, stderr=PIPE) - try: - self.__verify_startup(process) - - return process, assignedPort - except Exception as e: - self.close() - if rep > 3: - raise Exception( - "Failed to start SystemDS context with " + str(rep) + " repeated tries") - else: - rep += 1 - print("Failed to startup JVM process, retrying: " + str(rep)) - sleep(0.5) - return self.__try_startup(command, port, rep) - - def __verify_startup(self, process): - first_stdout = process.stdout.readline() - if(not b"GatewayServer Started" in first_stdout): - stderr = process.stderr.readline().decode("utf-8") - if(len(stderr) > 1): - raise Exception( - "Exception in startup of GatewayServer: " + stderr) - outputs = [] - outputs.append(first_stdout.decode("utf-8")) - max_tries = 10 - for i in range(max_tries): - next_line = process.stdout.readline() - if(b"GatewayServer Started" in next_line): - print("WARNING: Stdout corrupted by prints: " + str(outputs)) - print("Startup success") - break - else: - outputs.append(next_line) - - if (i == max_tries-1): - raise Exception("Error in startup of systemDS gateway process: \n gateway StdOut: " + str( - outputs) + " \n gateway StdErr" + process.stderr.readline().decode("utf-8")) - - def __build_startup_command(self): + # Handle Std out from the subprocess. + self.__stdout = Queue() + self.__stderr = Queue() + + self.__stdout_thread = Thread(target=self.__enqueue_output, args=( + self.__process.stdout, self.__stdout), daemon=True) + + self.__stderr_thread = Thread(target=self.__enqueue_output, args=( + self.__process.stderr, self.__stderr), daemon=True) + + self.__stdout_thread.start() + self.__stderr_thread.start() + + return self.__verify_startup(command) + + + def __verify_startup(self, command) -> bool: + first_stdout = self.get_stdout() + if(not "GatewayServer Started" in first_stdout): + return self.__verify_startup_retry( command) + else: + return True + + def __verify_startup_retry(self, command, retry: int=1) -> bool: + sleep(0.8 * retry) + stdout = self.get_stdout() + if "GatewayServer Started" in stdout: + return True, "" + elif retry < 3: # retry 3 times + return self.__verify_startup_retry(command, retry + 1) + else: + error_message = "Error in startup of systemDS gateway process:" + error_message += "\n" + " ".join(command) + stderr = self.get_stderr() + if len(stderr) > 0: + error_message += "\n" + "\n".join(stderr) + if len(stdout) > 0: + error_message += "\n\n" + "\n".join(stdout) + self.__error_message = error_message + return False + + def __build_startup_command(self, port: int): command = ["java", "-cp"] root = os.environ.get("SYSTEMDS_ROOT") if root == None: # If there is no systemds install default to use the PIP packaged java files. - root = os.path.join(get_module_dir(), "systemds-java") + root = os.path.join(get_module_dir()) # nt means its Windows cp_separator = ";" if os.name == "nt" else ":" if os.environ.get("SYSTEMDS_ROOT") != None: - lib_cp = os.path.join(root, "target", "lib", "*") - systemds_cp = os.path.join(root, "target", "SystemDS.jar") - classpath = cp_separator.join([lib_cp, systemds_cp]) - - command.append(classpath) - files = glob(os.path.join(root, "conf", "log4j*.properties")) - if len(files) > 1: - print( - "WARNING: Multiple logging files found selecting: " + files[0]) - if len(files) == 0: - print("WARNING: No log4j file found at: " - + os.path.join(root, "conf") - + " therefore using default settings") + lib_release = os.path.join(root, "lib") + lib_cp = os.path.join(root, "target", "lib") + if os.path.exists(lib_release): + classpath = cp_separator.join([os.path.join(lib_release, '*')]) + elif os.path.exists(lib_cp): + systemds_cp = os.path.join(root, "target", "SystemDS.jar") + classpath = cp_separator.join( + [os.path.join(lib_cp, '*'), systemds_cp]) else: - command.append("-Dlog4j.configuration=file:" + files[0]) + raise ValueError( + "Invalid setup at SYSTEMDS_ROOT env variable path") + else: + lib1 = os.path.join(root, "lib", "*") + lib2 = os.path.join(root, "lib") + classpath = cp_separator.join([lib1, lib2]) + + command.append(classpath) + + files = glob(os.path.join(root, "conf", "log4j*.properties")) + if len(files) > 1: + print( + "WARNING: Multiple logging files found selecting: " + files[0]) + if len(files) == 0: + print("WARNING: No log4j file found at: " + + os.path.join(root, "conf") + + " therefore using default settings") else: - lib_cp = os.path.join(root, "lib", "*") - command.append(lib_cp) + command.append("-Dlog4j.configuration=file:" + files[0]) command.append("org.apache.sysds.api.PythonDMLScript") - return command + files = glob(os.path.join(root, "conf", "SystemDS*.xml")) + if len(files) > 1: + print( + "WARNING: Multiple config files found selecting: " + files[0]) + if len(files) == 0: + print("WARNING: No log4j file found at: " + + os.path.join(root, "conf") + + " therefore using default settings") + else: + command.append("-config") + command.append(files[0]) + + if port == -1: + actual_port = self.__get_open_port() + else: + actual_port = port + + command.append("--python") + command.append(str(actual_port)) + + return command, actual_port + + + def __start(self, port:int): + command, actual_port = self.__build_startup_command(port) + success = self.__try_startup(command) + + if not success: + retry = 1 + while not success and retry < 3: + self.__kill_Popen(self.__process) + # retry after waiting a bit. + sleep(3 * retry) + self.close() + self.__error_message = None + success, command, actual_port = self.__retry_start(retry) + retry = retry + 1 + if not success: + self.exception_and_close(self.__error_message) + return actual_port + + + def __retry_start(self, ret): + command, actual_port = self.__build_startup_command(-1) + success = self.__try_startup(command) + return success, command, actual_port + + def __start_gateway(self, actual_port: int): + process = self.__process + gwp = GatewayParameters(port=actual_port, eager_load=True) + self.__retry_start_gateway(process, gwp) + + def __retry_start_gateway(self,process:Popen, gwp:GatewayParameters, retry:int = 0 ): + try: + self.java_gateway = JavaGateway(gateway_parameters=gwp, java_process=process) + self.__process = None # On success clear process variable + return + except: + sleep(3 * retry) + if retry < 3: + self.__retry_start_gateway(process, gwp, retry + 1) + return + else: + e = "Error in startup of Java Gateway" + self.exception_and_close(e) + def __enter__(self): return self @@ -224,26 +277,28 @@ def __exit__(self, exc_type, exc_val, exc_tb): def close(self): """Close the connection to the java process and do necessary cleanup.""" - if(self.__stdout_thread.is_alive()): + if hasattr(self, 'java_gateway'): + self.__kill_Popen(self.java_gateway.java_process) + self.java_gateway.shutdown() + if hasattr(self, '__process'): + print("Has process variable") + self.__kill_Popen(self.__process) + if hasattr(self, '__stdout_thread') and self.__stdout_thread.is_alive(): self.__stdout_thread.join(0) - if(self.__stdout_thread.is_alive()): + if hasattr(self, '__stderr_thread') and self.__stderr_thread.is_alive(): self.__stderr_thread.join(0) - pid = self.java_gateway.java_process.pid - if self.java_gateway.java_gateway_server is not None: - try: - self.java_gateway.shutdown(True) - except Py4JNetworkError as e: - if "Gateway is not connected" not in str(e): - self.java_gateway.java_process.kill() - os.kill(pid, 14) + def __kill_Popen(self, process: Popen): + process.kill() + process.__exit__(None, None, None) def __enqueue_output(self, out, queue): """Method for handling the output from java. It is locating the string handeling inside a different thread, since the 'out.readline' is a blocking command. """ for line in iter(out.readline, b""): - queue.put(line.decode("utf-8").strip()) + line_string = line.decode("utf-8") + queue.put(line_string.strip()) def __get_open_port(self): """Get a random available port. @@ -291,7 +346,7 @@ def seq(self, start: Union[float, int], stop: Union[float, int] = None, def rand(self, rows: int, cols: int, min: Union[float, int] = None, max: Union[float, int] = None, pdf: str = "uniform", sparsity: Union[float, int] = None, seed: Union[float, int] = None, - lambd: Union[float, int] = 1) -> 'Matrix': + lamb: Union[float, int] = 1) -> 'Matrix': """Generates a matrix filled with random values :param sds_context: SystemDS context @@ -299,26 +354,26 @@ def rand(self, rows: int, cols: int, :param cols: number of cols :param min: min value for cells :param max: max value for cells - :param pdf: "uniform"/"normal"/"poison" distribution + :param pdf: probability distribution function: "uniform"/"normal"/"poison" distribution :param sparsity: fraction of non-zero cells :param seed: random seed - :param lambd: lamda value for "poison" distribution + :param lamb: lambda value for "poison" distribution :return: """ - available_pdfs = ["uniform", "normal", "poisson"] + available_pdf = ["uniform", "normal", "poisson"] if rows < 0: raise ValueError("In rand statement, can only assign rows a long (integer) value >= 0 " "-- attempted to assign value: {r}".format(r=rows)) if cols < 0: raise ValueError("In rand statement, can only assign cols a long (integer) value >= 0 " "-- attempted to assign value: {c}".format(c=cols)) - if pdf not in available_pdfs: + if pdf not in available_pdf: raise ValueError("The pdf passed is invalid! given: {g}, expected: {e}".format( - g=pdf, e=available_pdfs)) + g=pdf, e=available_pdf)) pdf = '\"' + pdf + '\"' named_input_nodes = { - 'rows': rows, 'cols': cols, 'pdf': pdf, 'lambda': lambd} + 'rows': rows, 'cols': cols, 'pdf': pdf, 'lambda': lamb} if min is not None: named_input_nodes['min'] = min if max is not None: @@ -357,7 +412,11 @@ def read(self, path: os.PathLike, **kwargs: Dict[str, VALID_INPUT_TYPES]) -> Ope output_type = OutputType.from_str(kwargs.get("value_type", None)) kwargs["value_type"] = f'"{output_type.name}"' return Scalar(self, "read", [f'"{path}"'], named_input_nodes=kwargs, output_type=output_type) - + elif data_type == "list": + # Reading a list have no extra arguments. + return List(self, "read", [f'"{path}"']) + + kwargs["data_type"] = None print("WARNING: Unknown type read please add a mtd file, or specify in arguments") return OperationNode(self, "read", [f'"{path}"'], named_input_nodes=kwargs) diff --git a/src/main/python/systemds/operator/nodes/list.py b/src/main/python/systemds/operator/nodes/list.py index 6f5bfb1715e..6ad69cac2a7 100644 --- a/src/main/python/systemds/operator/nodes/list.py +++ b/src/main/python/systemds/operator/nodes/list.py @@ -80,7 +80,7 @@ def code_line(self, var_name: str, unnamed_input_vars: Sequence[str], unnamed_input_vars, named_input_vars) return f'{var_name}={self.operation}({inputs_comma_sep});' - def compute(self, verbose: bool = False, lineage: bool = False) -> Union[np.array]: + def compute(self, verbose: bool = False, lineage: bool = False) -> np.array: return super().compute(verbose, lineage) def __str__(self): diff --git a/src/main/python/systemds/operator/nodes/list_access.py b/src/main/python/systemds/operator/nodes/list_access.py index a954f9c6ac7..170820694ac 100644 --- a/src/main/python/systemds/operator/nodes/list_access.py +++ b/src/main/python/systemds/operator/nodes/list_access.py @@ -21,10 +21,7 @@ __all__ = ["ListAccess"] -from typing import Dict, Iterable, Sequence, Tuple, Union - -import numpy as np -from py4j.java_gateway import JavaObject +from typing import Dict, Sequence from systemds.operator import Frame, Matrix, OperationNode, Scalar from systemds.script_building.dag import OutputType diff --git a/src/main/python/tests/algorithms/test_kmeans.py b/src/main/python/tests/algorithms/test_kmeans.py index 328cdd4a591..6369749cee0 100644 --- a/src/main/python/tests/algorithms/test_kmeans.py +++ b/src/main/python/tests/algorithms/test_kmeans.py @@ -82,7 +82,7 @@ def test_500x2(self): self.assertTrue(len(corners) == 4) - def generate_matrices_for_k_means(self, dims: (int, int), seed: int = 1234): + def generate_matrices_for_k_means(self, dims, seed: int = 1234): np.random.seed(seed) mu, sigma = 0, 0.1 s = np.random.normal(mu, sigma, dims[0] * dims[1]) diff --git a/src/main/python/tests/basics/test_context_creation.py b/src/main/python/tests/basics/test_context_creation.py new file mode 100644 index 00000000000..1a70deb4ca6 --- /dev/null +++ b/src/main/python/tests/basics/test_context_creation.py @@ -0,0 +1,51 @@ +# ------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# ------------------------------------------------------------- + +import unittest + +from systemds.context import SystemDSContext + + +class TestContextCreation(unittest.TestCase): + + def test_same_port(self): + # Same port should graciously change port + sds1 = SystemDSContext(port=9415) + sds2 = SystemDSContext(port=9415) + sds1.close() + sds2.close() + + def test_create_10_contexts(self): + # Creating multiple contexts and closing them should be no problem. + for _ in range(0, 10): + SystemDSContext().close() + + def test_create_multiple_context(self): + # Creating multiple contexts in sequence but open at the same time is okay. + a = SystemDSContext() + b = SystemDSContext() + c = SystemDSContext() + d = SystemDSContext() + + a.close() + b.close() + c.close() + d.close() diff --git a/src/main/python/tests/frame/test_hyperband.py b/src/main/python/tests/frame/test_hyperband.py index 78a12c7226b..d8cd9586898 100644 --- a/src/main/python/tests/frame/test_hyperband.py +++ b/src/main/python/tests/frame/test_hyperband.py @@ -54,32 +54,29 @@ def tearDown(self): pass def test_hyperband(self): - if "SYSTEMDS_ROOT" in os.environ: - x_train = self.sds.from_numpy(self.X_train) - y_train = self.sds.from_numpy(self.y_train) - x_val = self.sds.from_numpy(self.X_val) - y_val = self.sds.from_numpy(self.y_val) - paramRanges = self.sds.from_numpy(self.param_ranges) - params = self.params - [best_weights_mat, opt_hyper_params_df] = hyperband( - X_train=x_train, - y_train=y_train, - X_val=x_val, - y_val=y_val, - params=params, - paramRanges=paramRanges, - ).compute() - self.assertTrue(isinstance(best_weights_mat, np.ndarray)) - self.assertTrue(best_weights_mat.shape[0] == self.X_train.shape[1]) - self.assertTrue(best_weights_mat.shape[1] == self.y_train.shape[1]) + x_train = self.sds.from_numpy(self.X_train) + y_train = self.sds.from_numpy(self.y_train) + x_val = self.sds.from_numpy(self.X_val) + y_val = self.sds.from_numpy(self.y_val) + paramRanges = self.sds.from_numpy(self.param_ranges) + params = self.params + [best_weights_mat, opt_hyper_params_df] = hyperband( + X_train=x_train, + y_train=y_train, + X_val=x_val, + y_val=y_val, + params=params, + paramRanges=paramRanges, + ).compute() + self.assertTrue(isinstance(best_weights_mat, np.ndarray)) + self.assertTrue(best_weights_mat.shape[0] == self.X_train.shape[1]) + self.assertTrue(best_weights_mat.shape[1] == self.y_train.shape[1]) - self.assertTrue(isinstance(opt_hyper_params_df, pd.DataFrame)) - self.assertTrue(opt_hyper_params_df.shape[1] == 1) - for i, hyper_param in enumerate(opt_hyper_params_df.values.flatten().tolist()): - self.assertTrue( - self.min_max_params[i][0] <= hyper_param <= self.min_max_params[i][1]) - else: - print("to enable hyperband tests, set SYSTEMDS_ROOT") + self.assertTrue(isinstance(opt_hyper_params_df, pd.DataFrame)) + self.assertTrue(opt_hyper_params_df.shape[1] == 1) + for i, hyper_param in enumerate(opt_hyper_params_df.values.flatten().tolist()): + self.assertTrue( + self.min_max_params[i][0] <= hyper_param <= self.min_max_params[i][1]) if __name__ == "__main__": diff --git a/src/main/python/tests/frame/test_write_read.py b/src/main/python/tests/frame/test_write_read.py index ffa417ef784..cbbad68c76d 100644 --- a/src/main/python/tests/frame/test_write_read.py +++ b/src/main/python/tests/frame/test_write_read.py @@ -19,9 +19,7 @@ # # ------------------------------------------------------------- -import os import shutil -import sys import unittest import pandas as pd @@ -63,7 +61,8 @@ def test_write_read_binary(self): def test_write_read_csv(self): frame = self.sds.from_pandas(self.df) frame.write(self.temp_dir + "02", header=True, format="csv").compute() - NX = self.sds.read(self.temp_dir + "02", data_type="frame", format="csv") + NX = self.sds.read(self.temp_dir + "02", + data_type="frame", format="csv") result_df = NX.compute() self.assertTrue(isinstance(result_df, pd.DataFrame)) self.assertTrue(self.df.equals(result_df)) diff --git a/src/main/python/tests/lineage/test_lineagetrace.py b/src/main/python/tests/lineage/test_lineagetrace.py index 9f7528b94c9..a2237979631 100644 --- a/src/main/python/tests/lineage/test_lineagetrace.py +++ b/src/main/python/tests/lineage/test_lineagetrace.py @@ -21,11 +21,9 @@ import os import shutil -import sys import unittest from systemds.context import SystemDSContext -from systemds.utils.helpers import get_module_dir os.environ['SYSDS_QUIET'] = "1" @@ -48,29 +46,27 @@ def tearDownClass(cls): def tearDown(self): shutil.rmtree(temp_dir, ignore_errors=True) + @unittest.skipIf("SYSTEMDS_ROOT" not in os.environ, "The test is skipped if SYSTEMDS_ROOT is not set, this is required for this tests since it use the bin/systemds file to execute a reference") def test_compare_trace1(self): # test getLineageTrace() on an intermediate - if "SYSTEMDS_ROOT" in os.environ: - m = self.sds.full((10, 10), 1) - m_res = m + m - - python_trace = [x.strip().split("°") - for x in m_res.get_lineage_trace().split("\n")] - - dml_script = ( - "x = matrix(1, rows=10, cols=10);\n" - "y = x + x;\n" - "print(lineage(y));\n" - ) - - sysds_trace = create_execute_and_trace_dml(dml_script, "trace1") - - # It is not garantied, that the two lists 100% align to be the same. - # Therefore for now, we only compare if the command is the same, in same order. - python_trace_commands = [x[:1] for x in python_trace] - dml_script_commands = [x[:1] for x in sysds_trace] - self.assertEqual(python_trace_commands[0], dml_script_commands[0]) - else: - print("to enable lineage tests, set SYSTEMDS_ROOT") + m = self.sds.full((10, 10), 1) + m_res = m + m + + python_trace = [x.strip().split("°") + for x in m_res.get_lineage_trace().split("\n")] + + dml_script = ( + "x = matrix(1, rows=10, cols=10);\n" + "y = x + x;\n" + "print(lineage(y));\n" + ) + + sysds_trace = create_execute_and_trace_dml(dml_script, "trace1") + + # It is not garantied, that the two lists 100% align to be the same. + # Therefore for now, we only compare if the command is the same, in same order. + python_trace_commands = [x[:1] for x in python_trace] + dml_script_commands = [x[:1] for x in sysds_trace] + self.assertEqual(python_trace_commands[0], dml_script_commands[0]) # TODO add more tests cases. diff --git a/src/main/python/tests/list/test_list_readwrite.py b/src/main/python/tests/list/test_list_readwrite.py new file mode 100644 index 00000000000..0ec0cb51a91 --- /dev/null +++ b/src/main/python/tests/list/test_list_readwrite.py @@ -0,0 +1,63 @@ +# ------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# ------------------------------------------------------------- + +import shutil +import unittest + +import numpy as np +from systemds.context import SystemDSContext + + +class TestListOperations(unittest.TestCase): + + sds: SystemDSContext = None + temp_dir: str = "tests/list/tmp/readwrite/" + + @classmethod + def setUpClass(cls): + cls.sds = SystemDSContext() + + @classmethod + def tearDownClass(cls): + cls.sds.close() + shutil.rmtree(cls.temp_dir) + + def test_write_followed_by_read(self): + ''' Test write and read of lists variables in python. + Since we do not support serializing a list (from java to python) yet we + read and compute each list element when reading again + ''' + m1 = np.array([[1., 2., 3.]]) + m1p = self.sds.from_numpy(m1) + m2 = np.array([[4., 5., 6.]]) + m2p = self.sds.from_numpy(m2) + list_obj = self.sds.array(m1p, m2p) + + path = self.temp_dir + "01" + list_obj.write(path).compute() + ret_m1 = self.sds.read(path)[1].as_matrix().compute() + ret_m2 = self.sds.read(path)[2].as_matrix().compute() + self.assertTrue(np.allclose(m1, ret_m1)) + self.assertTrue(np.allclose(m2, ret_m2)) + + +if __name__ == "__main__": + unittest.main(exit=False) diff --git a/src/main/python/tests/matrix/test_cholesky.py b/src/main/python/tests/matrix/test_cholesky.py index 64772ed4e21..d6ba5ba232f 100644 --- a/src/main/python/tests/matrix/test_cholesky.py +++ b/src/main/python/tests/matrix/test_cholesky.py @@ -30,6 +30,7 @@ # set A = MM^T and A is a positive definite matrix A = np.matmul(A, A.transpose()) + class TestCholesky(unittest.TestCase): sds: SystemDSContext = None @@ -43,7 +44,7 @@ def tearDownClass(cls): cls.sds.close() -class TestCholesky_0(TestCholesky): +class TestCholeskyValid(TestCholesky): def test_basic1(self): L = self.sds.from_numpy(A).cholesky().compute() @@ -54,24 +55,27 @@ def test_basic2(self): # L * L.H = A self.assertTrue(np.allclose(A, np.dot(L, L.T.conj()))) -class TestCholesky_1(TestCholesky): + +class TestCholeskyInvalid_1(TestCholesky): def test_pos_def(self): m1 = -np.random.rand(shape, shape) - with self.assertRaises(RuntimeError) as context: + with self.assertRaises(Exception): self.sds.from_numpy(m1).cholesky().compute() - -class TestCholesky_2(TestCholesky): + + +class TestCholeskyInvalid_2(TestCholesky): def test_symmetric_matrix(self): m2 = np.asarray([[4, 9], [1, 4]]) np.linalg.cholesky(m2) - with self.assertRaises(RuntimeError) as context: + with self.assertRaises(Exception): self.sds.from_numpy(m2).cholesky().compute() -class TestCholesky_3(TestCholesky): + +class TestCholeskyInvalid_3(TestCholesky): def test_asymetric_dim(self): m3 = np.random.rand(shape, shape + 1) - with self.assertRaises(RuntimeError) as context: + with self.assertRaises(Exception): self.sds.from_numpy(m3).cholesky().compute() diff --git a/src/main/python/tests/matrix/test_order.py b/src/main/python/tests/matrix/test_order.py index 24878e16513..cd88ac4b4cd 100644 --- a/src/main/python/tests/matrix/test_order.py +++ b/src/main/python/tests/matrix/test_order.py @@ -19,8 +19,8 @@ # # ------------------------------------------------------------- -import unittest import random +import unittest import numpy as np from systemds.context import SystemDSContext @@ -33,7 +33,8 @@ my = np.random.rand(shape[0], 1) by = random.randrange(1, np.size(m, 1)+1) -class TestOrder(unittest.TestCase): + +class TestOrderBase(unittest.TestCase): sds: SystemDSContext = None @@ -45,26 +46,35 @@ def setUpClass(cls): def tearDownClass(cls): cls.sds.close() + +class TestOrderValid(TestOrderBase): + def test_basic(self): - o = self.sds.from_numpy(m).order(by=by, decreasing=False, index_return=False).compute() + o = self.sds.from_numpy(m).order( + by=by, decreasing=False, index_return=False).compute() s = m[np.argsort(m[:, by-1])] self.assertTrue(np.allclose(o, s)) def test_index(self): - o = self.sds.from_numpy(m).order(by=by, decreasing=False, index_return=True).compute() + o = self.sds.from_numpy(m).order( + by=by, decreasing=False, index_return=True).compute() s = np.argsort(m[:, by - 1]) + 1 self.assertTrue(np.allclose(np.transpose(o), s)) def test_decreasing(self): - o = self.sds.from_numpy(m).order(by=by, decreasing=True, index_return=True).compute() + o = self.sds.from_numpy(m).order( + by=by, decreasing=True, index_return=True).compute() s = np.argsort(-m[:, by - 1]) + 1 self.assertTrue(np.allclose(np.transpose(o), s)) -class TestOrder_1(TestOrder): + +class TestOrderInvalid(TestOrderBase): + def test_out_of_bounds(self): by_max = np.size(m, 1) + 2 - with self.assertRaises(RuntimeError) as context: + with self.assertRaises(Exception): self.sds.from_numpy(m).order(by=by_max).compute() + if __name__ == "__main__": unittest.main(exit=False) diff --git a/src/main/python/tests/matrix/test_print.py b/src/main/python/tests/matrix/test_print.py index b7231e778b5..c7337de73e6 100644 --- a/src/main/python/tests/matrix/test_print.py +++ b/src/main/python/tests/matrix/test_print.py @@ -22,6 +22,7 @@ import unittest import numpy as np +from time import sleep from systemds.context import SystemDSContext @@ -32,6 +33,10 @@ class TestPrint(unittest.TestCase): @classmethod def setUpClass(cls): cls.sds = SystemDSContext() + sleep(1.0) + # Clear stdout ... + cls.sds.get_stdout() + cls.sds.get_stdout() @classmethod def tearDownClass(cls): diff --git a/src/main/python/tests/script/test_dml_script.py b/src/main/python/tests/script/test_dml_script.py index f67405780fd..3c307e1cd58 100644 --- a/src/main/python/tests/script/test_dml_script.py +++ b/src/main/python/tests/script/test_dml_script.py @@ -20,7 +20,7 @@ # ------------------------------------------------------------- import unittest -import time +from time import sleep from systemds.context import SystemDSContext from systemds.script_building import DMLScript @@ -35,6 +35,9 @@ class Test_DMLScript(unittest.TestCase): @classmethod def setUpClass(cls): cls.sds = SystemDSContext() + sleep(1) + cls.sds.get_stdout() + cls.sds.get_stdout() @classmethod def tearDownClass(cls): @@ -44,7 +47,7 @@ def test_simple_print_1(self): script = DMLScript(self.sds) script.add_code('print("Hello")') script.execute() - time.sleep(0.5) + sleep(0.5) stdout = self.sds.get_stdout(100) self.assertListEqual(["Hello"], stdout) @@ -54,7 +57,7 @@ def test_simple_print_2(self): script.add_code('print("World")') script.add_code('print("!")') script.execute() - time.sleep(0.5) + sleep(0.5) stdout = self.sds.get_stdout(100) self.assertListEqual(['Hello', 'World', '!'], stdout) @@ -65,7 +68,7 @@ def test_multiple_executions_1(self): scr_a.add_code('y = x + 1') scr_a.add_code('print(y)') scr_a.execute() - time.sleep(0.5) + sleep(0.5) stdout = self.sds.get_stdout(100) self.assertEqual("4", stdout[0]) self.assertEqual("5", stdout[1]) diff --git a/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java b/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java index 01060991b93..585c6ef25fe 100644 --- a/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java +++ b/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java @@ -25,29 +25,36 @@ /** Simple tests to verify startup of Python Gateway server happens without crashes */ public class StartupTest { - @Test(expected = IllegalArgumentException.class) - public void testStartupIncorrect_1() { + @Test(expected = Exception.class) + public void testStartupIncorrect_1() throws Exception { PythonDMLScript.main(new String[] {}); } - @Test(expected = IllegalArgumentException.class) - public void testStartupIncorrect_2() { + @Test(expected = Exception.class) + public void testStartupIncorrect_2() throws Exception { PythonDMLScript.main(new String[] {""}); } - @Test(expected = IllegalArgumentException.class) - public void testStartupIncorrect_3() { + @Test(expected = Exception.class) + public void testStartupIncorrect_3() throws Exception { PythonDMLScript.main(new String[] {"131", "131"}); } - @Test(expected = NumberFormatException.class) - public void testStartupIncorrect_4() { + @Test(expected = Exception.class) + public void testStartupIncorrect_4() throws Exception { PythonDMLScript.main(new String[] {"Hello"}); } - @Test(expected = IllegalArgumentException.class) - public void testStartupIncorrect_5() { + @Test(expected = Exception.class) + public void testStartupIncorrect_5() throws Exception { // Number out of range - PythonDMLScript.main(new String[] {"918757"}); + PythonDMLScript.main(new String[] {"-python", "918757"}); + } + + @Test(expected = Exception.class) + public void testStartupCorrectButTwice() throws Exception { + // crash if you start two instances on same port. + PythonDMLScript.main(new String[] {"-python", "8142"}); + PythonDMLScript.main(new String[] {"-python", "8142"}); } }