Skip to content

Commit

Permalink
Fix wrongly introduced Python 3 incompatibility. Fixes #14 and #15.
Browse files Browse the repository at this point in the history
Add static type checks via mypy (optional static type checker),
Add relevant tests, which could trigger the issue.
  • Loading branch information
arthepsy committed Oct 17, 2016
1 parent f065118 commit 6b76e68
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 7 deletions.
20 changes: 13 additions & 7 deletions ssh-audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
"""
from __future__ import print_function
import os, io, sys, socket, struct, random, errno, getopt, re, hashlib, base64
try:
from typing import List, Tuple, Text
except:
pass

VERSION = 'v1.6.1.dev'

Expand Down Expand Up @@ -940,14 +944,15 @@ def get_banner(self, sshv=2):
return self.__banner, self.__header

def recv(self, size=2048):
# type: (int) -> Tuple[int, str]
try:
data = self.__sock.recv(size)
except socket.timeout as e:
r = 0 if e.strerror == 'timed out' else -1
return (r, e)
except socket.timeout:
return (-1, 'timeout')
except socket.error as e:
r = 0 if e.errno in (errno.EAGAIN, errno.EWOULDBLOCK) else -1
return (r, e)
if e.args[0] in (errno.EAGAIN, errno.EWOULDBLOCK):
return (0, 'retry')
return (-1, str(e.args[-1]))
if len(data) == 0:
return (-1, None)
pos = self._buf.tell()
Expand Down Expand Up @@ -977,6 +982,7 @@ def ensure_read(self, size):
raise SSH.Socket.InsufficientReadException(e)

def read_packet(self, sshv=2):
# type: (int) -> Tuple[int, bytes]
try:
header = WriteBuf()
self.ensure_read(4)
Expand Down Expand Up @@ -1024,7 +1030,7 @@ def read_packet(self, sshv=2):
header.write(self.read(self.unread_len))
e = header.write_flush().strip()
else:
e = ex.args[0]
e = ex.args[0].encode('utf-8')
return (-1, e)

def send_packet(self):
Expand Down Expand Up @@ -1651,7 +1657,7 @@ def audit(conf, sshv=None):
if err is None:
packet_type, payload = s.read_packet(sshv)
if packet_type < 0:
payload = str(payload).decode('utf-8')
payload = payload.decode('utf-8') if payload else u'empty'
if payload == u'Protocol major versions differ.':
if sshv == 2 and conf.ssh1:
audit(conf, 1)
Expand Down
96 changes: 96 additions & 0 deletions test/test_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import pytest, socket


class TestErrors(object):
@pytest.fixture(autouse=True)
def init(self, ssh_audit):
self.AuditConf = ssh_audit.AuditConf
self.audit = ssh_audit.audit

def _conf(self):
conf = self.AuditConf('localhost', 22)
conf.batch = True
return conf

def test_connection_refused(self, output_spy, virtual_socket):
vsocket = virtual_socket
vsocket.errors['connect'] = socket.error(61, 'Connection refused')
output_spy.begin()
with pytest.raises(SystemExit):
self.audit(self._conf())
lines = output_spy.flush()
assert len(lines) == 1
assert 'Connection refused' in lines[-1]

def test_connection_closed_before_banner(self, output_spy, virtual_socket):
vsocket = virtual_socket
vsocket.rdata.append(socket.error(54, 'Connection reset by peer'))
output_spy.begin()
with pytest.raises(SystemExit):
self.audit(self._conf())
lines = output_spy.flush()
assert len(lines) == 1
assert 'did not receive banner' in lines[-1]

def test_connection_closed_after_header(self, output_spy, virtual_socket):
vsocket = virtual_socket
vsocket.rdata.append(b'header line 1\n')
vsocket.rdata.append(b'header line 2\n')
vsocket.rdata.append(socket.error(54, 'Connection reset by peer'))
output_spy.begin()
with pytest.raises(SystemExit):
self.audit(self._conf())
lines = output_spy.flush()
assert len(lines) == 3
assert 'did not receive banner' in lines[-1]

def test_connection_closed_after_banner(self, output_spy, virtual_socket):
vsocket = virtual_socket
vsocket.rdata.append(b'SSH-2.0-ssh-audit-test\r\n')
vsocket.rdata.append(socket.error(54, 'Connection reset by peer'))
output_spy.begin()
with pytest.raises(SystemExit):
self.audit(self._conf())
lines = output_spy.flush()
assert len(lines) == 2
assert 'error reading packet' in lines[-1]
assert 'reset by peer' in lines[-1]

def test_empty_data_after_banner(self, output_spy, virtual_socket):
vsocket = virtual_socket
vsocket.rdata.append(b'SSH-2.0-ssh-audit-test\r\n')
output_spy.begin()
with pytest.raises(SystemExit):
self.audit(self._conf())
lines = output_spy.flush()
assert len(lines) == 2
assert 'error reading packet' in lines[-1]
assert 'empty' in lines[-1]

def test_wrong_data_after_banner(self, output_spy, virtual_socket):
vsocket = virtual_socket
vsocket.rdata.append(b'SSH-2.0-ssh-audit-test\r\n')
vsocket.rdata.append(b'xxx\n')
output_spy.begin()
with pytest.raises(SystemExit):
self.audit(self._conf())
lines = output_spy.flush()
assert len(lines) == 2
assert 'error reading packet' in lines[-1]
assert 'xxx' in lines[-1]

def test_protocol_mismatch_by_conf(self, output_spy, virtual_socket):
vsocket = virtual_socket
vsocket.rdata.append(b'SSH-1.3-ssh-audit-test\r\n')
vsocket.rdata.append(b'Protocol major versions differ.\n')
output_spy.begin()
with pytest.raises(SystemExit):
conf = self._conf()
conf.ssh1, conf.ssh2 = True, False
self.audit(conf)
lines = output_spy.flush()
assert len(lines) == 3
assert 'error reading packet' in lines[-1]
assert 'major versions differ' in lines[-1]

0 comments on commit 6b76e68

Please sign in to comment.