Navigation Menu

Skip to content

Commit

Permalink
Merge pull request getredash#1985 from miketheman/miketheman/user-ema…
Browse files Browse the repository at this point in the history
…il-case-insensitive

Ensure email is case-insensitive
  • Loading branch information
arikfr committed Nov 22, 2017
2 parents 1b7cc37 + 45a6651 commit 9de3e97
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 2 deletions.
3 changes: 2 additions & 1 deletion redash/cli/users.py
Expand Up @@ -112,7 +112,8 @@ def delete(email, organization=None):
models.User.org == org.id,
).delete()
else:
deleted_count = models.User.query.filter(models.User.email == email).delete()
deleted_count = models.User.query.filter(models.User.email == email).delete(
synchronize_session=False)
models.db.session.commit()
print("Deleted %d users." % deleted_count)

Expand Down
25 changes: 24 additions & 1 deletion redash/models.py
Expand Up @@ -22,6 +22,7 @@
from redash.query_runner import (get_configuration_schema_for_query_runner_type,
get_query_runner)
from redash.utils import generate_token, json_dumps
from redash.utils.comparators import CaseInsensitiveComparator
from redash.utils.configuration import ConfigurationContainer
from sqlalchemy import distinct, or_
from sqlalchemy.dialects import postgresql
Expand Down Expand Up @@ -368,12 +369,32 @@ def __unicode__(self):
return unicode(self.id)


class LowercasedString(TypeDecorator):
"""
A lowercased string
"""
impl = db.String
comparator_factory = CaseInsensitiveComparator

def __init__(self, length=320, *args, **kwargs):
super(LowercasedString, self).__init__(length=length, *args, **kwargs)

def process_bind_param(self, value, dialect):
if value is not None:
return value.lower()
return value

@property
def python_type(self):
return self.impl.type.python_type


class User(TimestampMixin, db.Model, BelongsToOrgMixin, UserMixin, PermissionsCheckMixin):
id = Column(db.Integer, primary_key=True)
org_id = Column(db.Integer, db.ForeignKey('organizations.id'))
org = db.relationship(Organization, backref=db.backref("users", lazy="dynamic"))
name = Column(db.String(320))
email = Column(db.String(320))
email = Column(LowercasedString)
password_hash = Column(db.String(128), nullable=True)
# XXX replace with association table
group_ids = Column('groups', MutableList.as_mutable(postgresql.ARRAY(db.Integer)), nullable=True)
Expand All @@ -385,6 +406,8 @@ class User(TimestampMixin, db.Model, BelongsToOrgMixin, UserMixin, PermissionsCh
__table_args__ = (db.Index('users_org_id_email', 'org_id', 'email', unique=True),)

def __init__(self, *args, **kwargs):
if kwargs.get('email') is not None:
kwargs['email'] = kwargs['email'].lower()
super(User, self).__init__(*args, **kwargs)

def to_dict(self, with_api_key=False):
Expand Down
7 changes: 7 additions & 0 deletions redash/utils/comparators.py
@@ -0,0 +1,7 @@
from sqlalchemy import func
from sqlalchemy.ext.hybrid import Comparator


class CaseInsensitiveComparator(Comparator):
def __eq__(self, other):
return func.lower(self.__clause_element__()) == func.lower(other)
24 changes: 24 additions & 0 deletions tests/handlers/test_users.py
Expand Up @@ -26,6 +26,16 @@ def test_creates_user(self):
self.assertEqual(rv.json['name'], test_user['name'])
self.assertEqual(rv.json['email'], test_user['email'])

def test_creates_user_case_insensitive_email(self):
admin = self.factory.create_admin()

test_user = {'name': 'User', 'email': 'User@Example.com', 'password': 'test'}
rv = self.make_request('post', '/api/users', data=test_user, user=admin)

self.assertEqual(rv.status_code, 200)
self.assertEqual(rv.json['name'], test_user['name'])
self.assertEqual(rv.json['email'], 'user@example.com')

def test_returns_400_when_email_taken(self):
admin = self.factory.create_admin()

Expand All @@ -34,6 +44,20 @@ def test_returns_400_when_email_taken(self):

self.assertEqual(rv.status_code, 400)

def test_returns_400_when_email_taken_case_insensitive(self):
admin = self.factory.create_admin()

test_user1 = {'name': 'User', 'email': 'user@example.com', 'password': 'test'}
rv = self.make_request('post', '/api/users', data=test_user1, user=admin)

self.assertEqual(rv.status_code, 200)
self.assertEqual(rv.json['email'], 'user@example.com')

test_user2 = {'name': 'User', 'email': 'user@Example.com', 'password': 'test'}
rv = self.make_request('post', '/api/users', data=test_user2, user=admin)

self.assertEqual(rv.status_code, 400)


class TestUserListGet(BaseTestCase):
def test_returns_users_for_given_org_only(self):
Expand Down
20 changes: 20 additions & 0 deletions tests/models/test_users.py
Expand Up @@ -25,3 +25,23 @@ def test_finds_users(self):
users = User.find_by_email(user.email)
self.assertIn(user, users)
self.assertIn(user2, users)

def test_finds_users_case_insensitive(self):
user = self.factory.create_user(email='test@example.com')

users = User.find_by_email('test@EXAMPLE.com')
self.assertIn(user, users)


class TestUserGetByEmailAndOrg(BaseTestCase):
def test_get_user_by_email_and_org(self):
user = self.factory.create_user(email='test@example.com')

found_user = User.get_by_email_and_org(user.email, user.org)
self.assertEqual(user, found_user)

def test_get_user_by_email_and_org_case_insensitive(self):
user = self.factory.create_user(email='test@example.com')

found_user = User.get_by_email_and_org("TEST@example.com", user.org)
self.assertEqual(user, found_user)
12 changes: 12 additions & 0 deletions tests/test_handlers.py
Expand Up @@ -111,6 +111,18 @@ def test_submit_correct_user_and_password(self):
self.assertEquals(rv.status_code, 302)
login_user_mock.assert_called_with(user, remember=False)

def test_submit_case_insensitive_user_and_password(self):
user = self.factory.user
user.hash_password('password')

self.db.session.add(user)
self.db.session.commit()

with patch('redash.handlers.authentication.login_user') as login_user_mock:
rv = self.client.post('/default/login', data={'email': user.email.upper(), 'password': 'password'})
self.assertEquals(rv.status_code, 302)
login_user_mock.assert_called_with(user, remember=False)

def test_submit_correct_user_and_password_and_remember_me(self):
user = self.factory.user
user.hash_password('password')
Expand Down

0 comments on commit 9de3e97

Please sign in to comment.