Skip to content

Commit

Permalink
Ratelimits (#160)
Browse files Browse the repository at this point in the history
* Added flexible ratelimits that can be distributed amongst oauth clients that belong to the user account

* Added more tests and more logic to bootstrap

* Updated alembic and added documentation
  • Loading branch information
romanchyla committed Sep 18, 2018
1 parent 42daa15 commit 86c3f29
Show file tree
Hide file tree
Showing 15 changed files with 381 additions and 40 deletions.
2 changes: 1 addition & 1 deletion adsws/accounts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,5 +181,5 @@ def print_token(token):
'expire_in': expiry,
'token_type': 'Bearer',
'scopes': token.scopes,
'anonymous': anon,
'anonymous': anon
}
134 changes: 118 additions & 16 deletions adsws/accounts/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

from adsws.ext.ratelimiter import ratelimit, scope_func
from flask.sessions import SecureCookieSessionInterface
from flask.ext.login import current_user, login_user
from flask.ext.restful import Resource, abort, reqparse
from flask_login import current_user, login_user
from flask_restful import Resource, abort, reqparse, inputs
from flask.ext.wtf.csrf import generate_csrf
from flask import current_app, session, abort, request
from .utils import validate_email, validate_password, \
Expand All @@ -19,8 +19,8 @@
from .exceptions import ValidationError, NoClientError, NoTokenError
from .emails import PasswordResetEmail, VerificationEmail, \
EmailChangedNotification, WelcomeVerificationEmail


from sqlalchemy import func
from sqlalchemy.orm import load_only

class StatusView(Resource):
"""
Expand Down Expand Up @@ -657,6 +657,8 @@ class Bootstrap(Resource):
for the bumblebee javascript client to authenticate and interact with
other adsws-api resources.
"""

decorators = [oauth2.optional_oauth()]

def get(self):
"""
Expand All @@ -679,21 +681,40 @@ def get(self):
parser.add_argument('redirect_uri', type=str)
parser.add_argument('scope', type=str)
parser.add_argument('client_name', type=str)
parser.add_argument('ratelimit', type=float)
parser.add_argument('create_new', type=inputs.boolean)

kwargs = parser.parse_args()

scopes = kwargs.get('scope', None)
client_name = kwargs.get('client_name', None)
redirect_uri = kwargs.get('redirect_uri', None)

ratelimit = kwargs.get('ratelimit', 1.0)
create_new = kwargs.get('create_new', False)

if ratelimit is None:
ratelimit = 1.0

assert ratelimit >= 0.0

# If we visit this endpoint and are unauthenticated, then login as
# our anonymous user
if not current_user.is_authenticated():

if 'scopes' in kwargs or client_name or redirect_uri:
abort(401, "Sorry, you cant change scopes/name/redirect_uri when creating temporary OAuth application")

login_user(user_manipulator.first(
email=current_app.config['BOOTSTRAP_USER_EMAIL']
))

if scopes or client_name:
abort(401, "Sorry, you cant change scopes/name/redirect_uri of this user")
try:
scopes = self._sanitize_scopes(kwargs.get('scope', None))
except ValidationError, e:
return {'error': e.value}, 400
try:
self._check_ratelimit(ratelimit)
except ValidationError, e:
return {'error': e.value}, 400

if current_user.email == current_app.config['BOOTSTRAP_USER_EMAIL']:
try:
Expand All @@ -712,24 +733,73 @@ def get(self):
client, token = Bootstrap.bootstrap_bumblebee()
session['oauth_client'] = client.client_id
else:
client, token = Bootstrap.bootstrap_user()
if create_new:
client, token = Bootstrap.bootstrap_user_new(client_name, scopes=scopes, ratelimit=ratelimit)
else:
client, token = Bootstrap.bootstrap_user(client_name, scopes=scopes, ratelimit=ratelimit)

if scopes:
client._default_scopes = scopes
if redirect_uri:
client._redirect_uris = redirect_uri
if client_name:
client.client_name = client_name
if client.ratelimit != ratelimit:
client.ratelimit = ratelimit

client.last_activity = datetime.datetime.now()
output = print_token(token)

output['client_id'] = client.client_id
output['client_secret'] = client.client_secret
output['ratelimit'] = client.ratelimit
output['client_name'] = client.name

db.session.commit()
return output


def _check_ratelimit(self, ratelimit):
"""Method to verify that there exists available space in the allotted resources
available to this user. A user account can have unlimited 'ratelimit_level'
if the ratelimit_level=-1, or the ratelimit_level specifies how big the global
amount is."""

# we are always called with some user logged in
allowed_limit = current_user.ratelimit_level or 2.0
if allowed_limit == -1:
return True

# count the existing clients
used = db.session.query(func.sum(OAuthClient.ratelimit).label('sum')).filter(OAuthClient.user_id==current_user.get_id()).first()[0] or 0.0
#for x in db.session.query(OAuthClient).filter_by(user_id=current_user.get_id()).options(load_only('ratelimit')).all():
# used += x.ratelimit_level

if allowed_limit - (used+ratelimit) < 0:
raise ValidationError('The current user account does not have enough capacity to create a new client. Requested: %s, Available: %s' % (ratelimit, allowed_limit-used))
return True


def _sanitize_scopes(self, scopes):
"""Makes sure that one can request only scopes that are available
to the given user."""
if not scopes:
return

if hasattr(request, 'oauth'):
allowed_scopes = request.oauth.user.allowed_scopes
elif current_user:
allowed_scopes = current_user.allowed_scopes
else:
raise ValidationError('kabooom') # should NEVER ever happen

if '*' in allowed_scopes:
return scopes
scopes = set(scopes.split())
if not set(allowed_scopes).issuperset(scopes):
raise ValidationError('You have requested a scope not available to the current user')
return ' '.join(sorted(set(allowed_scopes).intersection(scopes)))


@staticmethod
def load_client(clientid):
"""
Expand Down Expand Up @@ -798,11 +868,44 @@ def bootstrap_bumblebee():
return client, token


@staticmethod
@ratelimit.shared_limit_and_check("2/60 second", scope=scope_func)
def bootstrap_user_new(client_name=None, scopes=None, ratelimit=1.0):
"""
Create a OAuthClient owned by the authenticated real user.
Similar logic performed for the OAuthToken.
:return: OAuthToken instance
"""
assert current_user.email != current_app.config['BOOTSTRAP_USER_EMAIL']

uid = current_user.get_id()
client_name = client_name or current_app.config.get('BOOTSTRAP_CLIENT_NAME', 'BB client')

client = OAuthClient(
user_id=current_user.get_id(),
name=client_name,
description=client_name,
is_confidential=True,
is_internal=True,
_default_scopes=scopes or ' '.join(current_app.config['USER_DEFAULT_SCOPES']),
ratelimit=ratelimit
)
client.gen_salt()
db.session.add(client)

token = Bootstrap.create_user_token(client)
db.session.add(token)
current_app.logger.info(
"Created OAuth client for {email}".format(email=current_user.email)
)
db.session.commit()
return client, token

@staticmethod
@ratelimit.shared_limit_and_check("100/600 second", scope=scope_func)
def bootstrap_user():
def bootstrap_user(client_name=None, scopes=None, ratelimit=1.0):
"""
Return or create a OAuthClient owned by the authenticated real user.
Re-uses an existing client if "oauth_client" is found in the database
Expand All @@ -815,15 +918,13 @@ def bootstrap_user():
assert current_user.email != current_app.config['BOOTSTRAP_USER_EMAIL']

uid = current_user.get_id()
client_name = current_app.config.get('BOOTSTRAP_CLIENT_NAME', 'BB client')
client_name = client_name or current_app.config.get('BOOTSTRAP_CLIENT_NAME', 'BB client')

client = OAuthClient.query.filter_by(
user_id=uid,
name=client_name,
).first()
).order_by(OAuthClient.created.desc()).first()

scopes = ' '.join(current_app.config['USER_DEFAULT_SCOPES'])
salt_length = current_app.config.get('OAUTH2_CLIENT_ID_SALT_LEN', 40)

if client is None:
client = OAuthClient(
Expand All @@ -832,7 +933,8 @@ def bootstrap_user():
description=client_name,
is_confidential=True,
is_internal=True,
_default_scopes=scopes,
_default_scopes=scopes or ' '.join(current_app.config['USER_DEFAULT_SCOPES']),
ratelimit=ratelimit
)
client.gen_salt()
db.session.add(client)
Expand Down
15 changes: 15 additions & 0 deletions adsws/core/users/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from adsws.ext.sqlalchemy import db
from sqlalchemy.orm import synonym
from flask.ext.security.utils import encrypt_password, verify_password
from flask import current_app


roles_users = db.Table(
'roles_users',
Expand All @@ -32,6 +34,7 @@ class User(UserMixin, db.Model):
login_count = db.Column(db.Integer)
registered_at = db.Column(db.DateTime())
ratelimit_level = db.Column(db.Integer)
_allowed_scopes = db.Column(db.Text)

roles = db.relationship('Role', secondary=roles_users,
backref=db.backref('users', lazy='dynamic'))
Expand Down Expand Up @@ -62,6 +65,18 @@ def get_id(self):
need to convert it to unicode.
"""
return unicode(self.id)

@property
def allowed_scopes(self):
"""Returns list of scopes that this user is allowed to request/initialize
when bootstraping a new OAuth client; for example ads:internal scopes
should only be given/owned to user accounts that can safely dispense
with them, but should not be available to other users.
"""

if self._allowed_scopes:
return self._allowed_scopes.split(' ')
return current_app.config['USER_DEFAULT_SCOPES']


class Role(RoleMixin, db.Model):
Expand Down
12 changes: 7 additions & 5 deletions adsws/ext/ratelimiter/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def scope_func(endpoint_name):

def limit_func(counts, per_second):
"""
Returns the default limit multiplied by user's ratelimit_level attribute,
Returns the default limit multiplied by the OAuth client's ratelimit attribute,
if it exists.
:param counts: default rate limit
:type counts: int
Expand All @@ -48,10 +48,12 @@ def limit_func(counts, per_second):
:return user's ratelimit
:rtype int
"""
factor = 1

if hasattr(request, 'oauth'):
try:
factor = request.oauth.user.ratelimit_level or 1
factor = request.oauth.client.ratelimit
if factor is None:
factor = 1.0
except AttributeError:
pass
return "{0}/{1} second".format(counts * factor, per_second)
factor = 1.0
return "{0}/{1} second".format(int(counts * factor), per_second)
8 changes: 7 additions & 1 deletion adsws/modules/oauth2server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from sqlalchemy_utils import URLType

from adsws.core import db, user_manipulator
from adsmutils import get_date, UTCDateTime

from oauthlib.oauth2.rfc6749.errors import InsecureTransportError, \
InvalidRedirectURIError
Expand Down Expand Up @@ -141,7 +142,12 @@ class OAuthClient(db.Model):

user = db.relationship('User')
""" Relationship to user. """


ratelimit = db.Column(db.Float, default=0.0)
""" Pre-computed allotment of the available rates of the user's global ratelimit."""

created = db.Column(UTCDateTime, default=get_date)

@property
def allowed_grant_types(self):
return current_app.config['OAUTH2_ALLOWED_GRANT_TYPES']
Expand Down
36 changes: 32 additions & 4 deletions adsws/modules/oauth2server/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,44 @@

from datetime import datetime, timedelta

from flask import current_app
from flask import current_app, request
from flask.ext.login import current_user
from flask_oauthlib.provider import OAuth2Provider
from flask_oauthlib.utils import extract_params
from flask_login import current_user

from adsws.core import db, user_manipulator
from .models import OAuthToken, OAuthClient, OAuthGrant


oauth2 = OAuth2Provider()
from functools import wraps

class OAuth2bProvider(OAuth2Provider):
def optional_oauth(self, *scopes):
"""Protect resource with specified scopes."""
def wrapper(f):
@wraps(f)
def decorated(*args, **kwargs):
for func in self._before_request_funcs:
func()

if hasattr(request, 'oauth') and request.oauth:
return f(*args, **kwargs)

server = self.server
uri, http_method, body, headers = extract_params()
valid, req = server.verify_request(
uri, http_method, body, headers, scopes
)

for func in self._after_request_funcs:
valid, req = func(valid, req)

if valid:
request.oauth = req
return f(*args, **kwargs)
return decorated
return wrapper

oauth2 = OAuth2bProvider()

@oauth2.clientgetter
def load_client(client_id):
Expand Down
3 changes: 2 additions & 1 deletion adsws/tests/api_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ def handle_404(e):
is_confidential=False,
user_id=user.id,
_redirect_uris='%s/client/authorized' % self.base_url,
_default_scopes="adsws:internal"
_default_scopes="adsws:internal",
ratelimit=1.0
)
db.session.add(c1)
db.session.commit()
Expand Down
Loading

0 comments on commit 86c3f29

Please sign in to comment.