Skip to content

Commit

Permalink
chore: Applying pre-commit checks
Browse files Browse the repository at this point in the history
  • Loading branch information
aadel committed Apr 2, 2024
1 parent e26c4b7 commit e7fa18c
Show file tree
Hide file tree
Showing 23 changed files with 171 additions and 144 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,3 @@ solr = "sqlalchemy_solr.http:SolrDialect_http"

[tool.setuptools.dynamic]
version = {attr = "sqlalchemy_solr.__version__"}

5 changes: 2 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from setuptools import setup

readme = os.path.join(os.path.dirname(__file__), "README.md")
with open(readme, encoding='utf-8') as f:
with open(readme, encoding="utf-8") as f:
long_description = f.read()

setup(
long_description=long_description,
long_description_content_type="text/markdown",
tests_require=["pysolr", "pytest >= 6.2.1"],
test_suite="nose.collector",
zip_safe=False
zip_safe=False,
)
10 changes: 6 additions & 4 deletions src/sqlalchemy_solr/admin/solr_spec.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
from requests import Session


class SolrSpec:

_spec = None

def __init__(self, solr_base_url):
session = Session()
sys_info_response = session.get(solr_base_url + "/admin/info/system",
params={"wt": "json"})
sys_info_response = session.get(
solr_base_url + "/admin/info/system", params={"wt": "json"}
)
spec_version = sys_info_response.json()["lucene"]["solr-spec-version"]
self._spec = list(map(int, spec_version.split('.')))
self._spec = list(map(int, spec_version.split(".")))

def __str__(self) -> str:
return '.'.join(list(map(str, self._spec)))
return ".".join(list(map(str, self._spec)))

def spec(self):
return self._spec
65 changes: 37 additions & 28 deletions src/sqlalchemy_solr/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,14 @@
from sqlalchemy.sql import operators
from sqlalchemy.sql.expression import BindParameter

from .solrdbapi import Connection

from . import solrdbapi as module

from .type_map import type_map

from .solr_type_compiler import SolrTypeCompiler
from .solrdbapi import Connection
from .type_map import type_map

logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.ERROR)


class SolrCompiler(compiler.SQLCompiler):
# pylint: disable=abstract-method

Expand All @@ -61,13 +59,15 @@ def default_from(self):
return " FROM (values(1))"

# pylint: disable=too-many-arguments, too-many-branches
def visit_binary(self,
def visit_binary(
self,
binary,
override_operator=None,
eager_grouping=False,
from_linter=None,
lateral_from_linter=None,
**kw):
**kw,
):

# Handled in Solr 9
if Connection.solr_spec.spec()[0] >= self.SOLR_DATE_RANGE_TRANS_RELEASE:
Expand Down Expand Up @@ -104,13 +104,17 @@ def visit_binary(self,
if isinstance(
kw[str(binary.left)][uoperator].right.text, BindParameter
):
udatetime = parser.parse(self.unescape_colon(
kw[str(binary.left)][uoperator].right.effective_value
))
udatetime = parser.parse(
self.unescape_colon(
kw[str(binary.left)][uoperator].right.effective_value
)
)
else:
udatetime = parser.parse(self.unescape_colon(
kw[str(binary.left)][uoperator].right.text
))
udatetime = parser.parse(
self.unescape_colon(
kw[str(binary.left)][uoperator].right.text
)
)
else:
ubound, udatetime = "]", "*"
else:
Expand All @@ -123,16 +127,18 @@ def visit_binary(self,
if operators.ge in kw[str(binary.left)]
else ("{", operators.gt)
)
if isinstance(
kw[str(binary.left)][loperator].right, BindParameter
):
ldatetime = parser.parse(self.unescape_colon(
kw[str(binary.left)][loperator].right.effective_value
))
if isinstance(kw[str(binary.left)][loperator].right, BindParameter):
ldatetime = parser.parse(
self.unescape_colon(
kw[str(binary.left)][loperator].right.effective_value
)
)
else:
ldatetime = parser.parse(self.unescape_colon(
kw[str(binary.left)][loperator].right.text
))
ldatetime = parser.parse(
self.unescape_colon(
kw[str(binary.left)][loperator].right.text
)
)
else:
lbound, ldatetime = "[", "*"

Expand Down Expand Up @@ -170,23 +176,26 @@ def visit_clauselist(self, clauselist, **kw):
return super().visit_clauselist(clauselist, **kw)

def visit_array(self, element, **kw):
return "ARRAY[%s]" % self.visit_clauselist(element, **kw) # pylint: disable=consider-using-f-string
return f"ARRAY[{self.visit_clauselist(element, **kw)}]"

def unescape_colon(self, s: str) -> str:
"""Unescape colon if present in datetime value"""
return s.replace(r'\:', ':')
return s.replace(r"\:", ":")

def datetime_str(self, dt) -> str:
if dt == "*":
return dt

return dt.isoformat() + "Z"


class SolrIdentifierPreparer(compiler.IdentifierPreparer):
# pylint: disable=too-few-public-methods

# Solr has no schema concept
schema_for_object = lambda self, obj: () # pylint: disable=unnecessary-lambda-assignment
# Solr has no schema concept
schema_for_object = (
lambda self, obj: () # pylint: disable=unnecessary-lambda-assignment
)

reserved_words = compiler.RESERVED_WORDS.copy()
reserved_words.update(
Expand Down Expand Up @@ -546,14 +555,14 @@ def get_foreign_keys(self, connection, table_name, schema=None, **kw):
return []

def get_indexes(self, connection, table_name, schema=None, **kw):
"""Solr has no support for indexes. Returns an empty list. """
"""Solr has no support for indexes. Returns an empty list."""
return []

def get_pk_constraint(self, connection, table_name, schema=None, **kw):
"""Solr has no support for primary keys. Retunrs an empty list."""
return []

def get_schema_names(self, connection, **kw): # pylint: disable=unused-argument
def get_schema_names(self, connection, **kw): # pylint: disable=unused-argument
return tuple(["default"])

def get_view_names(self, connection, schema=None, **kw):
Expand Down
27 changes: 15 additions & 12 deletions src/sqlalchemy_solr/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,19 @@
# -*- coding: utf-8 -*-
import logging

from requests import RequestException, Session

from requests import RequestException
from requests import Session
from sqlalchemy_solr.solrdbapi.api_exceptions import DatabaseError

from .api_globals import _HEADER
from .api_globals import _PAYLOAD

from .base import SolrDialect
from .message_formatter import MessageFormatter

logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.ERROR)


class SolrDialect_http(SolrDialect): # pylint: disable=invalid-name
class SolrDialect_http(SolrDialect): # pylint: disable=invalid-name
# pylint: disable=abstract-method,too-many-instance-attributes

supports_statement_cache = True
Expand Down Expand Up @@ -72,7 +71,7 @@ def create_connect_args(self, url):
self.proto = "https://"

if "token" in url.query:
if url.query["token"] is not None :
if url.query["token"] is not None:
self.token = url.query["token"]

# Mapping server path and collection
Expand All @@ -97,11 +96,15 @@ def create_connect_args(self, url):
# Prepare a session with proper authorization handling.
session = Session()
# session.verify property which is bydefault true so Handled here
if "verify_ssl" in url.query and url.query["verify_ssl"] in [False, "False", "false"]:
if "verify_ssl" in url.query and url.query["verify_ssl"] in [
False,
"False",
"false",
]:
session.verify = False

if self.token is not None:
session.headers.update({'Authorization': f'Bearer {self.token}'})
session.headers.update({"Authorization": f"Bearer {self.token}"})
else:
session.auth = (self.username, self.password)
# Utilize this session in other methods.
Expand All @@ -126,13 +129,13 @@ def get_columns(self, connection, table_name, schema=None, **kw):
local_payload = _PAYLOAD.copy()

if "columns" in kw:
columns = kw['columns']
columns = kw["columns"]
else:
columns = []
self._get_aliases(local_payload)

if table_name in self.aliases:
for collection in self.aliases[table_name].split(','):
for collection in self.aliases[table_name].split(","):
self.get_columns(None, collection, columns=columns)

return self.get_unique_columns(columns)
Expand Down Expand Up @@ -199,10 +202,10 @@ def _session_get(self, payload, path: str):

def get_unique_columns(self, columns):
unique_columns = []
columns_set = {column['name'] for column in columns}
columns_set = {column["name"] for column in columns}
for c in columns:
if c['name'] in columns_set:
if c["name"] in columns_set:
unique_columns.append(c)
columns_set.remove(c['name'])
columns_set.remove(c["name"])

return unique_columns
2 changes: 1 addition & 1 deletion src/sqlalchemy_solr/solr_type_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

class SolrTypeCompiler(compiler.GenericTypeCompiler):
# pylint: disable=too-few-public-methods
def visit_ARRAY(self, type_, **kw): # pylint: disable=invalid-name,unused-argument
def visit_ARRAY(self, type_, **kw): # pylint: disable=invalid-name,unused-argument

inner = self.process(type_.item_type)
return inner
32 changes: 15 additions & 17 deletions src/sqlalchemy_solr/solrdbapi/_solrdbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,21 @@
from requests import Session

from ..admin.solr_spec import SolrSpec

from ..api_globals import _HEADER
from ..api_globals import _PAYLOAD
from ..message_formatter import MessageFormatter

from .api_exceptions import ConnectionClosedException, DatabaseHTTPError
from .api_exceptions import UninitializedResultSetError
from .api_exceptions import ConnectionClosedException
from .api_exceptions import CursorClosedException
from .api_exceptions import DatabaseHTTPError
from .api_exceptions import ProgrammingError
from .api_exceptions import UninitializedResultSetError
from .solr_reflect import SolrTableReflection

apilevel = "2.0" # pylint: disable=invalid-name
threadsafety = 3 # pylint: disable=invalid-name
paramstyle = "qmark" # pylint: disable=invalid-name
default_storage_plugin = "" # pylint: disable=invalid-name
apilevel = "2.0" # pylint: disable=invalid-name
threadsafety = 3 # pylint: disable=invalid-name
paramstyle = "qmark" # pylint: disable=invalid-name
default_storage_plugin = "" # pylint: disable=invalid-name


# Python DB API 2.0 classes
class Cursor:
Expand Down Expand Up @@ -83,16 +83,14 @@ def substitute_in_query(string_query, parameters):
query = string_query
for param in parameters:
if isinstance(param, str):
query = query.replace("?", f"'{param}'", 1)
query = query.replace("?", f"{param!r}", 1)
else:
query = query.replace("?", str(param), 1)
return query

@staticmethod
# pylint: disable=too-many-arguments
def submit_query(
query, host, port, proto, server_path, collection, session
):
def submit_query(query, host, port, proto, server_path, collection, session):
local_payload = _PAYLOAD.copy()
local_payload["stmt"] = query
return session.get(
Expand Down Expand Up @@ -175,7 +173,7 @@ def execute(self, operation, parameters=()):
self.rowcount = len(self._result_set)
self._result_set_status = iter(range(len(self._result_set)))
self.description = tuple(
zip(
zip( # noqa: B905
column_names,
column_types,
[None for i in range(len(self._result_set.dtypes.index))],
Expand Down Expand Up @@ -303,7 +301,7 @@ def __init__(
self._session = session
self._connected = True

Connection.solr_spec = SolrSpec(f'{proto}{host}:{port}/{server_path}')
Connection.solr_spec = SolrSpec(f"{proto}{host}:{port}/{server_path}")

SolrTableReflection.connection = self

Expand Down Expand Up @@ -344,7 +342,6 @@ def rollback(self):
Solr does not support rollback
"""


@connected_
def cursor(self):
return Cursor(
Expand Down Expand Up @@ -377,7 +374,7 @@ def connect(

session = Session()
# bydefault session.verify is set to True
if verify_ssl is not None and verify_ssl in [False,"False","false"]:
if verify_ssl is not None and verify_ssl in [False, "False", "false"]:
session.verify = False

if use_ssl in [True, "True", "true"]:
Expand All @@ -401,8 +398,9 @@ def connect(
host, db, username, password, server_path, collection, port, proto, session
)


def add_authorization(session, username, password, token):
if token is not None:
session.headers.update({'Authorization': f'Bearer {token}'})
session.headers.update({"Authorization": f"Bearer {token}"})
else:
session.auth = (username, password)
Loading

0 comments on commit e7fa18c

Please sign in to comment.