public
Description: Oliver's Python utilities
Homepage:
Clone URL: git://github.com/osteele/python-utils.git
python-utils / sqlobj.py
100644 135 lines (113 sloc) 3.97 kb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
""" sqlobj --- Connection object and middleware for SQL
"""
 
__author__ = "Oliver Steele <steele@osteele.com>"
__copyright__ = "Copyright 1999-2001 by Oliver Steele."
__license__ = "Python License"
 
#
# Database Connection
#
 
import MySQLdb
 
class Connection:
    """Thin wrapper around a MySQLdb connection and a single cursor.

    Class attributes (may be overridden per instance):
      trace      -- when true, ``insert``/``select`` print each statement
                    before running it.
      simulation -- when true, statements are not sent to the database;
                    ``select`` returns a fake, ever-increasing id instead.
    """
    trace = 0
    simulation = 0

    def __init__(self, **keys):
        """Open a connection; ``keys`` are passed to ``MySQLdb.Connect``."""
        self.db = MySQLdb.Connect(**keys)
        self.cursor = self.db.cursor()
        self.nextid = 0 # for debugging: source of simulated select() results

    def sqlvalue(self, s):
        """Return ``s`` rendered as an SQL literal string.

        None becomes NULL; text is escaped with the connection's
        ``escape_string`` and wrapped in double quotes; any other value is
        formatted with ``%s``.
        """
        if s is None:
            return 'NULL'
        try:
            string_types = basestring  # Python 2: both str and unicode
        except NameError:
            string_types = str  # Python 3
        if isinstance(s, string_types):
            # NOTE(review): non-ASCII unicode may need explicit encoding
            # before escape_string on some MySQLdb versions -- confirm.
            return '"' + self.db.escape_string(s) + '"'
        return "%s" % s

    def sqltest(self, k, v):
        """Return an SQL condition comparing column ``k`` with value ``v``.

        A None value yields ``k IS NULL``: the comparison ``k=NULL`` is never
        true in SQL and would silently match no rows.
        """
        if v is None:
            return "%s IS NULL" % k
        return "%s=%s" % (k, self.sqlvalue(v))

    def insert(self, expr):
        """Execute statement ``expr`` without fetching any result.

        Echoes ``expr`` when ``trace`` is set; skips execution in simulation
        mode.
        """
        if self.trace:
            print(expr)
        if not self.simulation:
            self.cursor.execute(expr)

    def select(self, expr, limit=-1):
        """Execute query ``expr`` and return its rows.

        limit -- 1 returns a single row via ``cursor.fetchone()``; a negative
                 value (the default) returns all rows; any other value is
                 passed to ``cursor.fetchmany()``.

        In simulation mode nothing is executed: ``nextid`` is incremented and
        ``[nextid]`` is returned.  If execution fails, ``expr`` is written to
        stderr and the original exception is re-raised.
        """
        if self.trace:
            print(expr)
        if self.simulation:
            self.nextid += 1
            return [self.nextid]
        try:
            self.cursor.execute(expr)
        except Exception:
            # Log the failing statement, then propagate the original error.
            import sys
            sys.stderr.write("%s\n" % (expr,))
            raise
        if limit == 1:
            return self.cursor.fetchone()
        if limit is not None and limit < 0:
            # DB-API fetchmany() has no "all rows" size; MySQLdb slices
            # rows[pos:pos + size], so size=-1 used to drop the last row.
            return self.cursor.fetchall()
        return self.cursor.fetchmany(limit)

    def insertRow(self, tableName, fields=None):
        """Insert one row into ``tableName``.

        fields -- mapping of column name to value; values are rendered with
                  ``sqlvalue``.  None (the default) means no columns.
        """
        if fields is None:
            fields = {}
        items = list(fields.items())
        columns = ','.join(["%s" % name for name, _ in items])
        values = ','.join([self.sqlvalue(value) for _, value in items])
        self.insert("INSERT INTO " + tableName + "(" + columns + ") " +
                    "VALUES (" + values + ");")

    def truncateTables(self):
        """Empty every table listed by ``show tables`` (no-op in simulation).

        Table names are backtick-quoted so reserved words and unusual
        characters in names do not break the statement.
        """
        if self.simulation:
            return
        for table, in self.select('show tables'):
            self.cursor.execute('truncate table `%s`' % table.replace('`', '``'))
 
#
# Tables
#
 
class Table:
    """Base class describing one database table.

    name       -- table name used in generated SQL.
    keys       -- column names used by ``select`` to find an existing row
                  before inserting a new one.
    primaryKey -- column whose value ``lookup``/``select`` return.
    connection -- the Connection used for queries; when None, a module-level
                  global named ``connection`` is used instead.
    """
    def __init__(self, name, keys=None, primaryKey=None, connection=None):
        self.name = name
        # A fresh list per instance; a shared [] default would be aliased by
        # every Table constructed without ``keys``.
        if keys is None:
            keys = []
        self.keys = keys
        self.primaryKey = primaryKey
        self.connection = connection

    def _getConnection(self):
        """Return this table's connection, else the module-level global."""
        if self.connection is not None:
            return self.connection
        return connection

    def insert(self, **keys):
        """Insert a row whose column values are given as keyword arguments."""
        self._getConnection().insertRow(self.name, keys)

    def lookup(self, **keys):
        """Return the primary-key value of the row matching every
        ``column=value`` in ``keys``, or None if no row matches.

        Raises ValueError when called without criteria (an empty WHERE
        clause is invalid SQL).
        """
        if not keys:
            raise ValueError("lookup() requires at least one column=value criterion")
        conn = self._getConnection()
        criteria = ' AND '.join([conn.sqltest(k, v) for k, v in keys.items()])
        # Fetch up to two rows so that a duplicate match can be detected.
        results = conn.select(
                "SELECT %s FROM %s WHERE " % (self.primaryKey, self.name) +
                criteria + ';',
                limit=2)
        if not results:
            return None
        assert len(results) == 1, "duplicate rows for %r" % (keys,)
        row = results[0]
        # A real cursor yields row tuples; Connection's simulation mode
        # returns a bare id instead.  Either way, hand back the scalar id.
        if isinstance(row, (tuple, list)):
            return row[0]
        return row

    def select(self, **keys):
        """Return the primary-key value of a row with these column values,
        inserting the row first if no existing row matches.

        Only the columns listed in ``self.keys`` are used to search for an
        existing row; a found row is returned as-is (other values in ``keys``
        are not written).  Otherwise every value in ``keys`` is inserted and
        ``LAST_INSERT_ID()`` is returned.
        """
        assert self.primaryKey
        lookupKeys = dict([(k, keys[k]) for k in self.keys if k in keys])
        if lookupKeys:
            existing = self.lookup(**lookupKeys)
            if existing is not None:
                # todo: update if the fields aren't the same
                return existing
        self.insert(**keys)
        return self._getConnection().select("SELECT LAST_INSERT_ID();", limit=1)[0]

    def update(self, **keys):
        """Disabled: always raises NotImplementedError."""
        raise NotImplementedError("Table.update() called")

    def get_id(self, **attrs):
        """Alias for ``select``: find-or-insert and return the row's id."""
        return self.select(**attrs)
 
class EntityTable(Table):
    """A table of entities, each identified by a unique primary key.

    ``primaryKey`` must be supplied as a non-empty keyword argument; a
    ValueError is raised otherwise (an ``assert`` would vanish under -O).
    """
    def __init__(self, name, **attrs):
        if not attrs.get('primaryKey'):
            raise ValueError("EntityTable %r requires a primaryKey" % (name,))
        Table.__init__(self, name, **attrs)
 
class DetailTable(Table):
    """A table of details: each row points at its owner through a
    non-unique foreign key, forming a one-to-many relationship whose target
    is a data primitive."""
 
class Relation(Table):
    """A relation table: each row carries at least two foreign keys,
    expressing a one-to-many or many-to-many relationship."""