|
@@ -39,12 +39,13 @@
 
 from sdc.io.csv_ext import (
     _gen_csv_reader_py_pyarrow_py_func,
-    _gen_csv_reader_py_pyarrow_func_text_dataframe,
+    _gen_pandas_read_csv_func_text,
 )
 from sdc.str_arr_ext import string_array_type
 
 from sdc.hiframes import join, aggregate, sort
 from sdc.types import CategoricalDtypeType, Categorical
+from sdc.datatypes.categorical.pdimpl import _reconstruct_CategoricalDtype
 
 
 def get_numba_array_types_for_csv(df):
@@ -255,45 +256,69 @@ def sdc_pandas_read_csv(
         usecols = [col.literal_value for col in usecols]
 
     if infer_from_params:
-        # dtype should be constants and is important only for inference from params
+        # dtype is a tuple of format ('A', A_dtype, 'B', B_dtype, ...)
+        # where column names should be constants and is important only for inference from params
         if isinstance(dtype, types.Tuple):
-            assert all(isinstance(key, types.Literal) for key in dtype[::2])
+            assert all(isinstance(key, types.StringLiteral) for key in dtype[::2])
             keys = (k.literal_value for k in dtype[::2])
-
             values = dtype[1::2]
-            values = [v.typing_key if isinstance(v, types.Function) else v for v in values]
-            values = [types.Array(numba.from_dtype(np.dtype(v.literal_value)), 1, 'C')
-                      if isinstance(v, types.Literal) else v for v in values]
-            values = [types.Array(types.int_, 1, 'C') if v == int else v for v in values]
-            values = [types.Array(types.float64, 1, 'C') if v == float else v for v in values]
-            values = [string_array_type if v == str else v for v in values]
-            values = [Categorical(v) if isinstance(v, CategoricalDtypeType) else v for v in values]
 
-            dtype = dict(zip(keys, values))
+            def _get_df_col_type(dtype):
+                if isinstance(dtype, types.Function):
+                    if dtype.typing_key == int:
+                        return types.Array(types.int_, 1, 'C')
+                    elif dtype.typing_key == float:
+                        return types.Array(types.float64, 1, 'C')
+                    elif dtype.typing_key == str:
+                        return string_array_type
+                    else:
+                        assert False, f"map_dtype_to_col_type: failing to infer column type for dtype={dtype}"
+
+                if isinstance(dtype, types.StringLiteral):
+                    if dtype.literal_value == 'str':
+                        return string_array_type
+                    else:
+                        return types.Array(numba.from_dtype(np.dtype(dtype.literal_value)), 1, 'C')
+
+                if isinstance(dtype, types.NumberClass):
+                    return types.Array(dtype.dtype, 1, 'C')
+
+                if isinstance(dtype, CategoricalDtypeType):
+                    return Categorical(dtype)
+
+            col_types_map = dict(zip(keys, map(_get_df_col_type, values)))
 
     # in case of both are available
     # inferencing from params has priority over inferencing from file
     if infer_from_params:
-        col_names = names
         # all names should be in dtype
-        return_columns = usecols if usecols else names
-        col_typs = [dtype[n] for n in return_columns]
+        col_names = usecols if usecols else names
+        col_types = [col_types_map[n] for n in col_names]
 
     elif infer_from_file:
-        col_names, col_typs = infer_column_names_and_types_from_constant_filename(
+        col_names, col_types = infer_column_names_and_types_from_constant_filename(
             filepath_or_buffer, delimiter, names, usecols, skiprows)
 
     else:
         return None
 
-    dtype_present = not isinstance(dtype, (types.Omitted, type(None)))
+    def _get_py_col_dtype(ctype):
+        """ Re-creates column dtype as python type to be used in read_csv call """
+        dtype = ctype.dtype
+        if ctype == string_array_type:
+            return str
+        if isinstance(ctype, Categorical):
+            return _reconstruct_CategoricalDtype(ctype.pd_dtype)
+        return numpy_support.as_dtype(dtype)
+
+    py_col_dtypes = {cname: _get_py_col_dtype(ctype) for cname, ctype in zip(col_names, col_types)}
 
     # generate function text with signature and returning DataFrame
-    func_text, func_name = _gen_csv_reader_py_pyarrow_func_text_dataframe(
-        col_names, col_typs, dtype_present, usecols, signature)
+    func_text, func_name, global_vars = _gen_pandas_read_csv_func_text(
+        col_names, col_types, py_col_dtypes, usecols, signature)
 
     # compile with Python
-    csv_reader_py = _gen_csv_reader_py_pyarrow_py_func(func_text, func_name)
+    csv_reader_py = _gen_csv_reader_py_pyarrow_py_func(func_text, func_name, global_vars)
 
     return csv_reader_py
 
 
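For reference, a minimal usage sketch of the kind of call this change targets. It is illustrative only: the file name data.csv and the column names are made up, it assumes Intel SDC is installed so that pandas.read_csv compiles inside an njit function with constant arguments, and the exact set of supported dtype spellings is not confirmed by this diff. The dtype values cover the kinds of inputs the new _get_df_col_type has to map to column types: Python builtins (int, float, str), dtype strings such as 'int32', NumPy scalar types such as np.float64, and a pandas CategoricalDtype.

# Usage sketch (assumptions: SDC installed, data.csv exists next to the script).
import numpy as np
import pandas as pd
from numba import njit
import sdc  # noqa: F401  -- registers the pandas overloads in Numba

# Created outside the jitted function so it reaches the compiler as a constant;
# whether inline construction inside njit is supported is not shown by this diff.
CAT_DTYPE = pd.CategoricalDtype(['x', 'y', 'z'])


@njit
def read_data():
    return pd.read_csv('data.csv',
                       names=['A', 'B', 'C', 'D'],
                       dtype={'A': np.float64, 'B': 'int32', 'C': str, 'D': CAT_DTYPE},
                       usecols=['A', 'B', 'C', 'D'])


df = read_data()  # requires a real data.csv at runtime
print(df.dtypes)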
|
|
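And a standalone sketch, using only public Numba/NumPy APIs (no SDC imports), of the dtype round trip the two helpers perform: a user-supplied dtype is mapped to a Numba array type for DataFrame typing, and later converted back to a Python-level dtype for the generated pandas.read_csv call, mirroring _get_df_col_type and _get_py_col_dtype. Module paths follow recent Numba releases.

# Illustrative only: the dtype round trip with public Numba APIs.
import numpy as np
import numba
from numba.core import types
from numba.np import numpy_support

# forward: a dtype string from the read_csv call becomes a 1D C-contiguous array type,
# as _get_df_col_type does for types.StringLiteral values
col_type = types.Array(numba.from_dtype(np.dtype('float32')), 1, 'C')
print(col_type)            # array(float32, 1d, C)

# backward: the array's element type is turned back into a NumPy dtype,
# as _get_py_col_dtype does before handing it to the real pandas.read_csv
py_dtype = numpy_support.as_dtype(col_type.dtype)
print(py_dtype)            # float32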