Skip to content

Commit

Permalink
Merge 55cc102 into 90e3854
Browse files Browse the repository at this point in the history
  • Loading branch information
jssuzanne committed Apr 12, 2021
2 parents 90e3854 + 55cc102 commit b5598d1
Show file tree
Hide file tree
Showing 17 changed files with 179 additions and 104 deletions.
16 changes: 8 additions & 8 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ sudo: true

python:
- "3.6"
- "3.7"
- "3.8"
- "3.9"
- "3.10-dev"
- "nightly"
# - "3.7"
# - "3.8"
# - "3.9"
# - "3.10-dev"
# - "nightly"

env:
global:
Expand Down Expand Up @@ -42,9 +42,9 @@ matrix:
- python: "3.10-dev"
- python: "nightly"
- python: "pypy3"
include:
- python: "pypy3"
env: ANYBLOK_DATABASE_DRIVER=postgresql+psycopg2cffi ANYBLOK_DATABASE_USER=postgres SQLSERVER='psql -c' SQLPYCLIENT='psycopg2cffi'
# include:
# - python: "pypy3"
# env: ANYBLOK_DATABASE_DRIVER=postgresql+psycopg2cffi ANYBLOK_DATABASE_USER=postgres SQLSERVER='psql -c' SQLPYCLIENT='psycopg2cffi'
fast_finish: true

virtualenv:
Expand Down
13 changes: 13 additions & 0 deletions anyblok/bloks/anyblok_core/core/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# obtain one at http://mozilla.org/MPL/2.0/.
from sqlalchemy.orm import Session as SA_Session
from anyblok import Declarations
from sqlalchemy.orm.decl_api import DeclarativeMeta


@Declarations.register(Declarations.Core)
Expand All @@ -17,3 +18,15 @@ class Session(SA_Session):
def __init__(self, *args, **kwargs):
kwargs['query_cls'] = self.registry_query
super(Session, self).__init__(*args, **kwargs)

def get_bind(self, bind=None, mapper=None, clause=None, **kwargs):
if mapper is not None:
return mapper.class_.get_bind()
if clause is not None and hasattr(clause, 'Model'):
return self.anyblok.get(clause.Model).get_bind()
if self._flushing and bind and isinstance(bind.class_, DeclarativeMeta):
return bind.class_.get_bind()
if bind:
return bind

return self.anyblok.named_binds['default']
8 changes: 8 additions & 0 deletions anyblok/bloks/anyblok_core/core/sqlbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,14 @@ def __repr__(self):
self.__class__.__registry_name__, ', '.join(field_reprs)
)

@classmethod
def get_engine(self):
return self.anyblok.named_engines[self.engine_name]

@classmethod
def get_bind(self):
return self.anyblok.named_binds[self.engine_name]

@classmethod
def define_table_args(cls):
return ()
Expand Down
4 changes: 2 additions & 2 deletions anyblok/bloks/anyblok_core/system/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def initialize_model(cls):
""" Create the sequence to determine name """
super(Sequence, cls).initialize_model()
seq = SQLASequence(cls._cls_seq_name)
seq.create(cls.anyblok.bind)
seq.create(cls.anyblok.named_binds['default'])

to_create = getattr(cls.anyblok,
'_need_sequence_to_create_if_not_exist', ())
Expand Down Expand Up @@ -140,7 +140,7 @@ def create_sequence(cls, values):
values['seq_name'] = seq_name

seq = SQLASequence(seq_name, start=start)
seq.create(cls.anyblok.bind)
seq.create(cls.get_bind())
return values

@classmethod
Expand Down
56 changes: 29 additions & 27 deletions anyblok/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def autodoc_get_properties(self):
('is crypted', False),
))

def native_type(self, registry):
def native_type(self, engine):
"""Return the native SqlAlchemy type
:param registry:
Expand Down Expand Up @@ -278,18 +278,20 @@ def get_sqlalchemy_mapping(self, registry, namespace, fieldname,
else:
kwargs['default'] = self.default_val

sqlalchemy_type = self.native_type(registry)
sqlalchemy_type = self.native_type(
registry.named_engines[properties['engine_name']])

if self.encrypt_key:
encrypt_key = self.format_encrypt_key(registry, namespace)
engine = registry.named_engines[properties['engine_name']]
sqlalchemy_type = self.get_encrypt_key_type(
registry, sqlalchemy_type, encrypt_key)
engine, sqlalchemy_type, encrypt_key)

return SA_Column(db_column_name, sqlalchemy_type, *args, **kwargs)

def get_encrypt_key_type(self, registry, sqlalchemy_type, encrypt_key):
def get_encrypt_key_type(self, engine, sqlalchemy_type, encrypt_key):
sqlalchemy_type = StringEncryptedType(sqlalchemy_type, encrypt_key)
if sgdb_in(registry.engine, ['MySQL', 'MariaDB']):
if sgdb_in(engine, ['MySQL', 'MariaDB']):
sqlalchemy_type.impl = types.String(64)

return sqlalchemy_type
Expand Down Expand Up @@ -663,7 +665,7 @@ class Test:
"""
sqlalchemy_type = types.Interval

def native_type(self, registry):
def native_type(self, engine):
if self.encrypt_key:
return types.VARCHAR(1024)

Expand Down Expand Up @@ -728,9 +730,9 @@ def autodoc_get_properties(self):
res['size'] = self.size
return res

def get_encrypt_key_type(self, registry, sqlalchemy_type, encrypt_key):
def get_encrypt_key_type(self, engine, sqlalchemy_type, encrypt_key):
sqlalchemy_type = StringEncryptedType(sqlalchemy_type, encrypt_key)
if sgdb_in(registry.engine, ['MySQL', 'MariaDB']):
if sgdb_in(engine, ['MySQL', 'MariaDB']):
sqlalchemy_type.impl = types.String(max(self.size, 64))

return sqlalchemy_type
Expand Down Expand Up @@ -836,9 +838,9 @@ def setter_format_value(self, value):
value = SAU_PWD(value, context=self.sqlalchemy_type.context)
return value

def native_type(self, registry):
def native_type(self, engine):
""" Return the native SqlAlchemy type """
if sgdb_in(registry.engine, ['MsSQL']):
if sgdb_in(engine, ['MsSQL']):
return MsSQLPasswordType(max_length=self.size, **self.crypt_context)

return self.sqlalchemy_type
Expand Down Expand Up @@ -885,9 +887,9 @@ class Test:
"""
sqlalchemy_type = TextType

def get_encrypt_key_type(self, registry, sqlalchemy_type, encrypt_key):
def get_encrypt_key_type(self, engine, sqlalchemy_type, encrypt_key):
sqlalchemy_type = StringEncryptedType(sqlalchemy_type, encrypt_key)
if sgdb_in(registry.engine, ['MySQL', 'MariaDB']):
if sgdb_in(engine, ['MySQL', 'MariaDB']):
sqlalchemy_type.impl = types.Text()

return sqlalchemy_type
Expand Down Expand Up @@ -1102,7 +1104,7 @@ def update_table_args(self, registry, Model):
# can add new entry
return []

if sgdb_in(registry.engine, ['MariaDB', 'MsSQL']):
if sgdb_in(Model.get_engine(), ['MariaDB', 'MsSQL']):
# No check constraint in MariaDB
return []

Expand Down Expand Up @@ -1134,9 +1136,9 @@ def update_table_args(self, registry, Model):

return []

def get_encrypt_key_type(self, registry, sqlalchemy_type, encrypt_key):
def get_encrypt_key_type(self, engine, sqlalchemy_type, encrypt_key):
sqlalchemy_type = StringEncryptedType(sqlalchemy_type, encrypt_key)
if sgdb_in(registry.engine, ['MySQL', 'MariaDB']):
if sgdb_in(engine, ['MySQL', 'MariaDB']):
sqlalchemy_type.impl = types.String(max(self.size, 64))

return sqlalchemy_type
Expand Down Expand Up @@ -1166,9 +1168,9 @@ class Test:
"""
sqlalchemy_type = types.JSON(none_as_null=True)

def native_type(self, registry):
def native_type(self, engine):
""" Return the native SqlAlchemy type """
if sgdb_in(registry.engine, ['MariaDB', 'MsSQL']):
if sgdb_in(engine, ['MariaDB', 'MsSQL']):
return JSONType

return self.sqlalchemy_type
Expand All @@ -1188,9 +1190,9 @@ def getter_format_value(self, value):

return value

def get_encrypt_key_type(self, registry, sqlalchemy_type, encrypt_key):
def get_encrypt_key_type(self, engine, sqlalchemy_type, encrypt_key):
sqlalchemy_type = StringEncryptedType(sqlalchemy_type, encrypt_key)
if sgdb_in(registry.engine, ['MySQL', 'MariaDB']):
if sgdb_in(engine, ['MySQL', 'MariaDB']):
sqlalchemy_type.impl = types.Text()

return sqlalchemy_type
Expand All @@ -1217,7 +1219,7 @@ class Test:
"""
sqlalchemy_type = types.LargeBinary

def native_type(self, registry):
def native_type(self, engine):
if self.encrypt_key:
return types.Text

Expand All @@ -1235,9 +1237,9 @@ def getter_format_value(self, value):

return value

def get_encrypt_key_type(self, registry, sqlalchemy_type, encrypt_key):
def get_encrypt_key_type(self, engine, sqlalchemy_type, encrypt_key):
sqlalchemy_type = StringEncryptedType(sqlalchemy_type, encrypt_key)
if sgdb_in(registry.engine, ['MySQL', 'MariaDB']):
if sgdb_in(engine, ['MySQL', 'MariaDB']):
sqlalchemy_type.impl = types.Text()

return sqlalchemy_type
Expand Down Expand Up @@ -1353,9 +1355,9 @@ def autodoc_get_properties(self):
res['size'] = self.max_length
return res

def get_encrypt_key_type(self, registry, sqlalchemy_type, encrypt_key):
def get_encrypt_key_type(self, engine, sqlalchemy_type, encrypt_key):
sqlalchemy_type = StringEncryptedType(sqlalchemy_type, encrypt_key)
if sgdb_in(registry.engine, ['MySQL', 'MariaDB']):
if sgdb_in(engine, ['MySQL', 'MariaDB']):
sqlalchemy_type.impl = types.String(max(self.max_length, 64))

return sqlalchemy_type
Expand Down Expand Up @@ -1478,9 +1480,9 @@ def autodoc_get_properties(self):
res['max_length'] = self.max_length
return res

def get_encrypt_key_type(self, registry, sqlalchemy_type, encrypt_key):
def get_encrypt_key_type(self, engine, sqlalchemy_type, encrypt_key):
sqlalchemy_type = StringEncryptedType(sqlalchemy_type, encrypt_key)
if sgdb_in(registry.engine, ['MySQL', 'MariaDB']):
if sgdb_in(engine, ['MySQL', 'MariaDB']):
sqlalchemy_type.impl = types.String(max(self.max_length, 64))

return sqlalchemy_type
Expand Down Expand Up @@ -1621,7 +1623,7 @@ def update_table_args(self, registry, Model):
# can add new entry
return []

if sgdb_in(registry.engine, ['MariaDB', 'MsSQL']):
if sgdb_in(Model.get_engine(), ['MariaDB', 'MsSQL']):
# No Check constraint in MariaDB
return []

Expand Down
6 changes: 4 additions & 2 deletions anyblok/migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1440,10 +1440,12 @@ class Migration:

def __init__(self, registry):
self.withoutautomigration = registry.withoutautomigration
self.conn = registry.connection()
self.conn = registry.connection(
bind=registry.named_binds['default']
)
self.loaded_namespaces = registry.loaded_namespaces
self.loaded_views = registry.loaded_views
self.metadata = registry.declarativebase.metadata
self.metadata = registry.named_declarativebases['default'].metadata
self.ddl_compiler = self.conn.dialect.ddl_compiler(
self.conn.dialect, None)
self.ignore_migration_for = registry.ignore_migration_for
Expand Down
8 changes: 6 additions & 2 deletions anyblok/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,9 @@ def register(self, parent, name, cls_, **kwargs):
if parent is Declarations:
return # pragma: no cover

if 'engine_name' not in kwargs:
kwargs['engine_name'] = 'default'

kwargs['__registry_name__'] = _registryname
kwargs['__tablename__'] = tablename
update_factory(kwargs)
Expand Down Expand Up @@ -480,8 +483,9 @@ def load_namespace_second_step(cls, registry, namespace,
tablename = properties['__tablename__']
modelname = namespace.replace('.', '')
cls.init_core_properties_and_bases(registry, bases, properties)

if tablename in registry.declarativebase.metadata.tables:
declarativebase = registry.named_declarativebases[
properties['engine_name']]
if tablename in declarativebase.metadata.tables:
cls.apply_existing_table(
registry, namespace, tablename, properties,
bases, transformation_properties)
Expand Down
11 changes: 8 additions & 3 deletions anyblok/model/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ def insert_core_bases(self, bases, properties):
if has_sql_fields(bases):
bases.extend(
[x for x in self.registry.loaded_cores['SqlBase']])
bases.append(self.registry.declarativebase)
declarativebase = self.registry.named_declarativebases[
properties['engine_name']]
bases.append(declarativebase)
else:
# remove tablename to inherit from a sqlmodel
del properties['__tablename__']
Expand Down Expand Up @@ -97,6 +99,9 @@ def apply_view(self, base, properties):
:exception: ViewException
"""
tablename = base.__tablename__
declarativebase = self.registry.named_declarativebases[
properties['engine_name']]

if hasattr(base, '__view__'):
view = base.__view__
elif tablename in self.registry.loaded_views:
Expand All @@ -118,7 +123,7 @@ def apply_view(self, base, properties):
col = c._make_proxy(view)[1]
view._columns.replace(col)

metadata = self.registry.declarativebase.metadata
metadata = declarativebase.metadata
event.listen(metadata, 'before_create', DropView(
view, if_exists=True))
event.listen(metadata, 'after_create', CreateView(
Expand All @@ -137,7 +142,7 @@ def apply_view(self, base, properties):

pks = [getattr(view.c, x) for x in pks]
mapper_properties = self.get_mapper_properties(base, view, properties)
base.anyblok.declarativebase.registry.map_imperatively(
declarativebase.registry.map_imperatively(
base, view, primary_key=pks, properties=mapper_properties)
setattr(base, '__view__', view)

Expand Down
3 changes: 2 additions & 1 deletion anyblok/model/table_and_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ def insert_in_bases(self, new_base, namespace, properties,
transformation_properties['table_args'] = True

if transformation_properties['table_kwargs'] is True:
if sgdb_in(self.registry.engine, ['MySQL', 'MariaDB']):
engine = self.registry.named_engines[properties['engine_name']]
if sgdb_in(engine, ['MySQL', 'MariaDB']):
new_base.define_table_kwargs = self.define_table_kwargs(
new_base, namespace)

Expand Down

0 comments on commit b5598d1

Please sign in to comment.