Skip to content

Commit

Permalink
Add HTTP Headers fixer for Client and RequestFactory.
Browse files Browse the repository at this point in the history
  • Loading branch information
David-Wobrock committed Dec 23, 2022
1 parent c26ab9c commit 0ca8fce
Show file tree
Hide file tree
Showing 3 changed files with 327 additions and 1 deletion.
205 changes: 205 additions & 0 deletions src/django_upgrade/fixers/http_headers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
"""
Rewrite HTTP headers as headers keyword parameter:
https://docs.djangoproject.com/en/dev/releases/4.2/#tests
"""
from __future__ import annotations

import ast
from functools import partial
from typing import Iterable

from tokenize_rt import Offset
from tokenize_rt import Token

from django_upgrade.ast import ast_start_offset
from django_upgrade.data import Fixer
from django_upgrade.data import State
from django_upgrade.data import TokenFunc
from django_upgrade.fixers.assert_form_error import looks_like_client_call
from django_upgrade.tokens import find
from django_upgrade.tokens import NAME
from django_upgrade.tokens import OP
from django_upgrade.tokens import parse_call_args

fixer = Fixer(
__name__,
min_version=(4, 2),
)

HEADERS_KWARG = "headers"
HTTP_PREFIX = "HTTP_"


def merge_http_headers_kwargs(tokens: list[Token], i: int, *, node: ast.Call) -> None:
http_kwargs_names = []
contains_headers_kwarg = False

for kw in node.keywords:
if kw.arg and kw.arg.startswith(HTTP_PREFIX):
http_kwargs_names.append(kw.arg)
elif kw.arg == HEADERS_KWARG:
contains_headers_kwarg = True

new_headers_kwarg_values = {}
first_start_idx = -1
for http_kwarg_name in http_kwargs_names:
erased_tokens, start_idx = erase_keyword_argument(
tokens, i, kwarg_name=http_kwarg_name
)
new_headers_kwarg_values[transform_header_name(http_kwarg_name)] = erased_tokens
if first_start_idx == -1:
first_start_idx = start_idx

if new_headers_kwarg_values:
if contains_headers_kwarg:
insert_into_existing_headers_kwarg(
tokens, i, kwarg_dict=new_headers_kwarg_values
)
else:
insert_headers_kwarg(
tokens,
i,
kwarg_dict=new_headers_kwarg_values,
insert_position=first_start_idx,
)


def transform_header_name(header: str) -> str:
if header.startswith(HTTP_PREFIX):
header = header[len(HTTP_PREFIX) :]
return header.replace("_", "-").lower()


def erase_keyword_argument(
tokens: list[Token], i: int, *, kwarg_name: str
) -> tuple[list[Token], int]:
"""
Erases a keyword argument based on its name, and returns the tokens that
defined the value of the keyword argument, and its starting position.
The starting position will be used to start inserting the new argument in
the right place.
"""
open_idx = find(tokens, i, name=OP, src="(")
func_args, close_idx = parse_call_args(tokens, open_idx)
for arg_start_idx, arg_end_idx in func_args:
kwarg_name_start_idx = find(tokens, arg_start_idx, name=NAME, src=kwarg_name)
if kwarg_name_start_idx > arg_end_idx:
continue

value_start_idx = find(tokens, kwarg_name_start_idx, name=OP, src="=")
erased_tokens = tokens[value_start_idx + 1 : arg_end_idx]

ends_with_comma = (
tokens[arg_end_idx].name == OP and tokens[arg_end_idx].src == ","
)
if ends_with_comma:
arg_end_idx += 1
ends_with_whitespace = (
tokens[arg_end_idx].name == "UNIMPORTANT_WS"
and tokens[arg_end_idx].src == " "
)
if ends_with_whitespace:
arg_end_idx += 1

del tokens[kwarg_name_start_idx:arg_end_idx]
return erased_tokens, kwarg_name_start_idx
return [], -1


def insert_into_existing_headers_kwarg(
tokens: list[Token], i: int, *, kwarg_dict: dict[str, list[Token]]
) -> None:
open_idx = find(tokens, i, name=OP, src="(")

headers_arg_start_idx = find(tokens, open_idx, name=NAME, src=HEADERS_KWARG)
headers_closing_idx = find(tokens, headers_arg_start_idx, name=OP, src="}")

new_tokens = []
for key, value in kwarg_dict.items():
new_tokens += [
Token(name="OP", src=","),
Token(name="UNIMPORTANT_WS", src=" "),
]
new_tokens += [
Token(name="STRING", src=f'"{key}"'),
Token(name="OP", src=":"),
Token(name="UNIMPORTANT_WS", src=" "),
]
new_tokens += value
tokens[headers_closing_idx:headers_closing_idx] = new_tokens


def insert_headers_kwarg(
tokens: list[Token],
i: int,
*,
kwarg_dict: dict[str, list[Token]],
insert_position: int,
) -> None:
open_idx = find(tokens, i, name=OP, src="(")
func_args, close_idx = parse_call_args(tokens, open_idx)

new_tokens = []

new_tokens += [
Token(name=NAME, src=HEADERS_KWARG),
Token(name="OP", src="="),
Token(name="OP", src="{"),
]

for i, (key, value) in enumerate(kwarg_dict.items(), 1):
new_tokens += [
Token(name="STRING", src=f'"{key}"'),
Token(name="OP", src=":"),
Token(name="UNIMPORTANT_WS", src=" "),
]
new_tokens += value
if i < len(kwarg_dict):
new_tokens += [
Token(name="OP", src=","),
Token(name="UNIMPORTANT_WS", src=" "),
]

new_tokens += [
Token(name="OP", src="}"),
]

# If it's not the last argument.
if len(func_args) != 0 and func_args[-1][1] > insert_position:
new_tokens += [
Token(name="OP", src=","),
Token(name="UNIMPORTANT_WS", src=" "),
]

tokens[insert_position:insert_position] = new_tokens


@fixer.register(ast.Call)
def visit_Call(
state: State,
node: ast.Call,
parents: list[ast.AST],
) -> Iterable[tuple[Offset, TokenFunc]]:
if (
isinstance(node.func, ast.Name)
and "RequestFactory" in state.from_imports["django.test"]
and node.func.id == "RequestFactory"
):
yield ast_start_offset(node), partial(
merge_http_headers_kwargs,
node=node,
)
elif (
isinstance(node.func, ast.Name)
and "Client" in state.from_imports["django.test"]
and node.func.id == "Client"
):
yield ast_start_offset(node), partial(
merge_http_headers_kwargs,
node=node,
)
elif looks_like_client_call(node, "client") and node.args:
yield ast_start_offset(node), partial(
merge_http_headers_kwargs,
node=node,
)
2 changes: 1 addition & 1 deletion src/django_upgrade/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def erase_node(tokens: list[Token], i: int, *, node: ast.AST) -> None:

def find_and_replace_name(tokens: list[Token], i: int, *, name: str, new: str) -> None:
j = find(tokens, i, name=NAME, src=name)
tokens[j] = tokens[j]._replace(name="CODE", src=new)
tokens[j] = tokens[j]._replace(name=CODE, src=new)


def replace_argument_names(
Expand Down
121 changes: 121 additions & 0 deletions tests/fixers/test_http_headers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from __future__ import annotations

from django_upgrade.data import Settings
from tests.fixers.tools import check_noop
from tests.fixers.tools import check_transformed

settings = Settings(target_version=(4, 2))


def test_transform_constructor_call():
check_transformed(
"""\
from django.test import RequestFactory
RequestFactory(HTTP_HOST="test.com")
""",
"""\
from django.test import RequestFactory
RequestFactory(headers={"host": "test.com"})
""",
settings,
)


def test_transform_constructor_call_multiple():
check_transformed(
"""\
from django.test import Client
c = Client(HTTP_HOST="test.com", headers={"user-agent": "example"}, HTTP_ACCEPT_LANGUAGE="fr-fr")
""", # noqa: E501
"""\
from django.test import Client
c = Client(headers={"user-agent": "example", "host": "test.com", "accept-language": "fr-fr"}, )
""", # noqa: E501
settings,
)


def test_transform():
check_transformed(
"""\
from django.test import Client
self.client = Client()
self.client.get("/", HTTP_HOST="test.com", SCRIPT_NAME="/app/")
""",
"""\
from django.test import Client
self.client = Client()
self.client.get("/", headers={"host": "test.com"}, SCRIPT_NAME="/app/")
""",
settings,
)


def test_transform_only():
check_transformed(
"""\
from django.test import Client
self.client = Client()
self.client.get("/", HTTP_HOST="test.com")
""",
"""\
from django.test import Client
self.client = Client()
self.client.get("/", headers={"host": "test.com"})
""",
settings,
)


def test_transform_multiple():
check_transformed(
"""\
from django.test import Client
self.client = Client()
self.client.get("/", HTTP_HOST="test.com", HTTP_ACCEPT="application/json")
""",
"""\
from django.test import Client
self.client = Client()
self.client.get("/", headers={"host": "test.com", "accept": "application/json"})
""",
settings,
)


def test_unchanged_parameter():
check_noop(
"""
self.client.get("/", SCRIPT_NAME="/app/"),
""",
settings,
)


def test_unchanged_other_call():
check_noop(
"""
request.META.get("/", HTTP_HOST="host.com"),
""",
settings,
)


def test_transform_multiline():
check_transformed(
"""
response = self.client.post(
"/password_reset/",
{"email": "staffmember@example.com"},
HTTP_HOST="www.example:dr.frankenstein@evil.tld",
)
""",
"""
response = self.client.post(
"/password_reset/",
{"email": "staffmember@example.com"},
headers={"host": "www.example:dr.frankenstein@evil.tld"}
)
""",
settings,
)

0 comments on commit 0ca8fce

Please sign in to comment.