Skip to content

Commit

Permalink
added the dbclient support using resty
Browse files Browse the repository at this point in the history
  • Loading branch information
darius BERNARD committed Nov 23, 2016
1 parent 151581a commit 85899cf
Show file tree
Hide file tree
Showing 12 changed files with 296 additions and 30 deletions.
15 changes: 8 additions & 7 deletions rest_models/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,14 @@ def get_connection_params(self):
def get_new_connection(self, conn_params):
return ApiConnexion(**conn_params)

@property
def timeout(self):
return self.settings_dict['OPTIONS'].get('TIMEOUT', 4)

def init_connection_state(self):
c = self.connection
self.autocommit = True
r = c.head('', timeout=4)
if r.status_code == 403:
raise FakeDatabaseDbAPI2.OperationalError("bad credentials for database %s on %s" %
(self.alias, self.settings_dict['NAME']))
c.head('', timeout=self.timeout) # it will raise an exceptions if 403

def create_cursor(self):
return FakeCursor()
Expand All @@ -93,11 +94,11 @@ def _start_transaction_under_autocommit(self):
pass

def is_usable(self):
c = self.connection # type: requests.Session
c = self.connection
try:
c.head('', timeout=4)
c.head('', timeout=self.timeout)
return True
except requests.RequestException:
except FakeDatabaseDbAPI2.OperationalError:
return False

def _set_autocommit(self, autocommit):
Expand Down
46 changes: 42 additions & 4 deletions rest_models/backend/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,58 @@
import subprocess

from django.db.backends.base.client import BaseDatabaseClient
from django.test.testcases import LiveServerThread, _StaticFilesHandler

from rest_models.backend.connexion import LocalApiAdapter


class DatabaseClient(BaseDatabaseClient):
executable_name = 'bash'
port_range = (8097, 9015)

def start_server_thread(self):
self.server_thread = LiveServerThread('localhost', range(*self.port_range),
_StaticFilesHandler)
self.server_thread.daemon = True
self.server_thread.start()

# Wait for the live server to be ready
self.server_thread.is_ready.wait()
if self.server_thread.error:
# Clean up behind ourselves, since tearDownClass won't get called in
# case of errors.
self.stop_server_thread()
raise self.server_thread.error
return 'http://%s:%s' % (
self.server_thread.host, self.server_thread.port)

def stop_server_thread(self):
# There may not be a 'server_thread' attribute if setUpClass() for some
# reasons has raised an exception.
if hasattr(self, 'server_thread'):
# Terminate the live server's thread
self.server_thread.terminate()
self.server_thread.join()

def runshell(self):
resty_path = os.path.join(os.path.dirname(__file__), "exec", "resty")
args = [self.executable_name, "--init-file", resty_path]

subprocess.call(args, env={
"_EXTRA_CURL_AUTH": self.get_middleware_curl_args(),
"_resty_host": self.connection.settings_dict['NAME'],
cname = self.connection.settings_dict['NAME']
if cname.startswith(LocalApiAdapter.SPECIAL_URL):
cname = cname.replace(LocalApiAdapter.SPECIAL_URL, self.start_server_thread())
cname = cname.rstrip("/") + "*"
envs = os.environ.copy()
envs.update(dict(
_EXTRA_CURL_AUTH=self.get_middleware_curl_args(),
_resty_host=cname)
)
self.execute_subprocess(args=args, env=envs)

self.stop_server_thread()

})
def execute_subprocess(self, args, env):
subprocess.call(args, env=env) # pragma: no cover

def get_middleware_curl_args(self):
"""
Expand Down
37 changes: 32 additions & 5 deletions rest_models/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
from collections import namedtuple

from django.db.models.lookups import Lookup
from django.db.models.lookups import Lookup, Exact, IsNull
from django.db.models.sql.compiler import SQLCompiler as BaseSQLCompiler
from django.db.models.sql.constants import MULTI, NO_RESULTS, SINGLE
from django.db.models.sql.where import SubqueryConstraint, WhereNode
Expand All @@ -23,6 +23,22 @@
"""


def extract_exact_pk_value(where):
"""
check if the where node given represent a exclude(val=Model) node, which seem
more complicated, but can be passed as is to the api
:param django.db.models.sql.where.WhereNode where: the where node
:return: the real is exact
"""
if len(where.children) == 2:
exact, isnull = where.children

if (
isinstance(exact, Exact) and isinstance(isnull, IsNull) and
exact.lhs.target == isnull.lhs.target):
return exact
return None


class QueryParser(object):
"""
Expand Down Expand Up @@ -213,7 +229,12 @@ def check_compatibility(self):
is_and = where.connector == 'AND'
is_negated = where.negated
# AND xor negated
if len(where.children) == 1 or (is_and and not is_negated):
is_simple_lookup = len(where.children) == 1

exact_pk_value = extract_exact_pk_value(where)
if exact_pk_value is not None:
pass
elif is_simple_lookup or (is_and and not is_negated):
for child in where.children:
if isinstance(child, WhereNode):
where_nodes.append(child)
Expand Down Expand Up @@ -245,8 +266,13 @@ def flaten_where_clause(self, where_node):
"""
res = []
for child in where_node.children:

if isinstance(child, WhereNode):
res.extend(self.flaten_where_clause(child))
exact_pk_value = extract_exact_pk_value(child)
if exact_pk_value is not None:
res.append((child.negated, exact_pk_value))
else:
res.extend(self.flaten_where_clause(child))
else:
res.append((where_node.negated, child))
return res
Expand All @@ -262,7 +288,7 @@ def build_filter_params(self):
query = self.query
for negated, lookup in self.flaten_where_clause(query.where): # type: bool, Lookup
negated_mark = "-" if negated else ""
field = lookup.lhs.field.name
field = self.query_parser.get_rest_path_for_col(lookup.lhs)
if lookup.lookup_name == 'exact': # implicite lookup is not needed
fieldname = field
else:
Expand Down Expand Up @@ -321,10 +347,11 @@ def results_iter(self, results=None):
"""
Returns an iterator over the results from executing this query.
"""
raise NotImplementedError()# pragma: no cover
raise


class SQLInsertCompiler(SQLCompiler):

def execute_sql(self, return_id=False):
query = self.query
""":type: django.db.models.sql.subqueries.InsertQuery"""
Expand Down
4 changes: 3 additions & 1 deletion rest_models/backend/connexion.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

class LocalApiAdapter(BaseAdapter):

SPECIAL_URL = "http://localapi"

def __init__(self):
self.request_factory = RequestFactory()
super(LocalApiAdapter, self).__init__()
Expand Down Expand Up @@ -97,7 +99,7 @@ class ApiConnexion(object):
"""
def __init__(self, url, auth=None, retry=3):
self.session = requests.Session()
self.session.mount('http://localapi', LocalApiAdapter())
self.session.mount(LocalApiAdapter.SPECIAL_URL, LocalApiAdapter())
self.session.auth = self.auth = auth
self.url = url
self.retry = retry
Expand Down
12 changes: 8 additions & 4 deletions rest_models/backend/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from django.conf import settings
from django.db.backends.base.creation import BaseDatabaseCreation

from rest_models.backend.connexion import LocalApiAdapter

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -40,9 +42,11 @@ def create_test_db(self, verbosity=1, autoclobber=False, serialize=True, keepdb=
database already exists. Returns the name of the test database created.
"""
# Don't import django.core.management if it isn't needed.
test_database_name = self._get_test_db_name()
settings.DATABASES[self.connection.alias]["NAME"] = test_database_name
self.connection.settings_dict["NAME"] = test_database_name
if not self.connection.alias.startswith('TEST_'):
test_database_name = self._get_test_db_name()
settings.DATABASES[self.connection.alias]["NAME"] = test_database_name
self.connection.settings_dict["NAME"] = test_database_name
return self.connection.settings_dict["NAME"]

def _get_test_db_name(self):
"""
Expand All @@ -55,4 +59,4 @@ def _get_test_db_name(self):
if settings.DATABASES.get(test_alias):
return settings.DATABASES[test_alias]['NAME']
name = self.connection.settings_dict['NAME']
return re.sub('https?://[^/]+/', 'http://localapi/', name, count=1)
return re.sub('https?://[^/]+/', LocalApiAdapter.SPECIAL_URL + "/", name, count=1)
3 changes: 2 additions & 1 deletion rest_models/backend/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from django.db.utils import (DatabaseError, Error, IntegrityError, # noqa
InterfaceError, InternalError, NotSupportedError,
OperationalError, ProgrammingError)
OperationalError, ProgrammingError, DataError)

__ALL__ = ['ProgrammingError', 'OperationalError', 'IntegrityError', 'InternalError',
'NotSupportedError', 'DatabaseError', 'InterfaceError', 'Error']
Expand All @@ -17,3 +17,4 @@ class FakeDatabaseDbAPI2(object):
DatabaseError = DatabaseError
InterfaceError = InterfaceError
Error = Error
DataError = DataError
4 changes: 2 additions & 2 deletions rest_models/backend/exec/resty
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
# Copyright 2009, no rights reserved.
#

export _resty_host=""
export _resty_path=""
export _resty_nohistory=""

Expand Down Expand Up @@ -62,6 +61,7 @@ function resty() {
[ "$method" = "-v" ] && echo "$_resty_host $_resty_opts" && return
[ -z "$method" ] && echo "$_resty_host" && return
[ -n "$_path" ] && _resty_path=$_path

domain=$(echo -n "$_resty_host" \
|perl -ane '/^https?:\/\/([^\/\*]+)/; print $1')
_path="${_resty_host//\*/$_resty_path}"
Expand Down Expand Up @@ -166,4 +166,4 @@ echo "use GET,POST,PUT,PATCH,OPTIONS,TRACE to query the server."
echo "ie : GET /users/"
tput sgr0

export PS1='resty> '
export PS1='resty:$_resty_host> '
52 changes: 52 additions & 0 deletions rest_models/tests/test_clients.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals, absolute_import, print_function

from django.core.management import call_command
from django.test.testcases import TestCase

from rest_models.backend.client import DatabaseClient


class ClientTest(TestCase):
def setUp(self):
self.original_execute = DatabaseClient.execute_subprocess
self.original_ports = DatabaseClient.port_range

def tearDown(self):
DatabaseClient.execute_subprocess = self.original_execute
DatabaseClient.port_range = self.original_ports

def test_existing_db(self):
called = []

def tmp_exec(self_dc, args, env):
self.assertEqual(env['_resty_host'], 'http://localhost:8097/api/v2*')
called.append(args)

DatabaseClient.execute_subprocess = tmp_exec

self.assertEqual(len(called), 0)
call_command('dbshell', database='api')
self.assertEqual(len(called), 1)

def test_to_run_db(self):
called = []

def tmp_exec(self_dc, args, env):
self.assertEqual(env['_resty_host'], 'http://127.0.0.1:8080/api/v2*')
called.append(args)

DatabaseClient.execute_subprocess = tmp_exec

self.assertEqual(len(called), 0)
call_command('dbshell', database='api2')
self.assertEqual(len(called), 1)

def test_run_error_server(self):
DatabaseClient.port_range = (80, 80) # this port wont work

def tmp_exec(self_dc, args, env):
raise AssertionError('the exec shall not being called')

DatabaseClient.execute_subprocess = tmp_exec
self.assertRaises(Exception, call_command, 'dbshell', database='api')
Loading

0 comments on commit 85899cf

Please sign in to comment.