Skip to content

Commit

Permalink
add unittest for container based service upload
Browse files Browse the repository at this point in the history
  • Loading branch information
VertexC committed Jun 30, 2019
1 parent 9ce8d8c commit 2c5424f
Show file tree
Hide file tree
Showing 8 changed files with 405 additions and 21 deletions.
10 changes: 5 additions & 5 deletions mod_config/controllers.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ def data_processing_ajax(action):


def verify_and_import_module(temp_path, final_path, form, is_container=False):
# import pdb; pdb.set_trace()
if is_container:
instance = ServiceLoader.load_from_container(temp_path)
else:
Expand All @@ -299,17 +300,16 @@ def verify_and_import_module(temp_path, final_path, form, is_container=False):
@template_renderer()
def services():
form = NewServiceForm()

if form.validate_on_submit():
# Process uploaded file
file = request.files[form.file.name]
if file:
filename = secure_filename(file.filename)
basename = filename.split('.')[0]
basename, extname = os.path.splitext(filename)
temp_dir = os.path.join('./pipot/services/temp', basename)
final_dir = os.path.join('./pipot/services', basename)
if not os.path.isdir(final_dir):
if zipfile.is_zipfile(file):
if extname == '.zip':
zip_file = zipfile.ZipFile(file)
ret = zip_file.testzip()
if ret:
Expand All @@ -321,7 +321,7 @@ def services():
# Reset form, all ok
form = NewServiceForm(None)
except ServiceLoader.ServiceLoaderException as e:
os.remove(temp_dir)
shutil.rmtree(temp_dir)
form.errors['container'] = [e.value]
else:
if os.path.exists(temp_dir):
Expand All @@ -338,7 +338,7 @@ def services():
form = NewServiceForm(None)
except ServiceLoader.ServiceLoaderException as e:
# Remove file
os.remove(temp_dir)
shutil.rmtree(temp_dir)
# Pass error to user
form.errors['file'] = [e.value]
else:
Expand Down
4 changes: 3 additions & 1 deletion mod_config/forms.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
import os
from enum import Enum
from flask_wtf import Form
from wtforms import SubmitField, FileField, TextAreaField, HiddenField, \
Expand All @@ -16,14 +17,15 @@ class FileType(Enum):
def is_python_or_container(file_name):
# Check if it ends on .py
is_py = re.compile(r"^[^/\\]*.py$").match(file_name)
is_container = re.compile(r"^[^/\\]*.zip$").match(file_name)
is_container = re.compile((r"^[^/\\]*.zip$")).match(file_name)
if not is_py and not is_container:
raise ValidationError('Provided file is not a python (.py) file or a container (.zip)!')
return FileType.CONTAINER if is_container else FileType.PYTHONFILE


def simple_service_file_validation(check_service=True):
def validate_file(form, field):
field.data.filename = os.path.basename(field.data.filename)
file_type = is_python_or_container(field.data.filename)
if file_type is FileType.PYTHONFILE:
# Name cannot be one of the files we already have
Expand Down
44 changes: 29 additions & 15 deletions tests/testAppBase.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,19 @@
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import tests.config
import flask
from flask import g, current_app, session
from collections import namedtuple
from flask import g, current_app
from database import create_session, Base
from mod_auth.models import User, Role, Page, PageAccess
from mod_config.models import Service, Notification, Actions, Conditions, Rule
from mod_honeypot.models import Profile, PiModels, PiPotReport, ProfileService, \
CollectorTypes, Deployment


def generate_keys(tempdir):
secret_csrf_path = os.path.join(tempdir, "secret_csrf")
secret_key_path = os.path.join(tempdir, "secret_key")
def generate_keys(keydir):
secret_csrf_path = os.path.join(keydir, "secret_csrf")
secret_key_path = os.path.join(keydir, "secret_key")
if not os.path.exists(secret_csrf_path):
secret_csrf_cmd = "head -c 24 /dev/urandom > {path}".format(path=secret_csrf_path)
os.system(secret_csrf_cmd)
Expand All @@ -35,8 +36,8 @@ def generate_keys(tempdir):
return {'secret_csrf_path': secret_csrf_path, 'secret_key_path': secret_key_path}


def load_config(tempdir):
key_paths = generate_keys(tempdir)
def load_config(keydir):
key_paths = generate_keys(keydir)
with open(key_paths['secret_key_path'], 'rb') as secret_key_file:
secret_key = secret_key_file.read()
with open(key_paths['secret_csrf_path'], 'rb') as secret_csrf_file:
Expand All @@ -59,34 +60,47 @@ def load_config(tempdir):
}


class TestAppBaseTest(unittest.TestCase):
tempdir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "temp")
class TestAppBase(unittest.TestCase):
keydir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "keys")

def create_app(self):
with patch('config_parser.parse_config', return_value=load_config(self.tempdir)):
with patch('config_parser.parse_config', return_value=load_config(self.keydir)):
from run import app
return app

def create_admin(self):
# test if there is admin existed
name, password, email = "admin", "adminpwd", "admin@email.com"
db = create_session(self.app.config['DATABASE_URI'], drop_tables=False)
role = Role(name="Admin")
role = Role(name=name)
db.add(role)
db.commit()
admin_user = User(role_id=role.id, name="Admin", password="admin", email="admin@sample.com")
admin_user = User(role_id=role.id, name=name, password=password, email=email)
db.add(admin_user)
db.commit()
db.remove()
return admin_user
return name, password, email

def setUp(self):
if not os.path.exists(self.tempdir):
os.mkdir(self.tempdir)
if not os.path.exists(self.keydir):
os.mkdir(self.keydir)
self.app = self.create_app()
self.client = self.app.test_client(self)

def tearDown(self):
db = create_session(self.app.config['DATABASE_URI'], drop_tables=False)
db_engine = create_engine(self.app.config['DATABASE_URI'], convert_unicode=True)
Base.metadata.drop_all(bind=db_engine)
db.remove()


class TestApp(TestAppBase):

def setUp(self):
super(TestApp, self).setUp()

def tearDown(self):
super(TestApp, self).tearDown()

def test_app_is_running(self):
self.assertFalse(current_app is None)
Expand Down
125 changes: 125 additions & 0 deletions tests/testFiles/TelnetService.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import datetime
from sqlalchemy import Column, String
from twisted.internet.protocol import Protocol, Factory

from pipot.services.IService import INetworkService, IModelIP


class ReportTelnet(IModelIP):
__tablename__ = 'report_telnet'

password = Column(String(100))

def __init__(self, deployment_id, ip, port, password, timestamp=None):
super(ReportTelnet, self).__init__(deployment_id, ip, port, timestamp)
self.password = password

def get_message_for_level(self, notification_level):
message = 'Telnet login attempt with password %s' % self.password
message += '\nPlease take action!' if notification_level == 2 else ''
return message


class SimpleTelnetProtocol(Protocol):
"""
Example Telnet Protocol
$ telnet localhost 8025
Trying 127.0.0.1...
Connected to localhost.
Escape character is '^]'.
password:
password:
password:
% Bad passwords
Connection closed by foreign host.
"""
def __init__(self):
self.prompts = 0
self.buffer = ""

def connectionMade(self):
self.transport.write("\xff\xfb\x03\xff\xfb\x01password: ")
self.prompts += 1

def dataReceived(self, data):
"""
Received data is unbuffered so we buffer it for telnet.
"""
self.buffer += data

i = self.buffer.find("\x01")
if i >= 0:
self.buffer = self.buffer[i+1:]
return

if self.buffer.find("\x00") >= 0:
password = self.buffer.strip("\r\n\x00")
log_data = {"password": password}
self.factory.log(log_data, transport=self.transport)
self.buffer = ""

if self.prompts < 3:
self.transport.write("\r\npassword: ")
self.prompts += 1
else:
self.transport.write("\r\n% Bad passwords\r\n")
self.transport.loseConnection()


class TelnetService(INetworkService, Factory):
protocol = SimpleTelnetProtocol

def __init__(self, collector, config):
super(TelnetService, self).__init__(collector, config, 8025)
""":type : list"""
self._report_types = ['entries']

def get_notification_levels(self):
return [1, 2]

def get_used_table_names(self):
return {ReportTelnet.__tablename__: ReportTelnet}

def create_storage_row(self, deployment_id, data, timestamp):
return ReportTelnet(deployment_id, data['src_host'], data['src_port'],
data['password'], timestamp)

def get_notification_level(self, storage_row):
return 1 if storage_row.password == "admin" else 2

def get_report_types(self):
return self._report_types

def get_data_for_type(self, report_type, **kwargs):
if report_type == 'entries':
days = kwargs.pop('time', 7)
timestamp = datetime.datetime.utcnow() - datetime.timedelta(
days=days)
data = ReportTelnet.query.filter(
ReportTelnet.timestamp >= timestamp).order_by(
ReportTelnet.timestamp.desc()).all()
return data
return {}

def get_data_for_type_default_args(self, report_type):
if report_type == 'entries':
return {'time': 7}
return {}

def get_template_for_type(self, report_type):
if report_type == 'entries':
return '<table><thead><tr><th>ID</th><th>Timestamp</th>' \
'<th>IP:port</th><th>Password</th></tr></thead><tbody>' \
'{% for entry in entries %}<tr><td>{{ entry.id }}</td>' \
'<td>{{ entry.timestamp }}</td><td>{{ entry.ip}}:' \
'{{ entry.port }}</td><td>{{ entry.password }}</td></tr>' \
'{% else %}<tr><td colspan="4">No entries for this ' \
'timespan</td></tr>{% endfor %}</tbody></table>'
return ''

def get_template_arguments(self, report_type, initial_data):
if report_type == 'entries':
return {
'entries': initial_data
}
return {}
Binary file added tests/testFiles/TelnetService.zip
Binary file not shown.
Loading

0 comments on commit 2c5424f

Please sign in to comment.