Skip to content

Commit

Permalink
Escape unsafe query string when rendering to html
Browse files Browse the repository at this point in the history
  • Loading branch information
rayluo committed Sep 16, 2023
1 parent 9ae3b95 commit 3427c25
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 7 deletions.
29 changes: 22 additions & 7 deletions oauth2cli/authcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
try: # Python 3
from http.server import HTTPServer, BaseHTTPRequestHandler
from urllib.parse import urlparse, parse_qs, urlencode
from html import escape
except ImportError: # Fall back to Python 2
from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler
from urlparse import urlparse, parse_qs
from urllib import urlencode
from cgi import escape


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -77,25 +79,37 @@ def _qs2kv(qs):
for k, v in qs.items()}


def _is_html(text):
return text.startswith("<") # Good enough for our purpose


def _escape(key_value_pairs):
return {k: escape(v) for k, v in key_value_pairs.items()}


class _AuthCodeHandler(BaseHTTPRequestHandler):
def do_GET(self):
# For flexibility, we choose to not check self.path matching redirect_uri
#assert self.path.startswith('/THE_PATH_REGISTERED_BY_THE_APP')
qs = parse_qs(urlparse(self.path).query)
if qs.get('code') or qs.get("error"): # So, it is an auth response
self.server.auth_response = _qs2kv(qs)
logger.debug("Got auth response: %s", self.server.auth_response)
auth_response = _qs2kv(qs)
logger.debug("Got auth response: %s", auth_response)
template = (self.server.success_template
if "code" in qs else self.server.error_template)
self._send_full_response(
template.safe_substitute(**self.server.auth_response))
if _is_html(template.template):
safe_data = _escape(auth_response)
else:
safe_data = auth_response
self._send_full_response(template.safe_substitute(**safe_data))
self.server.auth_response = auth_response # Set it now, after the response is likely sent
# NOTE: Don't do self.server.shutdown() here. It'll halt the server.
else:
self._send_full_response(self.server.welcome_page)

def _send_full_response(self, body, is_ok=True):
self.send_response(200 if is_ok else 400)
content_type = 'text/html' if body.startswith('<') else 'text/plain'
content_type = 'text/html' if _is_html(body) else 'text/plain'
self.send_header('Content-type', content_type)
self.end_headers()
self.wfile.write(body.encode("utf-8"))
Expand Down Expand Up @@ -318,6 +332,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
default="https://login.microsoftonline.com/common/oauth2/v2.0/authorize")
p.add_argument('client_id', help="The client_id of your application")
p.add_argument('--port', type=int, default=0, help="The port in redirect_uri")
p.add_argument('--timeout', type=int, default=60, help="Timeout value, in second")
p.add_argument('--host', default="127.0.0.1", help="The host of redirect_uri")
p.add_argument('--scope', default=None, help="The scope list")
args = parser.parse_args()
Expand All @@ -331,8 +346,8 @@ def __exit__(self, exc_type, exc_val, exc_tb):
auth_uri=flow["auth_uri"],
welcome_template=
"<a href='$auth_uri'>Sign In</a>, or <a href='$abort_uri'>Abort</a",
error_template="Oh no. $error",
error_template="<html>Oh no. $error</html>",
success_template="Oh yeah. Got $code",
timeout=60,
timeout=args.timeout,
state=flow["state"], # Optional
), indent=4))
17 changes: 17 additions & 0 deletions tests/test_authcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import socket
import sys

import requests

from oauth2cli.authcode import AuthCodeReceiver


Expand All @@ -23,3 +25,18 @@ def test_no_two_concurrent_receivers_can_listen_on_same_port(self):
with AuthCodeReceiver(port=receiver.get_port()):
pass

def test_template_should_escape_input(self):
with AuthCodeReceiver() as receiver:
receiver._scheduled_actions = [( # Injection happens here when the port is known
1, # Delay it until the receiver is activated by get_auth_response()
lambda: self.assertEqual(
"<html>&lt;tag&gt;foo&lt;/tag&gt;</html>",
requests.get("http://localhost:{}?error=<tag>foo</tag>".format(
receiver.get_port())).text,
"Unsafe data in HTML should be escaped",
))]
receiver.get_auth_response( # Starts server and hang until timeout
timeout=3,
error_template="<html>$error</html>",
)

0 comments on commit 3427c25

Please sign in to comment.