began the refactoring of code to enable running on a PostgreSQL backend
benadida committed Aug 18, 2008
1 parent 3c1589d commit 9e8d221
Showing 21 changed files with 769 additions and 296 deletions.
127 changes: 127 additions & 0 deletions base/DB.py
@@ -0,0 +1,127 @@
"""
DB Abstraction for PG
Author: ben@adida.net, arjun@arjun.nu
"""

import psycopg2
import psycopg2.extensions
import utils, config
from DBUtils.PooledDB import PooledDB


# do unicode
psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
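# connection pooling: DBUtils' PooledDB hands out pooled psycopg2 connections;
# load() below attaches one to the current context via pool.connection()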
pool = PooledDB(psycopg2, port = config.DB_PORT, database = config.DB_NAME, host = config.DB_HOST)

class Context:
    pass

default_context = Context()

#
# Thread Safe way to manage context
#
def set_context_function(func):
    # let a framework install its own (e.g. per-thread) context lookup
    global get_context
    get_context = func

# by default, the context is a single module-level object
def get_context():
    return default_context

def load():
    ctx = get_context()
    ctx.conn = pool.connection()
    ctx.transact_count = 0

def unload():
    ctx = get_context()
    del ctx.conn

def _get_cursor():
    ctx = get_context()

    if not hasattr(ctx, 'conn'):
        load()

    if hasattr(ctx, 'cursor'):
        del ctx.cursor

    ctx.cursor = ctx.conn.cursor()
    ctx.cursor.execute('set client_encoding = \'UNICODE\'')
    return ctx.cursor

def _cursor_execute(cursor, sql, vars):
    cursor.execute(sql, vars)

def _done():
    ctx = get_context()
    if ctx.transact_count == 0:
        ctx.conn.commit()

# transactions
def transact():
    ctx = get_context()

    if not hasattr(ctx, 'transact_count'):
        load()

    # nested transactions by just waiting for the last commit
    ctx.transact_count += 1

def commit():
    ctx = get_context()

    ctx.transact_count -= 1
    if ctx.transact_count == 0:
        ctx.conn.commit()

def rollback():
    ctx = get_context()

    ctx.conn.rollback()

    ctx.transact_count -= 1
    if ctx.transact_count > 0:
        ctx.transact_count = 0
        raise Exception('cannot roll back an inner transaction, rolling back everything.')

def perform(sql, level=0, extra_vars=None):
    db_cursor = _get_cursor()

    _cursor_execute(db_cursor, sql, utils.parent_vars(level+1, extra_vars))
    _done()

def oneval(sql, level=0):
    singlerow = onerow(sql, level+1)
    if singlerow == None:
        return None

    return singlerow.values()[0]

def onerow(sql, level=0):
    rows = multirow(sql, level+1)
    if len(rows) == 0:
        return None
    return rows[0]

def multirow(sql, level=0, extra_vars=None):
    db_cursor = _get_cursor()

    _cursor_execute(db_cursor, sql, utils.parent_vars(level+1, extra_vars))
    rows = db_cursor.fetchall()
    colnames = [t[0] for t in db_cursor.description]
    dict_rows = [dict(zip(colnames, row)) for row in rows]

    # if we're not in a transaction, commit, since psycopg2 opens a transaction on every query
    _done()
    return dict_rows

def dbstr(the_str):
    if type(the_str) == int or type(the_str) == long or type(the_str) == bool:
        return the_str
    return "'"+the_str.replace("'","''").replace("%","%%")+"'"


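A minimal usage sketch of the new DB module (illustrative only, not part of the diff): it assumes DB.py is imported as DB, that config supplies DB_HOST, DB_PORT and DB_NAME, that utils.parent_vars merges the caller's locals with extra_vars for psycopg2's %(name)s substitution, and that the voters table and its columns are hypothetical.

import DB

DB.transact()
try:
    DB.perform("insert into voters (name, email) values (%(name)s, %(email)s)",
               extra_vars={'name': 'Alice', 'email': 'alice@example.com'})
    # oneval returns the single value of a single-row result
    num_voters = DB.oneval("select count(*) from voters")
    # multirow returns a list of dicts keyed by column name
    voters = DB.multirow("select name, email from voters order by name")
    DB.commit()
except:
    # roll back the whole (possibly nested) transaction on any error
    DB.rollback()
    raise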
166 changes: 6 additions & 160 deletions base/DBObject.py
@@ -1,21 +1,9 @@
"""
The DBObject base class
Database objects that could be DB-backed, or GAE backed. For now, just GAE.
This needs more work to be much more generic.
ben@adida.net
"""

import utils

from django.utils import simplejson
import datetime

from google.appengine.ext import db


def from_utf8(string):
    if type(string) == str:
        return string.decode('utf-8')
@@ -28,151 +16,9 @@ def to_utf8(string):
    else:
        return string

class DBObject(db.Model):

    # GAE get_id
    def get_id(self):
        return self.key()

    @classmethod
    def selectById(cls, key_value):
        return cls.get(key_value)

    @classmethod
    def selectByKey(cls, key_name, key_value):
        obj = cls()
        if obj.select(keys={key_name:key_value}):
            return obj
        else:
            return None

    @classmethod
    def selectByKeys(cls, keys):
        # GAE
        all_values = cls.selectAllByKeys(keys)
        if len(all_values) == 0:
            return None
        else:
            return all_values[0]

    @classmethod
    def selectAll(cls, order_by = None, offset = None, limit = None):
        # GAE query
        query = cls.all()

        # order
        if order_by:
            query.order(order_by)

        return query.fetch(limit or 1000, offset or 0)

    @classmethod
    def selectAllByKey(cls, key_name, key_value, order_by = None, offset = None, limit = None):
        keys = dict()
        keys[key_name] = key_value
        return cls.selectAllByKeys(keys, order_by, offset, limit)

    @classmethod
    def selectAllByKeys(cls, keys, order_by = None, offset = None, limit = None):
        # unicode
        for k,v in keys.items():
            keys[k] = to_utf8(v)

        # GAE query
        query = cls.all()

        # order
        if order_by:
            query.order(order_by)

        # conditions
        for k,v in keys.items():
            query.filter('%s' % k, v)

        return query.fetch(limit or 1000, offset or 0)

    def _load_from_row(self, row, extra_fields=[]):

        prepared_row = self._prepare_object_values(row)

        for field in self.FIELDS:
            # unicode
            self.__dict__[field] = from_utf8(prepared_row[field])

        for field in extra_fields:
            # unicode
            self.__dict__[field] = from_utf8(prepared_row[field])

    def insert(self):
        """
        Insert a new object, but only if it hasn't been inserted yet
        """
        self.save()

    def update(self):
        """
        Update an object
        """
        # GAE
        self.save()

    # DELETE inherited from GAE

    @classmethod
    def multirow_to_array(cls, multirow, extra_fields=[]):
        objects = []

        if multirow == None:
            return objects

        for row in multirow:
            one_object = cls()
            one_object._load_from_row(row, extra_fields)
            objects.append(one_object)

        return objects

    def toJSONDict(self, extra_fields = []):
        # a helper recursive procedure to navigate down the items
        # even if they don't have a toJSONDict() method
        def toJSONRecurse(item):
            if type(item) == int or type(item) == bool or hasattr(item, 'encode') or not item:
                return item

            if hasattr(item,'toJSONDict'):
                return item.toJSONDict()

            if type(item) == dict:
                new_dict = dict()
                for k in item.keys():
                    new_dict[k] = toJSONRecurse(item[k])
                return new_dict

            if hasattr(item,'__iter__'):
                return [toJSONRecurse(el) for el in item]

            return str(item)

        # limit the fields to just JSON_FIELDS if it exists
        json_dict = dict()
        if hasattr(self.__class__,'JSON_FIELDS'):
            keys = self.__class__.JSON_FIELDS + extra_fields
        else:
            keys = extra_fields

        # go through the keys and recurse down each one
        for f in keys:
            ## FIXME: major hack here while I figure out how to dynamically get the right field
            if hasattr(self, f):
                json_dict[f] = toJSONRecurse(getattr(self, f))
            else:
                if self.__dict__.has_key(f):
                    json_dict[f] = toJSONRecurse(self.__dict__[f])
                else:
                    continue

        return json_dict

    def toJSON(self):
        return simplejson.dumps(self.toJSONDict())
#from DBObjectGAE import *
try:
    from google.appengine.ext import db
    from DBObjectGAE import *
except:
    from DBObjectStandalone import *
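Application models are then expected to subclass DBObject and work unchanged whichever backend the try/except above selects. A hypothetical sketch, assuming the GAE and standalone implementations keep the selectByKey / toJSON interface of the class removed above (the Voter class and its fields are invented for illustration):

class Voter(DBObject):
    # hypothetical model: field lists as consumed by _load_from_row and toJSONDict
    FIELDS = ['voter_id', 'name', 'email']
    JSON_FIELDS = ['voter_id', 'name', 'email']

voter = Voter.selectByKey('email', 'alice@example.com')
if voter:
    print voter.toJSON()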
