From d3a82e1e25ddd338b9da6ffeec746a0cc932a06c Mon Sep 17 00:00:00 2001 From: Wenjun Si Date: Fri, 23 Dec 2022 14:31:11 +0800 Subject: [PATCH] Upgrade to v0.11.2.3 (#194) --- docs/source/df-basic.rst | 2 + docs/source/locale/en/LC_MESSAGES/df-basic.po | 74 +++++++++++-------- docs/source/options.rst | 2 +- odps/_version.py | 2 +- odps/config.py | 2 +- odps/core.py | 24 +++++- odps/df/backends/odpssql/context.py | 8 +- odps/df/backends/odpssql/tests/test_engine.py | 2 +- odps/lib/importer.py | 26 ++++--- odps/lib/tests/test_importer.py | 45 +++++++++++ odps/models/instance.py | 12 ++- odps/models/readers.py | 10 ++- odps/models/table.py | 9 ++- odps/models/tests/test_instances.py | 10 +++ odps/models/tests/test_session.py | 57 +++++++++++++- odps/models/tests/test_tables.py | 24 ++++++ odps/readers.py | 7 +- odps/serializers.py | 10 ++- odps/src/types_c.pyx | 3 + odps/tests/test_types.py | 10 +++ odps/tunnel/io/types.py | 35 +++++---- odps/tunnel/io/writer.py | 16 +++- odps/tunnel/tests/test_arrow_tabletunnel.py | 21 +++++- 23 files changed, 332 insertions(+), 79 deletions(-) diff --git a/docs/source/df-basic.rst b/docs/source/df-basic.rst index 74dd2e91..70b6248e 100644 --- a/docs/source/df-basic.rst +++ b/docs/source/df-basic.rst @@ -854,6 +854,8 @@ DataFrame 中相应字段的值决定该行将被写入的分区。例如,当 >>> iris[iris.sepalwidth < 2.5].persist('pyodps_iris4', partition='ds=test', drop_partition=True, create_partition=True) +persist 时,默认会覆盖原有数据。例如,当 persist 到一张分区表,对应分区的数据将会被重写。如果写入一张非分区表,整张表的数据都将被重写。如果你想要追加数据,可以使用参数 ``overwrite=False`` 。 + 写入表时,还可以指定表的生命周期,如下列语句将表的生命周期指定为10天: .. code:: python diff --git a/docs/source/locale/en/LC_MESSAGES/df-basic.po b/docs/source/locale/en/LC_MESSAGES/df-basic.po index ee6baf7a..697c6af1 100644 --- a/docs/source/locale/en/LC_MESSAGES/df-basic.po +++ b/docs/source/locale/en/LC_MESSAGES/df-basic.po @@ -1508,16 +1508,28 @@ msgid "" msgstr "" #: ../../source/df-basic.rst:857 +msgid "" +"persist 时,默认会覆盖原有数据。例如,当 persist " +"到一张分区表,对应分区的数据将会被重写。如果写入一张非分区表,整张表的数据都将被重写。如果你想要追加数据,可以使用参数 " +"``overwrite=False`` 。" +msgstr "" +"Persisting a DataFrame will overwrite existing data by default. For " +"instance, when persisting into a partitioned table, data in corresponding " +"partitions will be overwritten, while persisting into an unpartitioned " +"table will overwrite all data in it. If you want to append data into " +"existing tables or partitions, you may add ``overwrite=False`` ." + +#: ../../source/df-basic.rst:859 msgid "写入表时,还可以指定表的生命周期,如下列语句将表的生命周期指定为10天:" msgstr "" "You can also specify the lifecycle of a table when writing to it. The " "following example sets the lifecycle of a table to 10 days." -#: ../../source/df-basic.rst:859 +#: ../../source/df-basic.rst:861 msgid ">>> iris[iris.sepalwidth < 2.5].persist('pyodps_iris5', lifecycle=10)" msgstr "" -#: ../../source/df-basic.rst:863 +#: ../../source/df-basic.rst:865 msgid "" "如果数据源中没有 ODPS 对象,例如数据源仅为 Pandas,在 persist 时需要手动指定 ODPS 入口对象, " "或者将需要的入口对象标明为全局对象,如:" @@ -1526,7 +1538,7 @@ msgstr "" "data, you need to manually specify the ODPS object or mark the object as " "global when calling persist. 
For example:" -#: ../../source/df-basic.rst:866 +#: ../../source/df-basic.rst:868 msgid "" ">>> # 假设入口对象为 o\n" ">>> # 指定入口对象\n" @@ -1542,17 +1554,17 @@ msgstr "" ">>> o.to_global()\n" ">>> df.persist('table_name')" -#: ../../source/df-basic.rst:876 +#: ../../source/df-basic.rst:878 msgid "保存执行结果为 Pandas DataFrame" msgstr "Save results to pandas DataFrame" -#: ../../source/df-basic.rst:878 +#: ../../source/df-basic.rst:880 msgid "我们可以使用 ``to_pandas``\\ 方法,如果wrap参数为True,将返回PyODPS DataFrame对象。" msgstr "" "You can use the ``to_pandas``\\ method. If wrap is set to True, a PyODPS " "DataFrame object is returned." -#: ../../source/df-basic.rst:880 +#: ../../source/df-basic.rst:882 msgid "" ">>> type(iris[iris.sepalwidth < 2.5].to_pandas())\n" "pandas.core.frame.DataFrame\n" @@ -1560,7 +1572,7 @@ msgid "" "odps.df.core.DataFrame" msgstr "" -#: ../../source/df-basic.rst:889 +#: ../../source/df-basic.rst:891 msgid "" "``to_pandas`` 返回的 pandas DataFrame 与直接通过 pandas 创建的 DataFrame 没有任何区别, " "数据的存储和计算均在本地。如果 ``wrap=True``,生成的即便是 PyODPS DataFrame,数据依然在本地。 " @@ -1574,11 +1586,11 @@ msgstr "" " of data, or your running enviromnent is quite restrictive, please be " "cautious when using ``to_pandas``." -#: ../../source/df-basic.rst:894 +#: ../../source/df-basic.rst:896 msgid "立即运行设置运行参数" msgstr "Set runtime parameters" -#: ../../source/df-basic.rst:896 +#: ../../source/df-basic.rst:898 msgid "" "对于立即执行的方法,比如 ``execute``、``persist``、``to_pandas`` 等,可以设置运行时参数(仅对ODPS " "SQL后端有效 )。" @@ -1586,35 +1598,35 @@ msgstr "" "For actions such as `execute``, ``persist``, and ``to_pandas``, you can " "set runtime parameters. This is only valid for MaxCompute SQL." -#: ../../source/df-basic.rst:898 +#: ../../source/df-basic.rst:900 msgid "一种方法是设置全局参数。详细参考 :ref:`SQL设置运行参数 ` 。" msgstr "" "You can also set global parameters. For details, see :ref:`SQL - runtime " "parameters `." -#: ../../source/df-basic.rst:900 +#: ../../source/df-basic.rst:902 msgid "也可以在这些立即执行的方法上,使用 ``hints`` 参数。这样,这些参数只会作用于当前的计算过程。" msgstr "" "Additionally, you can use the `hints`` parameter. These parameters are " "only valid for the current calculation." -#: ../../source/df-basic.rst:903 +#: ../../source/df-basic.rst:905 msgid "" ">>> iris[iris.sepallength < " "5].to_pandas(hints={'odps.sql.mapper.split.size': 16})" msgstr "" -#: ../../source/df-basic.rst:909 +#: ../../source/df-basic.rst:911 msgid "运行时显示详细信息" msgstr "Display details at runtime" -#: ../../source/df-basic.rst:911 +#: ../../source/df-basic.rst:913 msgid "有时,用户需要查看运行时instance的logview时,需要修改全局配置:" msgstr "" "You sometimes need to modify the global configuration to view the logview" " of an instance." 
-#: ../../source/df-basic.rst:913 +#: ../../source/df-basic.rst:915 msgid "" ">>> from odps import options\n" ">>> options.verbose = True\n" @@ -1636,11 +1648,11 @@ msgid "" "4 2.9 1.4 0.2 Iris-setosa" msgstr "" -#: ../../source/df-basic.rst:934 +#: ../../source/df-basic.rst:936 msgid "用户可以指定自己的日志记录函数,比如像这样:" msgstr "You can specify a logging function as follows:" -#: ../../source/df-basic.rst:936 +#: ../../source/df-basic.rst:938 msgid "" ">>> my_logs = []\n" ">>> def my_logger(x):\n" @@ -1663,11 +1675,11 @@ msgid "" "\\nLIMIT 5', 'logview:', u'http://logview']" msgstr "" -#: ../../source/df-basic.rst:956 +#: ../../source/df-basic.rst:958 msgid "缓存中间Collection计算结果" msgstr "Cache intermediate results" -#: ../../source/df-basic.rst:958 +#: ../../source/df-basic.rst:960 msgid "" "DataFrame的计算过程中,一些Collection被多处使用,或者用户需要查看中间过程的执行结果, 这时用户可以使用 ``cache``\\" " 标记某个collection需要被优先计算。" @@ -1677,13 +1689,13 @@ msgstr "" "intermediate process, you can use the cache method to mark a collection " "object so that it is calculated first." -#: ../../source/df-basic.rst:963 +#: ../../source/df-basic.rst:965 msgid "值得注意的是,``cache``\\ 延迟执行,调用cache不会触发立即计算。" msgstr "" "Note that ``cache``\\ delays execution. Calling this method does not " "trigger automatic calculation." -#: ../../source/df-basic.rst:965 +#: ../../source/df-basic.rst:967 msgid "" ">>> cached = iris[iris.sepalwidth < 3.5].cache()\n" ">>> df = cached['sepallength', 'name'].head(3)\n" @@ -1712,11 +1724,11 @@ msgstr "" "1 4.7 Iris-setosa\n" "2 4.6 Iris-setosa" -#: ../../source/df-basic.rst:984 +#: ../../source/df-basic.rst:986 msgid "异步和并行执行" msgstr "Asynchronous and parallel executions" -#: ../../source/df-basic.rst:986 +#: ../../source/df-basic.rst:988 msgid "" "DataFrame 支持异步操作,对于立即执行的方法,包括 " "``execute``、``persist``、``head``、``tail``、``to_pandas`` (其他方法不支持), 传入 " @@ -1732,7 +1744,7 @@ msgstr "" "`_ objects." -#: ../../source/df-basic.rst:990 +#: ../../source/df-basic.rst:992 msgid "" ">>> future = iris[iris.sepal_width < 10].head(10, async=True)\n" ">>> future.done()\n" @@ -1751,7 +1763,7 @@ msgid "" "9 4.9 3.1 1.5 0.1 Iris-setosa" msgstr "" -#: ../../source/df-basic.rst:1009 +#: ../../source/df-basic.rst:1011 msgid "" "DataFrame 的并行执行可以使用多线程来并行,单个 expr 的执行可以通过 ``n_parallel`` 参数来指定并发度。 比如,当一个" " DataFrame 的执行依赖的多个 cache 的 DataFrame 能够并行执行时,该参数就会生效。" @@ -1762,7 +1774,7 @@ msgstr "" "the multiple cached DataFrame objects that a single DataFrame execution " "depends on can be executed in parallel." -#: ../../source/df-basic.rst:1012 +#: ../../source/df-basic.rst:1014 msgid "" ">>> expr1 = " "iris.groupby('category').agg(value=iris.sepal_width.sum()).cache()\n" @@ -1807,14 +1819,14 @@ msgstr "" "7 Iris-versicolor 3.000\n" "8 Iris-virginica 4.500" -#: ../../source/df-basic.rst:1032 +#: ../../source/df-basic.rst:1034 msgid "当同时执行多个 expr 时,我们可以用多线程执行,但会面临一个问题, 比如两个 DataFrame 有共同的依赖,这个依赖将会被执行两遍。" msgstr "" "You can use multiple threads to execute multiple expr objects in " "parallel, but you may encounter a problem when two DataFrame objects " "share the same dependency, and this dependency will be executed twice." -#: ../../source/df-basic.rst:1035 +#: ../../source/df-basic.rst:1037 msgid "" "现在我们提供了新的 ``Delay API``, 来将立即执行的操作(包括 " "``execute``、``persist``、``head``、``tail``、``to_pandas``,其他方法不支持)变成延迟操作, " @@ -1829,7 +1841,7 @@ msgstr "" "dependency is executed based on the degree of parallelism you have " "specified. Asynchronous execution is supported." 
-#: ../../source/df-basic.rst:1040 +#: ../../source/df-basic.rst:1042 msgid "" ">>> from odps.df import Delay\n" ">>> delay = Delay() # 创建Delay对象\n" @@ -1865,7 +1877,7 @@ msgstr "" ">>> future2.result()\n" "3.0540000000000007" -#: ../../source/df-basic.rst:1057 +#: ../../source/df-basic.rst:1059 msgid "" "可以看到上面的例子里,共同依赖的对象会先执行,然后再以并发度为3分别执行future1到future3。 当 ``n_parallel`` " "为1时,执行时间会达到37s。" @@ -1875,7 +1887,7 @@ msgstr "" "parallelism set to 3. When ``n_parallel`` is set to 1, the execution time" " reaches 37s." -#: ../../source/df-basic.rst:1060 +#: ../../source/df-basic.rst:1062 msgid "" "``delay.execute`` 也接受 ``async`` 操作来指定是否异步执行,当异步的时候,也可以指定 ``timeout`` " "参数来指定超时时间。" diff --git a/docs/source/options.rst b/docs/source/options.rst index a8daa6a6..22c4eeb4 100644 --- a/docs/source/options.rst +++ b/docs/source/options.rst @@ -53,7 +53,7 @@ PyODPS 提供了一系列的配置选项,可通过 ``odps.options`` 获得, +------------------------+---------------------------------------------------+-------+ |pool_maxsize | 连接池最大容量 |10 | +------------------------+---------------------------------------------------+-------+ -|connect_timeout | 连接超时 |10 | +|connect_timeout | 连接超时 |120 | +------------------------+---------------------------------------------------+-------+ |read_timeout | 读取超时 |120 | +------------------------+---------------------------------------------------+-------+ diff --git a/odps/_version.py b/odps/_version.py index e5e49bea..7aefa67b 100644 --- a/odps/_version.py +++ b/odps/_version.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -version_info = (0, 11, 2, 2) +version_info = (0, 11, 2, 3) _num_index = max(idx if isinstance(v, int) else 0 for idx, v in enumerate(version_info)) __version__ = '.'.join(map(str, version_info[:_num_index + 1])) + \ diff --git a/odps/config.py b/odps/config.py index 7debeb3d..627bb086 100644 --- a/odps/config.py +++ b/odps/config.py @@ -25,7 +25,7 @@ DEFAULT_CHUNK_SIZE = 1496 DEFAULT_CONNECT_RETRY_TIMES = 4 -DEFAULT_CONNECT_TIMEOUT = 10 +DEFAULT_CONNECT_TIMEOUT = 120 DEFAULT_READ_TIMEOUT = 120 DEFAULT_POOL_CONNECTIONS = 10 DEFAULT_POOL_MAXSIZE = 10 diff --git a/odps/core.py b/odps/core.py index 26883d2a..c3f31ef2 100644 --- a/odps/core.py +++ b/odps/core.py @@ -22,6 +22,7 @@ from .rest import RestClient from .config import options from .errors import NoSuchObject +from .errors import ODPSError from .tempobj import clean_stored_objects from .utils import split_quoted from .compat import six, Iterable @@ -1873,7 +1874,7 @@ def create_session(self, session_worker_count, session_worker_memory, def run_sql_interactive(self, sql, hints=None, **kwargs): """ Run SQL query in interactive mode (a.k.a MaxCompute QueryAcceleration). - + Won't fallback to offline mode automatically if query not supported or fails :param sql: the sql query. :param hints: settings for sql query. :return: instance. @@ -1894,6 +1895,27 @@ def run_sql_interactive(self, sql, hints=None, **kwargs): self._default_session_name = service_name return self._default_session.run_sql(sql, hints, **kwargs) + def run_sql_interactive_with_fallback(self, sql, hints=None, **kwargs): + """ + Run SQL query in interactive mode (a.k.a MaxCompute QueryAcceleration). + If query is not supported or fails, will fallback to offline mode automatically + :param sql: the sql query. + :param hints: settings for sql query. + :return: instance. 
+ """ + inst = None + try: + if inst is None: + inst = self.run_sql_interactive(self, sql, hints=hints, **kwargs) + else: + inst.wait_for_success(interval=0.2) + rd = inst.open_reader(tunnel=True, limit=False) + if not rd: + raise ODPSError('Get sql result fail') + return inst + except: + return self.execute_sql(sql, hints=hints) + @classmethod def _build_account(cls, access_id, secret_access_key): return accounts.AliyunAccount(access_id, secret_access_key) diff --git a/odps/df/backends/odpssql/context.py b/odps/df/backends/odpssql/context.py index a45c4564..3f297f66 100644 --- a/odps/df/backends/odpssql/context.py +++ b/odps/df/backends/odpssql/context.py @@ -155,7 +155,8 @@ def prepare_resources(self, libraries): tar.close() res_name = self._gen_resource_name() + '.tar.gz' - res = self._odps.create_resource(res_name, 'archive', file_obj=tarbinary.getvalue()) + res = self._odps.create_resource(res_name, 'archive', file_obj=tarbinary.getvalue(), temp=True) + tempobj.register_temp_resource(self._odps, res_name) self._path_to_resources[lib] = res self._to_drops.append(res) ret_libs.append(res) @@ -166,7 +167,7 @@ def create_udfs(self, libraries=None): for func, udf in six.iteritems(self._func_to_udfs): udf_name = self._registered_funcs[func] - py_resource = self._odps.create_resource(udf_name + '.py', 'py', file_obj=udf) + py_resource = self._odps.create_resource(udf_name + '.py', 'py', file_obj=udf, temp=True) tempobj.register_temp_resource(self._odps, udf_name + '.py') self._to_drops.append(py_resource) @@ -176,8 +177,7 @@ def create_udfs(self, libraries=None): if not create: resources.append(name) else: - res = self._odps.create_resource(name, 'table', - table_name=table_name) + res = self._odps.create_resource(name, 'table', table_name=table_name, temp=True) tempobj.register_temp_resource(self._odps, name) resources.append(res) self._to_drops.append(res) diff --git a/odps/df/backends/odpssql/tests/test_engine.py b/odps/df/backends/odpssql/tests/test_engine.py index 56df43b7..6812563e 100644 --- a/odps/df/backends/odpssql/tests/test_engine.py +++ b/odps/df/backends/odpssql/tests/test_engine.py @@ -3515,7 +3515,7 @@ def testComposites(self): expr = expr_in[expr_in.name, expr_in.detail.values().explode()] res = self.engine.execute(expr) result = self._get_result(res) - self.assertEqual(result, expected) + self.assertEqual(sorted(result), sorted(expected)) expected = [ ['name1', 4.0, 2.0, False, False, ['HKG', 'PEK', 'SHA', 'YTY'], diff --git a/odps/lib/importer.py b/odps/lib/importer.py index a22fc596..f483af9b 100644 --- a/odps/lib/importer.py +++ b/odps/lib/importer.py @@ -127,25 +127,26 @@ def __init__(self, *compressed_files, **kwargs): # when it is forced to extract even if it is a text package, also extract f = self._extract_archive(f) - dir_prefixes = set() + prefixes = set(['']) + dir_prefixes = set() # only for lists or dicts if isinstance(f, zipfile.ZipFile): for name in f.namelist(): name = name if name.endswith('/') else (name.rsplit('/', 1)[0] + '/') - if name in dir_prefixes: + if name in prefixes: continue try: f.getinfo(name + '__init__.py') except KeyError: - dir_prefixes.add(name) + prefixes.add(name) elif isinstance(f, tarfile.TarFile): for member in f.getmembers(): name = member.name if member.isdir() else member.name.rsplit('/', 1)[0] - if name in dir_prefixes: + if name in prefixes: continue try: f.getmember(name + '/__init__.py') except KeyError: - dir_prefixes.add(name + '/') + prefixes.add(name + '/') elif isinstance(f, (list, dict)): # Force ArchiveResource to run 
under binary mode to resolve manually # opening __file__ paths in pure-python code. @@ -159,19 +160,22 @@ def __init__(self, *compressed_files, **kwargs): for name in rendered_names: name = name if name.endswith('/') else (name.rsplit('/', 1)[0] + '/') - if name in dir_prefixes or '/tests/' in name or '/__pycache__/' in name: + if name in prefixes or '/tests/' in name or '/__pycache__/' in name: continue if name + '__init__.py' not in rendered_names: + prefixes.add(name) dir_prefixes.add(name) else: if '/' in name.rstrip('/'): ppath = name.rstrip('/').rsplit('/', 1)[0] else: ppath = '' + prefixes.add(ppath) dir_prefixes.add(ppath) - # make sure only root packages are included, - # otherwise relative imports might be broken + # make sure only root packages are included, otherwise relative imports might be broken + # NOTE that it is needed to check sys.path duplication after all pruning done, + # otherwise path might be error once CompressImporter is called twice. path_patch = [] for p in sorted(dir_prefixes): parent_exist = False @@ -191,7 +195,11 @@ def __init__(self, *compressed_files, **kwargs): sys.path = sys.path + path_patch else: self._files.append(f) - self._prefixes[id(f)] = sorted([''] + path_patch) + if path_patch: + path_patch = [p for p in path_patch if p not in sys.path] + self._prefixes[id(f)] = sorted([''] + path_patch) + elif prefixes: + self._prefixes[id(f)] = sorted(prefixes) def _extract_archive(self, archive): if not self._extract_binary and not self._extract_all: diff --git a/odps/lib/tests/test_importer.py b/odps/lib/tests/test_importer.py index 56298e43..01e29839 100644 --- a/odps/lib/tests/test_importer.py +++ b/odps/lib/tests/test_importer.py @@ -48,6 +48,12 @@ def tearDown(self): sys.path = self.sys_path sys.meta_path = self.meta_path + @staticmethod + def _add_tar_directory(tar_file, path): + info = tarfile.TarInfo(name=path) + info.type = tarfile.DIRTYPE + tar_file.addfile(info, fileobj=BytesIO()) + def testImport(self): zip_io = BytesIO() zip_f = zipfile.ZipFile(zip_io, 'w') @@ -93,6 +99,32 @@ def testImport(self): from d import d self.assertEqual(d, 4) + def testRootedArchiveImport(self): + tar_io = BytesIO() + tar_f = tarfile.TarFile(fileobj=tar_io, mode='w') + self._add_tar_directory(tar_f, 'root') + self._add_tar_directory(tar_f, 'root/testb.1.0') + self._add_tar_directory(tar_f, 'root/testb.1.0/testb.info') + tar_f.addfile(tarfile.TarInfo(name='root/testb.1.0/testb.info/INFO.txt'), fileobj=BytesIO()) + self._add_tar_directory(tar_f, 'root/testb.1.0/testb') + tar_f.addfile(tarfile.TarInfo(name='root/testb.1.0/testb/__init__.py'), fileobj=BytesIO()) + info = tarfile.TarInfo(name='root/testb.1.0/testb/b.py') + c = b'b = 2' + s = BytesIO(c) + info.size = len(c) + tar_f.addfile(info, fileobj=s) + tar_f.close() + + tar_io.seek(0) + + tar_f = tarfile.TarFile(fileobj=tar_io) + importer.ALLOW_BINARY = False + imp = CompressImporter(tar_f) + sys.meta_path.append(imp) + + from testb.b import b + self.assertEqual(b, 2) + def testRealImport(self): six_path = os.path.join(os.path.dirname(os.path.abspath(six.__file__)), 'six.py') zip_io = BytesIO() @@ -227,6 +259,19 @@ def testBinaryImport(self): [f.close() for f in six.itervalues(lib_dict)] shutil.rmtree(temp_path) + def testRepeatImport(self): + dict_io_init = dict() + dict_io_init['/opt/test_pyodps_dev/testc/__init__.py'] = BytesIO() + dict_io_init['/opt/test_pyodps_dev/testc/c.py'] = BytesIO(b'from a import a; c = a + 2') + dict_io_init['/opt/test_pyodps_dev/testc/sub/__init__.py'] = BytesIO(b'from . 
import mod') + dict_io_init['/opt/test_pyodps_dev/testc/sub/mod.py'] = BytesIO(b'from ..c import c') + + sys.meta_path.append(CompressImporter(dict_io_init)) + sys.meta_path.append(CompressImporter(dict_io_init)) + self.assertIn('/opt/test_pyodps_dev', sys.path) + self.assertNotIn('/opt/test_pyodps_dev/testc', sys.path) + self.assertNotIn('/opt/test_pyodps_dev/testc/sub', sys.path) + if __name__ == '__main__': unittest.main() diff --git a/odps/models/instance.py b/odps/models/instance.py index b573a38c..ef0b5416 100644 --- a/odps/models/instance.py +++ b/odps/models/instance.py @@ -14,9 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import sys import base64 import json +import sys import threading import time import warnings @@ -33,6 +33,15 @@ from .readers import TunnelRecordReader, TunnelArrowReader from .worker import WorkerDetail2, LOG_TYPES_MAPPING +try: + from functools import wraps +except ImportError: + def wraps(_f): + def wrapper(fun): + return fun + + return wrapper + _RESULT_LIMIT_HELPER_MSG = ( 'See https://pyodps.readthedocs.io/zh_CN/latest/base-sql.html#read-sql-exec-result ' @@ -41,6 +50,7 @@ def _with_status_api_lock(func): + @wraps(func) def wrapped(self, *args, **kw): with self._status_api_lock: return func(self, *args, **kw) diff --git a/odps/models/readers.py b/odps/models/readers.py index b07a3f73..c156d18b 100644 --- a/odps/models/readers.py +++ b/odps/models/readers.py @@ -16,10 +16,11 @@ class TunnelRecordReader(AbstractRecordReader): - def __init__(self, parent, download_session): + def __init__(self, parent, download_session, columns=None): self._it = iter(self) self._parent = parent self._download_session = download_session + self._column_names = columns @property def download_id(self): @@ -51,6 +52,7 @@ def read(self, start=None, count=None, step=None, start = start or 0 step = step or 1 count = count * step if count is not None else self.count - start + columns = columns or self._column_names if count == 0: return @@ -104,10 +106,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): class TunnelArrowReader(object): - def __init__(self, parent, download_session): + def __init__(self, parent, download_session, columns=None): self._it = iter(self) self._parent = parent self._download_session = download_session + self._column_names = columns @property def download_id(self): @@ -133,6 +136,7 @@ def __next__(self): def read(self, start=None, count=None, columns=None): start = start or 0 count = count if count is not None else self.count - start + columns = columns or self._column_names if count == 0: return @@ -150,6 +154,7 @@ def read(self, start=None, count=None, columns=None): def read_all(self, start=None, count=None, columns=None): start = start or 0 count = count if count is not None else self.count - start + columns = columns or self._column_names if count == 0: return @@ -160,6 +165,7 @@ def read_all(self, start=None, count=None, columns=None): return reader.read() def to_pandas(self, start=None, count=None, columns=None): + columns = columns or self._column_names return self.read_all(start=start, count=count, columns=columns).to_pandas() def __enter__(self): diff --git a/odps/models/table.py b/odps/models/table.py index dc473f73..61496094 100644 --- a/odps/models/table.py +++ b/odps/models/table.py @@ -137,9 +137,9 @@ def __dir__(self): class TableRecordReader(TunnelRecordReader): - def __init__(self, table, download_session, partition_spec=None): + def __init__(self, table, download_session, 
partition_spec=None, columns=None): super(TableRecordReader, self).__init__( - table, download_session + table, download_session, columns=columns ) self._partition_spec = partition_spec @@ -607,6 +607,7 @@ def open_reader( download_id=None, timeout=None, arrow=False, + columns=None, **kw ): """ @@ -650,9 +651,9 @@ def open_reader( download_ids[partition] = download_session.id if arrow: - return TableArrowReader(self, download_session) + return TableArrowReader(self, download_session, columns=columns) else: - return TableRecordReader(self, download_session, partition) + return TableRecordReader(self, download_session, partition, columns=columns) def open_writer( self, diff --git a/odps/models/tests/test_instances.py b/odps/models/tests/test_instances.py index b1f4a923..c2d9d25c 100644 --- a/odps/models/tests/test_instances.py +++ b/odps/models/tests/test_instances.py @@ -78,6 +78,13 @@ def _open_tunnel_reader(self, **kw): class Test(TestBase): + def setUp(self): + super(Test, self).setUp() + options.connect_timeout = 10 + + def tearDown(self): + super(Test, self).tearDown() + options.connect_timeout = 120 def testInstances(self): self.assertIs(self.odps.get_project().instances, self.odps.get_project().instances) @@ -239,6 +246,9 @@ def testCreateInstance(self): self.assertTrue(instance.is_successful()) self.assertTrue(self.odps.exist_table(test_table)) + instance = self.odps.execute_sql('select id `中文标题` from %s' % test_table) + self.assertTrue(instance.is_successful()) + instance = self.odps.execute_sql('drop table %s' % test_table) self.assertTrue(instance.is_successful()) self.assertFalse(self.odps.exist_table(test_table)) diff --git a/odps/models/tests/test_session.py b/odps/models/tests/test_session.py index bb70d645..79f45de5 100644 --- a/odps/models/tests/test_session.py +++ b/odps/models/tests/test_session.py @@ -39,6 +39,7 @@ def testCreateSession(self): except ODPSError as ex: print("LOGVIEW: " + sess_instance.get_logview_address()) print("Task results: " + str(sess_instance.get_task_results())) + sess_instance.stop() raise ex # the status should keep consistent self.assertTrue(sess_instance.status == Instance.Status.RUNNING) @@ -54,6 +55,7 @@ def testAttachSession(self): except ODPSError as ex: print("LOGVIEW: " + sess_instance.get_logview_address()) print("Task results: " + str(sess_instance.get_task_results())) + sess_instance.stop() raise ex # the status should keep consistent self.assertTrue(sess_instance.status == Instance.Status.RUNNING) @@ -65,6 +67,7 @@ def testAttachSession(self): except ODPSError as ex: print("LOGVIEW: " + att_instance.get_logview_address()) print("Task results: " + str(att_instance.get_task_results())) + sess_instance.stop() raise ex # the status should keep consistent self.assertTrue(att_instance.status == Instance.Status.RUNNING) @@ -82,6 +85,7 @@ def testAttachDefaultSession(self): except ODPSError as ex: print("LOGVIEW: " + sess_instance.get_logview_address()) print("Task results: " + str(sess_instance.get_task_results())) + sess_instance.stop() raise ex # the status should keep consistent self.assertTrue(sess_instance.status == Instance.Status.RUNNING) @@ -96,6 +100,7 @@ def testSessionFailingSQL(self): except ODPSError as ex: print("LOGVIEW: " + sess_instance.get_logview_address()) print("Task results: " + str(sess_instance.get_task_results())) + sess_instance.stop() raise ex # the status should keep consistent self.assertTrue(sess_instance.status == Instance.Status.RUNNING) @@ -131,7 +136,8 @@ def testDirectExecuteFailingSQL(self): 
self.assertTrue(False) except ODPSError: pass # good - sess_instance.stop() + finally: + sess_instance.stop() def testSessionSQL(self): self.odps.delete_table(TEST_TABLE_NAME, if_exists=True) @@ -145,6 +151,7 @@ def testSessionSQL(self): except ODPSError as ex: print("LOGVIEW: " + sess_instance.get_logview_address()) print("Task results: " + str(sess_instance.get_task_results())) + sess_instance.stop() raise ex # the status should keep consistent self.assertTrue(sess_instance.status == Instance.Status.RUNNING) @@ -160,6 +167,7 @@ def testSessionSQL(self): except BaseException as ex: print("LOGVIEW: " + select_inst.get_logview_address()) print("Task Result:" + str(select_inst.get_task_results())) + sess_instance.stop() raise ex self.assertTrue(len(rows) == len(TEST_DATA)) self.assertTrue(len(rows[0]) == len(TEST_DATA[0])) @@ -167,6 +175,7 @@ def testSessionSQL(self): self.assertTrue(int(rows[index][0]) == int(TEST_DATA[index][0])) # OK, clear up self.odps.delete_table(TEST_TABLE_NAME, if_exists=True) + sess_instance.stop() def testDirectExecuteSQL(self): self.odps.delete_table(TEST_TABLE_NAME, if_exists=True) @@ -181,6 +190,7 @@ def testDirectExecuteSQL(self): except ODPSError as ex: print("LOGVIEW: " + sess_instance.get_logview_address()) print("Task results: " + str(sess_instance.get_task_results())) + sess_instance.stop() raise ex # the status should keep consistent self.assertTrue(sess_instance.status == Instance.Status.RUNNING) @@ -196,6 +206,7 @@ def testDirectExecuteSQL(self): except BaseException as ex: print("LOGVIEW: " + select_inst.get_logview_address()) print("Task Result:" + str(select_inst.get_task_results())) + sess_instance.stop() raise ex self.assertTrue(len(rows) == len(TEST_DATA)) self.assertTrue(len(rows[0]) == len(TEST_DATA[0])) @@ -203,6 +214,47 @@ def testDirectExecuteSQL(self): self.assertTrue(int(rows[index][0]) == int(TEST_DATA[index][0])) # OK, clear up self.odps.delete_table(TEST_TABLE_NAME, if_exists=True) + sess_instance.stop() + + def testDirectExecuteSQLFallback(self): + self.odps.delete_table(TEST_TABLE_NAME, if_exists=True) + table = self.odps.create_table(TEST_TABLE_NAME, TEST_CREATE_SCHEMA) + self.assertTrue(table) + # the default public session may not exist, so we create one beforehand + sess_instance = self.odps.create_session(TEST_SESSION_WORKERS, TEST_SESSION_WORKER_MEMORY) + self.assertTrue(sess_instance) + # wait to running + try: + sess_instance.wait_for_startup() + except ODPSError as ex: + print("LOGVIEW: " + sess_instance.get_logview_address()) + print("Task results: " + str(sess_instance.get_task_results())) + sess_instance.stop() + raise ex + # the status should keep consistent + self.assertTrue(sess_instance.status == Instance.Status.RUNNING) + records = [Record(schema=TEST_CREATE_SCHEMA, values=values) for values in TEST_DATA] + self.odps.write_table(table, 0, records) + hints = {"odps.mcqa.disable":"true"} + select_inst = self.odps.run_sql_interactive_with_fallback(TEST_SELECT_STRING, service_name=sess_instance.name, hints=hints) + select_inst.wait_for_success() + rows = [] + try: + with select_inst.open_reader() as rd: + for each_row in rd: + rows.append(each_row.values) + except BaseException as ex: + print("LOGVIEW: " + select_inst.get_logview_address()) + print("Task Result:" + str(select_inst.get_task_results())) + sess_instance.stop() + raise ex + self.assertTrue(len(rows) == len(TEST_DATA)) + self.assertTrue(len(rows[0]) == len(TEST_DATA[0])) + for index in range(5): + self.assertTrue(int(rows[index][0]) == int(TEST_DATA[index][0])) 
+ # OK, clear up + self.odps.delete_table(TEST_TABLE_NAME, if_exists=True) + sess_instance.stop() def testSessionSQLWithInstanceTunnel(self): self.odps.delete_table(TEST_TABLE_NAME, if_exists=True) @@ -216,6 +268,7 @@ def testSessionSQLWithInstanceTunnel(self): except ODPSError as ex: print("LOGVIEW: " + sess_instance.get_logview_address()) print("Task results: " + str(sess_instance.get_task_results())) + sess_instance.stop() raise ex # the status should keep consistent self.assertTrue(sess_instance.status == Instance.Status.RUNNING) @@ -231,6 +284,7 @@ def testSessionSQLWithInstanceTunnel(self): except BaseException as ex: print("LOGVIEW: " + select_inst.get_logview_address()) print("Task Result:" + str(select_inst.get_task_results())) + sess_instance.stop() raise ex self.assertTrue(len(rows) == len(TEST_DATA)) self.assertTrue(len(rows[0]) == len(TEST_DATA[0])) @@ -238,6 +292,7 @@ def testSessionSQLWithInstanceTunnel(self): self.assertTrue(int(rows[index][0]) == int(TEST_DATA[index][0])) # OK, clear up self.odps.delete_table(TEST_TABLE_NAME, if_exists=True) + sess_instance.stop() if __name__ == '__main__': unittest.main() diff --git a/odps/models/tests/test_tables.py b/odps/models/tests/test_tables.py index 5694bd74..a1f46c74 100644 --- a/odps/models/tests/test_tables.py +++ b/odps/models/tests/test_tables.py @@ -468,6 +468,30 @@ def testMultiProcessToPandas(self): pd_data = reader.to_pandas(n_process=2) assert len(pd_data) == 1000 + @unittest.skipIf(pd is None, "Need pandas to run this test") + def testColumnSelectToPandas(self): + test_table_name = tn('pyodps_t_tmp_col_select_table') + schema = Schema.from_lists(['num1', 'num2'], ['bigint', 'bigint']) + + self.odps.delete_table(test_table_name, if_exists=True) + + table = self.odps.create_table(test_table_name, schema) + with table.open_writer(arrow=True) as writer: + writer.write(pd.DataFrame({ + "num1": np.random.randint(0, 1000, 1000), + "num2": np.random.randint(0, 1000, 1000), + })) + + with table.open_reader(columns=["num1"]) as reader: + pd_data = reader.to_pandas() + assert len(pd_data) == 1000 + assert len(pd_data.columns) == 1 + + with table.open_reader(columns=["num1"], arrow=True) as reader: + pd_data = reader.to_pandas() + assert len(pd_data) == 1000 + assert len(pd_data.columns) == 1 + @unittest.skipIf(pa is None, "Need pyarrow to run this test") def testSimpleArrowReadWriteTable(self): test_table_name = tn('pyodps_t_tmp_simple_arrow_read_write_table') diff --git a/odps/readers.py b/odps/readers.py index e0bbde26..df1e086c 100644 --- a/odps/readers.py +++ b/odps/readers.py @@ -105,7 +105,10 @@ def to_result_frame(self, unknown_as_string=True, as_type=None): elif getattr(self, '_schema', None) is not None: # do not remove as there might be coverage missing kw['schema'] = odps_schema_to_df_schema(self._schema) - elif getattr(self, '_columns', None) is not None: + + if getattr(self, '_column_names', None) is not None: + self._columns = [self.schema[c] for c in self._column_names] + if getattr(self, '_columns', None) is not None: cols = [] for col in self._columns: col = copy.copy(col) @@ -277,4 +280,4 @@ def __enter__(self): return self def __exit__(self, *_): - self.close() \ No newline at end of file + self.close() diff --git a/odps/serializers.py b/odps/serializers.py index 97e6deb9..c0073dfb 100644 --- a/odps/serializers.py +++ b/odps/serializers.py @@ -22,7 +22,7 @@ import requests from . 
import compat, utils -from .compat import ElementTree, six +from .compat import BytesIO, ElementTree, PY26, six from .utils import to_text @@ -367,7 +367,13 @@ def parse(cls, response, obj=None, **kw): def serialize(self): root = self.serial() - xml_content = ElementTree.tostring(root, 'utf-8') + + sio = BytesIO() + if PY26: + ElementTree.ElementTree(root).write(sio, encoding="utf-8") + else: + ElementTree.ElementTree(root).write(sio, encoding="utf-8", xml_declaration=True) + xml_content = sio.getvalue() prettified_xml = minidom.parseString(xml_content).toprettyxml(indent=' '*2, encoding='utf-8') prettified_xml = to_text(prettified_xml, encoding='utf-8') diff --git a/odps/src/types_c.pyx b/odps/src/types_c.pyx index dd10fa76..f1a3ce90 100644 --- a/odps/src/types_c.pyx +++ b/odps/src/types_c.pyx @@ -270,6 +270,9 @@ cdef class BaseRecord: if values is not None: self._sets(values) + def __reduce__(self): + return type(self), (self._c_columns, None, self._c_values) + @property def _columns(self): return self._c_columns diff --git a/odps/tests/test_types.py b/odps/tests/test_types.py index 0691eb2d..6e719363 100644 --- a/odps/tests/test_types.py +++ b/odps/tests/test_types.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import decimal as _decimal from odps.types import * @@ -227,6 +228,15 @@ def testSetWithCast(self): r['datetime'] = '2016-01-01 0:0:0' self.assertEqual(datetime(2016, 1, 1), r['datetime']) + @bothPyAndC + def testRecordCopy(self): + s = Schema.from_lists(['col1'], ['string']) + r = Record(schema=s) + r.col1 = 'a' + + cr = copy.copy(r) + assert cr.col1 == r.col1 + @bothPyAndC def testRecordSetField(self): s = Schema.from_lists(['col1'], ['string',]) diff --git a/odps/tunnel/io/types.py b/odps/tunnel/io/types.py index 69a89858..e9b65b84 100644 --- a/odps/tunnel/io/types.py +++ b/odps/tunnel/io/types.py @@ -34,30 +34,39 @@ odps_types.float_: pa.float32(), odps_types.double: pa.float64(), odps_types.date: pa.date32(), - odps_types.datetime: pa.timestamp('ns'), + odps_types.datetime: pa.timestamp('ms'), odps_types.timestamp: pa.timestamp('ns') } else: _ODPS_ARROW_TYPE_MAPPING = {} -def odps_schema_to_arrow_schema(odps_schema): +def odps_type_to_arrow_type(odps_type): from ... 
import types + if odps_type in _ODPS_ARROW_TYPE_MAPPING: + col_type = _ODPS_ARROW_TYPE_MAPPING[odps_type] + else: + if isinstance(odps_type, types.Array): + col_type = pa.list_(odps_type_to_arrow_type(odps_type.value_type)) + elif isinstance(odps_type, types.Map): + col_type = pa.map_( + odps_type_to_arrow_type(odps_type.key_type), + odps_type_to_arrow_type(odps_type.value_type), + ) + elif isinstance(odps_type, types.Decimal): + col_type = pa.decimal128(odps_type.precision, odps_type.scale) + else: + raise TypeError('Unsupported type: {}'.format(odps_type)) + return col_type + + +def odps_schema_to_arrow_schema(odps_schema): + arrow_schema = [] for schema in odps_schema.simple_columns: col_name = schema.name - if schema.type in _ODPS_ARROW_TYPE_MAPPING: - col_type = _ODPS_ARROW_TYPE_MAPPING[schema.type] - else: - if isinstance(schema.type, types.Array): - col_type = pa.list_(_ODPS_ARROW_TYPE_MAPPING[schema.type.value_type]) - elif isinstance(schema.type, types.Decimal): - col_type = pa.decimal128(schema.type.precision, - schema.type.scale) - else: - raise TypeError('Unsupported type: {}'.format(schema.type)) - + col_type = odps_type_to_arrow_type(schema.type) arrow_schema.append(pa.field(col_name, col_type)) return pa.schema(arrow_schema) diff --git a/odps/tunnel/io/writer.py b/odps/tunnel/io/writer.py index 810e3c7f..f20accff 100644 --- a/odps/tunnel/io/writer.py +++ b/odps/tunnel/io/writer.py @@ -20,6 +20,10 @@ import pyarrow as pa except (AttributeError, ImportError): pa = None +try: + import numpy as np +except ImportError: + np = None try: import pandas as pd except ImportError: @@ -525,7 +529,17 @@ def _write_chunk(self, buf): self._cur_chunk_size = 0 def write(self, data): + from ...lib import tzlocal + if isinstance(data, pd.DataFrame): + copied = False + for col_name, dtype in data.dtypes.items(): + # cast timezone as local to make sure timezone of arrow is correct + if isinstance(dtype, np.dtype) and np.issubdtype(dtype, np.datetime64): + if not copied: + data = data.copy() + copied = True + data[col_name] = data[col_name].dt.tz_localize(tzlocal.get_localzone()) arrow_batch = pa.RecordBatch.from_pandas(data) else: arrow_batch = data @@ -542,7 +556,7 @@ def write(self, data): arrays.append(column_dict[name]) else: try: - arrays.append(column_dict[name].cast(tp)) + arrays.append(column_dict[name].cast(tp, safe=False)) except pa.ArrowInvalid: raise ValueError("Failed to cast column %s to type %s" % (name, tp)) arrow_batch = pa.RecordBatch.from_arrays(arrays, names=self._arrow_schema.names) diff --git a/odps/tunnel/tests/test_arrow_tabletunnel.py b/odps/tunnel/tests/test_arrow_tabletunnel.py index 8a2b47c2..ef017097 100644 --- a/odps/tunnel/tests/test_arrow_tabletunnel.py +++ b/odps/tunnel/tests/test_arrow_tabletunnel.py @@ -46,12 +46,12 @@ @unittest.skipIf(pa is None, "need to install pyarrow") class Test(TestBase): def setUp(self): - super().setUp() + super(Test, self).setUp() options.sql.use_odps2_extension = True def tearDown(self): options.sql.use_odps2_extension = None - super().tearDown() + super(Test, self).tearDown() def _upload_data(self, test_table, data, compress=False, **kw): upload_ss = self.tunnel.create_upload_session(test_table, **kw) @@ -62,12 +62,20 @@ def _upload_data(self, test_table, data, compress=False, **kw): upload_ss.commit([0, ]) def _download_data(self, test_table, columns=None, compress=False, **kw): + from odps.lib import tzlocal + count = kw.pop('count', 4) download_ss = self.tunnel.create_download_session(test_table, **kw) with 
download_ss.open_arrow_reader(0, count, compress=compress, columns=columns) as reader: - return reader.to_pandas() + pd_data = reader.to_pandas() + for col_name, dtype in pd_data.dtypes.items(): + if isinstance(dtype, np.dtype) and np.issubdtype(dtype, np.datetime64): + pd_data[col_name] = pd_data[col_name].dt.tz_localize(tzlocal.get_localzone()) + return pd_data def _gen_data(self, repeat=1): + from odps.lib import tzlocal + data = dict() data['id'] = ['hello \x00\x00 world', 'goodbye', 'c' * 2, 'c' * 20] * repeat data['int_num'] = [2**63-1, 222222, -2 ** 63 + 1, -2 ** 11 + 1] * repeat @@ -80,8 +88,13 @@ def _gen_data(self, repeat=1): datetime.datetime.now() + datetime.timedelta(days=idx) for idx in range(4) ] * repeat + data['dt'] = [ + dt.replace(microsecond=dt.microsecond // 1000 * 1000) for dt in data['dt'] + ] + pd_data = pd.DataFrame(data) + pd_data["dt"] = pd_data["dt"].dt.tz_localize(tzlocal.get_localzone()) - return pa.RecordBatch.from_pandas(pd.DataFrame(data)) + return pa.RecordBatch.from_pandas(pd_data) def _create_table(self, table_name): fields = ['id', 'int_num', 'float_num', 'bool', 'date', 'dt']
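
Taken together, the changes above add several user-facing features: ``persist`` documents its overwrite-by-default behaviour together with an ``overwrite=False`` escape hatch, the default ``connect_timeout`` grows from 10 to 120 seconds, ``run_sql_interactive_with_fallback`` reruns a query offline when query acceleration cannot handle it, and ``Table.open_reader`` accepts a ``columns`` argument for both the record and the Arrow readers. The sketch below mirrors the calls exercised by the new tests and documentation in this patch; the entry-object credentials, ``my_table``, ``pyodps_iris_subset`` and the column names are placeholders rather than part of the patch:

>>> from odps import ODPS
>>> from odps.df import DataFrame
>>>
>>> o = ODPS('<access-id>', '<secret-access-key>', '<project>', endpoint='<endpoint>')
>>>
>>> # Interactive (query-acceleration) execution that falls back to an
>>> # ordinary offline SQL instance if the query is unsupported or fails.
>>> inst = o.run_sql_interactive_with_fallback('select id from my_table')
>>> inst.wait_for_success()
>>> with inst.open_reader() as reader:
>>>     rows = [row.values for row in reader]
>>>
>>> # Read only selected columns, through the record reader or the Arrow reader.
>>> table = o.get_table('my_table')
>>> with table.open_reader(columns=['id']) as reader:
>>>     pd_data = reader.to_pandas()
>>> with table.open_reader(columns=['id'], arrow=True) as reader:
>>>     pd_data = reader.to_pandas()
>>>
>>> # Append to, rather than overwrite, an existing table when persisting a DataFrame.
>>> df = DataFrame(o.get_table('pyodps_iris'))
>>> df[df.sepalwidth < 2.5].persist('pyodps_iris_subset', overwrite=False)

As in the new ``testDirectExecuteSQLFallback`` test, hints such as ``odps.mcqa.disable`` can be passed through the ``hints`` argument of ``run_sql_interactive_with_fallback``; when the interactive path is unavailable, the method simply issues a regular ``execute_sql``.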