diff --git a/masonite/middleware/CsrfMiddleware.py b/masonite/middleware/CsrfMiddleware.py index ba5c690f1..741cab8d9 100644 --- a/masonite/middleware/CsrfMiddleware.py +++ b/masonite/middleware/CsrfMiddleware.py @@ -47,11 +47,11 @@ def in_exempt(self): Returns: bool """ + for route in self.exempt: + if self.request.contains(route): + return True - if self.request.path in self.exempt: - return True - else: - return False + return False def generate_token(self): """Generate a token that will be used for CSRF protection diff --git a/tests/middleware/test_csrf_middleware.py b/tests/middleware/test_csrf_middleware.py index 5e46746e1..c3bfdb3a5 100644 --- a/tests/middleware/test_csrf_middleware.py +++ b/tests/middleware/test_csrf_middleware.py @@ -6,17 +6,25 @@ from masonite.testsuite.TestSuite import generate_wsgi import pytest from masonite.exceptions import InvalidCSRFToken +from masonite.routes import Get, Route class TestCSRFMiddleware: def setup_method(self): self.app = App() - self.request = Request(generate_wsgi()) + wsgi = generate_wsgi() + self.request = Request(wsgi) + self.route = Route().load_environ(wsgi) self.view = View(self.app) self.app.bind('Request', self.request) self.request = self.app.make('Request') + self.app.bind('WebRoutes', [ + Get().route('/test/@route', None), + Get().route('/test/10', None), + ]) + self.request.container = self.app self.middleware = CsrfMiddleware(self.request, Csrf(self.request), self.view) @@ -27,10 +35,36 @@ def test_middleware_shares_correct_input(self): def test_middleware_throws_exception_on_post(self): self.request.environ['REQUEST_METHOD'] = 'POST' + self.request.path = '/' self.middleware.exempt = [] with pytest.raises(InvalidCSRFToken): self.middleware.before() + def test_middleware_can_accept_param_route(self): + self.request.environ['REQUEST_METHOD'] = 'POST' + self.request.path = '/test/1' + self.middleware.exempt = [ + '/test/@route' + ] + self.middleware.before() + + def test_middleware_can_exempt(self): + self.request.environ['REQUEST_METHOD'] = 'POST' + self.request.path = '/test/1' + self.middleware.exempt = [ + '/test/1' + ] + self.middleware.before() + + def test_middleware_throws_exeption_on_wrong_route(self): + self.request.environ['REQUEST_METHOD'] = 'POST' + self.request.path = '/test/10' + self.middleware.exempt = [ + '/test/2' + ] + with pytest.raises(InvalidCSRFToken): + self.middleware.before() + def test_incoming_token_does_not_throw_exception_with_token(self): self.request.environ['REQUEST_METHOD'] = 'POST' self.request.request_variables.update({'__token': self.request.get_cookie('csrf_token')})