Commit: Compressing old trace files

andresriancho committed Apr 24, 2019
1 parent 09e7544 commit 8cb8883

Showing 2 changed files with 261 additions and 54 deletions.
275 changes: 222 additions & 53 deletions w3af/core/data/db/history.py
@@ -24,10 +24,9 @@
import os
import time
import threading
+import zipfile

import msgpack

-import lz4.frame

from functools import wraps
from shutil import rmtree

@@ -75,10 +74,13 @@ class HistoryItem(object):
_PRIMARY_KEY_COLUMNS = ('id',)
_INDEX_COLUMNS = ('alias',)

-_EXTENSION = '.trace'
+_EXTENSION = 'trace'
_MSGPACK_CANARY = 'cute-and-yellow'

-COMPRESSION_LEVEL = 2
+_COMPRESSED_EXTENSION = 'zip'
+_COMPRESSED_FILE_BATCH = 150
+_UNCOMPRESSED_FILES = 50
+_COMPRESSION_LEVEL = 7

id = None
_request = None
@@ -101,6 +103,9 @@ def __init__(self):
self._session_dir = os.path.join(get_temp_dir(),
self._db.get_file_name() + '_traces')

def get_session_dir(self):
return self._session_dir

def init(self):
self.init_traces_dir()
self.init_db()
@@ -151,9 +156,10 @@ def set_request(self, req):

@verify_has_db
def find(self, searchData, result_limit=-1, orderData=[], full=False):
"""Make complex search.
search_data = {name: (value, operator), ...}
orderData = [(name, direction)]
"""
Make complex search.
search_data = {name: (value, operator), ...}
orderData = [(name, direction)]
"""
result = []
sql = 'SELECT * FROM ' + self._DATA_TABLE
@@ -196,64 +202,141 @@ def _load_from_row(self, row, full=True):
self.method = row[10]
self.response_size = int(row[11])

-def _get_fname_for_id(self, _id):
-    return os.path.join(self._session_dir, str(_id) + self._EXTENSION)

+def _get_trace_filename_for_id(self, _id):
+    return os.path.join(self._session_dir, '%s.%s' % (_id, self._EXTENSION))

def _load_from_trace_file(self, _id):
"""
Load a request/response from a trace file on disk. This is the
simplest implementation, without any retries for concurrency issues.
:param _id: The request-response ID
:return: A tuple containing request and response instances
"""
file_name = self._get_trace_filename_for_id(_id)

if not os.path.exists(file_name):
raise TraceReadException()

# The file exists, but the contents might not be all on-disk yet
serialized_req_res = open(file_name, 'rb').read()
return self._load_from_string(serialized_req_res)

def _load_from_string(self, serialized_req_res):
try:
data = msgpack.loads(serialized_req_res, use_list=True)
except ValueError:
# ValueError: Extra data. returned when msgpack finds invalid
# data in the file
raise TraceReadException()

try:
request_dict, response_dict, canary = data
except TypeError:
# https://github.com/andresriancho/w3af/issues/1101
# 'NoneType' object is not iterable
raise TraceReadException()

if not canary == self._MSGPACK_CANARY:
# read failed, most likely because the file write is not
# complete but for some reason it was a valid msgpack file
raise TraceReadException()

request = HTTPRequest.from_dict(request_dict)
response = HTTPResponse.from_dict(response_dict)
return request, response
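
The canary check above exists because a reader can open a trace file while the writer is still flushing it. A minimal, self-contained sketch of that validation (illustrative, not w3af code; the empty dicts stand in for the serialized request and response):

import msgpack

CANARY = 'cute-and-yellow'

def is_complete_trace(serialized):
    # A partial write either fails msgpack decoding, unpacks into the
    # wrong number of items, or lacks the trailing canary
    try:
        request_dict, response_dict, canary = msgpack.loads(serialized,
                                                            use_list=True)
    except (ValueError, TypeError):
        return False
    return canary == CANARY

assert is_complete_trace(msgpack.dumps(({}, {}, CANARY)))
assert not is_complete_trace(msgpack.dumps(({}, {})))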

def _load_from_trace_file_concurrent(self, _id):
"""
Load a request/response from a trace file on disk, using retries
and error handling to make sure all concurrency issues are handled.
:param _id: The request-response ID
:return: A tuple containing request and response instances
"""
wait_time = 0.05

#
# Retry the read a few times to handle concurrency issues
#
for _ in xrange(int(1 / wait_time)):
    try:
        return self._load_from_trace_file(_id)
    except TraceReadException:
        time.sleep(wait_time)
else:
    # int(1 / wait_time) == 20 attempts of 50ms each, so the trace file
    # gets roughly one second to become readable before we give up
    msg = 'Timeout expecting trace file "%s" to be ready'
    file_name = self._get_trace_filename_for_id(_id)
    raise DBException(msg % file_name)

def load_from_file(self, _id):
-    fname = self._get_fname_for_id(_id)
-    WAIT_TIME = 0.05
+    """
+    Loads a request/response from a trace file on disk. Two different
+    options exist:
+
+        * The file is compressed inside a zip
+        * The file is uncompressed in a trace
+
+    :param _id: The request-response ID
+    :return: A tuple containing request and response instances
+    """
     #
-    # Due to some concurrency issues, we need to perform these checks
+    # First we check if the trace file exists and try to load it from
+    # the uncompressed trace
     #
-    for _ in xrange(int(1 / WAIT_TIME)):
-        if not os.path.exists(fname):
-            time.sleep(WAIT_TIME)
-            continue
+    file_name = self._get_trace_filename_for_id(_id)

-        # Ok... the file exists, but it might still be being written
-        req_res = lz4.frame.decompress(open(fname, 'rb').read())
+    if os.path.exists(file_name):
+        return self._load_from_trace_file_concurrent(_id)

-        try:
-            data = msgpack.loads(req_res, use_list=True)
-        except ValueError:
-            # ValueError: Extra data. returned when msgpack finds invalid
-            # data in the file
-            time.sleep(WAIT_TIME)
-            continue
+    #
+    # The trace file doesn't exist, try to find the zip file where the
+    # compressed file lives and read it from there
+    #
+    try:
+        return self._load_from_zip(_id)
+    except TraceReadException:
+        #
+        # Give the .trace file a last chance, it might be possible that when
+        # we checked for os.path.exists(file_name) at the beginning of this
+        # method the file wasn't there yet, but is on disk now
+        #
+        return self._load_from_trace_file_concurrent(_id)

-        try:
-            request_dict, response_dict, canary = data
-        except TypeError:
-            # https://github.com/andresriancho/w3af/issues/1101
-            # 'NoneType' object is not iterable
-            time.sleep(WAIT_TIME)
-            continue
-
-        if not canary == self._MSGPACK_CANARY:
-            # read failed, most likely because the file write is not
-            # complete but for some reason it was a valid msgpack file
-            time.sleep(WAIT_TIME)
-            continue
-
-        request = HTTPRequest.from_dict(request_dict)
-        response = HTTPResponse.from_dict(response_dict)
-        return request, response
-
-    else:
-        msg = 'Timeout expecting trace file to be ready "%s"' % fname
-        raise DBException(msg)

+def _load_from_zip(self, _id):
+    files = os.listdir(self.get_session_dir())
+    files = [f for f in files if f.endswith(self._COMPRESSED_EXTENSION)]
+
+    for zip_file in files:
+        start, end = get_zip_id_range(zip_file)
+
+        if start <= _id <= end:
+            return self._load_from_zip_file(_id, zip_file)
+
+def _load_from_zip_file(self, _id, zip_file):
+    _zip = zipfile.ZipFile(os.path.join(self.get_session_dir(), zip_file))
+
+    try:
+        serialized_req_res = _zip.read('%s.%s' % (_id, self._EXTENSION))
+    except KeyError:
+        # We get here when the zip file doesn't contain the trace file
+        raise TraceReadException()
+
+    return self._load_from_string(serialized_req_res)
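
Since each archive name encodes the ID range it contains, locating a compressed trace is a directory scan plus a single member read. A self-contained sketch of that lookup (illustrative; find_in_archives() is not a w3af function):

import zipfile

def find_in_archives(_id, zip_paths):
    for path in zip_paths:
        name = path.rsplit('/')[-1].split('.')[0]        # e.g. '1-150'
        start, end = [int(x) for x in name.split('-')]
        if start <= _id <= end:
            archive = zipfile.ZipFile(path)
            try:
                # Read a single member without extracting the whole archive
                return archive.read('%s.trace' % _id)
            finally:
                archive.close()
    return None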

@verify_has_db
def delete(self, _id=None):
"""Delete data from DB by ID."""
"""
Delete data from DB by ID.
"""
if _id is None:
_id = self.id

sql = 'DELETE FROM ' + self._DATA_TABLE + ' WHERE id = ? '
self._db.execute(sql, (_id,))

-    fname = self._get_fname_for_id(_id)
+    fname = self._get_trace_filename_for_id(_id)

try:
os.remove(fname)
@@ -263,15 +346,15 @@ def delete(self, _id=None):
@verify_has_db
def load(self, _id=None, full=True, retry=True):
"""Load data from DB by ID."""
-    if not _id:
+    if _id is None:
_id = self.id

sql = 'SELECT * FROM ' + self._DATA_TABLE + ' WHERE id = ? '
try:
row = self._db.select_one(sql, (_id,))
except DBException, dbe:
-        msg = 'An unexpected error occurred while searching for id "%s"'\
-              ' in table "%s". Original exception: "%s".'
+        msg = ('An unexpected error occurred while searching for id "%s"'
+               ' in table "%s". Original exception: "%s".')
raise DBException(msg % (_id, self._DATA_TABLE, dbe))
else:
if row is not None:
@@ -345,7 +428,7 @@ def save(self):
#
# Save raw data to file
#
-    path_fname = self._get_fname_for_id(self.id)
+    path_fname = self._get_trace_filename_for_id(self.id)

try:
req_res = open(path_fname, 'wb')
@@ -378,11 +461,82 @@ def save(self):
self._MSGPACK_CANARY)
msgpack_data = msgpack.dumps(data)

-    req_res.write(lz4.frame.compress(msgpack_data))
+    req_res.write(msgpack_data)
req_res.close()

self._compress_old_traces()

return True
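
For reference, a condensed sketch of the new write path (illustrative, not the exact w3af code): with the per-file lz4 compression dropped, a trace is plain msgpack on disk, and the batched zip compression below runs after every save.

import msgpack

def write_trace(file_name, request_dict, response_dict,
                canary='cute-and-yellow'):
    # Write the (request, response, canary) triple; readers treat a
    # missing or wrong canary as an in-progress write and retry
    trace = open(file_name, 'wb')
    trace.write(msgpack.dumps((request_dict, response_dict, canary)))
    trace.close()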

def _compress_old_traces(self):
"""
We'll compress 150 of the oldest files in the session directory.
Not compressing all files because the newest ones might be read by
plugins and we don't want to decompress them right after compressing
them (waste of CPU).
:return: None
"""
#
# Get the list of files to compress, checking that we have enough to
# proceed with compression
#
session_dir = self._session_dir
min_file_count = self._UNCOMPRESSED_FILES + self._COMPRESSED_FILE_BATCH

# Initial check to boost performance
if len(os.listdir(session_dir)) <= min_file_count:
return

files = [os.path.join(session_dir, f) for f in os.listdir(session_dir)]
files = [f for f in files if f.endswith(self._EXTENSION)]
files = [f for f in files if os.path.isfile(f)]

if len(files) <= min_file_count:
return

#
# Sort by ID and remove the last 50 from the list to avoid
# compression-decompression CPU waste
#
files.sort(key=lambda trace_file: get_trace_id(trace_file))
files = files[:-self._UNCOMPRESSED_FILES]

#
# Only compress in batches of 150 files, making sure that the filenames
# are numerically ordered. We need this order so that traces 1, 2, ... 150
# end up in the same archive, which is named `1-150.zip` and is later used
# to locate a compressed trace by its ID.
#
files = files[:self._COMPRESSED_FILE_BATCH]

#
# Compress the oldest 150 files into a zip archive
#
start = get_trace_id(files[0])
end = get_trace_id(files[-1])
compressed_filename = '%s-%s.%s' % (start,
end,
self._COMPRESSED_EXTENSION)
compressed_filename = os.path.join(session_dir, compressed_filename)

_zip = zipfile.ZipFile(file=compressed_filename,
mode='w',
compression=zipfile.ZIP_DEFLATED)

for filename in files:
_zip.write(filename=filename,
arcname='%s.%s' % (get_trace_id(filename), self._EXTENSION))

_zip.close()

#
# And now remove the already compressed files
#
for filename in files:
os.remove(filename)
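
A worked example of the thresholds above (a sketch, not w3af code): with 230 traces on disk, min_file_count is 50 + 150 = 200, so compression runs; the 50 newest traces are excluded, the 150 oldest are written to `1-150.zip`, and 80 uncompressed traces remain.

ids = range(1, 231)        # trace IDs currently on disk
candidates = ids[:-50]     # exclude the _UNCOMPRESSED_FILES newest traces
batch = candidates[:150]   # one _COMPRESSED_FILE_BATCH of the oldest traces
assert (batch[0], batch[-1]) == (1, 150)  # archive name: '1-150.zip'
assert len(ids) - len(batch) == 80        # traces left on disk uncompressed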

def get_columns(self):
return self._COLUMNS

@@ -430,3 +584,18 @@ def clear(self):
rmtree(self._session_dir, ignore_errors=True)

return True


def get_trace_id(trace_file):
    # Extract the numeric trace ID from a path like '.../123.trace'
    return int(trace_file.rsplit('/')[-1].rsplit('.')[-2])


def get_zip_id_range(zip_file):
    # Extract the (start, end) trace ID range from a path like '.../1-150.zip'
    name_ext = zip_file.rsplit('/')[-1]
    name = name_ext.split('.')[0]
    start, end = name.split('-')
    return int(start), int(end)
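
Usage sketch for the two helpers above (paths are illustrative):

assert get_trace_id('/tmp/w3af_traces/123.trace') == 123
assert get_zip_id_range('/tmp/w3af_traces/1-150.zip') == (1, 150)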


class TraceReadException(Exception):
pass
