Skip to content

Commit

Permalink
datastore: peewee->sqlite migration and sqlite optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
johan-bjareholt committed May 13, 2018
1 parent 8a4249a commit 6985b82
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 74 deletions.
3 changes: 2 additions & 1 deletion aw_datastore/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import Dict, Callable, Any
import platform as _platform

from .migration import check_for_migration

from . import storages
from .datastore import Datastore


# The Callable[[Any], ...] here should be Callable[..., ...], but Python 3.5.2 doesn't
# accept an ellipsis there. See: https://github.com/python/typing/issues/259
def get_storage_methods() -> Dict[str, Callable[[Any], storages.AbstractStorage]]:
Expand Down
50 changes: 50 additions & 0 deletions aw_datastore/migration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Optional, List
import os
import re
import logging

from aw_core.dirs import get_data_dir
from .storages import AbstractStorage

logger = logging.getLogger(__name__)

def detect_db_files(data_dir: str, datastore_name: Optional[str] = None, version: Optional[int] = None) -> List[str]:
    """List the entries of ``data_dir``, optionally filtered by name and version.

    File names are split on "."; the first component is compared against
    ``datastore_name`` and the second against ``"v<version>"``
    (e.g. ``peewee-sqlite.v2.db`` has name ``peewee-sqlite`` and version 2).

    Args:
        data_dir: Directory whose entries are listed.
        datastore_name: If non-empty, keep only entries whose first
            dot-separated component equals this value. ``None`` or ``""``
            disables the name filter.
        version: If not ``None``, keep only entries whose second
            dot-separated component equals ``"v{version}"``.

    Returns:
        The matching entry names (not full paths), in ``os.listdir`` order.

    Raises:
        OSError: Propagated from ``os.listdir`` (e.g. missing directory).
    """
    db_files = os.listdir(data_dir)
    if datastore_name:
        db_files = [name for name in db_files if name.split(".")[0] == datastore_name]
    if version is not None:
        version_tag = "v{}".format(version)
        # Slice instead of indexing: names without a "." have no second
        # component and must simply not match rather than raise IndexError.
        db_files = [name for name in db_files if name.split(".")[1:2] == [version_tag]]
    return db_files

def check_for_migration(datastore: AbstractStorage, datastore_name: str, version: int):
    """Migrate data from an older database file into ``datastore`` if one exists.

    Currently only one path is handled: when ``datastore.sid`` is "sqlite"
    and a peewee v2 database file is present in the "aw-server" data
    directory, its contents are copied over via ``peewee_v2_to_sqlite_v1``.

    Args:
        datastore: The freshly created storage to migrate into. Its ``sid``
            and ``testing`` attributes are read.
        datastore_name: Name of the new datastore. Not used by the current
            migration path; kept for interface compatibility.
        version: Version of the new datastore. Not used by the current
            migration path; kept for interface compatibility.
    """
    data_dir = get_data_dir("aw-server")

    if datastore.sid == "sqlite":
        peewee_type = "peewee-sqlite"
        # The parentheses matter: a conditional expression binds looser than
        # "+", so without them the whole name became "" for non-testing runs,
        # which disabled the name filter and matched any v2 file.
        peewee_name = peewee_type + ("-testing" if datastore.testing else "")
        # Migrate from peewee v2
        peewee_db_v2 = detect_db_files(data_dir, peewee_name, 2)
        if peewee_db_v2:
            peewee_v2_to_sqlite_v1(datastore)

def peewee_v2_to_sqlite_v1(datastore):
    """Copy all buckets and their events from the peewee store into ``datastore``.

    A ``PeeweeStorage`` is opened with the same ``testing`` flag as
    ``datastore``. For each bucket it reports, a bucket with the same
    metadata is created in ``datastore`` and all of its events are inserted.

    Args:
        datastore: Destination storage; must provide ``testing``,
            ``create_bucket`` and ``insert_many``.
    """
    logger.info("Migrating database from peewee v2 to sqlite v1")
    # NOTE(review): imported here rather than at module top, presumably so
    # the peewee backend is only loaded when a migration runs; confirm.
    from .storages import PeeweeStorage
    source_db = PeeweeStorage(datastore.testing)
    source_buckets = source_db.buckets()

    # Metadata keys in the positional order create_bucket is called with.
    metadata_keys = ("id", "type", "client", "hostname", "created", "name")

    for bucket_id in source_buckets:
        logger.info("Migrating bucket {}".format(bucket_id))
        metadata = source_buckets[bucket_id]
        datastore.create_bucket(*[metadata[key] for key in metadata_keys])
        # A limit of -1 is passed to fetch the bucket's events for copying.
        events = source_db.get_events(bucket_id, -1)
        datastore.insert_many(bucket_id, events)

    logger.info("Migration of peewee v2 to sqlite v1 finished")
23 changes: 0 additions & 23 deletions aw_datastore/storages/peewee.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,34 +75,11 @@ def json(self):
}


def detect_db_files(data_dir: str) -> List[str]:
    """Return the names of entries in ``data_dir`` that contain "peewee-sqlite"."""
    matching = []
    for entry in os.listdir(data_dir):
        if "peewee-sqlite" in entry:
            matching.append(entry)
    return matching


def detect_db_version(data_dir: str, max_version: Optional[int] = None) -> Optional[int]:
    """Return the highest version number found among the detected database files.

    The version is taken from the first ``v<digits>`` match in each file name
    returned by ``detect_db_files``. When ``max_version`` is truthy, versions
    above it are ignored. Returns ``None`` if no version remains.
    """
    import re
    version_pattern = re.compile("v[0-9]+")

    found = []
    for filename in detect_db_files(data_dir):
        hit = version_pattern.search(filename)
        if hit:
            # Drop the leading "v" before converting to an integer.
            found.append(int(hit.group(0)[1:]))

    if max_version:
        found = [number for number in found if number <= max_version]

    if not found:
        return None
    return max(found)


class PeeweeStorage(AbstractStorage):
sid = "peewee"

def __init__(self, testing):
data_dir = get_data_dir("aw-server")
current_db_version = detect_db_version(data_dir, max_version=LATEST_VERSION)

if current_db_version is not None and current_db_version < LATEST_VERSION:
# DB file found but was of an older version
logger.info("Latest version database file found was of an older version")
logger.info("Creating database file for new version {}".format(LATEST_VERSION))
logger.warning("ActivityWatch does not currently support database migrations, new database file will be empty")

filename = 'peewee-sqlite' + ('-testing' if testing else '') + ".v{}".format(LATEST_VERSION) + '.db'
filepath = os.path.join(data_dir, filename)
Expand Down
82 changes: 32 additions & 50 deletions aw_datastore/storages/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,6 @@

LATEST_VERSION=1

def detect_db_files(data_dir: str) -> List[str]:
    """Return the names of entries in ``data_dir`` that contain "sqlite"."""
    entries = os.listdir(data_dir)
    return list(filter(lambda entry: "sqlite" in entry, entries))


def detect_db_version(data_dir: str, max_version: Optional[int] = None) -> Optional[int]:
    """Return the highest version number found among the detected database files.

    Each file name from ``detect_db_files`` is searched for its first
    ``v<digits>`` occurrence. When ``max_version`` is truthy, only versions
    less than or equal to it are considered. Returns ``None`` when nothing
    qualifies.
    """
    import re
    version_pattern = re.compile("v[0-9]+")

    candidates = []
    for filename in detect_db_files(data_dir):
        hit = version_pattern.search(filename)
        if not hit:
            continue
        # Strip the leading "v" to get the numeric part.
        candidates.append(int(hit.group(0)[1:]))

    if max_version:
        candidates = [number for number in candidates if number <= max_version]

    return max(candidates) if candidates else None


CREATE_BUCKETS_TABLE = """
CREATE TABLE IF NOT EXISTS buckets (
id TEXT PRIMARY KEY,
Expand All @@ -55,36 +39,39 @@ def detect_db_version(data_dir: str, max_version: Optional[int] = None) -> Optio
"""

INDEX_EVENTS_TABLE = """
CREATE INDEX IF NOT EXISTS event_index ON events(bucket, starttime);
CREATE INDEX IF NOT EXISTS event_index ON events(bucket, starttime, endtime);
"""


class SqliteStorage(AbstractStorage):
sid = "sqlite"

def __init__(self, testing):
self.testing = testing
data_dir = get_data_dir("aw-server")
current_db_version = detect_db_version(data_dir, max_version=LATEST_VERSION)

if current_db_version is not None and current_db_version < LATEST_VERSION:
# DB file found but was of an older version
logger.info("Latest version database file found was of an older version")
logger.info("Creating database file for new version {}".format(LATEST_VERSION))
logger.warning("ActivityWatch does not currently support database migrations, new database file will be empty")

filename = self.sid + ('-testing' if testing else '') + ".v{}".format(LATEST_VERSION) + '.db'
ds_name = self.sid + ('-testing' if testing else '')
filename = ds_name + ".v{}".format(LATEST_VERSION) + '.db'
filepath = os.path.join(data_dir, filename)
new_db_file = not os.path.exists(filepath)
self.conn = sqlite3.connect(filepath)
logger.info("Using database file: {}".format(filepath))

# Create tables
c = self.conn.cursor()
c.execute(CREATE_BUCKETS_TABLE)
c.execute(CREATE_EVENTS_TABLE)
c.execute(INDEX_EVENTS_TABLE)

c.execute("PRAGMA journal_mode = WAL;");

self.conn.execute(CREATE_BUCKETS_TABLE)
self.conn.execute(CREATE_EVENTS_TABLE)
self.conn.execute(INDEX_EVENTS_TABLE)
self.conn.execute("PRAGMA journal_mode = WAL;");
self.commit()

if new_db_file:
logger.info("Created new SQlite db file")
from aw_datastore import check_for_migration
check_for_migration(self, ds_name, LATEST_VERSION)

def commit(self):
# Useful for debugging and trying to lower the amount of
# unnecessary commits
self.conn.commit()

def buckets(self):
Expand All @@ -103,21 +90,19 @@ def buckets(self):

def create_bucket(self, bucket_id: str, type_id: str, client: str,
hostname: str, created: str, name: Optional[str] = None):
c = self.conn.cursor()
c.execute("INSERT INTO buckets VALUES (?, ?, ?, ?, ?, ?)",
self.conn.execute("INSERT INTO buckets VALUES (?, ?, ?, ?, ?, ?)",
[bucket_id, name, type_id, client, hostname, created])
self.conn.commit();
self.commit();
return self.get_metadata(bucket_id)

def delete_bucket(self, bucket_id: str):
c = self.conn.cursor()
c.execute("DELETE FROM events WHERE bucket = ?", [bucket_id])
c.execute("DELETE FROM buckets WHERE id = ?", [bucket_id])
self.conn.commit()
self.conn.execute("DELETE FROM events WHERE bucket = ?", [bucket_id])
self.conn.execute("DELETE FROM buckets WHERE id = ?", [bucket_id])
self.commit()

def get_metadata(self, bucket_id: str):
c = self.conn.cursor()
res = c.execute("SELECT * FROM buckets")
res = c.execute("SELECT * FROM buckets WHERE id = ?", [bucket_id])
row = res.fetchone()
bucket = {
"id": row[0],
Expand All @@ -131,7 +116,7 @@ def get_metadata(self, bucket_id: str):

def insert_one(self, bucket_id: str, event: Event) -> Event:
c = self.conn.cursor()
starttime = event.timestamp.timestamp()*1000000
starttime = event.timestamp.timestamp() * 1000000
endtime = starttime + (event.duration.total_seconds() * 1000000)
datastr = json.dumps(event.data)
c.execute("INSERT INTO events(bucket, starttime, endtime, datastr) VALUES (?, ?, ?, ?)",
Expand All @@ -154,40 +139,37 @@ def insert_many(self, bucket_id, events: List[Event], fast=False) -> None:
"VALUES (?, ?, ?, ?)"
self.conn.executemany(query, event_rows)
if len(event_rows) > 50:
self.conn.commit();
self.commit();

def replace_last(self, bucket_id, event):
c = self.conn.cursor()
starttime = event.timestamp.timestamp()*1000000
endtime = starttime + (event.duration.total_seconds() * 1000000)
datastr = json.dumps(event.data)
query = "UPDATE events " + \
"SET starttime = ?, endtime = ?, datastr = ? " + \
"WHERE endtime = (SELECT max(endtime) FROM events WHERE bucket = ?) AND bucket = ?"
c.execute(query, [starttime, endtime, datastr, bucket_id, bucket_id])
self.conn.execute(query, [starttime, endtime, datastr, bucket_id, bucket_id])
return True

def delete(self, bucket_id, event_id):
c = self.conn.cursor()
query = "DELETE FROM events WHERE bucket = ? AND id = ?"
c.execute(query, [bucket_id, event_id])
self.conn.execute(query, [bucket_id, event_id])
# TODO: Handle if event doesn't exist
return True

def replace(self, bucket_id, event_id, event):
c = self.conn.cursor()
starttime = event.timestamp.timestamp()*1000000
endtime = starttime + (event.duration.total_seconds() * 1000000)
datastr = json.dumps(event.data)
query = "UPDATE events " + \
"SET bucket = ?, starttime = ?, endtime = ?, datastr = ? " + \
"WHERE id = ?"
c.execute(query, [bucket_id, starttime, endtime, datastr, event_id])
self.conn.execute(query, [bucket_id, starttime, endtime, datastr, event_id])
return True

def get_events(self, bucket_id: str, limit: int,
starttime: Optional[datetime] = None, endtime: Optional[datetime] = None):
self.conn.commit()
self.commit()
c = self.conn.cursor()
if limit <= 0:
limit = -1
Expand Down Expand Up @@ -216,7 +198,7 @@ def get_events(self, bucket_id: str, limit: int,

def get_eventcount(self, bucket_id: str,
starttime: Optional[datetime] = None, endtime: Optional[datetime] = None):
self.conn.commit()
self.commit()
c = self.conn.cursor()
if not starttime:
starttime = 0
Expand Down

0 comments on commit 6985b82

Please sign in to comment.