Skip to content

Commit

Permalink
Make loading CONFIG attributes lazy
Browse files Browse the repository at this point in the history
This is mainly to avoid creating the `server.cfg` file automatically
even when it is not needed.
  • Loading branch information
CasperWA committed Dec 7, 2019
1 parent fe46ebb commit 4c658a2
Showing 1 changed file with 71 additions and 46 deletions.
117 changes: 71 additions & 46 deletions optimade/server/config.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,64 @@
import json
from typing import Dict, Set, Any
from typing import Any
from configparser import ConfigParser
from pathlib import Path


class NoFallback(Exception):
"""No fallback value can be found."""


class Config:
"""Base class for loading config files and its parameters"""

index_links_path: Path = Path("./optimade/server/index_links.json")
_path: Path = Path("./optimade/server/config.ini")

def __init__(self, server_cfg: Path = None):
server = (
self._server = (
Path(__file__).resolve().parent.parent.parent.joinpath("server.cfg")
if server_cfg is None
else server_cfg
)

self._create_server_config(server)
self._load_server_config(server)
def __getattr__(self, name: str) -> Any:
if not self._server.exists():
self._create_server_config()
self._load_server_config()

ftype = self._path.suffix[1:] # Remove initial "."
self.load(ftype)

@staticmethod
def _create_server_config(server_cfg: Path):
if name not in self.__dict__:
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'"
)

return getattr(self, name)

def _create_server_config(self):
"""Create 'server.cfg' in top-package dir from 'server_template.cfg' if it does not exist"""
if not server_cfg.exists():
import shutil
import shutil

server_cfg_template = server_cfg.parent.joinpath("server_template.cfg")
shutil.copyfile(server_cfg_template, server_cfg)
server_cfg_template = self._server.parent.joinpath("server_template.cfg")
shutil.copyfile(server_cfg_template, self._server)

def _load_server_config(self, server_cfg: Path):
def _load_server_config(self):
"""Load cfg-file determining paths to server config files"""
SECTION = "optimadeconfig"
INDEX_LINKS_PATH = "INDEX_LINKS"
SERVER_CONFIG_PATH = "CONFIG"

server = ConfigParser()
server.read(server_cfg)
server.read(self._server)

index_links_path = server.get(
SECTION, INDEX_LINKS_PATH, fallback=str(self.index_links_path)
)
self.index_links_path = server_cfg.parent.joinpath(index_links_path).resolve()
self.index_links_path = self._server.parent.joinpath(index_links_path).resolve()

_path = server.get(SECTION, SERVER_CONFIG_PATH, fallback=str(self._path))
self._path = server_cfg.parent.joinpath(_path).resolve()
self._path = self._server.parent.joinpath(_path).resolve()

if not self._path.exists():
raise ValueError(
Expand Down Expand Up @@ -81,24 +92,28 @@ class ServerConfig(Config):
"""

use_real_mongo = False
mongo_database = "optimade"
links_collection = "links"
references_collection = "references"
structures_collection = "structures"

page_limit = 500
version = "0.10.0"
default_db = "test_server"

provider = {
"prefix": "_exmpl_",
"name": "Example provider",
"description": "Provider used for examples, not to be assigned to a real database",
"homepage": "http://example.com",
"index_base_url": "http://example.com/optimade/index",
}
provider_fields: Dict[str, Set] = {}
@staticmethod
def _DEFAULTS(field: str) -> Any:
res = {
"use_real_mongo": False,
"mongo_database": "optimade",
"links_collection": "links",
"references_collection": "references",
"structures_collection": "structures",
"page_limit": 500,
"version": "v0.10.0",
"default_db": "test_server",
"provider": {
"prefix": "_exmpl_",
"name": "Example provider",
"description": "Provider used for examples, not to be assigned to a real database",
"homepage": "http://example.com",
"index_base_url": "http://example.com/optimade/index",
},
}
if field not in res:
raise NoFallback(f"No fallback value found for '{field}'")
return res[field]

def load_from_ini(self):
""" Load from the file "config.ini", if it exists. """
Expand All @@ -107,22 +122,26 @@ def load_from_ini(self):
config.read(self._path)

self.use_real_mongo = config.getboolean(
"BACKEND", "USE_REAL_MONGO", fallback=self.use_real_mongo
"BACKEND", "USE_REAL_MONGO", fallback=self._DEFAULTS("use_real_mongo")
)
self.mongo_database = config.get(
"BACKEND", "MONGO_DATABASE", fallback=self.mongo_database
"BACKEND", "MONGO_DATABASE", fallback=self._DEFAULTS("mongo_database")
)

self.page_limit = config.getint(
"IMPLEMENTATION", "PAGE_LIMIT", fallback=self.page_limit
"IMPLEMENTATION", "PAGE_LIMIT", fallback=self._DEFAULTS("page_limit")
)
self.version = config.get(
"IMPLEMENTATION", "VERSION", fallback=self._DEFAULTS("version")
)
self.version = config.get("IMPLEMENTATION", "VERSION", fallback=self.version)
self.default_db = config.get(
"IMPLEMENTATION", "DEFAULT_DB", fallback=self.default_db
"IMPLEMENTATION", "DEFAULT_DB", fallback=self._DEFAULTS("default_db")
)

if "PROVIDER" in config.sections():
self.provider = dict(config["PROVIDER"])
else:
self.provider = self._DEFAULTS("provider")

self.provider_fields = {}
for endpoint in {"links", "references", "structures"}:
Expand All @@ -139,7 +158,7 @@ def load_from_ini(self):
config.get(
"BACKEND",
f"{endpoint.upper()}_COLLECTION",
fallback=getattr(self, f"{endpoint}_collection"),
fallback=self._DEFAULTS(f"{endpoint}_collection"),
),
)

Expand All @@ -149,22 +168,28 @@ def load_from_json(self):
with open(self._path, "r") as f:
config = json.load(f)

self.use_real_mongo = bool(config.get("use_real_mongo", self.use_real_mongo))
self.mongo_database = config.get("mongo_database", self.mongo_database)
self.use_real_mongo = bool(
config.get("use_real_mongo", self._DEFAULTS("use_real_mongo"))
)
self.mongo_database = config.get(
"mongo_database", self._DEFAULTS("mongo_database")
)
for endpoint in {"links", "references", "structures"}:
setattr(
self,
f"{endpoint}_collection",
config.get(f"{endpoint}_collection"),
getattr(self, f"{endpoint}_collection"),
getattr(self._DEFAULTS(f"{endpoint}_collection")),
)

self.page_limit = int(config.get("page_limit", self.page_limit))
self.version = config.get("version", self.version)
self.default_db = config.get("default_db", self.default_db)
self.page_limit = int(config.get("page_limit", self._DEFAULTS("page_limit")))
self.version = config.get("version", self._DEFAULTS("version"))
self.default_db = config.get("default_db", self._DEFAULTS("default_db"))

self.provider = config.get("provider", self.provider)
self.provider_fields = set(config.get("provider_fields", self.provider_fields))
self.provider = config.get("provider", self._DEFAULTS("provider"))
self.provider_fields = set(
config.get("provider_fields", self._DEFAULTS("provider_fields"))
)


CONFIG = ServerConfig()

0 comments on commit 4c658a2

Please sign in to comment.