In [None]:
#| default_exp core

# FastSQL

> A MiniDataAPI spec implementation for SQLAlchemy v2

In [None]:
from aimagic import create_magic,models

In [None]:
# Dev-only helper (not exported): sets up `aimagic` with the first entry of `models`;
# presumably this registers the `%ai` magic used in the next cell -- TODO confirm.
create_magic(models[0]);

In [None]:
# Dev-only `aimagic` command (not exported); presumably resets the `%ai` assistant's state -- TODO confirm.
%ai reset

In [None]:
#| export
import dataclasses, types, typing
from dataclasses import dataclass,is_dataclass,asdict,MISSING
import sqlalchemy as sa
from sqlalchemy.orm import Session
from fastcore.utils import *

We create a `Database` class and a `DBTable` class (the latter is returned by `Database.create`) on top of SQLAlchemy v2. These classes let us work directly with dataclasses such as these:

In [None]:
@dataclass
class User:
    # The examples below use `name` as this table's primary key
    name: str
    pwd: str

@dataclass
class Todo:
    title: str
    name: str
    done: bool = False
    details: str = ''
    # Left as None on new objects so the database can assign it on insert
    id: int = None

In [None]:
#| export
_type_map = {int: sa.Integer, str: sa.String, bool: sa.Boolean, float: sa.Float, bytes: sa.LargeBinary}
def _column(name, typ, primary=False):
    "Create a `sa.Column` called `name` whose SQL type is looked up in `_type_map` from the Python type `typ`"
    # Unwrap `Optional[T]` / `T|None` to `T`, so nullable dataclass fields map like their base type
    if typing.get_origin(typ) in (typing.Union, types.UnionType):
        args = [a for a in typing.get_args(typ) if a is not type(None)]
        if len(args)==1: typ = args[0]
    try: sql_type = _type_map[typ]
    except KeyError:
        raise KeyError(f"No SQL type mapped for {typ!r} (column {name!r}); supported: {list(_type_map)}") from None
    return sa.Column(name, sql_type, primary_key=primary)

In [None]:
#| export
class Database:
    "Holds a SQLAlchemy engine for `conn_str` plus the `MetaData` that table definitions are registered on"
    def __init__(self, conn_str):
        # The raw connection string is kept for `__repr__`
        self.conn_str = conn_str
        self.engine, self.meta = sa.create_engine(conn_str), sa.MetaData()

    def __repr__(self): return "Database({})".format(self.conn_str)

    @property
    def conn(self):
        "A new `Connection` from `self.engine` on every access; intended for `with db.conn as conn:`"
        return self.engine.connect()

In [None]:
# In-memory SQLite database used by the examples below (nothing is written to disk)
db = Database("sqlite:///:memory:")

In [None]:
#| export
class DBTable:
    "Ties a SQLAlchemy `Table` to its `Database` and to the dataclass `cls` used for its rows; creates the table if missing"
    def __init__(self, table: sa.Table, database: Database, cls):
        self.table = table
        self.db = database
        self.cls = cls
        # `checkfirst=True` turns this into a no-op when the table already exists in the database
        self.table.create(self.db.engine, checkfirst=True)

    def __repr__(self) -> str:
        return self.table.name

    @property
    def conn(self):
        "A new connection from the owning `Database` on each access"
        return self.db.conn

In [None]:
#| export
@patch
def create(self:Database, cls, pk:str|None=None):
    """Get a `DBTable` for dataclass `cls`, creating the table in the DB if needed.

    The table is named after `cls`, and every dataclass field becomes a column, with the
    primary-key column(s) first. `pk` is a field name, or a list/tuple of field names for a
    composite key; `None` means no primary key. Raises `KeyError` if a `pk` name is not a
    field of `cls`. If a table with this name is already registered on `self.meta`, that
    definition is re-used as-is (its columns and `pk` are not re-checked).
    """
    existing = self.meta.tables.get(cls.__name__)
    # Defining a same-named `sa.Table` twice on one `MetaData` raises, so re-running `create`
    # (e.g. a re-executed notebook cell) re-uses the registered definition instead.
    if existing is not None: return DBTable(existing, self, cls)
    # `dataclasses.fields` omits ClassVar/InitVar pseudo-fields, which must not become columns
    cols = {f.name:f for f in dataclasses.fields(cls)}
    pks = [] if pk is None else [pk] if isinstance(pk, str) else list(pk)
    columns = [_column(k, cols.pop(k).type, primary=True) for k in pks]
    columns += [_column(k, f.type) for k,f in cols.items()]
    return DBTable(sa.Table(cls.__name__, self.meta, *columns), self, cls)

In [None]:
# `users` is keyed by `name`; `todos` is keyed by the integer `id`, which the DB fills in on insert
users = db.create(User, pk='name')
todos = db.create(Todo, pk='id')

In [None]:
#| export
@patch
def print_schema(self:Database):
    "Print each table in the database with its columns; primary-key columns are prefixed with `*`"
    insp = sa.inspect(self.engine)
    for name in insp.get_table_names():
        print(f"Table: {name}")
        pk_names = insp.get_pk_constraint(name)['constrained_columns']
        for col in insp.get_columns(name):
            col_name = col['name']
            marker = '*' if col_name in pk_names else ''
            print(f"  - {marker}{col_name}: {col['type']}")

In [None]:
# `*` marks primary-key columns
db.print_schema()

Table: Todo
  - *id: INTEGER
  - title: VARCHAR
  - name: VARCHAR
  - done: BOOLEAN
  - details: VARCHAR
Table: User
  - *name: VARCHAR
  - pwd: VARCHAR


In [None]:
#| export
@patch
def exists(self:DBTable):
    "Return whether a table named like this one is present in the database"
    inspector = sa.inspect(self.db.engine)
    name = self.table.name
    return inspector.has_table(name)

In [None]:
# True: constructing the `DBTable` in `db.create` already created the table
users.exists()

True

In [None]:
# Example objects; `t0` leaves `id` as None so the database assigns it on insert
u0 = User('jph','foo')
u1 = User('rlt','bar')
t0 = Todo('do it', 'jph')

In [None]:
#| export
def _wanted(obj): return {k:v for k,v in asdict(obj).items() if v not in (None,MISSING)}

In [None]:
#| export
@patch
def insert(self:DBTable, obj):
    "Add `obj` as a new row and return a fresh `self.cls` built from what the database stored (including generated values)"
    stmt = sa.insert(self.table).values(**_wanted(obj)).returning(*self.table.columns)
    with self.conn as c:
        # `.one()` fully consumes the RETURNING result before committing
        inserted = c.execute(stmt).one()
        c.commit()
    return self.cls(**inserted._asdict())

In [None]:
# `insert` returns an object rebuilt from the row the database actually stored
u = users.insert(u0)
assert u.name=='jph'
users.insert(u1)
u

User(name='jph', pwd='foo')

In [None]:
# `t0.id` is None, so `id` is left out of the INSERT and the database assigns it
t = todos.insert(t0)
assert t.id
t

Todo(title='do it', name='jph', done=False, details='', id=1)

In [None]:
#| export
@patch
def __call__(self:DBTable, where:str|None=None, where_args:Iterable|dict|None=None,
             order_by:str|None=None, limit:int|None=None, offset:int|None=None, **kw):
    """Query this table, returning a list of `self.cls` instances.

    `where` is a raw SQL condition that may use named `:param` placeholders. Their values come
    from `**kw` and from `where_args` (a mapping, which wins on duplicate names). `order_by` is
    raw SQL as well, so neither it nor `where` should be built from untrusted input.
    Raises `TypeError` if `where_args` is not a mapping, and `ValueError` if bind values are
    given without a `where` clause.
    """
    query = sa.select(self.table)
    params = dict(kw)
    if where_args:
        # Only named `:param` binds are supported, so positional sequences are rejected explicitly
        if not hasattr(where_args, 'keys'):
            raise TypeError("`where_args` must be a mapping of `:param` names to values")
        params.update(where_args)
    # Apply `where` even when it has no placeholders (e.g. `where="done=1"`)
    if where: query = query.where(sa.text(where).bindparams(**params))
    elif params: raise ValueError(f"Bind values {sorted(params)} were given without a `where` clause")
    if order_by: query = query.order_by(sa.text(order_by))
    if limit is not None: query = query.limit(limit)
    if offset is not None: query = query.offset(offset)
    with self.conn as conn:
        return [self.cls(**row._asdict()) for row in conn.execute(query).all()]

In [None]:
# Called with no arguments, a table returns all of its rows
assert users()==[u0,u1]
users()

[User(name='jph', pwd='foo'), User(name='rlt', pwd='bar')]

In [None]:
# `where` is raw SQL; the `:pwd` placeholder is bound from the `pwd` keyword argument
users(where="pwd LIKE :pwd", pwd="b%")

[User(name='rlt', pwd='bar')]

In [None]:
# Dataclass equality compares field values, so `t` (returned by `insert`) equals the stored row
assert todos()==[t]
todos()

[Todo(title='do it', name='jph', done=False, details='', id=1)]

In [None]:
#| export
@patch
def _pk_where(self:DBTable, meth, key):
    """Build `self.table.<meth>()` (e.g. `'select'`, `'update'`, `'delete'`) restricted to the row whose primary key is `key`.

    `key` is a single value, or a tuple with one value per primary-key column.
    Raises `ValueError` if the table has no primary key or `key` has the wrong number of values.
    """
    pk_cols = list(self.table.primary_key.columns)
    # An empty condition would leave the statement unfiltered, e.g. a `delete` would remove every row
    if not pk_cols: raise ValueError(f"Table {self.table.name!r} has no primary key, so rows cannot be addressed by key")
    if not isinstance(key,tuple): key = (key,)
    # `zip` alone would silently truncate a mismatched key and match more rows than intended
    if len(key) != len(pk_cols):
        raise ValueError(f"Table {self.table.name!r} has {len(pk_cols)} primary-key column(s) but key {key!r} has {len(key)} value(s)")
    cond = sa.and_(*[col==val for col,val in zip(pk_cols, key)])
    return getattr(self.table,meth)().where(cond)

In [None]:
#| export
@patch
def __getitem__(self:DBTable, key):
    "Look up the row whose primary key is `key` (a tuple for composite keys); return it as a `self.cls`, or None if absent"
    with self.conn as c:
        found = c.execute(self._pk_where('select', key)).first()
    if not found: return None
    return self.cls(**found._asdict())

In [None]:
# Indexing a table looks a row up by its primary-key value
assert users['jph']==u0
users['jph']

User(name='jph', pwd='foo')

In [None]:
#| export
@patch
def update(self:DBTable, obj):
    """Write `obj`'s non-None fields to the row with the same primary key, and return the stored row as a `self.cls`.

    `Result.one()` raises if no row matched the key.
    """
    vals = _wanted(obj)
    key = tuple(vals[col.name] for col in self.table.primary_key)
    stmt = self._pk_where('update', key).values(**vals).returning(*self.table.columns)
    with self.conn as c:
        updated = c.execute(stmt).one()
        c.commit()
    return self.cls(**updated._asdict())

In [None]:
u.pwd = 'new'
users.update(u)
# Check that the change reached the database, not just the local object
assert users['jph'].pwd == 'new'
users()

[User(name='jph', pwd='new'), User(name='rlt', pwd='bar')]

In [None]:
#| export
@patch
def delete(self:DBTable, key):
    "Remove the row whose primary key is `key` and return how many rows were deleted"
    stmt = self._pk_where('delete', key)
    with self.conn as c:
        res = c.execute(stmt)
        c.commit()
        n_deleted = res.rowcount
    return n_deleted

In [None]:
# Exactly one row is removed, after which a lookup by that key finds nothing
assert users.delete('jph') == 1
assert users['jph'] is None