diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 10757e9..38bc75a 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,9 @@ +2.1.0 +===== +* Added getitem for documents at the database level +* Added fill_default() on documents to replace None values by schema defaults +* fill_default() is automatically called on save + 2.0.2 ===== * Fixed contains functions diff --git a/pyArango/collection.py b/pyArango/collection.py index a234c5f..d69de35 100644 --- a/pyArango/collection.py +++ b/pyArango/collection.py @@ -20,7 +20,7 @@ class BulkMode(Enum): DELETE = 3 class CachedDoc(object): - """A cached document""" + """A cached document.""" def __init__(self, document, prev, nextDoc): self.prev = prev self.document = document @@ -43,7 +43,7 @@ def __getattribute__(self, k): raise e2 class DocumentCache(object): - "Document cache for collection, with insert, deletes and updates in O(1)" + """Document cache for collection, with insert, deletes and updates in O(1).""" def __init__(self, cacheSize): self.cacheSize = cacheSize @@ -78,7 +78,7 @@ def cache(self, doc): self.cacheStore[doc._key] = ret def delete(self, _key): - "removes a document from the cache" + """Remove a document from the cache.""" try: doc = self.cacheStore[_key] doc.prev.nextDoc = doc.nextDoc @@ -88,7 +88,7 @@ def delete(self, _key): raise KeyError("Document with _key %s is not available in cache" % _key) def getChain(self): - "returns a list of keys representing the chain of documents" + """Return a list of keys representing the chain of documents.""" l = [] h = self.head while h: @@ -96,8 +96,8 @@ def getChain(self): h = h.nextDoc return l - def stringify(self): - "a pretty str version of getChain()" + def stringify(self) -> str: + """Return a pretty string of 'getChain()'.""" l = [] h = self.head while h: @@ -119,14 +119,16 @@ def __repr__(self): class Field(object): """The class for defining pyArango fields.""" def __init__(self, validators = None, default = None): - """validators must be a list of 
validators. default can also be a callable""" + """Validators must be a list of validators. + + 'default' can also be a callable.""" if not validators: validators = [] self.validators = validators self.default = default def validate(self, value): - """checks the validity of 'value' given the lits of validators""" + """Check the validity of 'value' given the list of validators.""" for v in self.validators: v.validate(value) return True @@ -138,7 +140,7 @@ def __str__(self): return "" % ', '.join(strv) class Collection_metaclass(type): - """The metaclass that takes care of keeping a register of all collection types""" + """The metaclass that takes care of keeping a register of all collection types.""" collectionClasses = {} _validationDefault = { @@ -173,20 +175,20 @@ def check_set_ConfigDict(dictName): @classmethod def getCollectionClass(cls, name): - """Return the class object of a collection given its 'name'""" + """Return the class object of a collection given its 'name'.""" try: return cls.collectionClasses[name] except KeyError: raise KeyError( "There is no Collection Class of type: '%s'; currently supported values: [%s]" % (name, ', '.join(getCollectionClasses().keys())) ) @classmethod - def isCollection(cls, name): - """return true or false wether 'name' is the name of collection.""" + def isCollection(cls, name) -> bool: + """return 'True' or 'False' whether 'name' is the name of collection.""" return name in cls.collectionClasses @classmethod - def isDocumentCollection(cls, name): - """return true or false wether 'name' is the name of a document collection.""" + def isDocumentCollection(cls, name) -> bool: + """Return 'True' or 'False' whether 'name' is the name of a document collection.""" try: col = cls.getCollectionClass(name) return issubclass(col, Collection) @@ -194,37 +196,37 @@ def isDocumentCollection(cls, name): return False @classmethod - def isEdgeCollection(cls, name): - """return true or false wether 'name' is the name of an edge 
collection.""" + def isEdgeCollection(cls, name) -> bool: + """Return 'True' or 'False' whether 'name' is the name of an edge collection.""" try: col = cls.getCollectionClass(name) return issubclass(col, Edges) except KeyError: return False -def getCollectionClass(name): - """return true or false wether 'name' is the name of collection.""" +def getCollectionClass(name) -> type: + """Return the class object of a collection given its 'name'.""" return Collection_metaclass.getCollectionClass(name) -def isCollection(name): - """return true or false wether 'name' is the name of a document collection.""" +def isCollection(name) -> bool: + """Return 'True' or 'False' whether 'name' is the name of a document collection.""" return Collection_metaclass.isCollection(name) -def isDocumentCollection(name): - """return true or false wether 'name' is the name of a document collection.""" +def isDocumentCollection(name) -> bool: + """Return 'True' or 'False' whether 'name' is the name of a document collection.""" return Collection_metaclass.isDocumentCollection(name) -def isEdgeCollection(name): - """return true or false wether 'name' is the name of an edge collection.""" +def isEdgeCollection(name) -> bool: + """Return 'True' or 'False' whether 'name' is the name of an edge collection.""" return Collection_metaclass.isEdgeCollection(name) -def getCollectionClasses(): - """returns a dictionary of all defined collection classes""" +def getCollectionClasses() -> dict: + """Return a dictionary of all defined collection classes.""" return Collection_metaclass.collectionClasses class Collection(with_metaclass(Collection_metaclass, object)): - """A document collection. Collections are meant to be instanciated by databases""" - #here you specify the fields that you want for the documents in your collection + """A document collection. 
Collections are meant to be instantiated by databases.""" + # here you specify the fields that you want for the documents in your collection _fields = {} _validation = { @@ -291,7 +293,7 @@ def getDocumentsURL(self): return "%s/document" % (self.database.getURL()) def getIndexes(self): - """Fills self.indexes with all the indexes associates with the collection and returns it""" + """Fill 'self.indexes' with all the indexes associated with the collection and return it.""" self.indexes_by_name = {} url = "%s/index" % self.database.getURL() r = self.connection.session.get(url, params = {"collection": self.name}) @@ -310,28 +312,30 @@ def getIndex(self, name): return self.indexes_by_name[name] def activateCache(self, cacheSize): - """Activate the caching system. Cached documents are only available through the __getitem__ interface""" + """Activate the caching system. + + Cached documents are only available through the __getitem__ interface.""" self.documentCache = DocumentCache(cacheSize) def deactivateCache(self): - "deactivate the caching system" + """Deactivate the caching system.""" self.documentCache = None def delete(self): - """deletes the collection from the database""" + """Delete the collection from the database.""" r = self.connection.session.delete(self.getURL()) data = r.json() if not r.status_code == 200 or data["error"]: raise DeletionError(data["errorMessage"], data) def createDocument(self, initDict = None): - """create and returns a completely empty document or one populated with initDict""" + """Create and return a completely empty document unless the initial document is set via 'initDict'.""" # res = dict(self.defaultDocument) res = self.getDefaultDocument() - + if initDict is not None: res.update(initDict) - + return self.documentClass(self, res) def _writeBatch(self): @@ -381,7 +385,7 @@ def _saveBatch(self, document, params): if len(self._bulkCache) == self._bulkSize: self._writeBatch() self._bulkMode = BulkMode.NONE - + def _updateBatch(self): 
if not self._bulkCache: return @@ -390,7 +394,7 @@ def _updateBatch(self): payload = [] for d in self._bulkCache: dPayload = d._store.getPatches() - + if d.collection._validation['on_save']: d.validate() @@ -421,8 +425,8 @@ def _updateBatch(self): self._bulkCache = [] if bulkError is not None: raise bulkError - - + + def _patchBatch(self, document, params): if self._bulkMode != BulkMode.NONE and self._bulkMode != BulkMode.UPDATE: raise UpdateError("Mixed bulk operations not supported - have " + str(self._bulkMode)) @@ -490,7 +494,7 @@ def _finalizeBatch(self): self._isBulkInProgress = False self._batchParams = None self._bulkMode = BulkMode.NONE - + def importBulk(self, data, **addParams): url = "%s/import" % (self.database.getURL()) payload = json.dumps(data, default=str) @@ -514,7 +518,7 @@ def exportDocs( self, **data): return docs def ensureHashIndex(self, fields, unique = False, sparse = True, deduplicate = False, name = None): - """Creates a hash index if it does not already exist, and returns it""" + """Create a hash index if it does not already exist, then return it.""" data = { "type" : "hash", "fields" : fields, @@ -531,7 +535,7 @@ def ensureHashIndex(self, fields, unique = False, sparse = True, deduplicate = F return ind def ensureSkiplistIndex(self, fields, unique = False, sparse = True, deduplicate = False, name = None): - """Creates a skiplist index if it does not already exist, and returns it""" + """Create a skiplist index if it does not already exist, then return it.""" data = { "type" : "skiplist", "fields" : fields, @@ -548,7 +552,7 @@ def ensureSkiplistIndex(self, fields, unique = False, sparse = True, deduplicate return ind def ensurePersistentIndex(self, fields, unique = False, sparse = True, deduplicate = False, name = None): - """Creates a persistent index if it does not already exist, and returns it""" + """Create a persistent index if it does not already exist, then return it.""" data = { "type" : "persistent", "fields" : fields, @@ -565,7 
+569,7 @@ def ensurePersistentIndex(self, fields, unique = False, sparse = True, deduplica return ind def ensureTTLIndex(self, fields, expireAfter, unique = False, sparse = True, name = None): - """Creates a TTL index if it does not already exist, and returns it""" + """Create a TTL index if it does not already exist, then return it.""" data = { "type" : "ttl", "fields" : fields, @@ -582,7 +586,7 @@ def ensureTTLIndex(self, fields, expireAfter, unique = False, sparse = True, nam return ind def ensureGeoIndex(self, fields, name = None): - """Creates a geo index if it does not already exist, and returns it""" + """Create a geo index if it does not already exist, then return it.""" data = { "type" : "geo", "fields" : fields, @@ -596,7 +600,7 @@ def ensureGeoIndex(self, fields, name = None): return ind def ensureFulltextIndex(self, fields, minLength = None, name = None): - """Creates a fulltext index if it does not already exist, and returns it""" + """Create a fulltext index if it does not already exist, then return it.""" data = { "type" : "fulltext", "fields" : fields, @@ -613,13 +617,13 @@ def ensureFulltextIndex(self, fields, minLength = None, name = None): return ind def ensureIndex(self, index_type, fields, name=None, **index_args): - """Creates an index of any type.""" + """Create an index of any type.""" data = { "type" : index_type, "fields" : fields, } data.update(index_args) - + if name: data["name"] = name @@ -630,7 +634,7 @@ def ensureIndex(self, index_type, fields, name=None, **index_args): return ind def restoreIndexes(self, indexes_dct=None): - """restores all previously removed indexes""" + """Restore all previously removed indexes.""" if indexes_dct is None: indexes_dct = self.indexes @@ -642,7 +646,7 @@ def restoreIndexes(self, indexes_dct=None): self.ensureIndex(typ, idx.infos["fields"], **infos) def validatePrivate(self, field, value): - """validate a private field value""" + """Validate a private field value.""" if field not in 
self.arangoPrivates: raise ValueError("%s is not a private field of collection %s" % (field, self)) @@ -652,7 +656,9 @@ def validatePrivate(self, field, value): @classmethod def hasField(cls, fieldName): - """returns True/False wether the collection has field K in it's schema. Use the dot notation for the nested fields: address.street""" + """Return 'True' or 'False' whether the collection has field 'K' in its schema. + + Use the dot notation for the nested fields: address.street""" path = fieldName.split(".") v = cls._fields for k in path: @@ -663,8 +669,10 @@ def hasField(cls, fieldName): return True def fetchDocument(self, key, rawResults = False, rev = None): - """Fetches a document from the collection given it's key. This function always goes straight to the db and bypasses the cache. If you - want to take advantage of the cache use the __getitem__ interface: collection[key]""" + """Fetch a document from the collection given its key. + + This function always goes straight to the db and bypasses the cache. + If you want to take advantage of the cache use the '__getitem__' interface: collection[key]""" url = "%s/%s/%s" % (self.getDocumentsURL(), self.name, key) if rev is not None: r = self.connection.session.get(url, params = {'rev' : rev}) @@ -680,36 +688,41 @@ def fetchDocument(self, key, rawResults = False, rev = None): raise DocumentNotFoundError("Unable to find document with _key: %s, response: %s" % (key, r.json()), r.json()) def fetchByExample(self, exampleDict, batchSize, rawResults = False, **queryArgs): - """exampleDict should be something like {'age' : 28}""" + """'exampleDict' should be something like {'age' : 28}.""" return self.simpleQuery('by-example', rawResults, example = exampleDict, batchSize = batchSize, **queryArgs) def fetchFirstExample(self, exampleDict, rawResults = False): - """exampleDict should be something like {'age' : 28}. returns only a single element but still in a SimpleQuery object. 
- returns the first example found that matches the example""" + """'exampleDict' should be something like {'age' : 28}. + + Return the first example found that matches the example, still in a 'SimpleQuery' object.""" return self.simpleQuery('first-example', rawResults = rawResults, example = exampleDict) def fetchAll(self, rawResults = False, **queryArgs): - """Returns all the documents in the collection. You can use the optinal arguments 'skip' and 'limit':: + """Returns all the documents in the collection. + You can use the optional arguments 'skip' and 'limit':: + fetchAll(limit = 3, skip = 10)""" - fetchAll(limit = 3, skip = 10)""" return self.simpleQuery('all', rawResults = rawResults, **queryArgs) def simpleQuery(self, queryType, rawResults = False, **queryArgs): - """General interface for simple queries. queryType can be something like 'all', 'by-example' etc... everything is in the arango doc. - If rawResults, the query will return dictionaries instead of Document objetcs. - """ + """General interface for simple queries. + + 'queryType' takes the arguments known to the ArangoDB, for instance: 'all' or 'by-example'. + See the ArangoDB documentation for a list of valid 'queryType's. + If 'rawResults' is set to 'True', the query will return dictionaries instead of 'Document' objects.""" return SimpleQuery(self, queryType, rawResults, **queryArgs) def action(self, method, action, **params): - """a generic fct for interacting everything that doesn't have an assigned fct""" + """A generic 'fct' for interacting everything that does not have an assigned 'fct'.""" fct = getattr(self.connection.session, method.lower()) r = fct(self.getURL() + "/" + action, params = params) return r.json() def bulkSave(self, docs, onDuplicate="error", **params): - """Parameter docs must be either an iterrable of documents or dictionnaries. - This function will return the number of documents, created and updated, and will raise an UpdateError exception if there's at least one error. 
- params are any parameters from arango's documentation""" + """Parameter 'docs' must be either an iterable of documents or dictionaries. + + This function will return the number of documents, created and updated, and will raise an UpdateError exception if there is at least one error. + 'params' are any parameters from the ArangoDB documentation.""" payload = [] for d in docs: @@ -740,7 +753,7 @@ def bulkSave(self, docs, onDuplicate="error", **params): return data["updated"] + data["created"] def bulkImport_json(self, filename, onDuplicate="error", formatType="auto", **params): - """bulk import from a file repecting arango's key/value format""" + """Bulk import from a file following the ArangoDB key-value format.""" url = "%s/import" % self.database.getURL() params["onDuplicate"] = onDuplicate @@ -754,8 +767,8 @@ def bulkImport_json(self, filename, onDuplicate="error", formatType="auto", **pa raise UpdateError('Unable to bulk import JSON', r) def bulkImport_values(self, filename, onDuplicate="error", **params): - """bulk import from a file repecting arango's json format""" - + """Bulk import from a file following the ArangoDB json format.""" + url = "%s/import" % self.database.getURL() params["onDuplicate"] = onDuplicate params["collection"] = self.name @@ -767,43 +780,45 @@ def bulkImport_values(self, filename, onDuplicate="error", **params): raise UpdateError('Unable to bulk import values', r) def truncate(self): - """deletes every document in the collection""" + """Delete every document in the collection.""" return self.action('PUT', 'truncate') def empty(self): - """alias for truncate""" + """Alias for truncate.""" return self.truncate() def load(self): - """loads collection in memory""" + """Load collection in memory.""" return self.action('PUT', 'load') def unload(self): - """unloads collection from memory""" + """Unload collection from memory.""" return self.action('PUT', 'unload') def revision(self): - """returns the current revision""" + """Return the 
current revision.""" return self.action('GET', 'revision')["revision"] def properties(self): - """returns the current properties""" + """Return the current properties.""" return self.action('GET', 'properties') def checksum(self): - """returns the current checksum""" + """Return the current checksum.""" return self.action('GET', 'checksum')["checksum"] def count(self): - """returns the number of documents in the collection""" + """Return the number of documents in the collection.""" return self.action('GET', 'count')["count"] def figures(self): - "a more elaborate version of count, see arangodb docs for more infos" + """A more elaborate version of 'count', see the ArangoDB documentation for more.""" return self.action('GET', 'figures') def getType(self): - """returns a word describing the type of the collection (edges or ducments) instead of a number, if you prefer the number it's in self.type""" + """Return a word describing the type of the collection (edges or ducments) instead of a number. 
+ + If you prefer the number it is in 'self.type'.""" if self.type == CONST.COLLECTION_DOCUMENT_TYPE: return "document" elif self.type == CONST.COLLECTION_EDGE_TYPE: @@ -811,7 +826,7 @@ def getType(self): raise ValueError("The collection is of Unknown type %s" % self.type) def getStatus(self): - """returns a word describing the status of the collection (loaded, loading, deleted, unloaded, newborn) instead of a number, if you prefer the number it's in self.status""" + """Return a word describing the status of the collection (loaded, loading, deleted, unloaded, newborn) instead of a number, if you prefer the number it is in 'self.status'.""" if self.status == CONST.COLLECTION_LOADING_STATUS: return "loading" elif self.status == CONST.COLLECTION_LOADED_STATUS: @@ -825,14 +840,17 @@ def getStatus(self): raise ValueError("The collection has an Unknown status %s" % self.status) def __len__(self): - """returns the number of documents in the collection""" + """Return the number of documents in the collection.""" return self.count() def __repr__(self): return "ArangoDB collection name: %s, id: %s, type: %s, status: %s" % (self.name, self.id, self.getType(), self.getStatus()) def __getitem__(self, key): - """returns a document from the cache. If it's not there, fetches it from the db and caches it first. If the cache is not activated this is equivalent to fetchDocument( rawResults=False)""" + """Return a document from the cache. + + If it is not there, fetch from the db and cache it first. 
+ If the cache is not activated, this is equivalent to 'fetchDocument(rawResults=False)'.""" if self.documentCache is None: return self.fetchDocument(key, rawResults = False) try: @@ -843,7 +861,7 @@ def __getitem__(self, key): return doc def __contains__(self, key): - """if doc in collection""" + """Return 'True' or 'False' whether the doc is in the collection.""" try: self.fetchDocument(key, rawResults = False) return True @@ -851,26 +869,27 @@ def __contains__(self, key): return False class SystemCollection(Collection): - """for all collections with isSystem = True""" + """For all collections with 'isSystem=True'.""" def __init__(self, database, jsonData): Collection.__init__(self, database, jsonData) class Edges(Collection): - """The default edge collection. All edge Collections must inherit from it""" + """The default edge collection. All edge Collections must inherit from it.""" arangoPrivates = ["_id", "_key", "_rev", "_to", "_from"] def __init__(self, database, jsonData): - """This one is meant to be called by the database""" + """This one is meant to be called by the database.""" Collection.__init__(self, database, jsonData) self.documentClass = Edge self.edgesURL = "%s/edges/%s" % (self.database.getURL(), self.name) @classmethod def validateField(cls, fieldName, value): - """checks if 'value' is valid for field 'fieldName'. If the validation is unsuccessful, raises a SchemaViolation or a ValidationError. - for nested dicts ex: {address : { street: xxx} }, fieldName can take the form address.street - """ + """Check if 'value' is valid for field 'fieldName'. + + If the validation fails, raise a 'SchemaViolation' or a 'ValidationError'. 
+ For nested dicts ex: {address : { street: xxx} }, 'fieldName' can take the form 'address.street'.""" try: valValue = Collection.validateField(fieldName, value) except SchemaViolation as e: @@ -880,20 +899,23 @@ def validateField(cls, fieldName, value): return valValue def createEdge(self, initValues = None): - """Create an edge populated with defaults""" + """Create an edge populated with defaults.""" return self.createDocument(initValues) def getInEdges(self, vertex, rawResults = False): - """An alias for getEdges() that returns only the in Edges""" + """An alias for 'getEdges()' that returns only the in 'Edges'.""" return self.getEdges(vertex, inEdges = True, outEdges = False, rawResults = rawResults) def getOutEdges(self, vertex, rawResults = False): - """An alias for getEdges() that returns only the out Edges""" + """An alias for 'getEdges()' that returns only the out 'Edges'.""" return self.getEdges(vertex, inEdges = False, outEdges = True, rawResults = rawResults) def getEdges(self, vertex, inEdges = True, outEdges = True, rawResults = False): - """returns in, out, or both edges liked to a given document. vertex can be either a Document object or a string for an _id. - If rawResults a arango results will be return as fetched, if false, will return a liste of Edge objects""" + """Return in, out, or both edges linked to a given document. + + Vertex can be either a 'Document' object or a string for an '_id'. + If 'rawResults' is set to 'True', return the results just as fetched without any processing. 
+ Otherwise, return a list of Edge objects.""" if isinstance(vertex, Document): vId = vertex._id elif isinstance(vertex,str) or isinstance(vertex,bytes): @@ -928,11 +950,10 @@ class BulkOperation(object): def __init__(self, collection, batchSize=100): self.coll = collection self.batchSize = batchSize - + def __enter__(self): self.coll._isBulkInProgress = True self.coll._bulkSize = self.batchSize return self.coll def __exit__(self, type, value, traceback): self.coll._finalizeBatch(); - diff --git a/pyArango/connection.py b/pyArango/connection.py index 3826f36..c56b700 100644 --- a/pyArango/connection.py +++ b/pyArango/connection.py @@ -34,13 +34,14 @@ class AikidoSession: """ class Holder(object): - def __init__(self, fct, auth, max_conflict_retries=5, verify=True): + def __init__(self, fct, auth, max_conflict_retries=5, verify=True, timeout=30): self.fct = fct self.auth = auth self.max_conflict_retries = max_conflict_retries if not isinstance(verify, bool) and not isinstance(verify, CA_Certificate) and not not isinstance(verify, str) : raise ValueError("'verify' argument can only be of type: bool, CA_Certificate or str ") self.verify = verify + self.timeout = timeout def __call__(self, *args, **kwargs): if self.auth: @@ -50,10 +51,12 @@ def __call__(self, *args, **kwargs): else : kwargs["verify"] = self.verify + kwargs["timeout"] = self.timeout + try: do_retry = True retry = 0 - while do_retry and retry < self.max_conflict_retries : + while do_retry and retry < self.max_conflict_retries: ret = self.fct(*args, **kwargs) do_retry = ret.status_code == 1200 try : @@ -84,7 +87,8 @@ def __init__( max_retries=5, single_session=True, log_requests=False, - pool_maxsize=10 + pool_maxsize=10, + timeout=30, ): if username: self.auth = (username, password) @@ -95,6 +99,7 @@ def __init__( self.max_retries = max_retries self.log_requests = log_requests self.max_conflict_retries = max_conflict_retries + self.timeout = timeout self.session = None if single_session: @@ -133,12 
+138,13 @@ def __getattr__(self, request_function_name): auth = object.__getattribute__(self, "auth") verify = object.__getattribute__(self, "verify") + timeout = object.__getattribute__(self, "timeout") if self.log_requests: log = object.__getattribute__(self, "log") log["nb_request"] += 1 log["requests"][request_function.__name__] += 1 - return AikidoSession.Holder(request_function, auth, max_conflict_retries=self.max_conflict_retries, verify=verify) + return AikidoSession.Holder(request_function, auth, max_conflict_retries=self.max_conflict_retries, verify=verify, timeout=timeout) def disconnect(self): pass @@ -180,6 +186,8 @@ class Connection(object): max number of requests for a conflict error (1200 arangodb error). Does not work with gevents (grequests), pool_maxsize: int max number of open connections. (Not intended for grequest) + timeout: int + number of seconds to wait on a hanging connection before giving up """ LOAD_BLANCING_METHODS = {'round-robin', 'random'} @@ -199,7 +207,8 @@ def __init__( use_lock_for_reseting_jwt=True, max_retries=5, max_conflict_retries=5, - pool_maxsize=10 + pool_maxsize=10, + timeout=30 ): if loadBalancing not in Connection.LOAD_BLANCING_METHODS: @@ -215,6 +224,7 @@ def __init__( self.max_retries = max_retries self.max_conflict_retries = max_conflict_retries self.action = ConnectionAction(self) + self.timeout = timeout self.databases = {} self.verbose = verbose @@ -295,7 +305,8 @@ def create_aikido_session( max_conflict_retries=self.max_conflict_retries, max_retries=self.max_retries, log_requests=False, - pool_maxsize=self.pool_maxsize + pool_maxsize=self.pool_maxsize, + timeout=self.timeout ) def create_grequest_session( diff --git a/pyArango/database.py b/pyArango/database.py index fb16204..1dfe24e 100644 --- a/pyArango/database.py +++ b/pyArango/database.py @@ -526,16 +526,20 @@ def __contains__(self, name_or_id): else: return self.hasCollection(name_or_id) or self.hasGraph(name_or_id) - def __getitem__(self, 
collectionName): - """use database[collectionName] to get a collection from the database""" + def __getitem__(self, col_or_doc_id): + """use database[col_or_doc_id] to get a collection from the database""" try: - return self.collections[collectionName] - except KeyError: - self.reload() + col_name, doc_key = col_or_doc_id.split('/') + return self.collections[col_name][doc_key] + except ValueError: try: - return self.collections[collectionName] + return self.collections[col_or_doc_id] except KeyError: - raise KeyError("Can't find any collection named : %s" % collectionName) + self.reload() + try: + return self.collections[col_or_doc_id] + except KeyError: + raise KeyError("Can't find any collection named : %s" % col_or_doc_id) class DBHandle(Database): "As the loading of a Database also triggers the loading of collections and graphs within. Only handles are loaded first. The full database are loaded on demand in a fully transparent manner." diff --git a/pyArango/document.py b/pyArango/document.py index 3d3ed41..7ef03c8 100644 --- a/pyArango/document.py +++ b/pyArango/document.py @@ -108,12 +108,7 @@ def validate(self): return True def set(self, dct): - """Set the store using a dictionary""" - # if not self.mustValidate: - # self.store = dct - # self.patchStore = dct - # return - + """Set the values to a dict. 
Any missing value will be filled by its default""" for field, value in dct.items(): if field not in self.collection.arangoPrivates: if isinstance(value, dict): @@ -126,6 +121,14 @@ def set(self, dct): else: self[field] = value + def fill_default(self): + """Replace all None values with defaults.""" + for field, value in self.validators.items(): + if isinstance(value, dict): + self[field].fill_default() + elif self[field] is None: + self[field] = value.default + def __dir__(self): return dir(self.getStore()) @@ -234,6 +237,10 @@ def to_default(self): """reset the document to the default values""" self.reset(self.collection, self.collection.getDefaultDocument()) + def fill_default(self): + """Replace all None values with the schema defaults.""" + self._store.fill_default() + def validate(self): """validate the document""" self._store.validate() @@ -264,7 +271,9 @@ def save(self, waitForSync = False, **docArgs): If you want to only update the modified fields use the .patch() function. Use docArgs to put things such as 'waitForSync = True' (for a full list cf ArangoDB's doc). It will only trigger a saving of the document if it has been modified since the last save. 
If you want to force the saving you can use forceSave()""" + self._store.fill_default() payload = self._store.getStore() + # print(payload) self._save(payload, waitForSync = False, **docArgs) def _save(self, payload, waitForSync = False, **docArgs): diff --git a/pyArango/tests/tests.py b/pyArango/tests/tests.py index 69d3eb3..d08889a 100644 --- a/pyArango/tests/tests.py +++ b/pyArango/tests/tests.py @@ -1,5 +1,6 @@ import unittest, copy import os +from unittest.mock import MagicMock, patch from pyArango.connection import * from pyArango.database import * @@ -122,7 +123,78 @@ class theCol(Collection): doc.to_default() self.assertEqual(doc["address"]["street"], "Paper street") self.assertEqual(doc["name"], "Tyler Durden") + + # @unittest.skip("stand by") + def test_fill_default(self): + class theCol(Collection): + _fields = { + "name": Field( default="Paper"), + "dct1":{ + "num": Field(default=13), + "dct2":{ + "str": Field(default='string'), + } + } + } + + _validation = { + "on_save" : True, + "on_set" : True, + "allow_foreign_fields" : False + } + + col = self.db.createCollection("theCol") + doc = col.createDocument() + doc['name'] = 'Orson' + doc['dct1']['num'] = None + doc['dct1']['dct2']['str'] = None + + doc.fill_default() + self.assertEqual(doc['name'], 'Orson') + self.assertEqual(doc['dct1']['num'], 13) + self.assertEqual(doc['dct1']['dct2']['str'], 'string') + + # @unittest.skip("stand by") + def test_fill_default_on_save(self): + class theCol(Collection): + _fields = { + "name": Field( default="Paper"), + "dct1":{ + "num": Field(default=13), + "dct2":{ + "str": Field(default='string'), + } + } + } + + _validation = { + "on_save" : True, + "on_set" : True, + "allow_foreign_fields" : False + } + + col = self.db.createCollection("theCol") + doc = col.createDocument() + doc['name'] = 'Orson' + doc['dct1']['num'] = None + doc['dct1']['dct2']['str'] = None + store = doc.getStore() + doc.save() + + self.assertEqual(store['name'], 'Orson') + 
self.assertEqual(store['dct1']['num'], None) + self.assertEqual(store['dct1']['dct2']['str'], None) + + self.assertEqual(doc['name'], 'Orson') + self.assertEqual(doc['dct1']['num'], 13) + self.assertEqual(doc['dct1']['dct2']['str'], 'string') + + doc2 = col[doc['_key']] + self.assertEqual(doc2['name'], 'Orson') + self.assertEqual(doc2['dct1']['num'], 13) + self.assertEqual(doc2['dct1']['dct2']['str'], 'string') + # @unittest.skip("stand by") def test_bulk_operations(self): (collection, docs) = self.createManyUsersBulk(55, 17) @@ -1072,7 +1144,15 @@ def test_tasks(self): db_tasks.delete(task_id) self.assertListEqual(db_tasks(), []) + # @unittest.skip("stand by") + def test_timeout_parameter(self): + # Create a Connection object with the desired timeout + timeout = 120 + connection = Connection(arangoURL=ARANGODB_URL, username=ARANGODB_ROOT_USERNAME, password=ARANGODB_ROOT_PASSWORD, timeout=timeout) + # Verify that the Connection session was created with the correct timeout + assert connection.session.timeout == timeout + if __name__ == "__main__": # Change default username/password in bash like this: # export ARANGODB_ROOT_USERNAME=myUserName diff --git a/setup.py b/setup.py index c4f783b..5cddccf 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setup( name='pyArango', - version='2.0.2', + version='2.1.0', description='An easy to use python driver for ArangoDB with built-in validation', long_description=long_description,