Skip to content

Commit

Permalink
only allow methods which implement HTTP verbs to be called remotely
Browse files Browse the repository at this point in the history
This fixes 500 server crashes caused by requests such as:

curl -X__init__ "http://your-swift-object-server:6000/sda1/p/a/c/o"

Fixes bug 1005903

Change-Id: I6c0ad39a29e07ce5f46b0fdbd11a53a9a1010a04
  • Loading branch information
iartarisi committed Jun 4, 2012
1 parent 783f160 commit 9f5a6bb
Show file tree
Hide file tree
Showing 9 changed files with 147 additions and 32 deletions.
18 changes: 14 additions & 4 deletions swift/account/server.py
Expand Up @@ -31,7 +31,7 @@

import swift.common.db
from swift.common.db import AccountBroker
from swift.common.utils import get_logger, get_param, hash_path, \
from swift.common.utils import get_logger, get_param, hash_path, public, \
normalize_timestamp, split_path, storage_directory, TRUE_VALUES
from swift.common.constraints import ACCOUNT_LISTING_LIMIT, \
check_mount, check_float, check_utf8
Expand Down Expand Up @@ -63,6 +63,7 @@ def _get_account_broker(self, drive, part, account):
db_path = os.path.join(self.root, drive, db_dir, hsh + '.db')
return AccountBroker(db_path, account=account, logger=self.logger)

@public
def DELETE(self, req):
"""Handle HTTP DELETE request."""
start_time = time.time()
Expand All @@ -88,6 +89,7 @@ def DELETE(self, req):
self.logger.timing_since('DELETE.timing', start_time)
return HTTPNoContent(request=req)

@public
def PUT(self, req):
"""Handle HTTP PUT request."""
start_time = time.time()
Expand Down Expand Up @@ -149,6 +151,7 @@ def PUT(self, req):
else:
return HTTPAccepted(request=req)

@public
def HEAD(self, req):
"""Handle HTTP HEAD request."""
# TODO(refactor): The account server used to provide a 'account and
Expand Down Expand Up @@ -192,6 +195,7 @@ def HEAD(self, req):
self.logger.timing_since('HEAD.timing', start_time)
return HTTPNoContent(request=req, headers=headers)

@public
def GET(self, req):
"""Handle HTTP GET request."""
start_time = time.time()
Expand Down Expand Up @@ -292,6 +296,7 @@ def GET(self, req):
self.logger.timing_since('GET.timing', start_time)
return ret

@public
def REPLICATE(self, req):
"""
Handle HTTP REPLICATE request.
Expand All @@ -318,6 +323,7 @@ def REPLICATE(self, req):
self.logger.timing_since('REPLICATE.timing', start_time)
return ret

@public
def POST(self, req):
"""Handle HTTP POST request."""
start_time = time.time()
Expand Down Expand Up @@ -357,10 +363,14 @@ def __call__(self, env, start_response):
res = HTTPPreconditionFailed(body='Invalid UTF8')
else:
try:
if hasattr(self, req.method):
res = getattr(self, req.method)(req)
else:
# disallow methods which are not publicly accessible
try:
method = getattr(self, req.method)
getattr(method, 'publicly_accessible')
except AttributeError:
res = HTTPMethodNotAllowed()
else:
res = method(req)
except (Exception, Timeout):
self.logger.exception(_('ERROR __call__ error with %(method)s'
' %(path)s '), {'method': req.method, 'path': req.path})
Expand Down
15 changes: 15 additions & 0 deletions swift/common/utils.py
Expand Up @@ -1246,3 +1246,18 @@ def streq_const_time(s1, s2):
for (a, b) in zip(s1, s2):
result |= ord(a) ^ ord(b)
return result == 0


def public(func):
"""
Decorator to declare which methods are publicly accessible as HTTP
requests
:param func: function to make public
"""
func.publicly_accessible = True

@functools.wraps(func)
def wrapped(*a, **kw):
return func(*a, **kw)
return wrapped
18 changes: 14 additions & 4 deletions swift/container/server.py
Expand Up @@ -31,7 +31,7 @@

import swift.common.db
from swift.common.db import ContainerBroker
from swift.common.utils import get_logger, get_param, hash_path, \
from swift.common.utils import get_logger, get_param, hash_path, public, \
normalize_timestamp, storage_directory, split_path, validate_sync_to, \
TRUE_VALUES
from swift.common.constraints import CONTAINER_LISTING_LIMIT, \
Expand Down Expand Up @@ -138,6 +138,7 @@ def account_update(self, req, account, container, broker):
'device': account_device})
return None

@public
def DELETE(self, req):
"""Handle HTTP DELETE request."""
start_time = time.time()
Expand Down Expand Up @@ -187,6 +188,7 @@ def DELETE(self, req):
return HTTPNoContent(request=req)
return HTTPNotFound()

@public
def PUT(self, req):
"""Handle HTTP PUT request."""
start_time = time.time()
Expand Down Expand Up @@ -255,6 +257,7 @@ def PUT(self, req):
else:
return HTTPAccepted(request=req)

@public
def HEAD(self, req):
"""Handle HTTP HEAD request."""
start_time = time.time()
Expand Down Expand Up @@ -288,6 +291,7 @@ def HEAD(self, req):
self.logger.timing_since('HEAD.timing', start_time)
return HTTPNoContent(request=req, headers=headers)

@public
def GET(self, req):
"""Handle HTTP GET request."""
start_time = time.time()
Expand Down Expand Up @@ -409,6 +413,7 @@ def GET(self, req):
self.logger.timing_since('GET.timing', start_time)
return ret

@public
def REPLICATE(self, req):
"""
Handle HTTP REPLICATE request (json-encoded RPC calls for replication.)
Expand All @@ -434,6 +439,7 @@ def REPLICATE(self, req):
self.logger.timing_since('REPLICATE.timing', start_time)
return ret

@public
def POST(self, req):
"""Handle HTTP POST request."""
start_time = time.time()
Expand Down Expand Up @@ -485,10 +491,14 @@ def __call__(self, env, start_response):
res = HTTPPreconditionFailed(body='Invalid UTF8')
else:
try:
if hasattr(self, req.method):
res = getattr(self, req.method)(req)
else:
# disallow methods which have not been marked 'public'
try:
method = getattr(self, req.method)
getattr(method, 'publicly_accessible')
except AttributeError:
res = HTTPMethodNotAllowed()
else:
res = method(req)
except (Exception, Timeout):
self.logger.exception(_('ERROR __call__ error with %(method)s'
' %(path)s '), {'method': req.method, 'path': req.path})
Expand Down
18 changes: 14 additions & 4 deletions swift/obj/server.py
Expand Up @@ -35,7 +35,7 @@
from xattr import getxattr, setxattr
from eventlet import sleep, Timeout, tpool

from swift.common.utils import mkdirs, normalize_timestamp, \
from swift.common.utils import mkdirs, normalize_timestamp, public, \
storage_directory, hash_path, renamer, fallocate, \
split_path, drop_buffer_cache, get_logger, write_pickle
from swift.common.bufferedhttp import http_connect
Expand Down Expand Up @@ -484,6 +484,7 @@ def delete_at_update(self, op, delete_at, account, container, obj,
'%s-%s/%s/%s' % (delete_at, account, container, obj),
host, partition, contdevice, headers_out, objdevice)

@public
def POST(self, request):
"""Handle HTTP POST requests for the Swift Object Server."""
start_time = time.time()
Expand Down Expand Up @@ -543,6 +544,7 @@ def POST(self, request):
self.logger.timing_since('POST.timing', start_time)
return response_class(request=request)

@public
def PUT(self, request):
"""Handle HTTP PUT requests for the Swift Object Server."""
start_time = time.time()
Expand Down Expand Up @@ -641,6 +643,7 @@ def PUT(self, request):
self.logger.timing_since('PUT.timing', start_time)
return resp

@public
def GET(self, request):
"""Handle HTTP GET requests for the Swift Object Server."""
start_time = time.time()
Expand Down Expand Up @@ -729,6 +732,7 @@ def GET(self, request):
self.logger.timing_since('GET.timing', start_time)
return request.get_response(response)

@public
def HEAD(self, request):
"""Handle HTTP HEAD requests for the Swift Object Server."""
start_time = time.time()
Expand Down Expand Up @@ -774,6 +778,7 @@ def HEAD(self, request):
self.logger.timing_since('HEAD.timing', start_time)
return response

@public
def DELETE(self, request):
"""Handle HTTP DELETE requests for the Swift Object Server."""
start_time = time.time()
Expand Down Expand Up @@ -824,6 +829,7 @@ def DELETE(self, request):
self.logger.timing_since('DELETE.timing', start_time)
return resp

@public
def REPLICATE(self, request):
"""
Handle REPLICATE requests for the Swift Object Server. This is used
Expand Down Expand Up @@ -862,10 +868,14 @@ def __call__(self, env, start_response):
res = HTTPPreconditionFailed(body='Invalid UTF8')
else:
try:
if hasattr(self, req.method):
res = getattr(self, req.method)(req)
else:
# disallow methods which have not been marked 'public'
try:
method = getattr(self, req.method)
getattr(method, 'publicly_accessible')
except AttributeError:
res = HTTPMethodNotAllowed()
else:
res = method(req)
except (Exception, Timeout):
self.logger.exception(_('ERROR __call__ error with %(method)s'
' %(path)s '), {'method': req.method, 'path': req.path})
Expand Down
22 changes: 2 additions & 20 deletions swift/proxy/server.py
Expand Up @@ -53,7 +53,7 @@

from swift.common.ring import Ring
from swift.common.utils import cache_from_env, ContextPool, get_logger, \
get_remote_client, normalize_timestamp, split_path, TRUE_VALUES
get_remote_client, normalize_timestamp, split_path, TRUE_VALUES, public
from swift.common.bufferedhttp import http_connect
from swift.common.constraints import check_metadata, check_object_creation, \
check_utf8, CONTAINER_LISTING_LIMIT, MAX_ACCOUNT_NAME_LENGTH, \
Expand Down Expand Up @@ -86,21 +86,6 @@ def update_headers(response, headers):
response.headers[name] = value


def public(func):
"""
Decorator to declare which methods are publicly accessible as HTTP
requests
:param func: function to make public
"""
func.publicly_accessible = True

@functools.wraps(func)
def wrapped(*a, **kw):
return func(*a, **kw)
return wrapped


def delay_denial(func):
"""
Decorator to declare which methods should have any swift.authorize call
Expand Down Expand Up @@ -2022,11 +2007,8 @@ def handle_request(self, req):
self.logger.client_ip = get_remote_client(req)
try:
handler = getattr(controller, req.method)
if not getattr(handler, 'publicly_accessible'):
handler = None
getattr(handler, 'publicly_accessible')
except AttributeError:
handler = None
if not handler:
self.logger.increment('method_not_allowed')
return HTTPMethodNotAllowed(request=req)
if path_parts['version']:
Expand Down
24 changes: 24 additions & 0 deletions test/unit/account/test_server.py
Expand Up @@ -962,6 +962,30 @@ def start_response(*args):
self.assertEquals(errbuf.getvalue(), '')
self.assertEquals(outbuf.getvalue()[:4], '400 ')

def test_invalid_method_doesnt_exist(self):
inbuf = StringIO()
errbuf = StringIO()
outbuf = StringIO()
def start_response(*args):
outbuf.writelines(args)
self.controller.__call__({'REQUEST_METHOD': 'method_doesnt_exist',
'PATH_INFO': '/sda1/p/a'},
start_response)
self.assertEquals(errbuf.getvalue(), '')
self.assertEquals(outbuf.getvalue()[:4], '405 ')

def test_invalid_method_is_not_public(self):
inbuf = StringIO()
errbuf = StringIO()
outbuf = StringIO()
def start_response(*args):
outbuf.writelines(args)
self.controller.__call__({'REQUEST_METHOD': '__init__',
'PATH_INFO': '/sda1/p/a'},
start_response)
self.assertEquals(errbuf.getvalue(), '')
self.assertEquals(outbuf.getvalue()[:4], '405 ')

def test_params_utf8(self):
self.controller.PUT(Request.blank('/sda1/p/a',
headers={'X-Timestamp': normalize_timestamp(1)},
Expand Down
24 changes: 24 additions & 0 deletions test/unit/container/test_server.py
Expand Up @@ -928,6 +928,30 @@ def start_response(*args):
self.assertEquals(errbuf.getvalue(), '')
self.assertEquals(outbuf.getvalue()[:4], '400 ')

def test_invalid_method_doesnt_exist(self):
inbuf = StringIO()
errbuf = StringIO()
outbuf = StringIO()
def start_response(*args):
outbuf.writelines(args)
self.controller.__call__({'REQUEST_METHOD': 'method_doesnt_exist',
'PATH_INFO': '/sda1/p/a/c'},
start_response)
self.assertEquals(errbuf.getvalue(), '')
self.assertEquals(outbuf.getvalue()[:4], '405 ')

def test_invalid_method_is_not_public(self):
inbuf = StringIO()
errbuf = StringIO()
outbuf = StringIO()
def start_response(*args):
outbuf.writelines(args)
self.controller.__call__({'REQUEST_METHOD': '__init__',
'PATH_INFO': '/sda1/p/a/c'},
start_response)
self.assertEquals(errbuf.getvalue(), '')
self.assertEquals(outbuf.getvalue()[:4], '405 ')

def test_params_utf8(self):
self.controller.PUT(Request.blank('/sda1/p/a/c',
headers={'X-Timestamp': normalize_timestamp(1)},
Expand Down
24 changes: 24 additions & 0 deletions test/unit/obj/test_server.py
Expand Up @@ -1342,6 +1342,30 @@ def start_response(*args):
self.assertEquals(errbuf.getvalue(), '')
self.assertEquals(outbuf.getvalue()[:4], '405 ')

def test_invalid_method_doesnt_exist(self):
inbuf = StringIO()
errbuf = StringIO()
outbuf = StringIO()
def start_response(*args):
outbuf.writelines(args)
self.object_controller.__call__({'REQUEST_METHOD': 'method_doesnt_exist',
'PATH_INFO': '/sda1/p/a/c/o'},
start_response)
self.assertEquals(errbuf.getvalue(), '')
self.assertEquals(outbuf.getvalue()[:4], '405 ')

def test_invalid_method_is_not_public(self):
inbuf = StringIO()
errbuf = StringIO()
outbuf = StringIO()
def start_response(*args):
outbuf.writelines(args)
self.object_controller.__call__({'REQUEST_METHOD': '__init__',
'PATH_INFO': '/sda1/p/a/c/o'},
start_response)
self.assertEquals(errbuf.getvalue(), '')
self.assertEquals(outbuf.getvalue()[:4], '405 ')

def test_chunked_put(self):
listener = listen(('localhost', 0))
port = listener.getsockname()[1]
Expand Down

0 comments on commit 9f5a6bb

Please sign in to comment.