Skip to content

Commit

Permalink
Merge 0d2cd55 into 3cc6824
Browse files Browse the repository at this point in the history
  • Loading branch information
oleurud committed Sep 20, 2019
2 parents 3cc6824 + 0d2cd55 commit fc2d03c
Show file tree
Hide file tree
Showing 6 changed files with 497 additions and 220 deletions.
186 changes: 35 additions & 151 deletions cartoframes/data/dataset/registry/dataframe_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
from carto.exceptions import CartoException, CartoRateLimitException

from .base_dataset import BaseDataset
from ....utils.columns import Column, normalize_names
from ....utils.geom_utils import decode_geometry, compute_geodataframe, \
detect_encoding_type, save_index_as_column
from ....utils.columns import DataframeColumnsInfo, _first_value
from ....utils.geom_utils import decode_geometry, compute_geodataframe, save_index_as_column
from ....utils.utils import map_geom_type, load_geojson, is_geojson


Expand Down Expand Up @@ -53,14 +52,14 @@ def download(self, limit, decode_geom, retry_times):
def upload(self, if_exists, with_lnglat):
self._is_ready_for_upload_validation()

normalized_column_names = _normalize_column_names(self._df)
dataframe_columns_info = DataframeColumnsInfo(self._df, with_lnglat)

if if_exists == BaseDataset.REPLACE or not self.exists():
self._create_table(normalized_column_names, with_lnglat)
self._create_table(dataframe_columns_info.columns)
elif if_exists == BaseDataset.FAIL:
raise self._already_exists_error()

self._copyfrom(normalized_column_names, with_lnglat)
self._copyfrom(dataframe_columns_info, with_lnglat)

def delete(self):
raise ValueError('Method not allowed in DataFrameDataset. You should use a TableDataset: `Dataset(my_table)`')
Expand All @@ -80,48 +79,19 @@ def get_column_names(self, exclude=None):

return columns

def _copyfrom(self, normalized_column_names, with_lnglat):
geom_col = _get_geom_col_name(self._df)
enc_type = _detect_encoding_type(self._df, geom_col)
columns_normalized, columns_origin = self._copyfrom_column_names(
geom_col,
normalized_column_names,
with_lnglat)

def _copyfrom(self, dataframe_columns_info, with_lnglat):
query = """COPY {table_name}({columns}) FROM stdin WITH (FORMAT csv, DELIMITER '|');""".format(
table_name=self._table_name,
columns=','.join(columns_normalized))
columns=','.join(c.database for c in dataframe_columns_info.columns))

data = _rows(
self._df,
columns_origin,
with_lnglat,
geom_col,
enc_type,
len(columns_normalized))
data = _rows(self._df, dataframe_columns_info, with_lnglat)

self._context.upload(query, data)

def _copyfrom_column_names(self, geom_col, normalized_column_names, with_lnglat=None):
columns_normalized = []
columns_origin = []

if geom_col:
columns_origin.append(geom_col)

for norm, orig in normalized_column_names:
columns_normalized.append(norm)
columns_origin.append(orig)

if geom_col or with_lnglat:
columns_normalized.append('the_geom')

return columns_normalized, columns_origin

def _create_table(self, normalized_column_names, with_lnglat=None):
def _create_table(self, columns):
query = '''BEGIN; {drop}; {create}; {cartodbfy}; COMMIT;'''.format(
drop=self._drop_table_query(),
create=self._create_table_query(normalized_column_names, with_lnglat),
create=self._create_table_query(columns),
cartodbfy=self._cartodbfy_query())

try:
Expand All @@ -131,22 +101,12 @@ def _create_table(self, normalized_column_names, with_lnglat=None):
except CartoException as err:
raise CartoException('Cannot create table: {}.'.format(err))

def _create_table_query(self, normalized_column_names, with_lnglat=None):
if with_lnglat is None:
geom_type = _get_geom_col_type(self._df)
else:
geom_type = 'Point'
def _create_table_query(self, columns):
cols = ['{column} {type}'.format(column=c.database, type=c.database_type) for c in columns]

col = ('{col} {ctype}')
cols = ', '.join(col.format(col=norm,
ctype=_dtypes2pg(self._df.dtypes[orig]))
for norm, orig in normalized_column_names)

if geom_type:
cols += ', {geom_colname} geometry({geom_type}, 4326)'.format(geom_colname='the_geom', geom_type=geom_type)

create_query = '''CREATE TABLE {table_name} ({cols})'''.format(table_name=self._table_name, cols=cols)
return create_query
return '''CREATE TABLE {table_name} ({cols})'''.format(
table_name=self._table_name,
cols=', '.join(cols))

def _get_geom_type(self):
"""Compute geom type of the local dataframe"""
Expand All @@ -156,38 +116,34 @@ def _get_geom_type(self):
return map_geom_type(geometry.geom_type)


def _rows(df, cols, with_lnglat, geom_col, enc_type, columns_number=None):
columns_number = columns_number or len(cols)

def _rows(df, dataframe_columns_info, with_lnglat):
for i, row in df.iterrows():
row_data = []
the_geom_val = None
lng_val = None
lat_val = None
for col in cols:
for c in dataframe_columns_info.columns:
col = c.dataframe
if col not in df.columns: # we could have filtered columns in the df. See DataframeColumnsInfo
continue
val = row[col]

if _is_null(val):
val = ''
if with_lnglat:
if col == with_lnglat[0]:
lng_val = row[col]
if col == with_lnglat[1]:
lat_val = row[col]
if geom_col and col == geom_col:
the_geom_val = row[col]

if dataframe_columns_info.geom_column and col == dataframe_columns_info.geom_column:
geom = decode_geometry(val, dataframe_columns_info.enc_type)
if geom:
row_data.append('SRID=4326;{}'.format(geom.wkt))
else:
row_data.append('')
else:
row_data.append('{}'.format(val))

if the_geom_val is not None:
geom = decode_geometry(the_geom_val, enc_type)
if geom:
row_data.append('SRID=4326;{geom}'.format(geom=geom.wkt))

if len(row_data) < columns_number and with_lnglat is not None and lng_val is not None and lat_val is not None:
row_data.append('SRID=4326;POINT({lng} {lat})'.format(lng=lng_val, lat=lat_val))

if len(row_data) < columns_number:
row_data.append('')
if with_lnglat:
lng_val = row[with_lnglat[0]]
lat_val = row[with_lnglat[1]]
if lng_val and lat_val:
row_data.append('SRID=4326;POINT ({lng} {lat})'.format(lng=lng_val, lat=lat_val))
else:
row_data.append('')

csv_row = '|'.join(row_data)
csv_row += '\n'
Expand All @@ -201,75 +157,3 @@ def _is_null(val):
return vnull
else:
return vnull.all()


def _normalize_column_names(df):
    """Normalize dataframe column names for CARTO and report any renames.

    Skips reserved column names, normalizes the rest, and returns a list of
    ``(normalized, original)`` tuples. Any columns whose names changed are
    reported to the user via ``tqdm.write``.
    """
    originals = [name for name in df.columns if name not in Column.RESERVED_COLUMN_NAMES]
    normalized = normalize_names(originals)
    column_tuples = list(zip(normalized, originals))

    renames = [
        '\033[1m{orig}\033[0m -> \033[1m{new}\033[0m'.format(orig=orig, new=norm)
        for norm, orig in column_tuples if norm != orig
    ]

    if renames:
        tqdm.write('The following columns were changed in the CARTO '
                   'copy of this dataframe:\n{0}'.format('\n'.join(renames)))

    return column_tuples


def _get_geom_col_name(df):
geom_col = getattr(df, '_geometry_column_name', None)
if geom_col is None:
try:
geom_col = next(x for x in df.columns if x.lower() in Column.SUPPORTED_GEOM_COL_NAMES)
except StopIteration:
pass

return geom_col


def _detect_encoding_type(df, geom_col):
if geom_col is not None:
first_geom = _first_value(df[geom_col])
if first_geom:
return detect_encoding_type(first_geom)
return ''


def _dtypes2pg(dtype):
"""Returns equivalent PostgreSQL type for input `dtype`"""
mapping = {
'float64': 'numeric',
'int64': 'bigint',
'float32': 'numeric',
'int32': 'integer',
'object': 'text',
'bool': 'boolean',
'datetime64[ns]': 'timestamp',
'datetime64[ns, UTC]': 'timestamp',
}
return mapping.get(str(dtype), 'text')


def _get_geom_col_type(df):
    """Return the geometry type name of the dataframe's geometry column.

    Returns None when the dataframe has no geometry column, when that column
    holds no non-null value, or when the sampled value cannot be decoded
    (in which case a warning about null geometries is emitted).
    """
    geom_col = _get_geom_col_name(df)
    if geom_col is None:
        return None

    sample = _first_value(df[geom_col])
    if not sample:
        return None

    geom = decode_geometry(sample, detect_encoding_type(sample))
    if geom is None:
        warn('Dataset with null geometries')
        return None

    return geom.geom_type


def _first_value(array):
array = array.loc[~array.isnull()] # Remove null values
if len(array) > 0:
return array.iloc[0]
115 changes: 113 additions & 2 deletions cartoframes/utils/columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@

from unidecode import unidecode

from .geom_utils import detect_encoding_type, decode_geometry


class Column(object):
DATETIME_DTYPES = ['datetime64[D]', 'datetime64[ns]', 'datetime64[ns, UTC]']
SUPPORTED_GEOM_COL_NAMES = ['geom', 'the_geom', 'geometry']
RESERVED_COLUMN_NAMES = SUPPORTED_GEOM_COL_NAMES + ['the_geom_webmercator', 'cartodb_id']
SUPPORTED_GEOM_COL_NAMES = ['the_geom', 'geom', 'geometry']
FORBIDDEN_COLUMN_NAMES = ['the_geom_webmercator']
MAX_LENGTH = 63
MAX_COLLISION_LENGTH = MAX_LENGTH - 4
RESERVED_WORDS = ('ALL', 'ANALYSE', 'ANALYZE', 'AND', 'ANY', 'ARRAY', 'AS', 'ASC', 'ASYMMETRIC', 'AUTHORIZATION',
Expand Down Expand Up @@ -84,6 +86,94 @@ def _slugify(self, value):
return value


class DataframeColumnInfo(object):
    """Mapping between one dataframe column and its database counterpart.

    Exposes three attributes: ``dataframe`` (the source column name, or None),
    ``database`` (the column name used in the table) and ``database_type``
    (the PostgreSQL type). Passing a falsy *column* builds the synthetic
    ``the_geom`` point column used when geometry comes from lng/lat values.
    """

    def __init__(self, column, geom_column=None, geom_type=None, dtype=None):
        if not column:
            # Synthetic geometry column: no dataframe source, always a point.
            self.dataframe = None
            self.database = 'the_geom'
            self.database_type = 'geometry(Point, 4326)'
        else:
            self.dataframe = column
            self.database = self._database_column_name(geom_column)
            self.database_type = self._db_column_type(geom_column, geom_type, dtype)

    def _database_column_name(self, geom_column):
        """The geometry column maps to 'the_geom'; others get normalized names."""
        if geom_column and self.dataframe == geom_column:
            return 'the_geom'
        return normalize_name(self.dataframe)

    def _db_column_type(self, geom_column, geom_type, dtype):
        """Geometry columns get a geometry type (Point by default); others map via dtype."""
        if geom_column and self.dataframe == geom_column:
            return 'geometry({}, 4326)'.format(geom_type or 'Point')
        return _dtypes2pg(dtype)

    def __eq__(self, obj):
        """Compare against another instance, or a dict with the same three keys."""
        if isinstance(obj, dict):
            return self.dataframe == obj['dataframe'] and \
                self.database == obj['database'] and \
                self.database_type == obj['database_type']
        return self.dataframe == obj.dataframe and \
            self.database == obj.database and \
            self.database_type == obj.database_type


class DataframeColumnsInfo(object):
    """Column metadata for uploading a dataframe to a database table.

    Computes the geometry column name, its geometry/encoding types, and one
    ``DataframeColumnInfo`` per uploadable column. Forbidden columns are
    dropped; when *with_lnglat* is given, the original geometry column is
    dropped and a synthetic ``the_geom`` column is appended instead.
    """

    def __init__(self, df, with_lnglat=None):
        self.df = df
        self.with_lnglat = with_lnglat

        self.geom_column = self._get_geom_col_name()
        self.geom_type, self.enc_type = self._get_geometry_type()

        self.columns = self._get_columns_info()

    def _get_columns_info(self):
        """Build the list of DataframeColumnInfo for every uploadable column."""
        columns = [
            DataframeColumnInfo(name, self.geom_column, self.geom_type, self.df.dtypes[name])
            for name in self.df.columns
            if not self._filter_column(name)
        ]

        if self.with_lnglat:
            # Geometry will be built from lng/lat values: add the synthetic column.
            columns.append(DataframeColumnInfo(None))

        return columns

    def _filter_column(self, column):
        """True when *column* must not be uploaded."""
        if column.lower() in Column.FORBIDDEN_COLUMN_NAMES:
            return True
        if self.with_lnglat and column == self.geom_column:
            return True
        return False

    def _get_geom_col_name(self):
        """Geometry column name: geopandas attribute first, then supported names."""
        geom_col = getattr(self.df, '_geometry_column_name', None)
        if geom_col is not None:
            return geom_col

        lowered = [name.lower() for name in self.df.columns]
        for candidate in Column.SUPPORTED_GEOM_COL_NAMES:
            if candidate in lowered:
                return candidate

        return None

    def _get_geometry_type(self):
        """Return (geometry type, encoding type) sampled from the geometry column.

        Returns (None, None) when there is no geometry column or no non-null
        sample value.
        """
        if self.geom_column is not None:
            sample = _first_value(self.df[self.geom_column])
            if sample:
                enc_type = detect_encoding_type(sample)
                return decode_geometry(sample, enc_type).geom_type, enc_type

        return None, None


def normalize_names(column_names):
"""Given an arbitrary column name, translate to a SQL-normalized column
name a la CARTO's Import API will translate to
Expand Down Expand Up @@ -165,3 +255,24 @@ def pg2dtypes(pgtype):
'USER-DEFINED': 'object',
}
return mapping.get(str(pgtype), 'object')


def _dtypes2pg(dtype):
"""Returns equivalent PostgreSQL type for input `dtype`"""
mapping = {
'float64': 'numeric',
'int64': 'bigint',
'float32': 'numeric',
'int32': 'integer',
'object': 'text',
'bool': 'boolean',
'datetime64[ns]': 'timestamp',
'datetime64[ns, UTC]': 'timestamp',
}
return mapping.get(str(dtype), 'text')


def _first_value(series):
series = series.loc[~series.isnull()] # Remove null values
if len(series) > 0:
return series.iloc[0]
Loading

0 comments on commit fc2d03c

Please sign in to comment.