Skip to content

Commit

Permalink
Fixed add_data() method so that sequences can be supplied to the *data* argument.
Browse files Browse the repository at this point in the history
  • Loading branch information
hover2pi committed Mar 14, 2016
1 parent 010f374 commit 2b45f87
Showing 1 changed file with 36 additions and 17 deletions.
53 changes: 36 additions & 17 deletions astrodbkit/astrodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,18 +88,17 @@ def add_data(self, data, table, delimiter='|', bands=''):
"""
# Store raw entry
entry = data
del_records = []

# Digest the ascii file into table
if isinstance(data,str) and os.path.isfile(data):
data = ii.read(data)

# Or add the sequence of data elements into a table
# Or read the sequence of data elements into a table
elif isinstance(data,(list,tuple,np.ndarray)):
data = at.Table(list(np.asarray(data[1:]).T), names=data[0], dtype=[type(i) for i in data[1]])
data = ii.read(['|'.join(map(str,row)) for row in data], data_start=1, delimiter='|')

else: data = None

if data:

# Get list of all columns and make an empty table for new records
Expand Down Expand Up @@ -128,10 +127,16 @@ def add_data(self, data, table, delimiter='|', bands=''):

# Reject rows that fail column requirements, e.g. NOT NULL fields like 'source_id'
for r in columns[np.where(np.logical_and(required,columns!='id'))]:
# Null values...
new_records = new_records[np.where(new_records[r])]

# Masked values...
new_records = new_records[~new_records[r].mask]

# NaN values...
if new_records.dtype[r] in (int,float):
new_records = new_records[~np.isnan(new_records[r])]

# For spectra, try to populate the table by reading the FITS header
if table.lower()=='spectra':
for n,new_rec in enumerate(new_records):
Expand Down Expand Up @@ -191,7 +196,7 @@ def clean_up(self, table):
# Get the table info and all the records
metadata = self.query("PRAGMA table_info({})".format(table), fmt='table')
columns, types, required = [np.array(metadata[n]) for n in ['name','type','notnull']]
records = self.query("SELECT * FROM {}".format(table), fmt='table')
records = self.query("SELECT * FROM {}".format(table), fmt='table', use_converters=False)
ignore = self.query("SELECT * FROM ignore WHERE tablename LIKE ?", (table,))
duplicate, command = [1], ''

Expand Down Expand Up @@ -256,7 +261,8 @@ def _compare_records(self, table, duplicate, options=['r','c','k','sql']):
"""
# Print the old and new records suspected of being duplicates
data = self.query("SELECT * FROM {} WHERE id IN ({})".format(table,','.join(map(str,duplicate))), fmt='table', verbose=True)
data = self.query("SELECT * FROM {} WHERE id IN ({})".format(table,','.join(map(str,duplicate))), \
fmt='table', verbose=True, use_converters=False)
columns = data.colnames[1:]
old, new = [[data[n][k] for k in columns[1:]] for n in [0,1]]

Expand Down Expand Up @@ -646,7 +652,7 @@ def plot_spectrum(self, spectrum_id, overplot=False, color='b', norm=False):
except: print "Could not plot spectrum {}".format(spectrum_id); plt.close()
else: print "No spectrum {} in the SPECTRA table.".format(spectrum_id)

def query(self, SQL, params='', fmt='array', fetch='all', unpack=False, export='', verbose=False):
def query(self, SQL, params='', fmt='array', fetch='all', unpack=False, export='', verbose=False, use_converters=True):
"""
Wrapper for cursors so data can be retrieved as a list or dictionary from same method
Expand All @@ -664,6 +670,8 @@ def query(self, SQL, params='', fmt='array', fetch='all', unpack=False, export='
The file path of the ascii file to which the data should be exported
verbose: bool
Print the data also
use_converters: bool
Apply converters to columns with custom data types
Returns
-------
Expand All @@ -675,7 +683,7 @@ def query(self, SQL, params='', fmt='array', fetch='all', unpack=False, export='
if SQL.lower().startswith('select') or SQL.lower().startswith('pragma'):

# Make the query explicit so that column and table names are preserved
SQL, columns = self._explicit_query(SQL)
SQL, columns = self._explicit_query(SQL, use_converters=use_converters)

# Get the data as a dictionary
dictionary = self.dict(SQL, params).fetchall()
Expand Down Expand Up @@ -857,7 +865,7 @@ def table(self, table, columns, types, constraints='', new_table=False):
types, and constraints are formatted properly.'.format(table.upper(),\
'created' if new_table else 'modified')

def _explicit_query(self, SQL):
def _explicit_query(self, SQL, use_converters=True):
"""
Sorts the column names so they are returned in the same order they are queried. Also turns
ambiguous SELECT statements into explicit SQLite language in case column names are not unique.
Expand All @@ -866,6 +874,8 @@ def _explicit_query(self, SQL):
----------
SQL: str
The SQLite query to parse
use_converters: bool
Apply converters to columns with custom data types
Returns
-------
Expand All @@ -889,9 +899,9 @@ def _explicit_query(self, SQL):
except:
tdict[t] = t

# Get all the column names
# Get all the column names and dtype placeholders
columns = SQL.replace(' ','').lower().split('distinct' if 'distinct' in SQL.lower() else 'select')[1].split('from')[0].split(',')

# Replace * with the field names
for n,col in enumerate(columns):
if '.' in col:
Expand All @@ -906,15 +916,24 @@ def _explicit_query(self, SQL):

columns[n] = ["{}.{}".format(t,c) if len(tables)>1 else c for c in col]

# Flatten the list of columns
# Flatten the list of columns and dtypes
columns = [j for k in columns for j in k]

# Get the dtypes
dSQL = "SELECT " \
+ ','.join(["typeof({})".format(col) for col in columns])\
+ ' FROM '+SQL.replace('from','FROM').split('FROM')[-1]
if use_converters: dtypes = [None]*len(columns)
else: dtypes = self.list(dSQL).fetchone()

# Reconstruct SQL query
SQL = "SELECT {}".format('DISTINCT ' if 'distinct' in SQL.lower() else '')\
+','.join(["{0} AS '{0}'".format(col) for col in columns])\
+' FROM '\
+SQL.replace('from','FROM').split('FROM')[-1]

+ (','.join(["{0} AS '{0}'".format(col) for col in columns])\
if use_converters else ','.join(["{1}{0}{2} AS '{0}'".format(col,'CAST(' if dt!='null' else '',' AS {})'.format(dt) if dt!='null' else '') \
for dt,col in zip(dtypes,columns)])) \
+ ' FROM '\
+ SQL.replace('from','FROM').split('FROM')[-1]

elif 'pragma' in SQL.lower():
columns = ['cid','name','type','notnull','dflt_value','pk']

Expand Down

0 comments on commit 2b45f87

Please sign in to comment.