Skip to content
This repository has been archived by the owner on Jan 13, 2021. It is now read-only.

Commit

Permalink
Merge pull request #186 from irvind/override-default-headers
Browse files Browse the repository at this point in the history
Allow to override default request headers
  • Loading branch information
Lukasa committed Dec 28, 2015
2 parents 31cbc84 + f8a0b04 commit d395d57
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 27 deletions.
13 changes: 13 additions & 0 deletions hyper/common/headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,19 @@ def iter_raw(self):
for item in self._items:
yield item

def replace(self, key, value):
"""
Replace existing header with new value. If header doesn't exist this
method work like ``__setitem__``. Replacing leads to deletion of all
exsiting headers with the same name.
"""
try:
del self[key]
except KeyError:
pass

self[key] = value

def merge(self, other):
"""
Merge another header set or any other dict-like into this one.
Expand Down
10 changes: 10 additions & 0 deletions hyper/common/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
"""
from hyper.compat import unicode, bytes, imap
from ..packages.rfc3986.uri import URIReference
from ..compat import is_py3
import re


def to_bytestring(element):
"""
Converts a single string to a bytestring, encoding via UTF-8 if needed.
Expand All @@ -28,6 +30,7 @@ def to_bytestring_tuple(*x):
"""
return tuple(imap(to_bytestring, x))


def to_host_port_tuple(host_port_str, default_port=80):
"""
Converts the given string containing a host and possibly a port
Expand All @@ -48,3 +51,10 @@ def to_host_port_tuple(host_port_str, default_port=80):
port = int(uri.port)

return (host, port)


def to_native_string(string, encoding='utf-8'):
if isinstance(string, str):
return string

return string.decode(encoding) if is_py3 else string.encode(encoding)
10 changes: 6 additions & 4 deletions hyper/http20/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ..common.exceptions import ConnectionResetError
from ..common.bufsocket import BufferedSocket
from ..common.headers import HTTPHeaderMap
from ..common.util import to_host_port_tuple
from ..common.util import to_host_port_tuple, to_native_string
from ..packages.hyperframe.frame import (
FRAMES, DataFrame, HeadersFrame, PushPromiseFrame, RstStreamFrame,
SettingsFrame, Frame, WindowUpdateFrame, GoAwayFrame, PingFrame,
Expand Down Expand Up @@ -170,8 +170,10 @@ def request(self, method, url, body=None, headers={}):
"""
stream_id = self.putrequest(method, url)

default_headers = (':method', ':scheme', ':authority', ':path')
for name, value in headers.items():
self.putheader(name, value, stream_id)
is_default = to_native_string(name) in default_headers
self.putheader(name, value, stream_id, replace=is_default)

# Convert the body to bytes if needed.
if isinstance(body, str):
Expand Down Expand Up @@ -319,7 +321,7 @@ def putrequest(self, method, selector, **kwargs):

return s.stream_id

def putheader(self, header, argument, stream_id=None):
def putheader(self, header, argument, stream_id=None, replace=False):
"""
Sends an HTTP header to the server, with name ``header`` and value
``argument``.
Expand All @@ -341,7 +343,7 @@ def putheader(self, header, argument, stream_id=None):
:returns: Nothing.
"""
stream = self._get_stream(stream_id)
stream.add_header(header, argument)
stream.add_header(header, argument, replace)

return

Expand Down
11 changes: 8 additions & 3 deletions hyper/http20/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(self,
local_closed=False):
self.stream_id = stream_id
self.state = STATE_HALF_CLOSED_LOCAL if local_closed else STATE_IDLE
self.headers = []
self.headers = HTTPHeaderMap()

# Set to a key-value set of the response headers once their
# HEADERS..CONTINUATION frame sequence finishes.
Expand Down Expand Up @@ -109,11 +109,15 @@ def __init__(self,
self._encoder = header_encoder
self._decoder = header_decoder

def add_header(self, name, value):
def add_header(self, name, value, replace=False):
"""
Adds a single HTTP header to the headers to be sent on the request.
"""
self.headers.append((name.lower(), value))
if not replace:
self.headers[name] = value
else:
self.headers.replace(name, value)


def send_data(self, data, final):
"""
Expand Down Expand Up @@ -270,6 +274,7 @@ def open(self, end):
"""
# Strip any headers invalid in H2.
headers = h2_safe_headers(self.headers)

# Encode the headers.
encoded_headers = self._encoder.encode(headers)

Expand Down
19 changes: 19 additions & 0 deletions test/test_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,3 +258,22 @@ def test_merge_header_map_dict(self):
(b'hi', b'there'),
(b'cat', b'dog'),
]

def test_replacing(self):
h = HTTPHeaderMap([
(b'name', b'value'),
(b'name2', b'value2'),
(b'name2', b'value2'),
(b'name3', b'value3'),
])

h.replace('name2', '42')
h.replace('name4', 'other_value')

assert list(h.items()) == [
(b'name', b'value'),
(b'name3', b'value3'),
(b'name2', b'42'),
(b'name4', b'other_value'),
]

115 changes: 95 additions & 20 deletions test/test_hyper.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
combine_repeated_headers, split_repeated_headers, h2_safe_headers
)
from hyper.common.headers import HTTPHeaderMap
from hyper.compat import zlib_compressobj
from hyper.compat import zlib_compressobj, is_py2
from hyper.contrib import HTTP20Adapter
import hyper.http20.errors as errors
import errno
Expand All @@ -29,6 +29,7 @@
from io import BytesIO
import hyper


def decode_frame(frame_data):
f, length = Frame.parse_frame_header(frame_data[:9])
f.parse_body(memoryview(frame_data[9:9 + length]))
Expand Down Expand Up @@ -87,11 +88,11 @@ def test_putrequest_autosets_headers(self):
c.putrequest('GET', '/')
s = c.recent_stream

assert s.headers == [
(':method', 'GET'),
(':scheme', 'https'),
(':authority', 'www.google.com'),
(':path', '/'),
assert list(s.headers.items()) == [
(b':method', b'GET'),
(b':scheme', b'https'),
(b':authority', b'www.google.com'),
(b':path', b'/'),
]

def test_putheader_puts_headers(self):
Expand All @@ -101,12 +102,29 @@ def test_putheader_puts_headers(self):
c.putheader('name', 'value')
s = c.recent_stream

assert s.headers == [
(':method', 'GET'),
(':scheme', 'https'),
(':authority', 'www.google.com'),
(':path', '/'),
('name', 'value'),
assert list(s.headers.items()) == [
(b':method', b'GET'),
(b':scheme', b'https'),
(b':authority', b'www.google.com'),
(b':path', b'/'),
(b'name', b'value'),
]

def test_putheader_replaces_headers(self):
c = HTTP20Connection("www.google.com")

c.putrequest('GET', '/')
c.putheader(':authority', 'www.example.org', replace=True)
c.putheader('name', 'value')
c.putheader('name', 'value2', replace=True)
s = c.recent_stream

assert list(s.headers.items()) == [
(b':method', b'GET'),
(b':scheme', b'https'),
(b':path', b'/'),
(b':authority', b'www.example.org'),
(b'name', b'value2'),
]

def test_endheaders_sends_data(self):
Expand Down Expand Up @@ -203,6 +221,33 @@ def test_putrequest_sends_data(self):
assert len(sock.queue) == 2
assert c._out_flow_control_window == 65535 - len(b'hello')

def test_different_request_headers(self):
sock = DummySocket()

c = HTTP20Connection('www.google.com')
c._sock = sock
c.request('GET', '/', body='hello', headers={b'name': b'value'})
s = c.recent_stream

assert list(s.headers.items()) == [
(b':method', b'GET'),
(b':scheme', b'https'),
(b':authority', b'www.google.com'),
(b':path', b'/'),
(b'name', b'value'),
]

c.request('GET', '/', body='hello', headers={u'name2': u'value2'})
s = c.recent_stream

assert list(s.headers.items()) == [
(b':method', b'GET'),
(b':scheme', b'https'),
(b':authority', b'www.google.com'),
(b':path', b'/'),
(b'name2', b'value2'),
]

def test_closed_connections_are_reset(self):
c = HTTP20Connection('www.google.com')
c._sock = DummySocket()
Expand Down Expand Up @@ -502,11 +547,11 @@ def test_that_using_proxy_keeps_http_headers_intact(self):
c.request('GET', '/')
s = c.recent_stream

assert s.headers == [
(':method', 'GET'),
(':scheme', 'http'),
(':authority', 'www.google.com'),
(':path', '/'),
assert list(s.headers.items()) == [
(b':method', b'GET'),
(b':scheme', b'http'),
(b':authority', b'www.google.com'),
(b':path', b'/'),
]

def test_recv_cb_n_times(self):
Expand Down Expand Up @@ -695,13 +740,30 @@ def test_streams_have_ids(self):

def test_streams_initially_have_no_headers(self):
s = Stream(1, None, None, None, None, None, None)
assert s.headers == []
assert list(s.headers.items()) == []

def test_streams_can_have_headers(self):
s = Stream(1, None, None, None, None, None, None)
s.add_header("name", "value")
assert s.headers == [("name", "value")]
assert list(s.headers.items()) == [(b"name", b"value")]

def test_streams_can_replace_headers(self):
s = Stream(1, None, None, None, None, None, None)
s.add_header("name", "value")
s.add_header("name", "other_value", replace=True)

assert list(s.headers.items()) == [(b"name", b"other_value")]

def test_streams_can_replace_none_headers(self):
s = Stream(1, None, None, None, None, None, None)
s.add_header("name", "value")
s.add_header("other_name", "other_value", replace=True)

assert list(s.headers.items()) == [
(b"name", b"value"),
(b"other_name", b"other_value")
]

def test_stream_opening_sends_headers(self):
def data_callback(frame):
assert isinstance(frame, HeadersFrame)
Expand Down Expand Up @@ -1465,11 +1527,23 @@ def test_connection_error_when_send_out_of_range_frame(self):
with pytest.raises(ValueError):
c._send_cb(d)


# Some utility classes for the tests.
class NullEncoder(object):
@staticmethod
def encode(headers):
return '\n'.join("%s%s" % (name, val) for name, val in headers)

def to_str(v):
if is_py2:
return str(v)
else:
if not isinstance(v, str):
v = str(v, 'utf-8')
return v

return '\n'.join("%s%s" % (to_str(name), to_str(val))
for name, val in headers)


class FixedDecoder(object):
def __init__(self, result):
Expand All @@ -1478,6 +1552,7 @@ def __init__(self, result):
def decode(self, headers):
return self.result


class DummySocket(object):
def __init__(self):
self.queue = []
Expand Down

0 comments on commit d395d57

Please sign in to comment.