Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit f8e775d

Browse files
committed
Return back csv_reader_py changes from #918
This reverts commit 30122b2.
1 parent 5d09505 commit f8e775d

File tree

6 files changed

+139
-121
lines changed

6 files changed

+139
-121
lines changed

sdc/datatypes/hpat_pandas_functions.py

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,13 @@
3939

4040
from sdc.io.csv_ext import (
4141
_gen_csv_reader_py_pyarrow_py_func,
42-
_gen_csv_reader_py_pyarrow_func_text_dataframe,
42+
_gen_pandas_read_csv_func_text,
4343
)
4444
from sdc.str_arr_ext import string_array_type
4545

4646
from sdc.hiframes import join, aggregate, sort
4747
from sdc.types import CategoricalDtypeType, Categorical
48+
from sdc.datatypes.categorical.pdimpl import _reconstruct_CategoricalDtype
4849

4950

5051
def get_numba_array_types_for_csv(df):
@@ -255,45 +256,69 @@ def sdc_pandas_read_csv(
255256
usecols = [col.literal_value for col in usecols]
256257

257258
if infer_from_params:
258-
# dtype should be constants and is important only for inference from params
259+
# dtype is a tuple of format ('A', A_dtype, 'B', B_dtype, ...)
260+
# where column names should be constants and is important only for inference from params
259261
if isinstance(dtype, types.Tuple):
260-
assert all(isinstance(key, types.Literal) for key in dtype[::2])
262+
assert all(isinstance(key, types.StringLiteral) for key in dtype[::2])
261263
keys = (k.literal_value for k in dtype[::2])
262-
263264
values = dtype[1::2]
264-
values = [v.typing_key if isinstance(v, types.Function) else v for v in values]
265-
values = [types.Array(numba.from_dtype(np.dtype(v.literal_value)), 1, 'C')
266-
if isinstance(v, types.Literal) else v for v in values]
267-
values = [types.Array(types.int_, 1, 'C') if v == int else v for v in values]
268-
values = [types.Array(types.float64, 1, 'C') if v == float else v for v in values]
269-
values = [string_array_type if v == str else v for v in values]
270-
values = [Categorical(v) if isinstance(v, CategoricalDtypeType) else v for v in values]
271265

272-
dtype = dict(zip(keys, values))
266+
def _get_df_col_type(dtype):
267+
if isinstance(dtype, types.Function):
268+
if dtype.typing_key == int:
269+
return types.Array(types.int_, 1, 'C')
270+
elif dtype.typing_key == float:
271+
return types.Array(types.float64, 1, 'C')
272+
elif dtype.typing_key == str:
273+
return string_array_type
274+
else:
275+
assert False, f"map_dtype_to_col_type: failing to infer column type for dtype={dtype}"
276+
277+
if isinstance(dtype, types.StringLiteral):
278+
if dtype.literal_value == 'str':
279+
return string_array_type
280+
else:
281+
return types.Array(numba.from_dtype(np.dtype(dtype.literal_value)), 1, 'C')
282+
283+
if isinstance(dtype, types.NumberClass):
284+
return types.Array(dtype.dtype, 1, 'C')
285+
286+
if isinstance(dtype, CategoricalDtypeType):
287+
return Categorical(dtype)
288+
289+
col_types_map = dict(zip(keys, map(_get_df_col_type, values)))
273290

274291
# in case of both are available
275292
# inferencing from params has priority over inferencing from file
276293
if infer_from_params:
277-
col_names = names
278294
# all names should be in dtype
279-
return_columns = usecols if usecols else names
280-
col_typs = [dtype[n] for n in return_columns]
295+
col_names = usecols if usecols else names
296+
col_types = [col_types_map[n] for n in col_names]
281297

282298
elif infer_from_file:
283-
col_names, col_typs = infer_column_names_and_types_from_constant_filename(
299+
col_names, col_types = infer_column_names_and_types_from_constant_filename(
284300
filepath_or_buffer, delimiter, names, usecols, skiprows)
285301

286302
else:
287303
return None
288304

289-
dtype_present = not isinstance(dtype, (types.Omitted, type(None)))
305+
def _get_py_col_dtype(ctype):
306+
""" Re-creates column dtype as python type to be used in read_csv call """
307+
dtype = ctype.dtype
308+
if ctype == string_array_type:
309+
return str
310+
if isinstance(ctype, Categorical):
311+
return _reconstruct_CategoricalDtype(ctype.pd_dtype)
312+
return numpy_support.as_dtype(dtype)
313+
314+
py_col_dtypes = {cname: _get_py_col_dtype(ctype) for cname, ctype in zip(col_names, col_types)}
290315

291316
# generate function text with signature and returning DataFrame
292-
func_text, func_name = _gen_csv_reader_py_pyarrow_func_text_dataframe(
293-
col_names, col_typs, dtype_present, usecols, signature)
317+
func_text, func_name, global_vars = _gen_pandas_read_csv_func_text(
318+
col_names, col_types, py_col_dtypes, usecols, signature)
294319

295320
# compile with Python
296-
csv_reader_py = _gen_csv_reader_py_pyarrow_py_func(func_text, func_name)
321+
csv_reader_py = _gen_csv_reader_py_pyarrow_py_func(func_text, func_name, global_vars)
297322

298323
return csv_reader_py
299324

sdc/hiframes/pd_dataframe_ext.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626

2727

2828
import operator
29-
from typing import NamedTuple
3029

3130
import numba
3231
from numba import types
@@ -39,7 +38,7 @@
3938
from numba.core.imputils import impl_ret_new_ref, impl_ret_borrowed
4039

4140
from sdc.hiframes.pd_series_ext import SeriesType
42-
from sdc.hiframes.pd_dataframe_type import DataFrameType
41+
from sdc.hiframes.pd_dataframe_type import DataFrameType, ColumnLoc
4342
from sdc.str_ext import string_type
4443

4544

@@ -54,10 +53,6 @@ def generic_resolve(self, df, attr):
5453
return SeriesType(arr_typ.dtype, arr_typ, df.index, True)
5554

5655

57-
class ColumnLoc(NamedTuple):
58-
type_id: int
59-
col_id: int
60-
6156

6257
def get_structure_maps(col_types, col_names):
6358
# Define map column name to column location ex. {'A': (0,0), 'B': (1,0), 'C': (0,1)}

sdc/hiframes/pd_dataframe_type.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2525
# *****************************************************************************
2626

27+
import re
28+
from typing import NamedTuple
2729

2830
import numba
2931
from numba import types
@@ -48,7 +50,7 @@ def __init__(self, data=None, index=None, columns=None, has_parent=False, column
4850
self.has_parent = has_parent
4951
self.column_loc = column_loc
5052
super(DataFrameType, self).__init__(
51-
name="dataframe({}, {}, {}, {})".format(data, index, columns, has_parent))
53+
name="DataFrameType({}, {}, {}, {})".format(data, index, columns, has_parent))
5254

5355
def copy(self, index=None, has_parent=None):
5456
# XXX is copy necessary?
@@ -83,6 +85,16 @@ def unify(self, typingctx, other):
8385
def is_precise(self):
8486
return all(a.is_precise() for a in self.data) and self.index.is_precise()
8587

88+
def __repr__(self):
89+
90+
# To have correct repr of DataFrame we need some changes to what types.Type gives:
91+
# (1) e.g. array(int64, 1d, C) should be Array(int64, 1, 'C')
92+
# (2) ColumnLoc is not part of DataFrame name, so we need to add it
93+
default_repr = super(DataFrameType, self).__repr__()
94+
res = re.sub(r'array\((\w+), 1d, C\)', r'Array(\1, 1, \'C\')', default_repr)
95+
res = re.sub(r'\)$', f', column_loc={self.column_loc})', res)
96+
return res
97+
8698

8799
@register_model(DataFrameType)
88100
class DataFrameModel(models.StructModel):
@@ -104,6 +116,15 @@ def __init__(self, dmm, fe_type):
104116
super(DataFrameModel, self).__init__(dmm, fe_type, members)
105117

106118

119+
class ColumnLoc(NamedTuple):
120+
type_id: int
121+
col_id: int
122+
123+
124+
# FIXME_Numba#3372: add into numba.types to allow returning from objmode
125+
types.DataFrameType = DataFrameType
126+
types.ColumnLoc = ColumnLoc
127+
107128
make_attribute_wrapper(DataFrameType, 'data', '_data')
108129
make_attribute_wrapper(DataFrameType, 'index', '_index')
109130
make_attribute_wrapper(DataFrameType, 'columns', '_columns')

0 commit comments

Comments
 (0)