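# NOTE: the next six lines are GitHub web-page text that was captured along
# with the file; they are not part of the program.
# Switch branches/tags
# Nothing to show
# Find file
# Fetching contributors…
# Cannot retrieve contributors at this time
# executable file 357 lines (296 sloc) 12.8 KB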
# -*- coding: utf-8 -*-
import os
import sqlite3
import time
from utils import smart_str
class SimpleDB(object):
# Constructor: connects to the SQLite file `filename` and issues one
# CREATE TABLE statement per entry of the `schema` mapping.
#
# @filename - path handed to os.path.exists() and sqlite3.connect()
# @schema   - mapping of table name -> iterable of column definition strings
#             (the definitions are joined with ", " inside CREATE TABLE)
#
# NOTE(review): this extract has lost its indentation and appears to be
# missing several lines (the docstring quotes, a `try:`, the tail of the
# `sqlite3.connect(` call, the body of the `except` clause, an `else:` and
# the `table_name` argument of `.format(`). Restore them from the upstream
# file before relying on this method.
# NOTE(review): with the default filename=None, os.path.exists(filename) is
# evaluated before `schema` is looked at -- confirm that None is handled.
def __init__(self, filename=None, schema=None, **kwargs):
# (docstring remnant -- its surrounding quotes are not visible here)
This class originally inspired by:
# continue when the database file already exists or a schema was supplied
if os.path.exists(filename) or schema:
self.db = filename
self.connection = sqlite3.connect(filename,
self.cursor = self.connection.cursor()
except sqlite3.OperationalError:
elif not schema:
# neither a schema nor a filename was found
raise Exception, "The specified database file does not exist,\
and you haven't provided a schema"
# create a new table from the given schema
for table_name, fields in schema.items():
self.query("""CREATE TABLE {0} ({1})""".format(
", ".join(fields)))
# Destructor hook.
# NOTE(review): the body of this method is not visible in this extract;
# presumably it releases the cursor/connection -- confirm against the
# upstream file.
def __del__(self):
def query(self, *args, **kwargs):
    ''' Run a statement on this object's cursor

        Every positional and keyword argument is handed to the cursor's
        execute() untouched, and whatever execute() returns is returned.
    '''
    run = self.cursor.execute
    return run(*args, **kwargs)
def get_last_row(self, table):
    ''' Returns the very last row for the given table name

        @table - the name of the table to read; it is interpolated into the
                 SQL text, so it must be a trusted identifier

        Returns the row with the highest `id` as a tuple, or None when the
        table holds no rows.

        NOTE - this is not the same as lastrowid
    '''
    last_row = self.cursor.execute("""SELECT * FROM {0}
        ORDER BY id DESC LIMIT 1""".format(table))
    # execute() hands back a cursor, which cannot be indexed; fetchone()
    # yields the single selected row (or None when nothing was selected)
    return last_row.fetchone()
def get_field_list(self, table):
    ''' Return info about the fields of a given table

        @table - the name of the table whose fields you want info on
                 (interpolated into the PRAGMA text, so it must be a
                 trusted identifier)

        returns a list of tuples in the form:
            (cid, name, type, notnull, dflt_value, pk)

        example output:
            [
                (0, u'id', u'INTEGER', 1, None, 1),
                (1, u'timestamp', u'TEXT', 1, None, 0),
                (2, u'longitude', u'REAL', 1, None, 0),
                (3, u'latitude', u'REAL', 1, None, 0),
                (4, u'elevation', u'TEXT', 0, None, 0)
            ]

        The list is empty when the PRAGMA yields no rows.
    '''
    self.cursor.execute("PRAGMA table_info({0})".format(table))
    # hand the rows back to the caller; running the PRAGMA alone would
    # leave this method returning None
    return self.cursor.fetchall()
def clear_table(self, table):
    ''' Clear a table of all its data

        @table - the name of the table to empty; it is interpolated into
                 the DELETE statement, so it must be a trusted identifier

        Also removes the table's entry from sqlite_sequence (when that
        bookkeeping table exists) and prints a confirmation message.
        This method does not commit.
    '''
    self.cursor.execute("DELETE FROM {0}".format(table))
    # sqlite_sequence is only created once some table in the database uses
    # AUTOINCREMENT; deleting from it unconditionally raises
    # OperationalError on databases that have no such table
    has_sequence = self.cursor.execute(
        "SELECT 1 FROM sqlite_master "
        "WHERE type='table' AND name='sqlite_sequence'").fetchone()
    if has_sequence:
        # here the table name is a value, so it can be bound as a parameter
        self.cursor.execute(
            "DELETE FROM sqlite_sequence WHERE name=?", (table,))
    print('deleted all data in the `{0}` table'.format(table))
def dump_table(self, table):
    ''' Returns an sql command for creating the given table

        @table - the name of the table whose string you want

        Returns the CREATE TABLE text stored in sqlite_master, or None
        when no table of that name exists.

        example output:
            CREATE TABLE location (
                timestamp TEXT NOT NULL,
                longitude REAL NOT NULL,
                latitude REAL NOT NULL,
                elevation TEXT
            )
    '''
    # the name is compared as a value, so bind it instead of formatting it
    # into the statement (a name containing a quote would break the SQL)
    found = self.cursor.execute("""SELECT sql
        FROM sqlite_master
        WHERE type='table'
        AND name=? """, (table,)).fetchone()
    # fetchone() gives None for an unknown table; do not index into it
    return found[0] if found else None
def _table_info(self, table):
    ''' Describe the columns of a table as a list of dicts
        (approach taken from Django)

        @table - name of the table to inspect

        example output:
            [
                {'name': u'id', 'null_ok': False, 'pk': 1, 'type': u'INTEGER'},
                {'name': u'timestamp', 'null_ok': False, 'pk': 0, 'type': u'TEXT'},
                {'name': u'elevation', 'null_ok': True, 'pk': 0, 'type': u'TEXT'}
            ]
    '''
    self.cursor.execute('PRAGMA table_info({0})'.format(table))
    columns = []
    # each PRAGMA row is (cid, name, type, notnull, dflt_value, pk)
    for info in self.cursor.fetchall():
        column = {}
        column['name'] = info[1]
        column['type'] = info[2]
        column['null_ok'] = not info[3]
        column['pk'] = info[5]
        columns.append(column)
    return columns
def change_fieldname(self, table, old_field_name, new_field_name):
    ''' Rename a column (field) to some new name.

        @table - name of the table whose field you want altered
        @old_field_name - current name of the field you want altered
        @new_field_name - the name you want to rename the field to

        The table is renamed to `tmp_<table>`, re-created under its
        original name with the edited schema, and the rows are copied over.
        Raises ValueError when the table or the field cannot be found.

        The field is matched as a whole word inside the column definitions,
        so a field whose name is also spelled identically elsewhere in the
        definitions (e.g. as a type name) is renamed there too.

        TODO:
        - Wrapping all this in a BEGIN/END TRANSACTION; and COMMIT;
          or ROLLBACK to ensure that the renaming either completes
          successfully or not at all- is also probably a good idea.
        - check to make sure the new table is identical to the original
          (e.g. same number of rows) and only then drop the tmp table.
    '''
    import re  # local import: only this method needs it

    tmp_table = 'tmp_' + table

    # get the create statement for the original table
    table_schema = self.dump_table(table)
    if not table_schema:
        raise ValueError('no such table: {0}'.format(table))

    # Only rewrite the column definitions (everything after the first
    # "(") and only whole-word matches, so the table name and columns that
    # merely contain the old name are left alone.
    head, paren, columns = table_schema.partition('(')
    pattern = r'\b{0}\b'.format(re.escape(old_field_name))
    # a callable replacement keeps backslashes in the new name literal
    columns, hits = re.subn(pattern, lambda _match: new_field_name, columns)
    if not hits:
        raise ValueError('no field `{0}` in table `{1}`'.format(
            old_field_name, table))
    table_schema = head + paren + columns

    # temporarily rename the table
    self.cursor.execute("""ALTER TABLE {0} RENAME TO {1}""".format(
        table, tmp_table))
    # create your new table using the original table name & new fields
    self.cursor.execute(table_schema)
    # a list of tuples containing the rows of data in the tmp_table
    table_data = self.cursor.execute("""SELECT * FROM {0}""".format(
        tmp_table)).fetchall()

    # if there is any data to insert, then insert it
    if table_data:
        # generate one question mark per field in your data
        qmarks = ', '.join('?' * len(table_data[0]))
        # insert the data into your new table with updated field name
        self.cursor.executemany('''INSERT INTO {0} VALUES({1})'''.format(
            table, qmarks), table_data)

    # NOTE - the tmp table is deliberately kept (not dropped) until a data
    # integrity check exists; nothing is committed here either.
def rename_table(self, old_name, new_name):
    ''' Renames a table in the database

        @old_name - name of the table you want to rename
        @new_name - name you want your table to become

        Both names are interpolated into the ALTER TABLE statement, so
        they must be trusted identifiers. This method does not commit.
    '''
    self.cursor.execute("""ALTER TABLE {0} RENAME TO {1}""".format(
        old_name, new_name))
def is_duplicate(self, table, row):
    '''Checks to see if the row is in the database

       @table = the table you want to perform the test on (interpolated
                into the SQL text, so it must be a trusted identifier)
       @row = the row whose existence you are checking for, as a mapping
              of field name -> value

       Returns None when a matching row already exists (or when `row` is
       empty); otherwise returns the text of an INSERT statement that adds
       the row, ready to be passed to query().
    '''
    def _text(value):
        # smart_str may hand back a utf-8 byte string; work in text only
        raw = smart_str(value)
        return raw.decode('utf8') if isinstance(raw, bytes) else raw

    def _ident(name):
        # double-quoted identifier with embedded quotes doubled
        return u'"{0}"'.format(_text(name).replace(u'"', u'""'))

    def _literal(value):
        # single-quoted SQL string literal with embedded quotes doubled
        return u"'{0}'".format(_text(value).replace(u"'", u"''"))

    items = list(row.items())
    if not items:
        # nothing to compare and nothing to insert
        return None

    fields = [_ident(name) for name, _ in items]
    values = [_text(value) for _, value in items]

    # Bind the values instead of pasting them into the statement: a
    # double-quoted token can be resolved by SQLite as a column name, and
    # quotes inside a value would break the SQL.
    where = u" AND ".join(u"{0}=?".format(field) for field in fields)
    sql = u"SELECT id FROM {0} WHERE {1}".format(table, where)
    if self.cursor.execute(sql, values).fetchone():
        # its a duplicate
        return None

    # if its not a duplicate, return the insert statement; the column and
    # value lists are joined explicitly because formatting a tuple breaks
    # for single-field rows and for non-ASCII text
    return u"INSERT INTO {0} ({1}) VALUES ({2})".format(
        table,
        u", ".join(fields),
        u", ".join(_literal(value) for value in values))
# Manual smoke test: reads a CSV file from the current directory, removes
# duplicate rows, and asks SimpleDB.is_duplicate() about every remaining row.
#
# @tbl - 'logs' (reads test_eternity.csv) or 'location' (reads test_gps.csv)
#
# NOTE(review): this extract has lost its indentation and several lines
# (docstring quotes, the tail of the `headers` tuple for 'logs', an `else:`
# before the raise, the closing braces of both dicts, and whatever follows
# the final `if sql:`). Restore them from the upstream file before running.
def test(tbl='logs'):
>>> db, d = test('location')
for row in d.dict:
sql = db.is_duplicate('location', row)
if sql: db.query(sql)
# project-local helpers and the third-party tablib package
from utils import unicode_csv_reader, replace_txt, remove_duplicates
import tablib
# the database file is test.db in the current working directory
db = SimpleDB(os.path.join(os.getcwd(), 'test.db'))
if tbl == 'logs':
f_name = 'test_eternity.csv'
headers = ('day','start_time','stop_time','duration',
elif tbl == 'location':
f_name = 'test_gps.csv'
headers = ('latitude','longitude','elevation','timestamp')
raise Exception, "test(args) must = eternity or gps"
# get data
with open(os.path.join(os.getcwd(), f_name), 'r') as f:
#d = list(set([tuple(row) for row in unicode_csv_reader(f)]))
d = remove_duplicates([tuple(row) for row in unicode_csv_reader(f)])
data = tablib.Dataset(*d, headers=headers)
# TODO - adjust replace_txt() function to accept orderedDicts
# since the order of replacement is important.
# replacement dicts
parent_dict = {
u'Media>': u'MEDIA',
u'MISC - Real Life>': u'REAL_LIFE',
u'Basic Routine>Meals & Snacks>': u'BASIC',
u'Basic Routine>': u'BASIC',
u'Salubrious Living>': u'HEALTH',
activity_dict = {
u'RL - MISC - Home': u'HOME',
u'RL - MISC - Outside': u'OUTSIDE',
u'へんたい': u'HENTAI',
u'アニメ': u'ANIME',
u'Grocery Shopping': u'GROCERY-SHOPPING',
u'Restaurant': u'RESTAURANT',
u'Shower & Bathroom': u'SHOWER-BATHROOM'
# test for duplicates in data (skip over the first row to avoid headers)
for row in data.dict[1:]:
if tbl == 'logs':
row['parent'] = replace_txt(row['parent'], parent_dict)
row['activity'] = replace_txt(row['activity'], activity_dict)
sql = db.is_duplicate(tbl, row)
if sql:
# NOTE(review): orphaned fragment. These lines repeat the WHERE/VALUES
# string building found in is_duplicate() and reference names (`fields`,
# `row`, `self`, `table`, `field_list`) that nothing visible here defines.
# Presumably this is an earlier draft that was kept inside a string or
# comment block in the upstream file -- confirm before keeping or deleting.
CREATE TABLE "location" (
timestamp TEXT NOT NULL,
longitude REAL NOT NULL,
latitude REAL NOT NULL,
elevation REAL)
conditionals = []
field_names = []
field_values = []
for i, field in enumerate(fields):
# format row into its proper type
# val = _getFieldType(field, row[i])
# remove unicode
if i == 0:
conditionals.append("WHERE {0}='{1}'".format(field, row[i]))
conditionals.append(" AND {0}='{1}'".format(field, row[i]))
# WHERE timestamp=? AND longitude=? AND latitude=? and elevation=?
# WHERE latitiude=? AND longitude=? AND elevation=? AND timestamp=?
conditionals = "".join(x for x in conditionals)
field_names = tuple(field_names)
field_values = "VALUES{0}".format(tuple(field_values))
# get table info (to be used when determining field type below)
field_type = {x['name']:x['type'] for x in self._table_info(table)}
# Converts `row` to str, int or float according to the declared type of
# `field`, looked up in the `field_type` mapping built just above.
# NOTE(review): an `else:` before the raise appears to be missing here.
def _getFieldType(field, row):
typ = field_type[field]
if typ == u'TEXT': foo = str(row)
elif typ == u'INTEGER': foo = int(row)
elif typ == u'REAL': foo = float(row)
raise Exception, "_getFieldType should always match a fieldtype"
return foo
# replace the old field name with new field name
field_names = ", ".join(x[1] for x in field_list)
def copy_table(self, table, new_name):
    ''' Copy every row of one table into another table

        @table - name of the table to read the rows from
        @new_name - name of the table the rows are inserted into; it must
                    already exist and have the same number of columns

        Both names are interpolated into the SQL text, so they must be
        trusted identifiers. This method does not commit.

        Returns the number of rows copied.
    '''
    rows = self.cursor.execute(
        'SELECT * FROM {0}'.format(table)).fetchall()
    if not rows:
        return 0
    # one placeholder per column so the values are bound, not formatted
    qmarks = ', '.join('?' * len(rows[0]))
    self.cursor.executemany("""INSERT INTO {0} VALUES({1})""".format(
        new_name, qmarks), rows)
    return len(rows)