Skip to content

Commit

Permalink
Merge pull request #3 from ESSS/fb-RFDAP-485-replace-marshmallow
Browse files Browse the repository at this point in the history
Replace marshmallow serializers by our own implementation
  • Loading branch information
igortg committed Feb 21, 2018
2 parents 8380233 + 59780ba commit 594f05d
Show file tree
Hide file tree
Showing 17 changed files with 378 additions and 105 deletions.
4 changes: 2 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# Byte-compiled / optimized / DLL files
__pycache__/
.*cache
.~*
*.py[cod]
*$py.class
build/
dist/

# Project settings
/.idea/
.cache
.~*
1 change: 0 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ language: python
python:
- 3.6
- 3.5
- 2.7

install:
- pip install -r requirements.txt
Expand Down
20 changes: 7 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

![travis-ci](https://api.travis-ci.org/ESSS/flask-rest-orm.svg?branch=master)

A Flask extension to build REST APIs based on SQLAlchemy models. It uses [marshmallow-sqlalchemy]
to serialize models and avoid the need of building *Schema* classes, since *Schemas* are
typically a repetition of your model.
A Flask extension to build REST APIs. It dismiss the need of building *Schema* classes,
since usually all the information needed to serialize an SQLAlchemy instance is in the model
itself.

By adding a model to the API, all its properties will be exposed:

Expand All @@ -28,19 +28,13 @@ To change the way properties are serialized, declare only the one that needs a n
behaviour:

```python
from marshmallow import fields
from marshmallow_sqlalchemy import ModelSchema
from flask_rest_orm import ModelSerializer, Field

class UserSerializer(ModelSchema):
class Meta:
include_fk = True
model = User
class UserSerializer(ModelSerializer):

password = fields.Str(load_only=True)
password = Field(load_only=True)


api = Api(flask_app)
api.add_model(User, '/user', serializer=UserSerializer())
api.add_model(User, '/user', serializer_class=UserSerializer)
```

[marshmallow-sqlalchemy]: https://marshmallow-sqlalchemy.readthedocs.io
2 changes: 2 additions & 0 deletions flask_rest_orm/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .api import Api
from .serialization.modelserializer import *
from .serialization.serializer import Serializer
46 changes: 19 additions & 27 deletions flask_rest_orm/api.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from flask_restful import Api as RestfulApi
from marshmallow_sqlalchemy import ModelSchema

from flask_rest_orm.resources.resources import CollectionResource, ItemResource, CollectionRelationResource, \
ItemRelationResource, CollectionPropertyResource
from flask_rest_orm.serialization.modelserializer import ModelSerializer


class Api(object):
Expand All @@ -14,7 +14,7 @@ def __init__(self, app=None, prefix='', errors=None, request_decorators=None):
if app:
self.init_app(app)

def add_model(self, model, url=None, serializer=None, request_decorators=None,
def add_model(self, model, url=None, serializer_class=None, request_decorators=None,
collection_decorators=None, collection_name=None):
"""
Create API endpoints for the given SQLAlchemy declarative class.
Expand All @@ -28,9 +28,7 @@ def add_model(self, model, url=None, serializer=None, request_decorators=None,
:param string collection_name: custom name for the collection endpoint url definition, if not set the model
table name will be used
:param ModelSchema serializer: Marshmallow schema for serialization. If `None`, a default serializer will be
created.
:param Type[ModelSerializer] serializer_class: If `None`, a default serializer will be created.
:param list|dict request_decorators: decorators to be applied to HTTP methods. Could be a list of decorators
or a dict mapping HTTP method types to a list of decorators (dict keys should be 'get', 'post' or 'put').
Expand All @@ -42,8 +40,10 @@ def add_model(self, model, url=None, serializer=None, request_decorators=None,
"""
restful = self.restful_api
collection_name = collection_name or model.__tablename__
if not serializer:
serializer = self.create_default_serializer(model)()
if not serializer_class:
serializer = self.create_default_serializer(model)
else:
serializer = serializer_class(model)
url = url or '/' + collection_name.lower()

if not request_decorators:
Expand All @@ -70,7 +70,7 @@ class _ItemResource(ItemResource):
resource_class_args=(model, serializer, self.get_db_session)
)

def add_relation(self, relation_property, url_rule=None, serializer=None, request_decorators=None,
def add_relation(self, relation_property, url_rule=None, serializer_class=None, request_decorators=None,
collection_decorators=None, endpoint_name=None):
"""
Create API endpoints for the given SQLAlchemy relationship.
Expand All @@ -80,8 +80,7 @@ def add_relation(self, relation_property, url_rule=None, serializer=None, reques
:param string url_rule: one or more url routes to match for the resource, standard
flask routing rules apply. Defaults to model name in lower case.
:param ModelSchema serializer: Marshmallow schema for serialization. If `None`, a default serializer will be
created.
:param Type[ModelSerializer] serializer_class: If `None`, a default serializer will be created.
:param list|dict request_decorators: decorators to be applied to HTTP methods. Could be a list of decorators
or a dict mapping HTTP method types to a list of decorators (dict keys should be 'get', 'post' or 'put').
Expand All @@ -100,8 +99,10 @@ def add_relation(self, relation_property, url_rule=None, serializer=None, reques
model_collection_name = model.__tablename__.lower()
related_collection_name = related_model.__tablename__.lower()
endpoint_name = endpoint_name or '{}-{}-relation'.format(model_collection_name, related_collection_name)
if not serializer:
serializer = self.create_default_serializer(model)()
if not serializer_class:
serializer = self.create_default_serializer(model)
else:
serializer = serializer_class(model)
if url_rule:
assert '<relation_id>' in url_rule
else:
Expand Down Expand Up @@ -141,9 +142,11 @@ def _add_item_collection_resources(self, item_resource, collection_resource, url
resource_class_args=resource_init_args,
)

def add_property(self, model, related_model, property_name, url_rule=None, serializer=None, request_decorators=[]):
if not serializer:
serializer = self.create_default_serializer(model)()
def add_property(self, model, related_model, property_name, url_rule=None, serializer_class=None, request_decorators=[]):
if not serializer_class:
serializer = self.create_default_serializer(model)
else:
serializer = serializer_class(model)
related_collection_name = related_model.__tablename__.lower()
if url_rule:
assert '<relation_id>' in url_rule
Expand Down Expand Up @@ -173,18 +176,7 @@ def create_default_serializer(model_class):
:rtype: class
"""

class Meta(object):
model = model_class
include_fk = True

schema_class_name = '{}Schema'.format(model_class.__name__)
schema_class = type(
schema_class_name,
(ModelSchema,),
{'Meta': Meta}
)
return schema_class
return ModelSerializer(model_class)

def init_app(self, app):
self.restful_api.init_app(app)
Expand Down
37 changes: 18 additions & 19 deletions flask_rest_orm/resources/resources.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from flask import request, json
from flask_restful import Resource
from marshmallow_sqlalchemy import ModelSchema

from flask_rest_orm.serialization.modelserializer import ModelSerializer
from .utils import query_from_request


Expand All @@ -22,24 +21,24 @@ def __init__(self, declarative_model, serializer, session_getter):
self._resource_model = declarative_model
self._serializer = serializer
self._serializer.strict = True
assert isinstance(self._serializer, ModelSchema), 'Invalid serializer instance: {}'.format(serializer)
assert isinstance(self._serializer, ModelSerializer), 'Invalid serializer instance: {}'.format(serializer)
self._session_getter = session_getter

def save_from_request(self, extra_attrs={}):
session = self._session_getter()
model_obj = self._serializer.load(load_request_data(), session).data
model_obj = self._serializer.load(load_request_data())
for attr_name, value in extra_attrs.items():
setattr(model_obj, attr_name, value)
session.add(model_obj)
session.commit()
return self._serializer.dump(model_obj).data

def _save_serialized(self, serialized_data):
def _save_serialized(self, serialized_data, existing_model=None):
session = self._session_getter()
model_obj = self._serializer.load(serialized_data, session).data
model_obj = self._serializer.load(serialized_data, existing_model)
session.add(model_obj)
session.commit()
return self._serializer.dump(model_obj).data
return self._serializer.dump(model_obj)

@property
def _db_session(self):
Expand All @@ -56,15 +55,15 @@ def get(self, id):
data = self._resource_model.query.get(id)
if data is None:
return NOT_FOUND_ERROR, 404
return self._serializer.dump(data).data
return self._serializer.dump(data)

def put(self, id):
data = self._resource_model.query.get(id)
if data is None:
return NOT_FOUND_ERROR, 404
serialized = self._serializer.dump(data).data
serialized = self._serializer.dump(data)
serialized.update(load_request_data())
self._save_serialized(serialized)
self._save_serialized(serialized, data)
return serialized

def delete(self, id):
Expand All @@ -87,7 +86,7 @@ def get(self):
collection = []
data = query_from_request(self._resource_model, request)
for item in data:
collection.append(self._serializer.dump(item).data)
collection.append(self._serializer.dump(item))
return collection

def post(self):
Expand Down Expand Up @@ -127,7 +126,7 @@ def get(self, relation_id):
return NOT_FOUND_ERROR, 404
# TODO: Is there a more efficient way than using getattr?
data = getattr(related_obj, self._relation_property.key)
collection = [self._serializer.dump(item).data for item in data]
collection = [self._serializer.dump(item) for item in data]
return collection


Expand All @@ -137,11 +136,11 @@ def post(self, relation_id):
if not related_obj:
return NOT_FOUND_ERROR, 404
collection = getattr(related_obj, self._relation_property.key)
new_obj = self._serializer.load(load_request_data(), session).data
new_obj = self._serializer.load(load_request_data())
collection.append(new_obj)
session.add(new_obj)
session.commit()
return self._serializer.dump(new_obj).data, 201
return self._serializer.dump(new_obj), 201


class ItemRelationResource(BaseResource):
Expand Down Expand Up @@ -172,15 +171,15 @@ def get(self, relation_id, id):
requested_obj = self._query_related_obj(relation_id, id)
if not requested_obj:
return NOT_FOUND_ERROR, 404
return self._serializer.dump(requested_obj).data, 200
return self._serializer.dump(requested_obj), 200

def put(self, relation_id, id):
requested_obj = self._query_related_obj(relation_id, id)
if not requested_obj:
return NOT_FOUND_ERROR, 404
serialized = self._serializer.dump(requested_obj).data
serialized = self._serializer.dump(requested_obj)
serialized.update(load_request_data())
return self._save_serialized(serialized)
return self._save_serialized(serialized, requested_obj)

def delete(self, relation_id, id):
requested_obj = self._query_related_obj(relation_id, id)
Expand Down Expand Up @@ -212,7 +211,7 @@ def get(self, relation_id):
if related_obj is None:
return NOT_FOUND_ERROR, 404
data = getattr(related_obj, self._property_name)
collection = [self._serializer.dump(item).data for item in data]
collection = [self._serializer.dump(item) for item in data]
return collection

def post(self, relation_id):
Expand All @@ -232,4 +231,4 @@ def load_request_data():
return request.form.to_dict()


NOT_FOUND_ERROR = 'Resource not found in the database!'
NOT_FOUND_ERROR = 'Resource not found in the database!'
Empty file.

0 comments on commit 594f05d

Please sign in to comment.