Skip to content

Commit

Permalink
New get_average_rtt algorithm which handles timeouts when one of the …
Browse files Browse the repository at this point in the history
…threads dead-locks, fails, etc.
  • Loading branch information
andresriancho committed Nov 29, 2019
1 parent 27fd576 commit b54977d
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 57 deletions.
148 changes: 99 additions & 49 deletions w3af/core/data/url/get_average_rtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,15 @@


class GetAverageRTTForMutant(object):

TIMEOUT = 120

def __init__(self, url_opener):
# Store the URL opener and create the RTT cache plus the per-cache-key
# bookkeeping structures used by the methods of this class.
#
# :param url_opener: Kept as-is in self._url_opener
self._url_opener = url_opener

# Cache to measure RTT
# Entries are stored as (timestamp, average_rtt) tuples, keyed by the
# value returned by _get_cache_key()
self._rtt_mutant_cache = SynchronizedLRUDict(capacity=128)
# Guards reads and writes of self._specific_rtt_mutant_locks
# NOTE(review): this view is a diff without +/- markers; the two lock
# attributes below appear to belong to the pre-change version -- confirm
# against the real file
self._rtt_mutant_lock = threading.RLock()
# cache_key -> threading.RLock
self._specific_rtt_mutant_locks = dict()
# cache_key -> threading.Event created by the thread that is currently
# sending the HTTP requests for that key; other threads wait on it
self._rtt_processing_events = dict()

def _get_cache_key(self, mutant):
#
Expand Down Expand Up @@ -72,59 +74,118 @@ def get_average_rtt_for_mutant(self, mutant, count=3, debugging_id=None):
"""
assert count >= 3, 'Count must be greater or equal than 3.'

#
# First we try to get the data from the cache
#
cache_key = self._get_cache_key(mutant)
cached_rtt = self._get_cached_rtt(cache_key, debugging_id=debugging_id)

if cached_rtt is not None:
return cached_rtt

#
# Only perform one of these checks at the time, this is useful to prevent
# different threads which need the same result from duplicating efforts
#
specific_rtt_mutant_lock = self._get_specific_rtt_mutant_lock(cache_key)

with specific_rtt_mutant_lock:
cached_value = self._rtt_mutant_cache.get(cache_key, default=None)
rtt_processing_event = self._rtt_processing_events.get(cache_key, None)

if cached_value is not None:
timestamp, value = cached_value
if time.time() - timestamp <= 5:
#
# The cache entry is still valid, return the cached value
#
msg = 'Returning cached average RTT of %.2f seconds for mutant %s'
om.out.debug(msg % (value, cache_key))
return value
if rtt_processing_event is not None:
# There is another thread sending HTTP requests to get the average RTT
# we need to wait for that thread to finish
wait_result = rtt_processing_event.wait(timeout=self.TIMEOUT)

#
# Need to send the HTTP requests and do the average
#
rtts = self._get_rtts(mutant, count, debugging_id)

if self._has_outliers(rtts):
if not wait_result:
# The TIMEOUT has been reached, the thread that was trying to get
# the RTT for us found a serious issue, is dead-locked, etc.
#
# We're going to have to try to get the RTT ourselves by sending
# the HTTP requests. Just `pass` here and get to the code below
# that sends the HTTP requests
msg = ('get_average_rtt_for_mutant() timed out waiting for'
' results from another thread. Will send HTTP requests'
' and collect the data from the network (did:%s)')
args = (debugging_id,)
om.out.debug(msg % args)
else:
# The event was set! The other thread finished and we can read
# the result from the cache.
#
# The measurement has outliers, we can't continue! If we do
# continue the average_rtt will be completely invalid and
# potentially yield false positives
# Just in case the other thread had issues getting the RTTs, we
# need to check if the cache actually has the data, and if the
# data is valid
#
self._remove_cache_key_from_mutant_locks(cache_key)
rtts_str = ', '.join(str(i) for i in rtts)
msg = 'Found outliers while sampling average RTT: %s' % rtts_str
raise OutlierException(msg)
# No need to check the timestamp because we know it will be
# valid, it has been just set by the other thread
cached_rtt = self._get_cached_rtt(cache_key, debugging_id=debugging_id)

average_rtt = float(sum(rtts)) / len(rtts)
self._rtt_mutant_cache[cache_key] = (time.time(), average_rtt)
if cached_rtt is not None:
return cached_rtt

self._remove_cache_key_from_mutant_locks(cache_key)
msg = ('get_average_rtt_for_mutant() found no cache entry after'
' the other thread finished. Will send HTTP requests'
' and collect the data from the network (did:%s)')
args = (debugging_id,)
om.out.debug(msg % args)

msg = 'Returning fresh average RTT of %.2f seconds for mutant %s'
om.out.debug(msg % (average_rtt, cache_key))
#
# There is no other thread getting data for `cache_key`, we'll have to
# extract the information by sending the HTTP requests
#
event = threading.Event()
self._rtt_processing_events[cache_key] = event

try:
average_rtt = self._get_average_rtt_for_mutant(mutant,
count=count,
debugging_id=debugging_id)
self._rtt_mutant_cache[cache_key] = (time.time(),
average_rtt)
finally:
event.set()
self._rtt_processing_events.pop(event, None)

msg = 'Returning fresh average RTT of %.2f seconds for mutant %s (did:%s)'
args = (average_rtt, cache_key, debugging_id)
om.out.debug(msg % args)

return average_rtt

def _remove_cache_key_from_mutant_locks(self, cache_key):
with self._rtt_mutant_lock:
if cache_key in self._specific_rtt_mutant_locks:
self._specific_rtt_mutant_locks.pop(cache_key)
def _get_cached_rtt(self, cache_key, debugging_id):
# Look up the average RTT stored for `cache_key` in the RTT cache.
#
# :param cache_key: Key used to query self._rtt_mutant_cache
# :param debugging_id: Only used in the debug message
# :return: The cached average RTT, or None when there is no entry or
#          the entry is older than 5 seconds
cached_value = self._rtt_mutant_cache.get(cache_key, default=None)

if cached_value is None:
return None

# NOTE(review): the `def _get_rtts` line below looks like residue from
# the pre-change version of the file (this view is a diff without +/-
# markers) and is not part of this method -- confirm against the real file
def _get_rtts(self, mutant, count=3, debugging_id=None):
# Cache entries are (timestamp, average_rtt) tuples
timestamp, value = cached_value
# Entries older than 5 seconds are treated as expired
if time.time() - timestamp > 5:
return None

# The cache entry is still valid, return the cached value
msg = 'Returning cached average RTT of %.2f seconds for mutant %s (did:%s)'
args = (value, cache_key, debugging_id)
om.out.debug(msg % args)
return value

def _get_average_rtt_for_mutant(self, mutant, count=3, debugging_id=None):
    """
    Collect RTT samples for the mutant and return their arithmetic mean.

    :param mutant: Passed through to self._get_all_rtts()
    :param count: Passed through to self._get_all_rtts()
    :param debugging_id: Passed through to self._get_all_rtts()
    :return: The average of the collected RTT samples, as a float
    :raises OutlierException: When self._has_outliers() flags the samples
    """
    samples = self._get_all_rtts(mutant, count, debugging_id)

    if not self._has_outliers(samples):
        return float(sum(samples)) / len(samples)

    #
    # An average calculated from samples which contain outliers would be
    # completely invalid and potentially yield false positives, so the
    # measurement is rejected instead
    #
    joined_samples = ', '.join(map(str, samples))
    error = 'Found outliers while sampling average RTT: %s' % joined_samples
    raise OutlierException(error)

def _get_all_rtts(self, mutant, count=3, debugging_id=None):
"""
:param mutant: The mutant to send and measure RTT from
:param count: Number of checks to perform
Expand Down Expand Up @@ -165,17 +226,6 @@ def _has_outliers(self, rtts):
#
return False

def _get_specific_rtt_mutant_lock(self, cache_key):
with self._rtt_mutant_lock:
specific_rtt_mutant_lock = self._specific_rtt_mutant_locks.get(cache_key)

if specific_rtt_mutant_lock is not None:
return specific_rtt_mutant_lock

specific_rtt_mutant_lock = threading.RLock()
self._specific_rtt_mutant_locks[cache_key] = specific_rtt_mutant_lock
return specific_rtt_mutant_lock


class OutlierException(Exception):
    """
    Raised when the RTT samples collected for a mutant contain outliers and
    no meaningful average can be calculated from them.
    """
42 changes: 34 additions & 8 deletions w3af/core/data/url/tests/test_get_average_rtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@
import time
import random
import unittest

import httpretty

from nose.plugins.attrib import attr
from multiprocessing.dummy import Pool as ThreadPool
from itertools import repeat

from w3af.core.data.url.extended_urllib import ExtendedUrllib
from w3af.core.data.parsers.doc.url import URL
Expand All @@ -43,17 +45,17 @@ def tearDown(self):
self.uri_opener.end()
httpretty.reset()

@staticmethod
def request_callback_05(request, uri, headers):
    """
    Callback registered as the `body` of an httpretty URI: waits half a
    second and then answers with a 200 response.

    :param request: Unused
    :param uri: Unused
    :param headers: Returned unchanged as the response headers
    :return: A (status code, headers, body) tuple
    """
    time.sleep(0.5)
    return 200, headers, 'Yup'

@httpretty.activate
def test_get_average_rtt_for_mutant_all_equal(self):

def request_callback(request, uri, headers):
time.sleep(0.5)
body = 'Yup'
return 200, headers, body

httpretty.register_uri(httpretty.GET,
self.MOCK_URL,
body=request_callback)
body=TestGetAverageRTT.request_callback_05)

mock_url = URL(self.MOCK_URL)
fuzzable_request = FuzzableRequest(mock_url)
Expand Down Expand Up @@ -102,6 +104,30 @@ def test_get_average_rtt_for_mutant_one_off(self):
self.assertGreater(average_rtt, 0.80)
self.assertGreater(0.90, average_rtt)

@httpretty.activate
def test_get_average_rtt_for_mutant_with_threads(self):
    """
    Call get_average_rtt_for_mutant() for the same request from many
    threads and check that every caller receives the same value, which
    should be close to the 0.5 seconds delay introduced by the mocked
    HTTP response.
    """
    httpretty.register_uri(httpretty.GET,
                           self.MOCK_URL,
                           body=TestGetAverageRTT.request_callback_05)

    mock_url = URL(self.MOCK_URL)
    fuzzable_request = FuzzableRequest(mock_url)

    iterations = 50
    pool = ThreadPool(25)

    try:
        results = pool.map(self.uri_opener.get_average_rtt_for_mutant,
                           repeat(fuzzable_request, iterations))
    finally:
        # Release the worker threads even when map() raises, otherwise
        # they outlive the test and leak into the rest of the test run
        pool.close()
        pool.join()

    self.assertEqual(len(results), iterations)

    # All the callers must have received exactly the same measurement
    for result_n in results:
        self.assertEqual(result_n, results[0])

    # Check the response
    self.assertGreater(results[0], 0.45)
    self.assertLess(results[0], 0.55)


class RequestCallBackWithDelays(object):

Expand Down

0 comments on commit b54977d

Please sign in to comment.