In [None]:
#| default_exp core

# Source

Source code for FastSQL, an implementation of the MiniDataAPI spec for SQLAlchemy v2.

In [None]:
#| export
from dataclasses import dataclass, is_dataclass, MISSING, fields, field, make_dataclass
from enum import Enum
from pathlib import Path
from typing import Any, Optional, Union, Iterable, Generator, List, Tuple, Dict, get_args

import sqlalchemy as sa
from sqlalchemy.orm import Session
from fastcore.utils import *
from fastcore.test import test_fail, test_eq
from fastcore.xtras import dataclass_src
from itertools import starmap


In [None]:
#| export
class Default:
    "Sentinel type whose single instance, `DEFAULT`, marks a keyword argument the caller did not supply."


DEFAULT = Default()


## `Database` and `DBTable`

We create a `Database` class and a `DBTable` class (which is returned by `Database.create`), using SQLAlchemy v2. These classes let us work directly with dataclasses, such as the `User`, `Todo` and `Student` classes defined further below.

In [None]:
#| export
def _db_str(path):
    if isinstance(path, Path): path = str(path)
    if not isinstance(path, str): return path
    if '://' in path: return path
    if path == ':memory:': return 'sqlite:///:memory:'
    return f"sqlite:///{path}"


In [None]:
conn_str = _db_str(":memory:"); conn_str

'sqlite:///:memory:'

In [None]:
#| export
class Database:
    "A connection to a SQLAlchemy database"

    def __init__(self, conn_str):
        "Create an engine for `conn_str`, reflect its existing tables into `meta`, and open one shared connection."
        self.conn_str = conn_str
        self.engine = sa.create_engine(conn_str)
        self.meta = sa.MetaData()
        self.meta.reflect(bind=self.engine)
        # SQLAlchemy 2 no longer binds MetaData to an engine; these are plain attributes,
        # presumably kept for code that reads them.
        self.meta.bind = self.engine
        self.conn = self.engine.connect()
        self.meta.conn = self.conn
        # Cache of `DBTable` objects keyed by table name (filled by `table`/`create`).
        self._tables = {}

    def execute(self, st, params=None, opts=None):
        "Execute statement `st` with `params` on the shared connection; `opts` are execution options."
        return self.conn.execute(st, params, execution_options=opts)

    def close(self):
        "Close the shared connection and release the engine's connection pool."
        self.conn.close()
        # Without this the engine's pooled DBAPI connections stay open after `close`.
        self.engine.dispose()

    def __repr__(self): return f"Database({self.conn_str})"


In [None]:
db = Database(conn_str); db

Database(sqlite:///:memory:)

In [None]:
# create a test table
db.execute(sa.text('create table test (id integer primary key, name text, age integer)'))
db.meta.reflect(bind=db.engine)

In [None]:
db.meta.tables

FacadeDict({'test': Table('test', MetaData(), Column('id', INTEGER(), table=<test>, primary_key=True), Column('name', TEXT(), table=<test>), Column('age', INTEGER(), table=<test>), schema=None)})

In [None]:
#| export
@patch
def q(self: Database, sql: str, **params):
    """Query database with raw SQL and optional parameters. Returns list of dicts.

    Statements that return no rows (INSERT/UPDATE/DDL, ...) give `[]`. The connection is committed
    afterwards, so writes made through `q` persist just like those made by `insert`/`update`/`delete`.
    """
    result = self.execute(sa.text(sql), params=params)
    # Materialise the rows before committing so the cursor is fully consumed first.
    rows = list(map(dict, result.mappings())) if result.returns_rows else []
    # A no-op when no transaction is in progress.
    self.conn.commit()
    return rows


In [None]:
# add some data to it
db.q('insert into test (name, age) values (:name, :age)', name='Alice', age=30)
db.q('insert into test (name, age) values (:name, :age)', name='Bob', age=25)


[]

In [None]:
db.q('select * from test')

[{'id': 1, 'name': 'Alice', 'age': 30}, {'id': 2, 'name': 'Bob', 'age': 25}]

In [None]:
# update Alice's age to 32
db.q('update test set age = :age where name = :name', age=32, name='Alice')

[]

In [None]:
db.q('select * from test')

[{'id': 1, 'name': 'Alice', 'age': 32}, {'id': 2, 'name': 'Bob', 'age': 25}]

In [None]:
#| export
class DBTable:
    "A connection to a SQLAlchemy table, created if needed"

    def __init__(self, table: sa.Table, db: Database, cls, _exists=None):
        # `store_attr` copies `table`, `db`, `cls` and `_exists` onto `self`.
        store_attr()
        self.xtra_id = {}
        self.result = []
        has_columns = len(table.columns) > 0
        if has_columns:
            # Only a table with column definitions can be emitted as DDL; `checkfirst` skips existing tables.
            table.create(self.db.engine, checkfirst=True)
            self._exists = True

    def __repr__(self):
        name = self.table.name
        missing = self._exists is False or (self._exists is None and len(self.table.columns) == 0)
        if missing: return f"<Table {name} (does not exist yet)>"
        col_list = ', '.join(self.table.c.keys())
        return f"<Table {name} ({col_list})>"

    def __str__(self):
        "The table name in double quotes, ready to splice into SQL."
        return f'"{self.table.name}"'

    @property
    def conn(self):
        "The owning `Database`'s shared connection."
        return self.db.conn

    def xtra(self, **kwargs):
        "Set `xtra_id`: extra column=value constraints used by this table's queries and writes (no args clears them)."
        self.xtra_id = kwargs


In [None]:
# create a DBTable with the Test table
tbl = DBTable(db.meta.tables['test'], db, None)
tbl

<Table test (id, name, age)>

In [None]:
#| export
@patch(as_prop=True)
def t(self: DBTable):
    "The underlying SQLAlchemy `Table`, paired with its column collection."
    tbl = self.table
    return tbl, tbl.c


In [None]:
tbl.t

(Table('test', MetaData(), Column('id', INTEGER(), table=<test>, primary_key=True), Column('name', TEXT(), table=<test>), Column('age', INTEGER(), table=<test>), schema=None),
 <sqlalchemy.sql.base.ReadOnlyColumnCollection>)

In [None]:
#| export
@patch(as_prop=True)
def pks(self: DBTable):
    "Primary-key columns of the table, followed by the columns named in `xtra_id` (in `xtra_id` order)."
    key_cols = tuple(self.table.primary_key)
    xtra_cols = tuple(self.table.c[name] for name in self.xtra_id.keys())
    return key_cols + xtra_cols


In [None]:
tbl.pks

(Column('id', INTEGER(), table=<test>, primary_key=True),)

In [None]:
#| export
@patch(as_prop=True)
def schema(self: DBTable):
    "The `CREATE TABLE` DDL for this table, compiled for the database's dialect."
    ddl = sa.schema.CreateTable(self.table).compile(self.db.engine)
    return str(ddl).strip()


In [None]:
print(tbl.schema)

CREATE TABLE test (
	id INTEGER, 
	name TEXT, 
	age INTEGER, 
	PRIMARY KEY (id)
)


In [None]:
#| export
@patch
def table(self: Database, nm: str, cls=None):
    """Get the `DBTable` for table or view `nm`; results are cached per name.

    Existing tables and views are reflected from the database. An unknown name gives a column-less
    table object that does not exist yet. `cls` sets the dataclass used for returned rows. It defaults
    to one attached by `Database.create`, and is also applied when the table is already cached.
    """
    if nm in self._tables:
        res = self._tables[nm]
        # Honour an explicitly supplied dataclass even for a cached table.
        if cls is not None: res.cls = cls
        return res

    if nm in self.meta.tables:
        tbl = self.meta.tables[nm]
        exists = True
    else:
        inspector = sa.inspect(self.engine)
        if nm in inspector.get_table_names() or nm in inspector.get_view_names():
            tbl = sa.Table(nm, self.meta, autoload_with=self.engine)
            exists = True
        else:
            tbl = sa.Table(nm, self.meta)
            exists = False

    # `Database.create` stores the dataclass on the `sa.Table` itself.
    if cls is None and hasattr(tbl, "cls"): cls = tbl.cls
    res = DBTable(tbl, self, cls, _exists=exists)
    self._tables[nm] = res
    return res


In [None]:
db.table('test')

<Table test (id, name, age)>

In [None]:
#| export
@patch
def __getitem__(self: Database, nm: str):
    "`db[nm]` is shorthand for `db.table(nm)`."
    res = self.table(nm)
    return res


In [None]:
db['test']

<Table test (id, name, age)>

In [None]:
#| export
def database(path, wal=True, **kwargs) -> Any:
    """Create a `Database` from a path or connection string.

    `path` is normalised by `_db_str`. When `wal` is true and the database is SQLite, the journal mode is
    switched to WAL. `kwargs` is accepted for API compatibility and currently unused.
    """
    conn_str = _db_str(path)
    db = Database(conn_str)
    # Check the engine's dialect rather than the URL prefix, so driver-qualified URLs such as
    # `sqlite+pysqlite:///...` are also recognised.
    if wal and db.engine.dialect.name == "sqlite": db.execute(sa.text("PRAGMA journal_mode=WAL"))
    return db


In [None]:
db = database("sqlite:///:memory:"); db

Database(sqlite:///:memory:)

In [None]:
#| export
def _get_flds(tbl):
    "`make_dataclass` field specs for `tbl`: one optional field per column, defaulting to `UNSET`."
    def _pytype(col):
        # Some SQLAlchemy types don't implement `python_type`; fall back to `str` for those.
        try: return col.type.python_type
        except Exception: return str

    return [(col.name, _pytype(col) | None, field(default=UNSET)) for col in tbl.columns]


In [None]:
#| export
def _dataclass(self: DBTable, store=True, suf="") -> type:
    """Create a `dataclass` with the types and defaults of this table.

    The class is named after the table (`str.title()`) plus `suf`. With `store` it also becomes `self.cls`.
    """
    cls_name = self.table.name.title() + suf
    new_cls = make_dataclass(cls_name, _get_flds(self.table))
    flexiclass(new_cls)
    if store: self.cls = new_cls
    return new_cls


DBTable.dataclass = _dataclass


def all_dcs(db, with_views=False, store=True, suf=""):
    """dataclasses for all tables in `db`, followed by its views when `with_views` is true.

    `store` and `suf` are passed through to `DBTable.dataclass`.
    """
    objs = list(db.t)
    # `with_views` was previously accepted but ignored.
    if with_views: objs += list(db.v)
    return [o.dataclass(store=store, suf=suf) for o in objs]


def create_mod(db, mod_fn, with_views=False, store=True, suf=""):
    """Create a module at `mod_fn` (".py" appended if missing) with dataclass source for every object in `db`.

    The classes come from `all_dcs`, and the module's `__all__` lists their names. The file is written as
    UTF-8, so non-ASCII table or column names don't depend on the platform's default encoding.
    """
    mod_fn = str(mod_fn)
    if not mod_fn.endswith(".py"): mod_fn += ".py"
    dcs = all_dcs(db, with_views, store=store, suf=suf)
    strlist = ", ".join(f'"{o.__name__}"' for o in dcs)
    with open(mod_fn, "w", encoding="utf-8") as f:
        print(f"__all__ = [{strlist}]", file=f)
        print("from dataclasses import dataclass", file=f)
        print("from fastsql.core import UNSET", file=f)
        # Column python types may be datetime/decimal types.
        print("import datetime, decimal", file=f)
        for o in dcs:
            print(dataclass_src(o), file=f)


In [None]:
#| export
@patch
def link_dcs(self: Database, mod):
    "Set the internal dataclass type links for tables using `mod` (created via `create_mod`)"
    for cls_name in mod.__all__:
        dc = getattr(mod, cls_name)
        self.t[cls_name.lower()].cls = dc


@patch
def set_classes(self: Database, glb):
    "Set each table's dataclass to the type in namespace `glb` named after the table (`name.title()`)."
    for tbl in self.t:
        cls_name = tbl.table.name.title()
        tbl.cls = glb[cls_name]


@patch
def get_tables(self: Database, glb):
    "Add every table to namespace `glb`, keyed by its lowercased name plus 's' (e.g. `todo` -> `todos`)."
    for tbl in self.t:
        var_name = f"{tbl.table.name.lower()}s"
        glb[var_name] = tbl


In [None]:
#| export
@patch
def lookup(
    self: DBTable,
    lookup_values: Dict[str, Any],
    extra_values: Dict[str, Any] | None = None,
    pk: str | None = "id",
    **kwargs,
):
    """Return the first row matching all of `lookup_values` (merged with `kwargs`).

    If no row matches, insert one built from those values plus `extra_values` and return it.
    Returns `{}` when there is nothing to look up. `pk` is accepted for API compatibility and unused.
    """
    vals = {**(lookup_values or {}), **kwargs}
    if not vals: return {}
    cond = " and ".join(f'"{k}" = :{k}' for k in vals.keys())
    found = self(where=cond, where_args=vals, limit=2)
    if found: return found[0]
    return self.insert({**vals, **(extra_values or {})})


def _pk_names(self):
    return [c.name for c in self.table.primary_key]


In [None]:
#| export
@patch
def upsert(
    self: DBTable,
    record: Any = None,
    pk=DEFAULT,
    foreign_keys=DEFAULT,
    column_order: Union[List[str], Default, None] = DEFAULT,
    not_null: Union[Iterable[str], Default, None] = DEFAULT,
    defaults: Union[Dict[str, Any], Default, None] = DEFAULT,
    hash_id: Union[str, Default] | None = DEFAULT,
    hash_id_columns: Union[Iterable[str], Default, None] = DEFAULT,
    alter: Union[bool, Default] | None = DEFAULT,
    extracts: Union[Dict[str, str], List[str], Default, None] = DEFAULT,
    conversions: Union[Dict[str, str], Default, None] = DEFAULT,
    columns: Union[Dict[str, Any], Default, None] = DEFAULT,
    strict: Union[bool, Default] | None = DEFAULT,
    **kwargs,
) -> Any:
    """Insert `record` (merged with `kwargs` and `xtra_id`), or update the row with the same primary key.

    `pk` is a column name or list of names, defaulting to the table's primary-key columns. Only `record`,
    `pk` and `kwargs` are used; the other parameters are accepted for API compatibility and ignored.
    Returns the stored row (as `self.cls` when set), or `{}` if there is nothing to write.
    Raises `MissingPrimaryKey` if a pk value is absent. On sqlite/postgresql, raises `NotFoundError`
    if the conflicting row falls outside this table's `xtra` scope, in which case nothing is written.
    """
    record = _process_row(record)
    record = {**record, **kwargs}
    if not record: return {}
    if pk is DEFAULT: pk = _pk_names(self)
    if isinstance(pk, str): pk = [pk]
    missing = [k for k in pk if k not in record]
    if missing: raise MissingPrimaryKey(f"Missing primary key value(s): {missing}")
    record = {**record, **self.xtra_id}
    dialect = self.db.engine.dialect.name
    if dialect in ("sqlite", "postgresql"):
        if dialect == "sqlite":
            from sqlalchemy.dialects.sqlite import insert as dialect_insert
        else:
            from sqlalchemy.dialects.postgresql import insert as dialect_insert
        # In `ON CONFLICT ... DO UPDATE ... WHERE`, plain column references mean the existing row.
        # This stops an upsert from overwriting a row that belongs to a different `xtra` scope.
        xcond = sa.and_(*[self.table.c[k] == v for k, v in self.xtra_id.items()]) if self.xtra_id else None
        ins = dialect_insert(self.table).values(**record)
        stmt = ins.on_conflict_do_update(index_elements=pk, set_=record, where=xcond).returning(
            *self.table.columns
        )
        row = self.conn.execute(stmt).one_or_none()
        if row is None:
            # The conflicting row was excluded by `xcond`, so nothing was inserted or updated.
            self.conn.rollback()
            raise NotFoundError()
        self.conn.commit()
        return _row_to_obj(self, row)
    if dialect in ("mysql", "mariadb"):
        from sqlalchemy.dialects.mysql import insert as dialect_insert

        ins = dialect_insert(self.table).values(**record)
        stmt = ins.on_duplicate_key_update(**record)
        result = self.conn.execute(stmt)
        self.conn.commit()
        try:
            row = result.one()
            return _row_to_obj(self, row)
        except Exception:
            # No RETURNING rows here: re-read the row by its primary key.
            return self.get([record[k] for k in pk])
    # Generic fallback: read first, then update or insert.
    existing = None
    try:
        existing = self.get([record[k] for k in pk])
    except NotFoundError:
        pass
    if existing: return self.update(record)
    return self.insert(record)


## CRUD Operations

In [None]:
#| export
def _is_enum(o):
    return isinstance(o, type) and issubclass(o, Enum)


def _enum_types(e):
    return {type(v.value) for v in e}


def _parse_typ(t):
    if not (_args := get_args(t)): return t
    return first(_args, bool)


def get_typ(t):
    "Get the underlying type."
    t = _parse_typ(t)
    if _is_enum(t) and len(types := _enum_types(t)) == 1: return first(types)
    return t


In [None]:
get_typ(int|None), get_typ(list[str])

(int, str)

In [None]:
#| export
# Python type -> SQLAlchemy column type; any other type is stored as `sa.String`.
_type_map = {int: sa.Integer, str: sa.String, bool: sa.Boolean, float: sa.Float, bytes: sa.LargeBinary}


def _sa_type(typ):
    "SQLAlchemy column type for the Python annotation `typ` (after `get_typ`), defaulting to `sa.String`."
    return _type_map.get(get_typ(typ), sa.String)


In [None]:
_sa_type(int)

sqlalchemy.sql.sqltypes.Integer

In [None]:
#| export
def _column(name, typ, primary=False, nullable=True, default=MISSING):
    "Build an `sa.Column` for Python type `typ`; a `default` of `MISSING` or `UNSET` means no column default."
    no_default = default is MISSING or default is UNSET
    extra = {} if no_default else {"default": default}
    return sa.Column(name, _sa_type(typ), primary_key=primary, nullable=nullable, **extra)


In [None]:
_column('age', int)

Column('age', Integer(), table=None)

In [None]:
#| export
@patch
def create(
    self: Database,
    cls: type,
    pk="id",
    name: str | None = None,
    foreign_keys=None,
    defaults=None,
    column_order=None,
    not_null=None,
    hash_id=None,
    hash_id_columns=None,
    extracts=None,
    if_not_exists=False,
    replace=False,
    ignore=True,
    transform=False,
    strict=False,
):
    """Get a table object, creating in DB if needed.

    The table is named `name` (default: `cls.__name__` in snake_case) and has one column per field of
    `cls`, which is first made a dataclass via `flexiclass`. Columns in `pk` form the primary key,
    columns in `not_null` are NOT NULL, and `defaults` maps column names to column defaults. The other
    parameters are accepted for API compatibility and currently ignored.
    """
    pk, not_null, defaults = listify(pk), setify(not_null), defaults or {}
    flexiclass(cls)
    if name is None: name = camel2snake(cls.__name__)

    def _col(fld):
        return _column(fld.name, fld.type, primary=fld.name in pk,
                       nullable=fld.name not in not_null, default=defaults.get(fld.name, MISSING))

    tbl = sa.Table(name, self.meta, *map(_col, fields(cls)), extend_existing=True)
    res = DBTable(tbl, self, cls)
    # Remember the dataclass on the Table so `Database.table` can pick it up later.
    tbl.cls = cls
    self._tables[name] = res
    return res


In [None]:
class User: name:str; pwd:str
class Todo: title:str; name:str; id:int; done:bool=False; details:str=''
class Student: id:int; grad_year:int; name:str

In [None]:
users = db.create(User, pk='name')
todos = db.create(Todo, pk='id')
students = db.create(Student, pk=('id', 'grad_year'))

In [None]:
#| export
@patch
def table_names(self: Database):
    "Names of the tables in the database, as reported by SQLAlchemy's inspector."
    insp = sa.inspect(self.engine)
    return insp.get_table_names()


@patch
def view_names(self: Database):
    "Names of the views in the database, as reported by SQLAlchemy's inspector."
    insp = sa.inspect(self.engine)
    return insp.get_view_names()


In [None]:
db.table_names()

['student', 'todo', 'user']

In [None]:
#| export
@patch
def schema(self: Database):
    "Show all tables and columns (primary-key columns marked `*`, others `-`)"
    insp = sa.inspect(self.engine)
    lines = []
    for tname in insp.get_table_names():
        lines.append(f"Table: {tname}\n")
        key_cols = insp.get_pk_constraint(tname)["constrained_columns"]
        for col in insp.get_columns(tname):
            marker = "*" if col["name"] in key_cols else "-"
            lines.append(f"  {marker} {col['name']}: {col['type']}\n")
    return "".join(lines)

In [None]:
print(db.schema())

Table: student
  * id: INTEGER
  * grad_year: INTEGER
  - name: VARCHAR
Table: todo
  - title: VARCHAR
  - name: VARCHAR
  * id: INTEGER
  - done: BOOLEAN
  - details: VARCHAR
Table: user
  * name: VARCHAR
  - pwd: VARCHAR



In [None]:
#| export
class _Getter:
    "Abstract class exposing the DB objects named by `__dir__` as attributes, items and an iterable"

    def __init__(self, db): self.db = db

    def __repr__(self): return ", ".join(dir(self))

    def __contains__(self, s):
        nm = s if isinstance(s, str) else s.name
        return nm in dir(self)

    def __iter__(self):
        objs = self[dir(self)]
        return iter(objs)

    def __getitem__(self, idxs):
        if isinstance(idxs, str): return self.db.table(idxs)
        return list(map(self.db.table, idxs))

    def __getattr__(self, k):
        # Private/dunder lookups (e.g. from IPython) must fail normally rather than be treated as table names.
        if k[0] == "_": raise AttributeError
        return self.db[k]


class _TablesGetter(_Getter):
    def __dir__(self):
        # Hide SQLite's internal bookkeeping tables (names starting with `sqlite_`).
        return [nm for nm in self.db.table_names() if not nm.startswith("sqlite_")]


@patch(as_prop=True)
def t(self: Database):
    "Accessor for the database's tables: `db.t.name`, `db.t['name']`, iteration and auto-complete."
    return _TablesGetter(self)


By returning a `_TablesGetter`, `db.t` gets a repr and auto-completion that list all tables in the database.

In [None]:
db.t

student, todo, user

In [None]:
#| export
class _Col:
    "A column reference: `str()` gives the quoted `\"table\".\"col\"` form and `repr()` the bare name"

    def __init__(self, t, c): self.t, self.c = t, c

    def __str__(self): return f'"{self.t}"."{self.c}"'

    def __repr__(self): return self.c

    def __iter__(self): return iter(self.c)


class _ColsGetter:
    "Attribute-style access to the columns of the SQLAlchemy table `tbl`"

    def __init__(self, tbl): self.tbl = tbl

    def __dir__(self): return [repr(col) for col in self()]

    def __call__(self):
        tname = self.tbl.name
        return [_Col(tname, col.name) for col in self.tbl.columns]

    def __contains__(self, s):
        key = s if isinstance(s, str) else s.c
        return key in self.tbl.columns

    def __repr__(self): return ", ".join(dir(self))

    def __getattr__(self, k):
        # Private/dunder lookups must fail normally rather than be treated as column names.
        if k[0] == "_": raise AttributeError
        return _Col(self.tbl.name, k)


@patch(as_prop=True)
def c(self: DBTable):
    "Accessor for this table's columns as `_Col` objects, with a repr and auto-complete."
    return _ColsGetter(self.table)


Similarly, a table's `c` property returns a `_ColsGetter`, giving a repr and auto-completion for the table's columns.

In [None]:
users.c

name, pwd

In [None]:
#| export
@patch
def exists(self: DBTable):
    "Check if this table exists in the DB"
    return sa.inspect(self.db.engine).has_table(self.table.name)


@patch
def create(
    self: DBTable,
    columns: Dict[str, Any] = None,
    pk: Any = None,
    foreign_keys=None,
    column_order: Optional[List[str]] = None,
    not_null: Optional[Iterable[str]] = None,
    defaults: Optional[Dict[str, Any]] = None,
    hash_id: str | None = None,
    hash_id_columns: Optional[Iterable[str]] = None,
    extracts: Union[Dict[str, str], List[str], None] = None,
    if_not_exists: bool = False,
    replace: bool = False,
    ignore: bool = False,
    transform: bool = False,
    strict: bool = False,
    **kwargs,
):
    """Create table from column definitions passed as kwargs or columns dict.

    `columns` (merged with `kwargs`) maps column names to Python types. `pk` defaults to "id" if that
    column is present, otherwise to the first column. With `replace=True` an existing table of this name
    is dropped first; otherwise an existing table is left untouched (`checkfirst`). `foreign_keys`,
    `column_order`, `hash_id`, `hash_id_columns`, `extracts`, `if_not_exists`, `ignore`, `transform` and
    `strict` are accepted for API compatibility and currently ignored.
    Raises `ValueError` if no columns are given. Returns `self`.
    """
    if columns is None: columns = {}
    columns = {**columns, **kwargs}

    if not columns: raise ValueError("No columns specified for table creation")

    if pk is None and "id" in columns:
        pk = "id"
    elif pk is None:
        pk = list(columns.keys())[0]

    pk = listify(pk)
    not_null = setify(not_null)
    defaults = defaults or {}

    cols = [
        _column(
            name,
            typ,
            primary=name in pk,
            nullable=name not in not_null,
            default=defaults.get(name, MISSING),
        )
        for name, typ in columns.items()
    ]

    # `replace` was previously accepted but ignored; drop any existing table so the new definition applies.
    if replace: self.table.drop(self.db.engine, checkfirst=True)

    if self.table.name in self.db.meta.tables: self.db.meta.remove(self.table)

    new_tbl = sa.Table(self.table.name, self.db.meta, *cols, extend_existing=True)

    new_tbl.create(self.db.engine, checkfirst=True)

    self.table = new_tbl
    self._exists = True

    return self


In [None]:
users.exists()

True

In [None]:
u0 = User('jph','foo')
u1 = User('rlt','bar')
t0 = Todo('do it', 'jph')
t1 = Todo('get it done', 'rlt')

In [None]:
#| export
def _obj_to_dict(obj):
    if obj is None: return {}
    if isinstance(obj, dict): return obj
    if is_dataclass(obj): return asdict(obj)
    if hasattr(obj, "__dict__"): return dict(obj.__dict__)
    return {}


def _process_row(row):
    row = _obj_to_dict(row)
    if not row: return {}
    res = {}
    for k, v in row.items():
        if v is UNSET: continue
        if isinstance(v, Enum): v = v.value
        res[k] = v
    return res


def _row_to_obj(self, row, as_cls=True):
    if row is None: return {}
    data = row._asdict() if hasattr(row, "_asdict") else dict(row)
    return self.cls(**data) if as_cls and self.cls else data


In [None]:
#| export
@patch
def insert(
    self: DBTable,
    record: Dict[str, Any] = None,
    pk=DEFAULT,
    foreign_keys=DEFAULT,
    column_order: Union[List[str], Default, None] = DEFAULT,
    not_null: Union[Iterable[str], Default, None] = DEFAULT,
    defaults: Union[Dict[str, Any], Default, None] = DEFAULT,
    hash_id: Union[str, Default, None] = DEFAULT,
    hash_id_columns: Union[Iterable[str], Default, None] = DEFAULT,
    alter: Union[bool, Default, None] = DEFAULT,
    ignore: Union[bool, Default, None] = DEFAULT,
    replace: Union[bool, Default, None] = DEFAULT,
    extracts: Union[Dict[str, str], List[str], Default, None] = DEFAULT,
    conversions: Union[Dict[str, str], Default, None] = DEFAULT,
    columns: Union[Dict[str, Any], Default, None] = DEFAULT,
    strict: Union[bool, Default, None] = DEFAULT,
    **kwargs,
) -> Any:
    """Insert an object into this table, and return it.

    `record` (dict, dataclass or object) is merged with `kwargs`, then with `xtra_id`, which takes
    precedence. Returns the inserted row (as `self.cls` when set), or `{}` if there is nothing to insert.
    The remaining keyword parameters are accepted for API compatibility and currently ignored.
    """
    vals = {**_process_row(record), **kwargs}
    if not vals: return {}
    vals.update(self.xtra_id)
    stmt = sa.insert(self.table).values(**vals).returning(*self.table.columns)
    row = self.conn.execute(stmt).one()
    self.conn.commit()
    return _row_to_obj(self, row)


In [None]:
t = todos.insert(t0)
assert t.id
t

Todo(title='do it', name='jph', id=1, done=False, details='')

In [None]:
u = users.insert(u0)
assert u.name=='jph'
users.insert(u1)
u

User(name='jph', pwd='foo')

In [None]:
todos.insert(t1)

Todo(title='get it done', name='rlt', id=2, done=False, details='')

In [None]:
#| export
@patch
def insert_all(
    self: DBTable,
    records: Iterable | None = None,
    pk=DEFAULT,
    foreign_keys=DEFAULT,
    column_order: Union[List[str], Default, None] = DEFAULT,
    not_null: Union[Iterable[str], Default, None] = DEFAULT,
    defaults: Union[Dict[str, Any], Default, None] = DEFAULT,
    batch_size=DEFAULT,
    hash_id: Union[str, Default, None] = DEFAULT,
    hash_id_columns: Union[Iterable[str], Default, None] = DEFAULT,
    alter: Union[bool, Default, None] = DEFAULT,
    ignore: Union[bool, Default, None] = DEFAULT,
    replace: Union[bool, Default, None] = DEFAULT,
    truncate=False,
    extracts: Union[Dict[str, str], List[str], Default, None] = DEFAULT,
    conversions: Union[Dict[str, str], Default, None] = DEFAULT,
    columns: Union[Dict[str, Any], Default, None] = DEFAULT,
    strict: Union[bool, Default, None] = DEFAULT,
    upsert: bool = False,
    analyze: bool = False,
    xtra: dict | None = None,
    **kwargs,
) -> "DBTable":
    """Insert every record in `records`, each merged with `xtra` (default: the table's `xtra_id`).

    Records that are empty after `_process_row` are skipped. The inserted rows (as `self.cls` when set)
    are stored in `self.result`, and `self` is returned. The other keyword parameters are accepted
    for API compatibility and currently ignored.
    """
    from itertools import groupby

    if records is None: records = []
    if not xtra: xtra = getattr(self, "xtra_id", {})
    recs = []
    for o in list(records):
        row = _process_row(o)
        if not row: continue
        recs.append({**row, **xtra})
    if not recs:
        self.result = []
        return self
    stmt = sa.insert(self.table).returning(*self.table.columns)
    rows = []
    # An executemany builds its INSERT from the first parameter set, so records with different columns
    # (e.g. because some `UNSET` fields were dropped) go in separate consecutive batches, keeping order.
    for _, batch in groupby(recs, key=frozenset):
        rows.extend(self.conn.execute(stmt, list(batch)).fetchall())
    self.conn.commit()
    self.result = [_row_to_obj(self, r) for r in rows]
    return self


In [None]:
todos.insert_all([t0,t1])

<Table todo (title, name, id, done, details)>

In [None]:
#| export
def _bind_where(where, where_args):
    if not where_args: return where, {}
    if isinstance(where_args, dict): return where, where_args
    if not where: return where, {}
    params = {}
    if "?" in where:
        for i, v in enumerate(where_args):
            key = f"param_{i}"
            where = where.replace("?", f":{key}", 1)
            params[key] = v
        return where, params
    for i, v in enumerate(where_args):
        params[f"param_{i}"] = v
    return where, params


In [None]:
#| export
def _where(
    where: Optional[str] = None,
    where_args: Optional[Union[Iterable, dict]] = None,
    xtra: Optional[dict] = None,
    **kw,
):
    """Build a `sa.text` WHERE clause from the SQL fragment `where` and its parameters.

    `where_args` is a dict (for `:name` placeholders) or an iterable (for `?` placeholders), and `kw`
    adds further named parameters. Each `xtra` item becomes `"col" = :xtra_col`, AND-ed with `where`.
    """
    if xtra:
        xw = " and ".join(f'"{k}" = :xtra_{k}' for k in xtra.keys())
        # Parenthesise the caller's fragment, so an `or` inside it cannot escape the xtra constraint
        # (AND binds tighter than OR).
        where = f"{xw} and ({where})" if where else xw
        kw = {**kw, **{f"xtra_{k}": v for k, v in xtra.items()}}
    params = {}
    if isinstance(where_args, dict):
        params = {**where_args, **kw}
        where, params = _bind_where(where, params)
    elif where_args is None:
        params = kw
        where, params = _bind_where(where, params)
    else:
        where, params = _bind_where(where, where_args)
        params = {**params, **kw}
    return sa.text(where).bindparams(**params)


In [None]:
#| export
@patch
def count_where(
    self: DBTable,
    where: Optional[str] = None,
    where_args: Optional[Union[Iterable, dict]] = None,
    xtra: dict | None = None,
    **kw,
) -> int:
    """Number of rows matching `where`, restricted to `xtra` (default: the table's `xtra_id`).

    `where_args` and `kw` supply the parameters for `where`, as in `_where`.
    """
    # Previously `xtra_id` was ignored here, so counts included rows outside the table's xtra scope.
    xtra = xtra or getattr(self, "xtra_id", {})
    stmt = sa.select(sa.func.count()).select_from(self.table)
    if where or xtra: stmt = stmt.where(_where(where, where_args, xtra, **kw))
    return int(self.conn.execute(stmt).scalar_one())


In [None]:
#| export
@patch(as_prop=True)
def count(self: DBTable):
    "Row count of the table, as returned by `count_where()` with no arguments."
    n = self.count_where()
    return n


In [None]:
#| export
@patch
def __len__(self: DBTable):
    "`len(tbl)` is the table's `count`."
    n = self.count
    return n


In [None]:
todos.count, len(todos)

(4, 4)

In [None]:
#| export
@patch
def rows_where(
    self: DBTable,
    where: Optional[str] = None,
    where_args: Optional[Union[Iterable, dict]] = None,
    order_by: Optional[str] = None,
    select: str = "*",
    limit: Optional[int] = None,
    offset: Optional[int] = None,
    xtra: dict | None = None,
    **kw,
) -> Generator[dict, None, None]:
    """Yield rows (as dicts) matching `where`, restricted to `xtra` (default: the table's `xtra_id`).

    `select` is `"*"` or a comma-separated list of SQL column expressions. `order_by`, `limit` and
    `offset` are applied as given, and `where_args`/`kw` supply the parameters for `where`.
    """
    # Resolve the xtra scope up front, so it also applies when no `where`/`xtra`/`kw` is passed
    # (e.g. via the `rows` property).
    xtra = xtra or getattr(self, "xtra_id", {})
    if select == "*": query = sa.select(self.table)
    else:
        columns = [sa.text(col.strip()) for col in select.split(",")]
        query = sa.select(*columns).select_from(self.table)
    if where or xtra or kw: query = query.where(_where(where, where_args, xtra, **kw))
    if order_by: query = query.order_by(sa.text(order_by))
    if limit is not None: query = query.limit(limit)
    if offset is not None: query = query.offset(offset)
    rows = self.conn.execute(query).mappings().all()
    for row in rows: yield dict(row)


In [None]:
#| export
@patch
def pks_and_rows_where(
    self: DBTable,
    where: Optional[str] = None,
    where_args: Optional[Union[Iterable, dict]] = None,
    order_by: Optional[str] = None,
    select: str = "*",
    limit: Optional[int] = None,
    offset: Optional[int] = None,
    xtra: dict | None = None,
    **kw,
) -> Generator[Tuple[Any, Dict], None, None]:
    """Like `rows_where`, but yields `(pk, row)` pairs.

    `pk` is the bare value for a single-column primary key, or a tuple of values for a composite key.
    """
    key_cols = [col.name for col in self.table.primary_key]  # loop-invariant
    matches = self.rows_where(where=where, where_args=where_args, order_by=order_by, limit=limit,
                              offset=offset, select=select, xtra=xtra, **kw)
    for row in matches:
        key = tuple(row[name] for name in key_cols)
        yield (key[0] if len(key) == 1 else key), row


In [None]:
#| export
@patch(as_prop=True)
def rows(self: DBTable):
    "Generator over the table's rows as dicts (`rows_where()` with no arguments)."
    return self.rows_where()

In [None]:
#| export
@patch
def __call__(
    self: DBTable,
    where: str | None = None,  # SQL where fragment to use, for example `id > ?`
    where_args: Iterable | dict | NoneType = None,  # Parameters to use with `where`; iterable for `id>?`, or dict for `id>:id`
    order_by: str | None = None,  # Column or fragment of SQL to order by
    limit: int | None = None,  # Number of rows to limit to
    offset: int | None = None,  # SQL offset
    select: str = "*",  # Comma-separated list of columns to select
    with_pk: bool = False,  # Return tuple of (pk,row)?
    as_cls: bool = True,  # Convert returned dict to stored dataclass?
    xtra: dict | None = None,  # Extra constraints
    **kw,
):  # Combined with `where_args`
    "Result of `select` query on the table. Returns list of returned objects."
    fetch = self.pks_and_rows_where if with_pk else self.rows_where
    scope = xtra or getattr(self, "xtra_id", {})
    found = fetch(where=where, where_args=where_args, order_by=order_by, limit=limit,
                  offset=offset, select=select, xtra=scope, **kw)
    if not (as_cls and self.cls): return list(found)
    if with_pk: return [(k, self.cls(**v)) for k, v in found]
    return [self.cls(**o) for o in found]


In [None]:
assert users()==[u0,u1]
users()

[User(name='jph', pwd='foo'), User(name='rlt', pwd='bar')]

In [None]:
r = users(where="pwd LIKE :pwd", pwd="b%")
assert r==[u1]
r

[User(name='rlt', pwd='bar')]

In [None]:
users.xtra(name='rlt')
users(order_by='name')

[User(name='rlt', pwd='bar')]

In [None]:
users(where='name="rlt"')

[User(name='rlt', pwd='bar')]

In [None]:
users(where='name="jph"')

[]

In [None]:
assert len(todos())==4
todos()

[Todo(title='do it', name='jph', id=1, done=False, details=''),
 Todo(title='get it done', name='rlt', id=2, done=False, details=''),
 Todo(title='do it', name='jph', id=3, done=False, details=''),
 Todo(title='get it done', name='rlt', id=4, done=False, details='')]

In [None]:
#| export
@patch
def create_view(
    self: Database, name: str, sql: str, ignore: bool = False, replace: bool = False
):
    """Create a view called `name` defined by the SELECT statement `sql`.

    With `replace`, any existing view of that name is dropped first. Otherwise an existing view raises
    `ValueError`, unless `ignore` is set, in which case it is left unchanged.
    """
    # Quote the identifier when the dialect requires it (reserved words, special characters).
    qname = self.engine.dialect.identifier_preparer.quote(name)
    if replace: self.execute(sa.text(f"DROP VIEW IF EXISTS {qname}"))
    elif not ignore and name in self.view_names(): raise ValueError(f"View {name} already exists")
    self.execute(sa.text(f"CREATE VIEW IF NOT EXISTS {qname} AS {sql}"))
    # Commit so the view persists, like the other write methods.
    self.conn.commit()
    self.meta.reflect(bind=self.engine)

In [None]:
# Create a view showing only incomplete todos
db.create_view('pending_todos', 'SELECT * FROM todo WHERE done = 0')

In [None]:
#| export
class _ViewsGetter(_Getter):
    "Attribute/item access to the database's views"
    def __dir__(self):
        names = self.db.view_names()
        return names

@patch(as_prop=True)
def v(self: Database):
    "Accessor for the database's views: `db.v.name`, `db.v['name']`, iteration and auto-complete."
    return _ViewsGetter(self)

In [None]:
db.v.pending_todos()

[{'title': 'do it', 'name': 'jph', 'id': 1, 'done': False, 'details': ''},
 {'title': 'get it done',
  'name': 'rlt',
  'id': 2,
  'done': False,
  'details': ''},
 {'title': 'do it', 'name': 'jph', 'id': 3, 'done': False, 'details': ''},
 {'title': 'get it done',
  'name': 'rlt',
  'id': 4,
  'done': False,
  'details': ''}]

In [None]:
#| export
@patch
def _pk_where(self: DBTable, meth, key):
    """Build `self.table.<meth>()` (e.g. "update"/"delete"), filtered to the row with primary key `key`.

    `key` is a scalar or a list/tuple of pk values. The table's `xtra_id` values are appended to match the
    extra columns in `self.pks`. Raises `NotFoundError` if the number of values differs from the number
    of key columns, instead of silently matching on a partial (or, when empty, no) condition.
    """
    key = tuple(key) if isinstance(key, (tuple, list)) else (key,)
    cols = self.pks
    vals = key + tuple(self.xtra_id.values())
    if len(cols) != len(vals): raise NotFoundError(f"Need {len(cols)} pk")
    cond = sa.and_(*[col == val for col, val in zip(cols, vals)])
    return getattr(self.table, meth)().where(cond)


In [None]:
#| export
class NotFoundError(Exception):
    "Raised when no row matches the requested key or query, or when the key is malformed."


class MissingPrimaryKey(Exception):
    "Raised when a record lacks a value for a primary-key column."

In [None]:
#| export
@patch
def get(
    self: DBTable,
    pk_values,
    as_cls: bool = True,
    xtra: dict | None = None,
    default: Any = UNSET,
) -> Any:
    """Get the row whose primary key is `pk_values`, restricted to `xtra` (default: the table's `xtra_id`).

    `pk_values` is a scalar or a list/tuple of values. If no row matches, `default` is returned when given;
    otherwise `NotFoundError` is raised. It is also raised when the number of values doesn't fit the key.
    """
    keys = pk_values if isinstance(pk_values, (list, tuple)) else [pk_values]
    xtra = xtra or getattr(self, "xtra_id", {})
    vals = [*keys, *xtra.values()]
    cols = [*self.table.primary_key.columns, *(self.table.c[k] for k in xtra.keys())]
    if len(cols) != len(vals): raise NotFoundError(f"Need {len(cols)} pk")
    match = sa.and_(*[c == v for c, v in zip(cols, vals)])
    found = self.conn.execute(sa.select(self.table).where(match)).first()
    if found: return _row_to_obj(self, found, as_cls=as_cls)
    if default is UNSET: raise NotFoundError()
    return default


@patch
def __getitem__(self: DBTable, key):
    "`tbl[key]` is shorthand for `tbl.get(key)`."
    return self.get(key)

In [None]:
users.xtra(name='jph')
assert users['jph']==u0
users['jph']

User(name='jph', pwd='foo')

In [None]:
users.xtra(name='rlt')
test_fail(lambda: users['jph']==u0)

In [None]:
#| export
@patch
def selectone(
    self: DBTable,
    where: str | None = None,
    where_args: Iterable | dict | NoneType = None,
    select: str = "*",
    as_cls: bool = True,
    xtra: dict | None = None,
    **kwargs,
):
    "The single row matching `where`: raises `NotFoundError` if there is none and `ValueError` if there are several."
    # Fetching two rows is enough to tell "exactly one" from "more than one".
    res = self(where=where, where_args=where_args, select=select, as_cls=as_cls,
               xtra=xtra, limit=2, **kwargs)
    n = len(res)
    if n == 0: raise NotFoundError
    if n > 1: raise ValueError(f"Not unique: {len(res)} results")
    return res[0]


In [None]:
users.xtra()  # clear xtra
test_eq(users.selectone('name=?', ['jph']).name, 'jph')

In [None]:
#| export
@patch
def update(
    self: DBTable,
    updates: dict | None = None,
    pk_values: list | tuple | str | int | float | None = None,
    alter: bool = False,
    conversions: dict | None = None,
    xtra: dict | None = None,
    **kwargs,
) -> Any:
    """Update one row with `updates` (merged with `kwargs` and `xtra`) and return the updated row.

    The row is identified by `pk_values`, or by the primary-key values found in the merged updates.
    Returns `{}` if there is nothing to update. Raises `MissingPrimaryKey` if a pk value is absent and
    `NotFoundError` if no row matches. `alter` and `conversions` are accepted but unused.
    """
    d = _process_row(updates or {})
    d = {**d, **kwargs, **(xtra or getattr(self, "xtra_id", {}))}
    if not d: return {}
    if pk_values is None:
        pk_names = [o.name for o in self.table.primary_key]
        missing = [k for k in pk_names if k not in d]
        # Previously surfaced as a bare KeyError; match `upsert`'s MissingPrimaryKey instead.
        if missing: raise MissingPrimaryKey(f"Missing primary key value(s): {missing}")
        pk_values = [d[k] for k in pk_names]
    else: pk_values = listify(pk_values)
    qry = self._pk_where("update", pk_values).values(**d).returning(*self.table.columns)
    result = self.conn.execute(qry)
    if (row := result.one_or_none()) is None:
        self.conn.rollback()
        raise NotFoundError()

    self.conn.commit()
    return _row_to_obj(self, row)

In [None]:
users.xtra(name='jph')
u.pwd = 'new'
users.update(u)
users.xtra()
users()

[User(name='jph', pwd='new'), User(name='rlt', pwd='bar')]

In [None]:
users.xtra(name='rlt')
u.pwd = 'foo'
users.update(u)
users.xtra()
test_eq(users['jph'].pwd, 'new')

In [None]:
#| export
@patch
def update_where(
    self: DBTable,
    updates: dict,
    where: str | None = None,
    where_args: dict | Iterable | None = None,
    xtra: dict | None = None,
    **kw,
) -> list:
    """Update rows matching `where` with `updates`. Returns updated rows.

    `where` is a SQL fragment whose parameters come from `where_args` and/or `kw`.
    `xtra` (default: the table's `xtra_id`) further restricts which rows are updated,
    whether or not `where` is given.
    """
    xtra = xtra or getattr(self, "xtra_id", {})
    stmt = self.table.update().values(**updates)
    if where: stmt = stmt.where(_where(where, where_args, xtra, **kw))
    # Apply `xtra` unconditionally: with no `where`, the statement would otherwise update every
    # row in the table regardless of the `xtra` constraint.
    if xtra: stmt = stmt.where(*(self.table.c[k] == v for k, v in xtra.items()))
    rows = self.conn.execute(stmt.returning(*self.table.columns)).fetchall()
    self.conn.commit()
    return [_row_to_obj(self, r) for r in rows]


In [None]:
# extra keyword args (here `name`) supply the values for named `:params` in `where`
todos.update_where({'done': True}, where='name = :name', name='jph')

[Todo(title='do it', name='jph', id=1, done=True, details=''),
 Todo(title='do it', name='jph', id=3, done=True, details='')]

In [None]:
#| export
@patch
def delete(self: DBTable, key):
    """Delete item with PK `key` and return the deleted object.

    Raises `NotFoundError` (after rolling back the open transaction) if no row has that key,
    matching how `update` reports a missing row.
    """
    qry = self._pk_where("delete", key).returning(*self.table.columns)
    row = self.conn.execute(qry).one_or_none()
    if row is None:
        # `.one()` used to raise SQLAlchemy's NoResultFound here and leave the transaction open.
        self.conn.rollback()
        raise NotFoundError()
    self.conn.commit()
    return _row_to_obj(self, row)


In [None]:
# `delete` returns the removed row (truthy); afterwards looking the key up fails
assert users.delete('jph')
test_fail(lambda: users['jph'])

In [None]:
#| export
@patch
def delete_where(
    self: DBTable,
    where: Optional[str] = None,
    where_args: Optional[Union[Iterable, dict]] = None,
    xtra: dict | None = None,
    **kw,
) -> list:
    """Delete rows matching `where` and return the deleted objects.

    `where` is a SQL fragment whose parameters come from `where_args` and/or `kw`.
    `xtra` (default: the table's `xtra_id`) further restricts which rows are deleted,
    whether or not `where` is given. With neither `where` nor `xtra`, every row is deleted.
    """
    xtra = xtra or getattr(self, "xtra_id", {})
    stmt = self.table.delete()
    if where: stmt = stmt.where(_where(where, where_args, xtra, **kw))
    # Apply `xtra` unconditionally: a bare `delete_where()` must not wipe rows outside the
    # `xtra` constraint.
    if xtra: stmt = stmt.where(*(self.table.c[k] == v for k, v in xtra.items()))
    rows = self.conn.execute(stmt.returning(*self.table.columns)).fetchall()
    self.conn.commit()
    return [_row_to_obj(self, r) for r in rows]


In [None]:
# `?` placeholders are filled from the `where_args` list; the deleted rows are returned
todos.delete_where("name = ?", ["jph"])

[Todo(title='do it', name='jph', id=1, done=True, details=''),
 Todo(title='do it', name='jph', id=3, done=True, details='')]

In [None]:
# nothing for 'jph' should be left
todos('name=?', ['jph'])

[]

In [None]:
#| export
@patch
def __contains__(
    self: DBTable,
    pk_values: Union[
        list, tuple, str, int
    ],  # A single value, or a tuple of values for tables that have a compound primary key
) -> bool:
    "Return True if a row with these primary key value(s) can be fetched from the table, else False."
    # A bare str/int key is wrapped so it is looked up the same way as a compound key tuple.
    key = (pk_values,) if isinstance(pk_values, (str, int)) else pk_values
    try: self[key]
    except NotFoundError: return False
    return True

Demonstration with a single-field primary key:

In [None]:
assert 'jph' not in users  # removed by `delete` above
assert 'rlt' in users

For compound primary keys, let's check whether a student is in the `students` table or not.

In [None]:
students.insert(Student(1, 2021, 'jph'))


Student(id=1, grad_year=2021, name='jph')

In [None]:
# for a compound key, pass a tuple of key values (presumably in (id, grad_year) order here)
assert (1,2021) in students
assert (1,2030) not in students

In [None]:
#| export
@patch
def drop(self: DBTable, ignore: bool = False):
    """Drop this table from the database and forget it in `db._tables` and `db.meta`.

    If `ignore` is True, any error during the drop is suppressed (after rolling back the
    connection's transaction); otherwise it is re-raised.
    """
    try:
        # Run the DDL on the table's own connection and commit it there. Dropping through the
        # engine would use a second connection, which can block on locks held by the open
        # transaction on `self.conn`, and the following commit would not cover the drop.
        self.table.drop(self.conn)
        self.conn.commit()
        self.db._tables.pop(self.table.name, None)
        if self.table.name in self.db.meta.tables: self.db.meta.remove(self.table)
    except Exception:
        # Leave the connection usable (some backends reject statements after a failed one).
        self.conn.rollback()
        if not ignore: raise

In [None]:
students.drop()
# once dropped, the table should no longer be listed in `db.t`
assert 'student' not in db.t

## SQLAlchemy helpers

In [None]:
#| export
from fastcore.net import urlsave

from collections import namedtuple
from sqlalchemy import create_engine, text, MetaData, Table, Column, engine, sql
from sqlalchemy.sql.base import ReadOnlyColumnCollection
from sqlalchemy.engine.base import Connection
from sqlalchemy.engine.cursor import CursorResult


In [None]:
#| export
@patch
def __dir__(self: MetaData):
    "Add table names to `dir()` output, so they show up as attributes (e.g. for completion)."
    return self._orig___dir__() + list(self.tables)


@patch
def __dir__(self: ReadOnlyColumnCollection):
    "Add column names to `dir()` output, so they show up as attributes (e.g. for completion)."
    return self._orig___dir__() + list(self.keys())


def _getattr_(self, n):
    "Resolve unknown, non-private `MetaData` attributes as table names, e.g. `meta.todo`."
    # `startswith` (not `n[0]`) so an empty name raises AttributeError rather than IndexError.
    # `tables` itself is excluded to avoid infinite recursion on an object whose `tables`
    # attribute has not been set yet.
    if n.startswith("_") or n == "tables": raise AttributeError(n)
    if n in self.tables: return self.tables[n]
    raise AttributeError(f"{type(self).__name__!r} object has no attribute or table {n!r}")


MetaData.__getattr__ = _getattr_


In [None]:
dbm = db.meta

In [None]:
' '.join(dbm.tables)

'user todo pending_todos'

In [None]:
# attribute access on `MetaData` falls back to looking the name up in `dbm.tables`
t = dbm.todo

In [None]:
list(t.c)

[Column('title', String(), table=<todo>),
 Column('name', String(), table=<todo>),
 Column('id', Integer(), table=<todo>, primary_key=True),
 Column('done', Boolean(), table=<todo>),
 Column('details', String(), table=<todo>)]

In [None]:
from sqlalchemy.exc import ResourceClosedError

In [None]:
#| export
@patch
def tuples(self: CursorResult, nm="Row"):
    """Get all results as named tuples of type `nm`.

    Column names that are not valid identifiers (e.g. `count(*)`) or that are duplicated
    (e.g. two `id` columns from a join) become positional field names `_0`, `_1`, ...
    """
    rows = self.fetchall()
    nt = namedtuple(nm, self.keys(), rename=True)
    # Build positionally so renamed fields still line up with their values.
    return [nt._make(r) for r in rows]


@patch
def sql(self: Connection, statement, nm="Row", *args, **kwargs):
    """Execute `statement` (a string or SQLAlchemy statement) and return results, if any.

    Rows are returned as named tuples of type `nm`; statements that produce no rows
    return None. Extra positional/keyword arguments are accepted but ignored.
    """
    if isinstance(statement, str): statement = text(statement)
    res = self.execute(statement)
    # Check `returns_rows` instead of catching ResourceClosedError, which is not imported in
    # the exported module (its import cell is not `#| export`).
    return res.tuples(nm) if res.returns_rows else None


@patch
def sql(self: MetaData, statement, *args, **kwargs):
    "Execute `statement` on this `MetaData`'s `.conn` and return the rows as named tuples (if any)"
    return self.conn.sql(statement, *args, **kwargs)


In [None]:
# dbm.sql('delete from todo')
# db.conn.commit()

In [None]:
# `MetaData.sql` runs raw SQL on the attached connection and returns a list of named tuples
rs = dbm.sql('select * from user')
rs[0]

Row(name='rlt', pwd='foo')

In [None]:
#| export
@patch
def get(self: Table, where=None, limit=None):
    """Select from table, optionally limited by `where` and `limit` clauses.

    `where` is a SQLAlchemy boolean expression (e.g. `t.c.title.startswith('d')`); rows are
    returned as named tuples.
    """
    qry = self.select()
    # Only add a WHERE clause when one is given: SQLAlchemy coerces `.where(None)` to
    # `WHERE NULL`, which matches no rows.
    if where is not None: qry = qry.where(where)
    return self.metadata.conn.sql(qry.limit(limit))


In [None]:
# `where` is a SQLAlchemy column expression; results come back as named tuples
t.get(t.c.title.startswith('d'), limit=5)

[]

This is the query that will run behind the scenes:

In [None]:
# printing a statement shows its SQL without executing it
print(t.select().where(t.c.title.startswith('d')).limit(5))

SELECT todo.title, todo.name, todo.id, todo.done, todo.details 
FROM todo 
WHERE (todo.title LIKE :title_1 || '%')
 LIMIT :param_1


In [None]:
#| export
@patch
def close(self: MetaData):
    "Close the `Connection` that `Database` attached to this `MetaData` as `.conn`."
    connection = self.conn
    connection.close()

In [None]:
# `dbm.conn` is the same connection `db` uses, so `db` is closed as well
dbm.close()

# Export -

In [None]:
#|hide
import nbdev; nbdev.nbdev_export()