Skip to content

Commit

Permalink
More refactoring from code review
Browse files Browse the repository at this point in the history
  • Loading branch information
ml-evs committed Mar 9, 2021
1 parent 980d90c commit f225114
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 40 deletions.
2 changes: 1 addition & 1 deletion docs/static/default_config.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"config_file": "~/.optimade.json",
"debug": false,
"use_real_mongo": false,
"use_production_backend": false,
"mongo_database": "optimade",
"mongo_uri": "localhost:27017",
"links_collection": "links",
Expand Down
42 changes: 19 additions & 23 deletions optimade/server/entry_collections/elasticsearch.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,7 @@
import os

import json
from pathlib import Path

from typing import Tuple, List, Optional, Dict, Any, Iterable, Union
from fastapi import HTTPException
from elasticsearch_dsl import Search
from elasticsearch.helpers import bulk
import json
import os.path

from optimade.filterparser import LarkParser
from optimade.filtertransformers.elasticsearch import ElasticTransformer, Quantity
Expand All @@ -24,13 +18,11 @@

if CONFIG.database_backend.value == "elastic" or CI_FORCE_ELASTIC:
from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk
from elasticsearch_dsl import Search

CLIENT = Elasticsearch(hosts=CONFIG.elastic_hosts)
print("Using: Real Elastic (elasticsearch)")


with open(Path(__file__).parent.joinpath("elastic_indexes.json")) as f:
INDEX_DEFINITIONS = json.load(f)
LOGGER.info(f"Using: Elasticsearch backend at {CONFIG.elastic_hosts!r}")


class ElasticCollection(EntryCollection):
Expand Down Expand Up @@ -81,7 +73,7 @@ def __init__(
self.transformer = ElasticTransformer(quantities=quantities.values())

self.name = name
body = INDEX_DEFINITIONS.get(name)
body = self.predefined_index.get(name)
if body is None:
body = self.create_elastic_index_from_mapper(
resource_mapper, self.all_fields
Expand All @@ -98,6 +90,13 @@ def __init__(

LOGGER.debug(f"Created index for {self.name!r} with body {body}")

@property
def predefined_index(self) -> Dict[str, Any]:
"""Loads and returns the default pre-defined index."""
with open(Path(__file__).parent.joinpath("elastic_indexes.json")) as f:
index = json.load(f)
return index

@staticmethod
def create_elastic_index_from_mapper(
resource_mapper: BaseResourceMapper, fields: Iterable[str]
Expand Down Expand Up @@ -153,15 +152,19 @@ def get_id(item):
bulk(
self.client,
[
{"_index": self.name, "_id": get_id(item), "_type": "doc", "_source": item}
{
"_index": self.name,
"_id": get_id(item),
"_type": "doc",
"_source": item,
}
for item in data
],
)

def _run_db_query(
self, criteria: Dict[str, Any], single_entry=False
) -> Tuple[Union[List[Dict[str, Any]], Dict[str, Any]], int, bool]:
"""Execute the query on the Elasticsearch backend."""

search = Search(using=self.client, index=self.name)

Expand Down Expand Up @@ -190,19 +193,12 @@ def _run_db_query(

results = [hit.to_dict() for hit in response.hits]

nresults_now = len(results)
if not single_entry:
data_returned = response.hits.total
more_data_available = page_offset + limit < data_returned
else:
# SingleEntryQueryParams, e.g., /structures/{entry_id}
data_returned = nresults_now
data_returned = len(results)
more_data_available = False
if nresults_now > 1:
raise HTTPException(
status_code=404,
detail=f"Instead of a single entry, {nresults_now} entries were found",
)
results = results[0] if results else None

return results, data_returned, more_data_available
11 changes: 11 additions & 0 deletions optimade/server/entry_collections/entry_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import re

from lark import Transformer
from fastapi import HTTPException

from optimade.filterparser import LarkParser
from optimade.models import EntryResource
Expand Down Expand Up @@ -139,11 +140,21 @@ def find(
"""
criteria, response_fields = self.handle_query_params(params)
single_entry = isinstance(params, SingleEntryQueryParams)

results, data_returned, more_data_available = self._run_db_query(
criteria, single_entry=isinstance(params, SingleEntryQueryParams)
)

if single_entry:
results = results[0] if results else None

if data_returned > 1:
raise HTTPException(
status_code=404,
detail=f"Instead of a single entry, {data_returned} entries were found",
)

exclude_fields = self.all_fields - response_fields

if results:
Expand Down
14 changes: 3 additions & 11 deletions optimade/server/entry_collections/mongo.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os

from typing import Dict, Tuple, List, Any
from fastapi import HTTPException

from optimade.filterparser import LarkParser
from optimade.filtertransformers.mongo import MongoTransformer
Expand All @@ -20,15 +19,14 @@
if CONFIG.use_production_backend or CI_FORCE_MONGO:
from pymongo import MongoClient

client = MongoClient(CONFIG.mongo_uri)
LOGGER.info("Using: Real MongoDB (pymongo)")
else:
from mongomock.collection import Collection
from mongomock import MongoClient

client = MongoClient()
LOGGER.info("Using: Mock MongoDB (mongomock)")

CLIENT = MongoClient(CONFIG.mongo_uri)


class MongoCollection(EntryCollection):
"""Class for querying MongoDB collections (implemented by either pymongo or mongomock)
Expand Down Expand Up @@ -60,7 +58,7 @@ def __init__(
)

self.parser = LarkParser(version=(1, 0, 0), variant="default")
self.collection = client[database][name]
self.collection = CLIENT[database][name]

# check aliases do not clash with mongo operators
self._check_aliases(self.resource_mapper.all_aliases())
Expand Down Expand Up @@ -115,12 +113,6 @@ def _run_db_query(
# SingleEntryQueryParams, e.g., /structures/{entry_id}
data_returned = nresults_now
more_data_available = False
if nresults_now > 1:
raise HTTPException(
status_code=404,
detail=f"Instead of a single entry, {nresults_now} entries were found",
)
results = results[0] if results else None

return results, data_returned, more_data_available

Expand Down
4 changes: 2 additions & 2 deletions tests/server/query_params/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,8 +378,8 @@ def test_filter_on_relationships(check_response, check_error_response):
if CONFIG.database_backend == SupportedBackend.ELASTIC:
check_error_response(
request,
expected_status=501,
expected_title="NotImplementedError",
expected_status=400,
expected_title="Bad Request",
expected_detail="references is not a searchable quantity",
)
pytest.xfail("Elasticsearch backend does not support relationship filtering.")
Expand Down
6 changes: 3 additions & 3 deletions tests/server/test_server_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,16 @@ def test_with_validator_json_response(both_fake_remote_clients, capsys):
def test_mongo_backend_package_used():
import pymongo
import mongomock
from optimade.server.entry_collections.mongo import client
from optimade.server.entry_collections.mongo import CLIENT

force_mongo_env_var = os.environ.get("OPTIMADE_CI_FORCE_MONGO", None)
if force_mongo_env_var is None:
return

if int(force_mongo_env_var) == 1:
assert issubclass(client.__class__, pymongo.MongoClient)
assert issubclass(CLIENT.__class__, pymongo.MongoClient)
elif int(force_mongo_env_var) == 0:
assert issubclass(client.__class__, mongomock.MongoClient)
assert issubclass(CLIENT.__class__, mongomock.MongoClient)
else:
raise pytest.fail(
"The environment variable OPTIMADE_CI_FORCE_MONGO cannot be parsed as an int."
Expand Down

0 comments on commit f225114

Please sign in to comment.