Skip to content

Commit

Permalink
can now save related models as one-to-one relationships
Browse files Browse the repository at this point in the history
  • Loading branch information
jjorissen52 committed Apr 30, 2019
1 parent e3df456 commit 3eb4574
Showing 1 changed file with 131 additions and 63 deletions.
194 changes: 131 additions & 63 deletions fsmodels/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ class ValidationError(BaseException):
pass


class _BaseModel:
# all Model classes will be subclassed from this. Otherwise we would have circular requirements for type hints
# in methods that required BaseModel as type hints
pass


class Field:

# name is overwritten by the Model containing the Field instance.
Expand Down Expand Up @@ -109,7 +115,48 @@ def __repr__(self):
return f'<{self.__class__.__name__} name:{self.name} required:{self.required} default:{self.default} validation:{self.validation.__name__}>'


class BaseModel:
class ModelField(Field):
"""
Subclass of Field that makes reference to a subclass of BaseModel.
Used for one-to-many relationships.
"""

def __init__(self, model: _BaseModel, **kwargs):
# keeping track of this stuff so we can emit useful error messages
self.field_model = model
self.field_model_name = model.__name__
super(ModelField, self).__init__(**kwargs)

def validate(self, model_instance, raise_error: bool = True) -> (bool, dict):
"""
Check to see that the passed model instance is a subclass of `model` parameter passed into ModelField.__init__,
then validate the fields of that model as usual. Parallels the validate method of the Field class
:param model_instance: instance of model to validate (parallels `value` in validate method of Field class)
:param raise_error: whether or not an exception is raised on validation error
:return (bool, dict): whether or not there was an error and a dict describing the errors
"""
is_valid_model, is_valid_field, model_errors, field_errors = True, {}, True, {}
# check that the passed model_instance is a subclass of the prescribed model from __init__
if isinstance(model_instance, self.field_model):
is_valid_model, model_errors = model_instance.validate(raise_error)
# the model instance is not None, this will emit an error. Otherwise, we check the
# field validation logic to determine whether this is a required field.
elif model_instance is not None:
message = f'{self.name} field failed validation. {model_instance} is {model_instance.__class__.__name__}, must be {self.field_model_name}'
if raise_error:
raise ValidationError(message)
else:
return False, {'error': message}
if is_valid_model:
is_valid_field, field_errors = super(ModelField, self).validate(model_instance, raise_error)
if is_valid_field and is_valid_model:
return True, {}
return False, {**model_errors, **field_errors}


class BaseModel(_BaseModel):
"""
Example:
Expand All @@ -130,23 +177,28 @@ def _get_fields(self) -> frozenset:
field_attr_names = []
for field_attr_name in dir(self):
attr = inspect.getattr_static(self, field_attr_name)
if isinstance(attr, Field):
if not isinstance(attr, ModelField) and isinstance(attr, Field):
field_attr_names.append(field_attr_name)
# frozenset will always return set items in the same order regardless of the order
# that they are added. This results in hash-safe sets
return frozenset(field_attr_names)

def __init__(self, _validate_on_init: bool = False, **kwargs):
f"""
Sets all Field instances defined on a BaseModel subclass as private members to the BaseModel subclass instance.
Then creates public members with the value of Field<instance>.default(*args, **kwargs)
:param _validate_on_init: Whether or not to call the .validate() function on init
:param kwargs: values corresponding to fields defined on the subclass of BaseModel.
def _get_model_fields(self) -> frozenset:
"""
Used to keep track of all ModelField instances defined on a subclass of BaseModel.
{self.__class__.__doc__}
:return: hashable set (frozenset) of all ModelFields defined on the BaseModel subclass
"""
field_attr_names = []
for field_attr_name in dir(self):
attr = inspect.getattr_static(self, field_attr_name)
if isinstance(attr, ModelField):
field_attr_names.append(field_attr_name)
# frozenset will always return set items in the same order regardless of the order
# that they are added. This results in hash-safe sets
return frozenset(field_attr_names)

def _set_fields(self, _validate_on_init, kwargs):
self._field_names = self._get_fields()
for field_name in self._field_names:
field = getattr(self, field_name)
Expand All @@ -160,6 +212,36 @@ def __init__(self, _validate_on_init: bool = False, **kwargs):
if _validate_on_init:
field.validate(field_value)

def _set_model_fields(self, _validate_on_init, kwargs):
self._model_field_names = self._get_model_fields()
for field_name in self._model_field_names:
field = getattr(self, field_name)
field.model_name = self.__class__.__name__
field.name = field_name
field_value = kwargs.get(field_name, field.default())
# replaces the original field with the corresponding value
setattr(self, field_name, field_value)
# actually `Field` instance becomes hidden
setattr(self, f'_{field_name}', field)
# if _validate_on_init:
# field.validate(field_value)

def __init__(self, _validate_on_init: bool = False, **kwargs):
f"""
Sets all Field instances defined on a BaseModel subclass as private members to the BaseModel subclass instance.
Then creates public members with the value of Field<instance>.default(*args, **kwargs)
:param _validate_on_init: Whether or not to call the .validate() function on init
:param kwargs: values corresponding to fields defined on the subclass of BaseModel.
{self.__class__.__doc__}
"""

self._set_fields(_validate_on_init, kwargs)
self._set_model_fields(_validate_on_init, kwargs)



@property
def is_valid(self):
return self.validate(raise_error=False)[0]
Expand Down Expand Up @@ -209,10 +291,9 @@ class User(BaseModel):
:return:
"""
field_tuple = tuple(
(field_name, getattr(self, field_name)) for field_name in self._field_names)
field_tuple = tuple((field_name, getattr(self, field_name)) for field_name in self._field_names.union(self._model_field_names))
return {
field_name: field_value.to_dict() if hasattr(field_value, 'field_to_dict') else field_value
field_name: field_value.to_dict() if hasattr(field_value, 'to_dict') else field_value
for field_name, field_value in field_tuple
}

Expand Down Expand Up @@ -249,47 +330,6 @@ def save(self):
raise NotImplemented


class ModelField(Field):
"""
Subclass of Field that makes reference to a subclass of BaseModel.
Used for one-to-many relationships.
"""

def __init__(self, model: BaseModel, **kwargs):
# keeping track of this stuff so we can emit useful error messages
self.field_model = model
self.field_model_name = model.__name__
super(ModelField, self).__init__(**kwargs)

def validate(self, model_instance, raise_error: bool = True) -> (bool, dict):
"""
Check to see that the passed model instance is a subclass of `model` parameter passed into ModelField.__init__,
then validate the fields of that model as usual. Parallels the validate method of the Field class
:param model_instance: instance of model to validate (parallels `value` in validate method of Field class)
:param raise_error: whether or not an exception is raised on validation error
:return (bool, dict): whether or not there was an error and a dict describing the errors
"""
is_valid_model, is_valid_field, model_errors, field_errors = True, {}, True, {}
# check that the passed model_instance is a subclass of the prescribed model from __init__
if isinstance(model_instance, self.field_model) or issubclass(model_instance.__class__, self.field_model):
is_valid_model, model_errors = model_instance.validate(raise_error)
# the model instance is not None, this will emit an error. Otherwise, we check the
# field validation logic to determine whether this is a required field.
elif model_instance is not None:
message = f'{self.name} field failed validation. {model_instance} is {model_instance.__class__.__name__}, must be {self.field_model_name}'
if raise_error:
raise ValidationError(message)
else:
return False, {'error': message}
if is_valid_model:
is_valid_field, field_errors = super(ModelField, self).validate(model_instance, raise_error)
if is_valid_field and is_valid_model:
return True, {}
return False, {**model_errors, **field_errors}


class Model(BaseModel):
# TODO: implement relational firestore logic, add to __doc__
"""
Expand Down Expand Up @@ -349,34 +389,62 @@ def _document_exists(document) -> bool:
return bool(document.to_dict())

def clean(self) -> dict:
# TODO revisit clean to do more than just validate
"""
returns a cleaned version of the Model as a dict
:return:
"""
self.validate()
return self.to_dict()

def save(self, patch: bool = False) -> dict:
def save(self, patch: bool = True, additional_fields: Optional[dict] = None) -> dict:
"""
Save the record to the relevant collection in firestore (self._collection). If there is an id, it tries to
fetch the existing record first. If not, it creates a new record.
:param patch: only update the firestore record according to the values defined on the instance (rather than
overwriting the entire to match the instance)
:return: dictionary with id and the result of the write operation from firstore
:param additional_fields: dictionary of any additional fields to be saved on the firestore record that are
not defined explicitly on the model
:return: dictionary with id and the result of the write operation from firestore
"""
record = self.clean()

# after this if/else branch, we know for sure that self.id will refer to an id in firestore
if record.get('id', False):
# remove id from record so it is not saved as an attribute in firestore,
# use it to get an existing document or create a new document with corresponding ID.
document_ref = self.collection.document(str(record.pop('id')))
# if the user asks for "patch", we use update instead.
if self._document_exists(document_ref.get()) and patch:
return {'id': self.id, 'result': document_ref.update(record)}
new_record = not self._document_exists(document_ref.get())
else:
new_record = True
document_ref = self.collection.document()
self.id = document_ref.id

# we want to write the related model ids on the saved firestore record
related_record_ids = {}
reverse_id_label = f'{self._collection}_id'
for model_field_instance_name in self._get_model_fields():
field_name = model_field_instance_name.lstrip('_')
record.pop(field_name) # prevent from saving this information on the record (is that what we want?)
model_field_value = getattr(self, field_name)
# sanity check
assert model_field_value is None or isinstance(model_field_value, BaseModel)
# save related model field in firestore and give us the id
model_field_record_id = None if model_field_value is None else str(model_field_value.save(
additional_fields={reverse_id_label: str(self.id)}
)['id'])
related_record_ids[f'{model_field_value._collection}_id'] = model_field_record_id

record.update(related_record_ids)
if additional_fields:
record.update(additional_fields)

if not new_record:
if patch:
return {'id': self.id, 'result': document_ref.update(record)}
else:
return {'id': self.id, 'result': document_ref.set(record)}

return {'id': self.id, 'result': document_ref.set(record)}

def retrieve(self, overwrite_local: bool = False) -> dict:
Expand All @@ -388,7 +456,7 @@ def retrieve(self, overwrite_local: bool = False) -> dict:
:return:
"""
if not self.id:
raise ValidationError(f'Cannot retrieve document for {self._connection}; no id specified.')
raise ValidationError(f'Cannot retrieve document for {self._collection}; no id specified.')
document_dict = self.collection.document(str(self.id)).get().to_dict()
if document_dict is None:
return {}
Expand All @@ -404,6 +472,6 @@ def delete(self) -> dict:
:return: dictionary describing the result of the delete operation from firestore
"""
if not self.id:
raise ValidationError(f'Cannot call delete for {self._connection} document; no id specified.')
raise ValidationError(f'Cannot call delete for {self._collection} document; no id specified.')
document_ref = self.collection.document(str(self.id))
return {'result': document_ref.delete()}

0 comments on commit 3eb4574

Please sign in to comment.