Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion src/django_http_compression/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@

from collections.abc import AsyncGenerator, Awaitable, Generator, Iterator
from functools import lru_cache, partial
from gzip import GzipFile
from types import MappingProxyType
from typing import Callable, Literal, cast

from asgiref.sync import iscoroutinefunction, markcoroutinefunction
from django.http import HttpRequest, HttpResponse, StreamingHttpResponse
from django.http.response import HttpResponseBase
from django.utils.cache import patch_vary_headers
from django.utils.text import compress_sequence as gzip_compress_sequence
from django.utils.text import ( # type: ignore [attr-defined]
StreamingBuffer,
_get_random_filename,
)
from django.utils.text import compress_string as gzip_compress
from typing_extensions import assert_never

Expand Down Expand Up @@ -250,6 +254,29 @@ def _parse_part(
return None


def gzip_compress_sequence(
sequence: Iterator[bytes], *, max_random_bytes: int
) -> Generator[bytes]:
"""
Copy of Django’s compress_sequence() but with streaming response flushing
bug fixed.
"""
buf = StreamingBuffer()
filename = _get_random_filename(max_random_bytes) if max_random_bytes else None
with GzipFile(
filename=filename, mode="wb", compresslevel=6, fileobj=buf, mtime=0
) as zfile:
# Output headers...
yield b"" # Optimization
for item in sequence:
zfile.write(item)
zfile.flush() # Bug fix
data = buf.read()
if data:
yield data
yield buf.read()


def brotli_compress_sequence(sequence: Iterator[bytes]) -> Generator[bytes]:
# Output headers
yield b""
Expand Down
168 changes: 154 additions & 14 deletions tests/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@
import gzip
import inspect
import sys
import zlib
from collections.abc import Iterator
from gzip import decompress as gzip_decompress
from http import HTTPStatus
from textwrap import dedent
from typing import cast

import django
import pytest
from brotli import Decompressor as BrotliDecompressor
from brotli import decompress as brotli_decompress
from django.http import StreamingHttpResponse
from django.middleware import gzip as django_middleware_gzip
Expand All @@ -22,27 +26,31 @@
class HttpCompressionMiddlewareTests(SimpleTestCase):
def test_short(self):
response = self.client.get("/short/", headers={"accept-encoding": "gzip"})

assert response.status_code == HTTPStatus.OK
assert "content-encoding" not in response.headers
assert "vary" not in response.headers
assert response.content == b"short"

def test_encoded(self):
response = self.client.get("/encoded/", headers={"accept-encoding": "gzip"})

assert response.status_code == HTTPStatus.OK
assert response.headers["content-encoding"] == "supercompression"
assert "vary" not in response.headers
assert response.content.decode() == basic_html

def test_identity(self):
response = self.client.get("/")

assert response.status_code == HTTPStatus.OK
assert "content-encoding" not in response.headers
assert "vary" not in response.headers
assert response.content.decode() == basic_html

def test_gzip(self):
response = self.client.get("/", headers={"accept-encoding": "gzip"})

assert response.status_code == HTTPStatus.OK
assert response.headers["content-encoding"] == "gzip"
assert response.headers["vary"] == "accept-encoding"
Expand All @@ -52,6 +60,7 @@ def test_gzip(self):

def test_brotli(self):
response = self.client.get("/", headers={"accept-encoding": "br"})

assert response.status_code == HTTPStatus.OK
assert response.headers["content-encoding"] == "br"
assert response.headers["vary"] == "accept-encoding"
Expand All @@ -64,6 +73,7 @@ def test_zstd(self):
from compression.zstd import decompress

response = self.client.get("/", headers={"accept-encoding": "zstd"})

assert response.status_code == HTTPStatus.OK
assert response.headers["content-encoding"] == "zstd"
assert response.headers["vary"] == "accept-encoding"
Expand All @@ -73,47 +83,121 @@ def test_zstd(self):

def test_streaming_identity(self):
response = self.client.get("/streaming/")

assert isinstance(response, StreamingHttpResponse)
assert not response.is_async
assert response.status_code == HTTPStatus.OK
assert "content-encoding" not in response.headers
assert "vary" not in response.headers
content = response.getvalue()
streaming_content = cast(Iterator[bytes], response.streaming_content)
content = next(streaming_content)
assert content == b"<!doctype html>\n"
content += next(streaming_content)
assert content == b"<!doctype html>\n<html>\n"
for chunk in streaming_content:
content += chunk
assert content.decode() == basic_html

def test_streaming_gzip(self):
response = self.client.get("/streaming/", headers={"accept-encoding": "gzip"})

assert isinstance(response, StreamingHttpResponse)
assert not response.is_async
assert response.status_code == HTTPStatus.OK
assert response.headers["content-encoding"] == "gzip"
assert response.headers["vary"] == "accept-encoding"
content = response.getvalue()
assert content.startswith(b"\x1f\x8b\x08")
decompressed = gzip.decompress(content)
assert decompressed.decode() == basic_html

decompressor = zlib.decompressobj(zlib.MAX_WBITS | 16) # gzip decoding
content = b""
streaming_content = cast(Iterator[bytes], response.streaming_content)

decompressed = decompressor.decompress(next(streaming_content))
assert decompressed == b""
content += decompressed

decompressed = decompressor.decompress(next(streaming_content))
assert decompressed == b"<!doctype html>\n"
content += decompressed

decompressed = decompressor.decompress(next(streaming_content))
assert decompressed == b"<html>\n"
content += decompressed

for chunk in streaming_content:
content += decompressor.decompress(chunk)
content += decompressor.flush()

assert content.decode() == basic_html

def test_streaming_brotli(self):
response = self.client.get("/streaming/", headers={"accept-encoding": "br"})

assert isinstance(response, StreamingHttpResponse)
assert not response.is_async
assert response.status_code == HTTPStatus.OK
assert response.headers["content-encoding"] == "br"
assert response.headers["vary"] == "accept-encoding"
content = response.getvalue()
assert content.startswith(b"\x8b\x07\x00\xf8")
decompressed = brotli_decompress(content)
assert decompressed.decode() == basic_html

streaming_content = cast(Iterator[bytes], response.streaming_content)
decompressor = BrotliDecompressor()
content = b""

decompressed = decompressor.process(next(streaming_content))
assert decompressed == b""
content += decompressed

decompressed = decompressor.process(next(streaming_content))
assert decompressed == b"<!doctype html>\n"
content += decompressed

decompressed = decompressor.process(next(streaming_content))
assert decompressed == b"<html>\n"
content += decompressed

for chunk in streaming_content:
content += decompressor.process(chunk)

assert content.decode() == basic_html
assert decompressor.is_finished()

@pytest.mark.skipif(sys.version_info < (3, 14), reason="Python 3.14+")
def test_streaming_zstd(self):
from compression.zstd import decompress
from compression.zstd import ZstdDecompressor

response = self.client.get("/streaming/", headers={"accept-encoding": "zstd"})

assert isinstance(response, StreamingHttpResponse)
assert not response.is_async
assert response.status_code == HTTPStatus.OK
assert response.headers["content-encoding"] == "zstd"
assert response.headers["vary"] == "accept-encoding"
content = response.getvalue()
assert content.startswith(b"(\xb5/\xfd")
decompressed = decompress(content)
assert decompressed.decode() == basic_html

streaming_content = cast(Iterator[bytes], response.streaming_content)
decompressor = ZstdDecompressor()
content = b""

decompressed = decompressor.decompress(next(streaming_content))
assert decompressed == b""
content += decompressed

decompressed = decompressor.decompress(next(streaming_content))
assert decompressed == b"<!doctype html>\n"
content += decompressed

decompressed = decompressor.decompress(next(streaming_content))
assert decompressed == b"<html>\n"
content += decompressed

for chunk in streaming_content:
content += decompressor.decompress(chunk)

assert decompressor.eof
assert decompressor.unused_data == b""
assert content.decode() == basic_html

def test_streaming_empty_identity(self):
response = self.client.get("/streaming/empty/")

assert response.status_code == HTTPStatus.OK
assert "content-encoding" not in response.headers
assert "vary" not in response.headers
Expand All @@ -124,6 +208,7 @@ def test_streaming_empty_gzip(self):
response = self.client.get(
"/streaming/empty/", headers={"accept-encoding": "gzip"}
)

assert response.status_code == HTTPStatus.OK
assert response.headers["content-encoding"] == "gzip"
assert response.headers["vary"] == "accept-encoding"
Expand All @@ -136,6 +221,7 @@ def test_streaming_empty_brotli(self):
response = self.client.get(
"/streaming/empty/", headers={"accept-encoding": "br"}
)

assert response.status_code == HTTPStatus.OK
assert response.headers["content-encoding"] == "br"
assert response.headers["vary"] == "accept-encoding"
Expand All @@ -151,6 +237,58 @@ def test_streaming_empty_zstd(self):
response = self.client.get(
"/streaming/empty/", headers={"accept-encoding": "zstd"}
)

assert response.status_code == HTTPStatus.OK
assert response.headers["content-encoding"] == "zstd"
assert response.headers["vary"] == "accept-encoding"
content = response.getvalue()
assert content.startswith(b"(\xb5/\xfd")
decompressed = decompress(content)
assert decompressed == b""

def test_streaming_blanks_identity(self):
response = self.client.get("/streaming/blanks/")

assert response.status_code == HTTPStatus.OK
assert "content-encoding" not in response.headers
assert "vary" not in response.headers
content = response.getvalue()
assert content == b""

def test_streaming_blanks_gzip(self):
response = self.client.get(
"/streaming/blanks/", headers={"accept-encoding": "gzip"}
)

assert response.status_code == HTTPStatus.OK
assert response.headers["content-encoding"] == "gzip"
assert response.headers["vary"] == "accept-encoding"
content = response.getvalue()
assert content.startswith(b"\x1f\x8b\x08")
decompressed = gzip.decompress(content)
assert decompressed == b""

def test_streaming_blanks_brotli(self):
response = self.client.get(
"/streaming/blanks/", headers={"accept-encoding": "br"}
)

assert response.status_code == HTTPStatus.OK
assert response.headers["content-encoding"] == "br"
assert response.headers["vary"] == "accept-encoding"
content = response.getvalue()
assert content == b"k\x00\x03"
decompressed = brotli_decompress(content)
assert decompressed == b""

@pytest.mark.skipif(sys.version_info < (3, 14), reason="Python 3.14+")
def test_streaming_blanks_zstd(self):
from compression.zstd import decompress

response = self.client.get(
"/streaming/blanks/", headers={"accept-encoding": "zstd"}
)

assert response.status_code == HTTPStatus.OK
assert response.headers["content-encoding"] == "zstd"
assert response.headers["vary"] == "accept-encoding"
Expand Down Expand Up @@ -263,6 +401,7 @@ async def test_async_streaming_zstd(self):

def test_binary(self):
response = self.client.get("/binary/", headers={"accept-encoding": "gzip"})

assert response.status_code == HTTPStatus.OK
assert response.headers["content-encoding"] == "gzip"
assert response.headers["vary"] == "accept-encoding"
Expand All @@ -273,6 +412,7 @@ def test_binary(self):

def test_etag(self):
response = self.client.get("/etag/", headers={"accept-encoding": "gzip"})

assert response.status_code == HTTPStatus.OK
assert response.headers["content-encoding"] == "gzip"
assert response.headers["vary"] == "accept-encoding"
Expand Down
1 change: 1 addition & 0 deletions tests/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
path("async/", views.async_),
path("streaming/", views.streaming),
path("streaming/empty/", views.streaming_empty),
path("streaming/blanks/", views.streaming_blanks),
path("async/streaming/", views.async_streaming),
path("binary/", views.binary),
path("etag/", views.etag),
Expand Down
8 changes: 8 additions & 0 deletions tests/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@ def empty() -> Generator[bytes]:
return StreamingHttpResponse(empty())


def streaming_blanks(request: HttpRequest) -> StreamingHttpResponse:
def empty() -> Generator[bytes]:
yield b""
yield b""

return StreamingHttpResponse(empty())


async def async_streaming(request: HttpRequest) -> StreamingHttpResponse:
async def lines() -> AsyncGenerator[str]:
for line in basic_html.splitlines(keepends=True):
Expand Down