From 429abcf118bf9bc92ad7631782066b319039fe18 Mon Sep 17 00:00:00 2001 From: Dmitriy Dubson Date: Mon, 16 Sep 2019 16:54:44 -0700 Subject: [PATCH] Address logic complexity in logic.py - Extract remaining database related functions into respective repositories with unit test backing - Extract Python and NMap importers out of logic.py and into `importers` directory under `app` module - Optimize test fixtures for less repetition in test setup - Enhancing functions to follow camel case naming convention --- .gitignore | 2 + app/importers/NmapImporter.py | 317 ++++++++++ app/importers/PythonImporter.py | 72 +++ app/importers/__init__.py | 17 + app/logic.py | 545 +----------------- controller/controller.py | 119 ++-- db/filters.py | 10 +- db/repositories/CVERepository.py | 8 +- db/repositories/HostRepository.py | 41 +- db/repositories/NoteRepository.py | 41 ++ db/repositories/PortRepository.py | 37 +- db/repositories/ProcessRepository.py | 157 ++++- db/repositories/ScriptRepository.py | 34 ++ db/repositories/ServiceRepository.py | 10 +- log/legion-db.log | 1 - log/legion-startup.log | 1 - log/legion.log | 1 - tests/app/test_logic.py | 57 -- tests/db/helpers/db_helpers.py | 14 +- tests/db/repositories/test_CVERepository.py | 20 +- tests/db/repositories/test_HostRepository.py | 116 ++-- tests/db/repositories/test_NoteRepository.py | 54 ++ tests/db/repositories/test_PortRepository.py | 71 ++- .../db/repositories/test_ProcessRepository.py | 317 ++++++---- .../db/repositories/test_ScriptRepository.py | 30 + .../db/repositories/test_ServiceRepository.py | 65 +-- tests/db/test_filters.py | 60 +- 27 files changed, 1252 insertions(+), 965 deletions(-) create mode 100644 app/importers/NmapImporter.py create mode 100644 app/importers/PythonImporter.py create mode 100644 app/importers/__init__.py create mode 100644 db/repositories/NoteRepository.py create mode 100644 db/repositories/ScriptRepository.py delete mode 100644 log/legion-db.log delete mode 100644 log/legion-startup.log delete mode 100644 log/legion.log create mode 100644 tests/db/repositories/test_NoteRepository.py create mode 100644 tests/db/repositories/test_ScriptRepository.py diff --git a/.gitignore b/.gitignore index eca49e2..7b909fb 100644 --- a/.gitignore +++ b/.gitignore @@ -124,3 +124,5 @@ scripts/CloudFail docker/runLocal.sh docker/cleanupUntagged.sh docker/cleanupExited.sh + +log/*.log \ No newline at end of file diff --git a/app/importers/NmapImporter.py b/app/importers/NmapImporter.py new file mode 100644 index 0000000..97a2b8c --- /dev/null +++ b/app/importers/NmapImporter.py @@ -0,0 +1,317 @@ +""" +LEGION (https://govanguard.io) +Copyright (c) 2018 GoVanguard + + This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public + License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later + version. + + This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied + warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more + details. + + You should have received a copy of the GNU General Public License along with this program. + If not, see . + +Author(s): Dmitriy Dubson (d.dubson@gmail.com) +""" +import sys + +from PyQt5 import QtCore + +from db.database import nmapSessionObj, hostObj, note, osObj, serviceObj, portObj, l1ScriptObj +from parsers.Parser import Parser +from ui.ancillaryDialog import ProgressWidget, time + + +class NmapImporter(QtCore.QThread): + tick = QtCore.pyqtSignal(int, name="changed") # New style signal + done = QtCore.pyqtSignal(name="done") # New style signal + schedule = QtCore.pyqtSignal(object, bool, name="schedule") # New style signal + log = QtCore.pyqtSignal(str, name="log") + + def __init__(self): + QtCore.QThread.__init__(self, parent=None) + self.output = '' + self.importProgressWidget = ProgressWidget('Importing nmap..') + + def tsLog(self, msg): + self.log.emit(str(msg)) + + def setDB(self, db): + self.db = db + + def setFilename(self, filename): + self.filename = filename + + def setOutput(self, output): + self.output = output + + def run( + self): # it is necessary to get the qprocess because we need to send it back to the scheduler when we're done importing + try: + self.importProgressWidget.show() + session = self.db.session() + self.tsLog("Parsing nmap xml file: " + self.filename) + startTime = time() + + try: + parser = Parser(self.filename) + except: + self.tsLog('Giving up on import due to previous errors.') + self.tsLog("Unexpected error: {0}".format(sys.exc_info()[0])) + self.done.emit() + return + + self.db.dbsemaphore.acquire() # ensure that while this thread is running, no one else can write to the DB + s = parser.getSession() # nmap session info + if s: + n = nmapSessionObj(self.filename, s.startTime, s.finish_time, s.nmapVersion, s.scanArgs, s.totalHosts, + s.upHosts, s.downHosts) + session.add(n) + hostCount = len(parser.getAllHosts()) + if hostCount == 0: # to fix a division by zero if we ran nmap on one host + hostCount = 1 + totalprogress = 0 + + self.importProgressWidget.setProgress(int(totalprogress)) + self.importProgressWidget.show() + + createProgress = 0 + createOsNodesProgress = 0 + createPortsProgress = 0 + + for h in parser.getAllHosts(): # create all the hosts that need to be created + db_host = session.query(hostObj).filter_by(ip=h.ip).first() + + if not db_host: # if host doesn't exist in DB, create it first + hid = hostObj(osMatch='', osAccuracy='', ip=h.ip, ipv4=h.ipv4, ipv6=h.ipv6, macaddr=h.macaddr, + status=h.status, hostname=h.hostname, vendor=h.vendor, uptime=h.uptime, + lastboot=h.lastboot, distance=h.distance, state=h.state, count=h.count) + self.tsLog("Adding db_host") + session.add(hid) + t_note = note(h.ip, 'Added by nmap') + session.add(t_note) + else: + self.tsLog("Found db_host already in db") + + createProgress = createProgress + ((100.0 / hostCount) / 5) + totalprogress = totalprogress + createProgress + self.importProgressWidget.setProgress(int(totalprogress)) + self.importProgressWidget.show() + + session.commit() + + for h in parser.getAllHosts(): # create all OS, service and port objects that need to be created + self.tsLog("Processing h {ip}".format(ip=h.ip)) + + db_host = session.query(hostObj).filter_by(ip=h.ip).first() + if db_host: + self.tsLog("Found db_host during os/ports/service processing") + else: + self.log("Did not find db_host during os/ports/service processing") + + os_nodes = h.getOs() # parse and store all the OS nodes + self.tsLog(" 'os_nodes' to process: {os_nodes}".format(os_nodes=str(len(os_nodes)))) + for os in os_nodes: + self.tsLog(" Processing os obj {os}".format(os=str(os.name))) + db_os = session.query(osObj).filter_by(hostId=db_host.id).filter_by(name=os.name).filter_by( + family=os.family).filter_by(generation=os.generation).filter_by(osType=os.osType).filter_by( + vendor=os.vendor).first() + + if not db_os: + t_osObj = osObj(os.name, os.family, os.generation, os.osType, os.vendor, os.accuracy, + db_host.id) + session.add(t_osObj) + + createOsNodesProgress = createOsNodesProgress + ((100.0 / hostCount) / 5) + totalprogress = totalprogress + createOsNodesProgress + self.importProgressWidget.setProgress(int(totalprogress)) + self.importProgressWidget.show() + + session.commit() + + all_ports = h.all_ports() + self.tsLog(" 'ports' to process: {all_ports}".format(all_ports=str(len(all_ports)))) + for p in all_ports: # parse the ports + self.tsLog(" Processing port obj {port}".format(port=str(p.portId))) + s = p.getService() + + if not (s is None): # check if service already exists to avoid adding duplicates + # print(" Found service {service} for port {port}".format(service=str(s.name),port=str(p.portId))) + # db_service = session.query(serviceObj).filter_by(name=s.name).filter_by(product=s.product).filter_by(version=s.version).filter_by(extrainfo=s.extrainfo).filter_by(fingerprint=s.fingerprint).first() + db_service = session.query(serviceObj).filter_by(name=s.name).first() + if not db_service: + # print("Did not find service *********** name={0} prod={1} ver={2} extra={3} fing={4}".format(s.name, s.product, s.version, s.extrainfo, s.fingerprint)) + db_service = serviceObj(s.name, s.product, s.version, s.extrainfo, s.fingerprint) + session.add(db_service) + # else: + # print("FOUND service *************** name={0}".format(db_service.name)) + + else: # else, there is no service info to parse + db_service = None + # fetch the port + db_port = session.query(portObj).filter_by(hostId=db_host.id).filter_by(portId=p.portId).filter_by( + protocol=p.protocol).first() + + if not db_port: + # print("Did not find port *********** portid={0} proto={1}".format(p.portId, p.protocol)) + if db_service: + db_port = portObj(p.portId, p.protocol, p.state, db_host.id, db_service.id) + else: + db_port = portObj(p.portId, p.protocol, p.state, db_host.id, '') + session.add(db_port) + # else: + # print('FOUND port *************** portid={0}'.format(db_port.portId)) + createPortsProgress = createPortsProgress + ((100.0 / hostCount) / 5) + totalprogress = totalprogress + createPortsProgress + self.importProgressWidget.setProgress(totalprogress) + self.importProgressWidget.show() + + session.commit() + + # totalprogress += progress + # self.tick.emit(int(totalprogress)) + + for h in parser.getAllHosts(): # create all script objects that need to be created + db_host = session.query(hostObj).filter_by(ip=h.ip).first() + + for p in h.all_ports(): + for scr in p.getScripts(): + self.tsLog(" Processing script obj {scr}".format(scr=str(scr))) + print(" Processing script obj {scr}".format(scr=str(scr))) + db_port = session.query(portObj).filter_by(hostId=db_host.id).filter_by( + portId=p.portId).filter_by(protocol=p.protocol).first() + db_script = session.query(l1ScriptObj).filter_by(scriptId=scr.scriptId).filter_by( + portId=db_port.id).first() + + if not db_script: # if this script object doesn't exist, create it + t_l1ScriptObj = l1ScriptObj(scr.scriptId, scr.output, db_port.id, db_host.id) + self.tsLog(" Adding l1ScriptObj obj {script}".format(script=scr.scriptId)) + session.add(t_l1ScriptObj) + + for hs in h.getHostScripts(): + db_script = session.query(l1ScriptObj).filter_by(scriptId=hs.scriptId).filter_by( + hostId=db_host.id).first() + if not db_script: + t_l1ScriptObj = l1ScriptObj(hs.scriptId, hs.output, None, db_host.id) + session.add(t_l1ScriptObj) + + session.commit() + + for h in parser.getAllHosts(): # update everything + + db_host = session.query(hostObj).filter_by(ip=h.ip).first() + + if db_host.ipv4 == '' and not h.ipv4 == '': + db_host.ipv4 = h.ipv4 + if db_host.ipv6 == '' and not h.ipv6 == '': + db_host.ipv6 = h.ipv6 + if db_host.macaddr == '' and not h.macaddr == '': + db_host.macaddr = h.macaddr + if not h.status == '': + db_host.status = h.status + if db_host.hostname == '' and not h.hostname == '': + db_host.hostname = h.hostname + if db_host.vendor == '' and not h.vendor == '': + db_host.vendor = h.vendor + if db_host.uptime == '' and not h.uptime == '': + db_host.uptime = h.uptime + if db_host.lastboot == '' and not h.lastboot == '': + db_host.lastboot = h.lastboot + if db_host.distance == '' and not h.distance == '': + db_host.distance = h.distance + if db_host.state == '' and not h.state == '': + db_host.state = h.state + if db_host.count == '' and not h.count == '': + db_host.count = h.count + + session.add(db_host) + + tmp_name = '' + tmp_accuracy = '0' # TODO: check if better to convert to int for comparison + + os_nodes = h.getOs() + for os in os_nodes: + db_os = session.query(osObj).filter_by(hostId=db_host.id).filter_by(name=os.name).filter_by( + family=os.family).filter_by(generation=os.generation).filter_by(osType=os.osType).filter_by( + vendor=os.vendor).first() + + db_os.osAccuracy = os.accuracy # update the accuracy + + if not os.name == '': # get the most accurate OS match/accuracy to store it in the host table for easier access + if os.accuracy > tmp_accuracy: + tmp_name = os.name + tmp_accuracy = os.accuracy + + if os_nodes: # if there was operating system info to parse + + if not tmp_name == '' and not tmp_accuracy == '0': # update the current host with the most accurate OS match + db_host.osMatch = tmp_name + db_host.osAccuracy = tmp_accuracy + + session.add(db_host) + + for scr in h.getHostScripts(): + print("-----------------------Host SCR: {0}".format(scr.scriptId)) + db_host = session.query(hostObj).filter_by(ip=h.ip).first() + scrProcessorResults = scr.scriptSelector(db_host) + for scrProcessorResult in scrProcessorResults: + session.add(scrProcessorResult) + + for scr in h.getScripts(): + print("-----------------------SCR: {0}".format(scr.scriptId)) + db_host = session.query(hostObj).filter_by(ip=h.ip).first() + scrProcessorResults = scr.scriptSelector(db_host) + for scrProcessorResult in scrProcessorResults: + session.add(scrProcessorResult) + + for p in h.all_ports(): + s = p.getService() + if not (s is None): + # db_service = session.query(serviceObj).filter_by(name=s.name).filter_by(product=s.product).filter_by(version=s.version).filter_by(extrainfo=s.extrainfo).filter_by(fingerprint=s.fingerprint).first() + db_service = session.query(serviceObj).filter_by(name=s.name).first() + else: + db_service = None + # fetch the port + db_port = session.query(portObj).filter_by(hostId=db_host.id).filter_by(portId=p.portId).filter_by( + protocol=p.protocol).first() + if db_port: + # print("************************ Found {0}".format(db_port)) + + if db_port.state != p.state: + db_port.state = p.state + session.add(db_port) + + if not ( + db_service is None) and db_port.serviceId != db_service.id: # if there is some new service information, update it + db_port.serviceId = db_service.id + session.add(db_port) + + for scr in p.getScripts(): # store the script results (note that existing script outputs are also kept) + db_script = session.query(l1ScriptObj).filter_by(scriptId=scr.scriptId).filter_by( + portId=db_port.id).first() + + if not scr.output == '' and scr.output is not None: + db_script.output = scr.output + + session.add(db_script) + + totalprogress = 100 + self.importProgressWidget.setProgress(int(totalprogress)) + self.importProgressWidget.show() + + session.commit() + self.db.dbsemaphore.release() # we are done with the DB + self.tsLog('Finished in ' + str(time() - startTime) + ' seconds.') + self.done.emit() + self.importProgressWidget.hide() + self.schedule.emit(parser, + self.output == '') # call the scheduler (if there is no terminal output it means we imported nmap) + + except Exception as e: + self.tsLog('Something went wrong when parsing the nmap file..') + self.tsLog("Unexpected error: {0}".format(sys.exc_info()[0])) + self.tsLog(e) + raise + self.done.emit() diff --git a/app/importers/PythonImporter.py b/app/importers/PythonImporter.py new file mode 100644 index 0000000..f4cf99d --- /dev/null +++ b/app/importers/PythonImporter.py @@ -0,0 +1,72 @@ +""" +LEGION (https://govanguard.io) +Copyright (c) 2018 GoVanguard + + This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public + License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later + version. + + This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied + warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more + details. + + You should have received a copy of the GNU General Public License along with this program. + If not, see . + +Author(s): Dmitriy Dubson (d.dubson@gmail.com) +""" +from PyQt5 import QtCore + +from db.database import hostObj +from scripts.python import pyShodan +from ui.ancillaryDialog import ProgressWidget, time + + +class PythonImporter(QtCore.QThread): + tick = QtCore.pyqtSignal(int, name="changed") # New style signal + done = QtCore.pyqtSignal(name="done") # New style signal + schedule = QtCore.pyqtSignal(object, bool, name="schedule") # New style signal + log = QtCore.pyqtSignal(str, name="log") + + def __init__(self): + QtCore.QThread.__init__(self, parent=None) + self.output = '' + self.hostIp = '' + self.pythonScriptDispatch = {'pyShodan': pyShodan.PyShodanScript()} + self.pythonScriptObj = None + self.importProgressWidget = ProgressWidget('Importing shodan data..') + + def tsLog(self, msg): + self.log.emit(str(msg)) + + def setDB(self, db): + self.db = db + + def setHostIp(self, hostIp): + self.hostIp = hostIp + + def setPythonScript(self, pythonScript): + self.pythonScriptObj = self.pythonScriptDispatch[pythonScript] + + def setOutput(self, output): + self.output = output + + def run(self): # it is necessary to get the qprocess because we need to send it back to the scheduler when we're done importing + try: + session = self.db.session() + startTime = time() + self.db.dbsemaphore.acquire() # ensure that while this thread is running, no one else can write to the DB + #self.setPythonScript(self.pythonScript) + db_host = session.query(hostObj).filter_by(ip = self.hostIp).first() + self.pythonScriptObj.setDbHost(db_host) + self.pythonScriptObj.setSession(session) + self.pythonScriptObj.run() + session.commit() + self.db.dbsemaphore.release() # we are done with the DB + self.tsLog('Finished in ' + str(time() - startTime) + ' seconds.') + self.done.emit() + + except Exception as e: + self.tsLog(e) + raise + self.done.emit() diff --git a/app/importers/__init__.py b/app/importers/__init__.py new file mode 100644 index 0000000..bcdbc64 --- /dev/null +++ b/app/importers/__init__.py @@ -0,0 +1,17 @@ +""" +LEGION (https://govanguard.io) +Copyright (c) 2018 GoVanguard + + This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public + License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later + version. + + This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied + warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more + details. + + You should have received a copy of the GNU General Public License along with this program. + If not, see . + +Author(s): Dmitriy Dubson (d.dubson@gmail.com) +""" \ No newline at end of file diff --git a/app/logic.py b/app/logic.py index 658caf5..3c6a73f 100644 --- a/app/logic.py +++ b/app/logic.py @@ -24,11 +24,11 @@ from db.database import * from db.repositories.CVERepository import CVERepository from db.repositories.HostRepository import HostRepository +from db.repositories.NoteRepository import NoteRepository from db.repositories.PortRepository import PortRepository from db.repositories.ProcessRepository import ProcessRepository +from db.repositories.ScriptRepository import ScriptRepository from db.repositories.ServiceRepository import ServiceRepository -from parsers.Parser import * -from scripts.python import pyShodan from ui.ancillaryDialog import * @@ -40,11 +40,13 @@ def __init__(self, project_name: str, db: Database, shell: Shell): self.projectname = project_name log.info(project_name) self.createTemporaryFiles() # creates temporary files/folders used by SPARTA - self.service_repository: ServiceRepository = ServiceRepository(self.db) - self.process_repository: ProcessRepository = ProcessRepository(self.db, log) - self.host_repository: HostRepository = HostRepository(self.db) - self.port_repository: PortRepository = PortRepository(self.db) - self.cve_repository: CVERepository = CVERepository(self.db) + self.serviceRepository: ServiceRepository = ServiceRepository(self.db) + self.processRepository: ProcessRepository = ProcessRepository(self.db, log) + self.hostRepository: HostRepository = HostRepository(self.db) + self.portRepository: PortRepository = PortRepository(self.db) + self.cveRepository: CVERepository = CVERepository(self.db) + self.noteRepository: NoteRepository = NoteRepository(self.db, log) + self.scriptRepository: ScriptRepository = ScriptRepository(self.db) def createTemporaryFiles(self): try: @@ -104,7 +106,7 @@ def moveToolOutput(self, outputFilename): path = self.outputfolder+'/'+str(tool) if not os.path.exists(str(path)): os.makedirs(str(path)) - + # check if the outputFilename exists, if not try .xml and .txt extensions (different tools use different formats) if os.path.exists(str(outputFilename)) and os.path.isfile(str(outputFilename)): shutil.move(str(outputFilename), str(path)) @@ -201,530 +203,3 @@ def saveProjectAs(self, filename, replace=0, projectType = 'legion'): log.info('Something went wrong while saving the project..') log.info("Unexpected error: {0}".format(sys.exc_info()[0])) return False - - # get notes for given host IP - def getNoteFromDB(self, hostId): - session = self.db.session() - return session.query(note).filter_by(hostId=str(hostId)).first() - - # get script info for given host IP - def getScriptsFromDB(self, hostIP): - query = ('SELECT host.id, host.scriptId, port.portId, port.protocol FROM l1ScriptObj AS host ' + - 'INNER JOIN hostObj AS hosts ON hosts.id = host.hostId ' + - 'LEFT OUTER JOIN portObj AS port ON port.id = host.portId ' + - 'WHERE hosts.ip=?') - - return self.db.metadata.bind.execute(query, str(hostIP)).fetchall() - - def getScriptOutputFromDB(self, scriptDBId): - query = ('SELECT script.output FROM l1ScriptObj as script WHERE script.id = ?') - return self.db.metadata.bind.execute(query, str(scriptDBId)).fetchall() - - # used to delete all port/script data related to a host - to overwrite portscan info with the latest scan - def deleteAllPortsAndScriptsForHostFromDB(self, hostID, protocol): - session = self.db.session() - ports_for_host = session.query(portObj).filter(portObj.hostId == hostID).filter(portObj.protocol == str(protocol)).all() - for p in ports_for_host: - scripts_for_ports = session.query(l1ScriptObj).filter(l1ScriptObj.portId == p.id).all() - for s in scripts_for_ports: - session.delete(s) - for p in ports_for_host: - session.delete(p) - session.commit() - return - - def deleteHost(self, hostIP): - session = self.db.session() - h = session.query(hostObj).filter_by(ip=str(hostIP)).first() - session.delete(h) - session.commit() - return - - # this function returns all the processes from the DB - # the showProcesses flag is used to ensure we don't display processes in the process table after we have cleared them or when an existing project is opened. - # to speed up the queries we replace the columns we don't need by zeros (the reason we need all the columns is we are using the same model to display process information everywhere) - def getProcessesFromDB(self, filters, showProcesses='noNmap', sort = 'desc', ncol = 'id'): - if showProcesses == 'noNmap': # we do not fetch nmap processes because these are not displayed in the host tool tabs / tools - query = ('SELECT "0", "0", "0", process.name, "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0" FROM process AS process WHERE process.closed="False" AND process.name!="nmap" group by process.name') - result = self.db.metadata.bind.execute(query).fetchall() - - elif showProcesses == False: # when opening a project, fetch only the processes that have display=false and were not in tabs that were closed by the user - query = ('SELECT process.id, process.hostIp, process.tabTitle, process.outputfile, output.output FROM process AS process ' - 'INNER JOIN process_output AS output ON process.id = output.processId ' - 'WHERE process.display=? AND process.closed="False" order by process.id desc') - result = self.db.metadata.bind.execute(query, str(showProcesses)).fetchall() - - #query = ('SELECT process.id, process.hostIp, process.tabTitle, process.outputfile, output.output FROM process AS process ' - #'INNER JOIN process_output AS output ON process.id = output.processId ' - #'WHERE process.display=? AND process.closed="False" order by process.id desc') - - else: # show all the processes in the (bottom) process table (no matter their closed value) - query = ('SELECT * FROM process AS process WHERE process.display=? order by {0} {1}'.format(ncol, sort)) - result = self.db.metadata.bind.execute(query, str(showProcesses)).fetchall() - - return result - - def getHostsForTool(self, toolname, closed='False'): - if closed == 'FetchAll': - query = ('SELECT "0", "0", "0", "0", "0", process.hostIp, process.port, process.protocol, "0", "0", process.outputfile, "0", "0", "0" FROM process AS process WHERE process.name=?') - else: - query = ('SELECT process.id, "0", "0", "0", "0", "0", "0", process.hostIp, process.port, process.protocol, "0", "0", process.outputfile, "0", "0", "0" FROM process AS process WHERE process.name=? and process.closed="False"') - - return self.db.metadata.bind.execute(query, str(toolname)).fetchall() - - def toggleHostCheckStatus(self, ipaddr): - session = self.db.session() - h = session.query(hostObj).filter_by(ip=ipaddr).first() - if h: - if h.checked == 'False': - h.checked = 'True' - else: - h.checked = 'False' - session.add(h) - self.db.commit() - - # this function adds a new process to the DB - def addProcessToDB(self, proc): - log.info('Add process') - p_output = process_output() # add row to process_output table (separate table for performance reasons) - p = process(str(proc.pid()), str(proc.name), str(proc.tabTitle), str(proc.hostIp), str(proc.port), str(proc.protocol), unicode(proc.command), proc.startTime, "", str(proc.outputfile), 'Waiting', [p_output], 100, 0) - log.info(p) - session = self.db.session() - session.add(p) - self.db.commit() - proc.id = p.id - return p.id - - def addScreenshotToDB(self, ip, port, filename): - p_output = process_output() # add row to process_output table (separate table for performance reasons) - p = process(0, "screenshooter", "screenshot ("+str(port)+"/tcp)", str(ip), str(port), "tcp", "", getTimestamp(True), getTimestamp(True), str(filename), "Finished", [p_output], 2, 0) - session = self.db.session() - session.add(p) - session.commit() - return p.id - - # is not actually a toggle function. it sets all the non-running processes display flag to false to ensure they aren't shown in the process table - # but they need to be shown as tool tabs. this function is called when a user clears the processes or when a project is being closed. - def toggleProcessDisplayStatus(self, resetAll=False): - session = self.db.session() - proc = session.query(process).filter_by(display='True').all() - for p in proc: - session.add(self.toggle_process_status_field(p, resetAll)) - self.db.commit() - - def toggle_process_status_field(self, p, reset_all): - not_running = p.status != 'Running' - not_waiting = p.status != 'Waiting' - - if reset_all and not_running: - p.display = 'False' - else: - if not_running and not_waiting: - p.display = 'False' - - return p - - # this function updates the status of a process if it is killed - def storeProcessKillStatusInDB(self, procId): - session = self.db.session() - proc = session.query(process).filter_by(id=procId).first() - #proc = process.query.filter_by(id=procId).first() - if proc and not proc.status == 'Finished': - proc.status = 'Killed' - proc.endTime = getTimestamp(True) # store end time - session.add(proc) - #session.commit() - self.db.commit() - - def storeProcessCrashStatusInDB(self, procId): - session = self.db.session() - proc = session.query(process).filter_by(id=procId).first() - #proc = process.query.filter_by(id=procId).first() - if proc and not proc.status == 'Killed' and not proc.status == 'Cancelled': - proc.status = 'Crashed' - proc.endTime = getTimestamp(True) # store end time - session.add(proc) - #session.commit() - self.db.commit() - - # this function updates the status of a process if it is killed - def storeProcessCancelStatusInDB(self, procId): - session = self.db.session() - proc = session.query(process).filter_by(id=procId).first() - #proc = process.query.filter_by(id=procId).first() - if proc: - proc.status = 'Cancelled' - proc.endTime = getTimestamp(True) # store end time - session.add(proc) - #session.commit() - self.db.commit() - - def storeProcessRunningStatusInDB(self, procId, pid): - session = self.db.session() - proc = session.query(process).filter_by(id=procId).first() - #proc = process.query.filter_by(id=procId).first() - if proc: - proc.status = 'Running' - proc.pid = str(pid) - session.add(proc) - #session.commit() - self.db.commit() - - # change the status in the db as closed - def storeCloseTabStatusInDB(self, procId): - session = self.db.session() - proc = session.query(process).filter_by(id=procId).first() - #proc = process.query.filter_by(id=int(procId)).first() - if proc: - proc.closed = 'True' - session.add(proc) - #session.commit() - self.db.commit() - - # change the status in the db as closed - def storeProcessRunningElapsedInDB(self, procId, elapsed): - session = self.db.session() - proc = session.query(process).filter_by(id=procId).first() - if proc: - proc.elapsed = elapsed - session.add(proc) - self.db.commit() - - def storeNotesInDB(self, hostId, notes): - if len(notes) == 0: - notes = unicode("".format(hostId=hostId)) - log.debug("Storing notes for {hostId}, Notes {notes}".format(hostId=hostId, notes=notes)) - t_note = self.getNoteFromDB(hostId) - if t_note: - t_note.text = unicode(notes) - else: - t_note = note(hostId, unicode(notes)) - session = self.db.session() - session.add(t_note) - self.db.commit() - - -class PythonImporter(QtCore.QThread): - tick = QtCore.pyqtSignal(int, name="changed") # New style signal - done = QtCore.pyqtSignal(name="done") # New style signal - schedule = QtCore.pyqtSignal(object, bool, name="schedule") # New style signal - log = QtCore.pyqtSignal(str, name="log") - - def __init__(self): - QtCore.QThread.__init__(self, parent=None) - self.output = '' - self.hostIp = '' - self.pythonScriptDispatch = {'pyShodan': pyShodan.PyShodanScript()} - self.pythonScriptObj = None - self.importProgressWidget = ProgressWidget('Importing shodan data..') - - def tsLog(self, msg): - self.log.emit(str(msg)) - - def setDB(self, db): - self.db = db - - def setHostIp(self, hostIp): - self.hostIp = hostIp - - def setPythonScript(self, pythonScript): - self.pythonScriptObj = self.pythonScriptDispatch[pythonScript] - - def setOutput(self, output): - self.output = output - - def run(self): # it is necessary to get the qprocess because we need to send it back to the scheduler when we're done importing - try: - session = self.db.session() - startTime = time() - self.db.dbsemaphore.acquire() # ensure that while this thread is running, no one else can write to the DB - #self.setPythonScript(self.pythonScript) - db_host = session.query(hostObj).filter_by(ip = self.hostIp).first() - self.pythonScriptObj.setDbHost(db_host) - self.pythonScriptObj.setSession(session) - self.pythonScriptObj.run() - session.commit() - self.db.dbsemaphore.release() # we are done with the DB - self.tsLog('Finished in ' + str(time() - startTime) + ' seconds.') - self.done.emit() - - except Exception as e: - self.tsLog(e) - raise - self.done.emit() - - -class NmapImporter(QtCore.QThread): - tick = QtCore.pyqtSignal(int, name="changed") # New style signal - done = QtCore.pyqtSignal(name="done") # New style signal - schedule = QtCore.pyqtSignal(object, bool, name="schedule") # New style signal - log = QtCore.pyqtSignal(str, name="log") - - def __init__(self): - QtCore.QThread.__init__(self, parent=None) - self.output = '' - self.importProgressWidget = ProgressWidget('Importing nmap..') - - def tsLog(self, msg): - self.log.emit(str(msg)) - - def setDB(self, db): - self.db = db - - def setFilename(self, filename): - self.filename = filename - - def setOutput(self, output): - self.output = output - - def run(self): # it is necessary to get the qprocess because we need to send it back to the scheduler when we're done importing - try: - self.importProgressWidget.show() - session = self.db.session() - self.tsLog("Parsing nmap xml file: " + self.filename) - startTime = time() - - try: - parser = Parser(self.filename) - except: - self.tsLog('Giving up on import due to previous errors.') - self.tsLog("Unexpected error: {0}".format(sys.exc_info()[0])) - self.done.emit() - return - - self.db.dbsemaphore.acquire() # ensure that while this thread is running, no one else can write to the DB - s = parser.getSession() # nmap session info - if s: - n = nmapSessionObj(self.filename, s.startTime, s.finish_time, s.nmapVersion, s.scanArgs, s.totalHosts, s.upHosts, s.downHosts) - session.add(n) - hostCount = len(parser.getAllHosts()) - if hostCount==0: # to fix a division by zero if we ran nmap on one host - hostCount=1 - totalprogress = 0 - - self.importProgressWidget.setProgress(int(totalprogress)) - self.importProgressWidget.show() - - createProgress = 0 - createOsNodesProgress = 0 - createPortsProgress = 0 - - for h in parser.getAllHosts(): # create all the hosts that need to be created - db_host = session.query(hostObj).filter_by(ip=h.ip).first() - - if not db_host: # if host doesn't exist in DB, create it first - hid = hostObj(osMatch='', osAccuracy='', ip=h.ip, ipv4=h.ipv4, ipv6=h.ipv6, macaddr=h.macaddr, status=h.status, hostname=h.hostname, vendor=h.vendor, uptime=h.uptime, lastboot=h.lastboot, distance=h.distance, state=h.state, count=h.count) - self.tsLog("Adding db_host") - session.add(hid) - t_note = note(h.ip, 'Added by nmap') - session.add(t_note) - else: - self.tsLog("Found db_host already in db") - - createProgress = createProgress + ((100.0 / hostCount) / 5) - totalprogress = totalprogress + createProgress - self.importProgressWidget.setProgress(int(totalprogress)) - self.importProgressWidget.show() - - session.commit() - - for h in parser.getAllHosts(): # create all OS, service and port objects that need to be created - self.tsLog("Processing h {ip}".format(ip=h.ip)) - - db_host = session.query(hostObj).filter_by(ip=h.ip).first() - if db_host: - self.tsLog("Found db_host during os/ports/service processing") - else: - self.log("Did not find db_host during os/ports/service processing") - - os_nodes = h.getOs() # parse and store all the OS nodes - self.tsLog(" 'os_nodes' to process: {os_nodes}".format(os_nodes=str(len(os_nodes)))) - for os in os_nodes: - self.tsLog(" Processing os obj {os}".format(os=str(os.name))) - db_os = session.query(osObj).filter_by(hostId=db_host.id).filter_by(name=os.name).filter_by(family=os.family).filter_by(generation=os.generation).filter_by(osType=os.osType).filter_by(vendor=os.vendor).first() - - if not db_os: - t_osObj = osObj(os.name, os.family, os.generation, os.osType, os.vendor, os.accuracy, db_host.id) - session.add(t_osObj) - - createOsNodesProgress = createOsNodesProgress + ((100.0 / hostCount) / 5) - totalprogress = totalprogress + createOsNodesProgress - self.importProgressWidget.setProgress(int(totalprogress)) - self.importProgressWidget.show() - - session.commit() - - all_ports = h.all_ports() - self.tsLog(" 'ports' to process: {all_ports}".format(all_ports=str(len(all_ports)))) - for p in all_ports: # parse the ports - self.tsLog(" Processing port obj {port}".format(port=str(p.portId))) - s = p.getService() - - if not (s is None): # check if service already exists to avoid adding duplicates - #print(" Found service {service} for port {port}".format(service=str(s.name),port=str(p.portId))) - #db_service = session.query(serviceObj).filter_by(name=s.name).filter_by(product=s.product).filter_by(version=s.version).filter_by(extrainfo=s.extrainfo).filter_by(fingerprint=s.fingerprint).first() - db_service = session.query(serviceObj).filter_by(name=s.name).first() - if not db_service: - #print("Did not find service *********** name={0} prod={1} ver={2} extra={3} fing={4}".format(s.name, s.product, s.version, s.extrainfo, s.fingerprint)) - db_service = serviceObj(s.name, s.product, s.version, s.extrainfo, s.fingerprint) - session.add(db_service) - # else: - #print("FOUND service *************** name={0}".format(db_service.name)) - - else: # else, there is no service info to parse - db_service = None - # fetch the port - db_port = session.query(portObj).filter_by(hostId=db_host.id).filter_by(portId=p.portId).filter_by(protocol=p.protocol).first() - - if not db_port: - #print("Did not find port *********** portid={0} proto={1}".format(p.portId, p.protocol)) - if db_service: - db_port = portObj(p.portId, p.protocol, p.state, db_host.id, db_service.id) - else: - db_port = portObj(p.portId, p.protocol, p.state, db_host.id, '') - session.add(db_port) - #else: - #print('FOUND port *************** portid={0}'.format(db_port.portId)) - createPortsProgress = createPortsProgress + ((100.0 / hostCount) / 5) - totalprogress = totalprogress + createPortsProgress - self.importProgressWidget.setProgress(totalprogress) - self.importProgressWidget.show() - - session.commit() - - #totalprogress += progress - #self.tick.emit(int(totalprogress)) - - for h in parser.getAllHosts(): # create all script objects that need to be created - db_host = session.query(hostObj).filter_by(ip=h.ip).first() - - for p in h.all_ports(): - for scr in p.getScripts(): - self.tsLog(" Processing script obj {scr}".format(scr=str(scr))) - print(" Processing script obj {scr}".format(scr=str(scr))) - db_port = session.query(portObj).filter_by(hostId=db_host.id).filter_by(portId=p.portId).filter_by(protocol=p.protocol).first() - db_script = session.query(l1ScriptObj).filter_by(scriptId=scr.scriptId).filter_by(portId=db_port.id).first() - - if not db_script: # if this script object doesn't exist, create it - t_l1ScriptObj = l1ScriptObj(scr.scriptId, scr.output, db_port.id, db_host.id) - self.tsLog(" Adding l1ScriptObj obj {script}".format(script=scr.scriptId)) - session.add(t_l1ScriptObj) - - for hs in h.getHostScripts(): - db_script = session.query(l1ScriptObj).filter_by(scriptId=hs.scriptId).filter_by(hostId=db_host.id).first() - if not db_script: - t_l1ScriptObj = l1ScriptObj(hs.scriptId, hs.output, None, db_host.id) - session.add(t_l1ScriptObj) - - session.commit() - - for h in parser.getAllHosts(): # update everything - - db_host = session.query(hostObj).filter_by(ip=h.ip).first() - - if db_host.ipv4 == '' and not h.ipv4 == '': - db_host.ipv4 = h.ipv4 - if db_host.ipv6 == '' and not h.ipv6 == '': - db_host.ipv6 = h.ipv6 - if db_host.macaddr == '' and not h.macaddr == '': - db_host.macaddr = h.macaddr - if not h.status == '': - db_host.status = h.status - if db_host.hostname == '' and not h.hostname == '': - db_host.hostname = h.hostname - if db_host.vendor == '' and not h.vendor == '': - db_host.vendor = h.vendor - if db_host.uptime == '' and not h.uptime == '': - db_host.uptime = h.uptime - if db_host.lastboot == '' and not h.lastboot == '': - db_host.lastboot = h.lastboot - if db_host.distance == '' and not h.distance == '': - db_host.distance = h.distance - if db_host.state == '' and not h.state == '': - db_host.state = h.state - if db_host.count == '' and not h.count == '': - db_host.count = h.count - - session.add(db_host) - - tmp_name = '' - tmp_accuracy = '0' # TODO: check if better to convert to int for comparison - - os_nodes = h.getOs() - for os in os_nodes: - db_os = session.query(osObj).filter_by(hostId=db_host.id).filter_by(name=os.name).filter_by(family=os.family).filter_by(generation=os.generation).filter_by(osType=os.osType).filter_by(vendor=os.vendor).first() - - db_os.osAccuracy = os.accuracy # update the accuracy - - if not os.name == '': # get the most accurate OS match/accuracy to store it in the host table for easier access - if os.accuracy > tmp_accuracy: - tmp_name = os.name - tmp_accuracy = os.accuracy - - if os_nodes: # if there was operating system info to parse - - if not tmp_name == '' and not tmp_accuracy == '0': # update the current host with the most accurate OS match - db_host.osMatch = tmp_name - db_host.osAccuracy = tmp_accuracy - - session.add(db_host) - - for scr in h.getHostScripts(): - print("-----------------------Host SCR: {0}".format(scr.scriptId)) - db_host = session.query(hostObj).filter_by(ip=h.ip).first() - scrProcessorResults = scr.scriptSelector(db_host) - for scrProcessorResult in scrProcessorResults: - session.add(scrProcessorResult) - - for scr in h.getScripts(): - print("-----------------------SCR: {0}".format(scr.scriptId)) - db_host = session.query(hostObj).filter_by(ip=h.ip).first() - scrProcessorResults = scr.scriptSelector(db_host) - for scrProcessorResult in scrProcessorResults: - session.add(scrProcessorResult) - - for p in h.all_ports(): - s = p.getService() - if not (s is None): - #db_service = session.query(serviceObj).filter_by(name=s.name).filter_by(product=s.product).filter_by(version=s.version).filter_by(extrainfo=s.extrainfo).filter_by(fingerprint=s.fingerprint).first() - db_service = session.query(serviceObj).filter_by(name=s.name).first() - else: - db_service = None - # fetch the port - db_port = session.query(portObj).filter_by(hostId=db_host.id).filter_by(portId=p.portId).filter_by(protocol=p.protocol).first() - if db_port: - #print("************************ Found {0}".format(db_port)) - - if db_port.state != p.state: - db_port.state = p.state - session.add(db_port) - - if not (db_service is None) and db_port.serviceId != db_service.id: # if there is some new service information, update it - db_port.serviceId = db_service.id - session.add(db_port) - - for scr in p.getScripts(): # store the script results (note that existing script outputs are also kept) - db_script = session.query(l1ScriptObj).filter_by(scriptId=scr.scriptId).filter_by(portId=db_port.id).first() - - if not scr.output == '' and scr.output is not None: - db_script.output = scr.output - - session.add(db_script) - - totalprogress = 100 - self.importProgressWidget.setProgress(int(totalprogress)) - self.importProgressWidget.show() - - session.commit() - self.db.dbsemaphore.release() # we are done with the DB - self.tsLog('Finished in '+ str(time()-startTime) + ' seconds.') - self.done.emit() - self.importProgressWidget.hide() - self.schedule.emit(parser, self.output == '') # call the scheduler (if there is no terminal output it means we imported nmap) - - except Exception as e: - self.tsLog('Something went wrong when parsing the nmap file..') - self.tsLog("Unexpected error: {0}".format(sys.exc_info()[0])) - self.tsLog(e) - raise - self.done.emit() diff --git a/controller/controller.py b/controller/controller.py index bce0a2a..bbe387d 100644 --- a/controller/controller.py +++ b/controller/controller.py @@ -11,16 +11,19 @@ You should have received a copy of the GNU General Public License along with this program. If not, see . ''' -import sys, os, ntpath, signal, re, subprocess # for file operations, to kill processes, for regex, for subprocesses +import signal # for file operations, to kill processes, for regex, for subprocesses + +from app.importers.NmapImporter import NmapImporter +from app.importers.PythonImporter import PythonImporter + try: import queue except: import Queue as queue -from PyQt5.QtGui import * # for filters dialog from app.logic import * -from app.auxiliary import * from app.settings import * + class Controller(): # initialisations that will happen once - when the program is launched @@ -195,7 +198,7 @@ def openExistingProject(self, filename, projectType='legion'): def saveProject(self, lastHostIdClicked, notes): if not lastHostIdClicked == '': - self.logic.storeNotesInDB(lastHostIdClicked, notes) + self.logic.noteRepository.storeNotes(lastHostIdClicked, notes) def saveProjectAs(self, filename, replace=0): success = self.logic.saveProjectAs(filename, replace) @@ -207,7 +210,7 @@ def closeProject(self): self.saveSettings() # backup and save config file, if necessary self.screenshooter.terminate() self.initScreenshooter() - self.logic.toggleProcessDisplayStatus(True) + self.logic.processRepository.toggleProcessDisplayStatus(True) self.view.updateProcessesTableView() # clear process table self.logic.removeTemporaryFiles() @@ -271,16 +274,16 @@ def getContextMenuForHost(self, isChecked, showAll=True): # showAll ex def handleHostAction(self, ip, hostid, actions, action): if action.text() == 'Mark as checked' or action.text() == 'Mark as unchecked': - self.logic.toggleHostCheckStatus(ip) + self.logic.hostRepository.toggleHostCheckStatus(ip) self.view.updateInterface() return if action.text() == 'Run nmap (staged)': log.info('Purging previous portscan data for ' + str(ip)) # if we are running nmap we need to purge previous portscan results - if self.logic.port_repository.get_ports_by_ip_and_protocol(ip, 'tcp'): - self.logic.deleteAllPortsAndScriptsForHostFromDB(hostid, 'tcp') - if self.logic.port_repository.get_ports_by_ip_and_protocol(ip, 'udp'): - self.logic.deleteAllPortsAndScriptsForHostFromDB(hostid, 'udp') + if self.logic.portRepository.getPortsByIPAndProtocol(ip, 'tcp'): + self.logic.portRepository.deleteAllPortsAndScriptsByHostId(hostid, 'tcp') + if self.logic.portRepository.getPortsByIPAndProtocol(ip, 'udp'): + self.logic.portRepository.deleteAllPortsAndScriptsByHostId(hostid, 'udp') self.view.updateInterface() self.runStagedNmap(ip, False) return @@ -292,20 +295,20 @@ def handleHostAction(self, ip, hostid, actions, action): if action.text() == 'Purge Results': log.info('Purging previous portscan data for host {0}'.format(str(ip))) - if self.logic.port_repository.get_ports_by_ip_and_protocol(ip, 'tcp'): - self.logic.deleteAllPortsAndScriptsForHostFromDB(hostid, 'tcp') - if self.logic.port_repository.get_ports_by_ip_and_protocol(ip, 'udp'): - self.logic.deleteAllPortsAndScriptsForHostFromDB(hostid, 'udp') + if self.logic.portRepository.getPortsByIPAndProtocol(ip, 'tcp'): + self.logic.portRepository.deleteAllPortsAndScriptsByHostId(hostid, 'tcp') + if self.logic.portRepository.getPortsByIPAndProtocol(ip, 'udp'): + self.logic.portRepository.deleteAllPortsAndScriptsByHostId(hostid, 'udp') self.view.updateInterface() return if action.text() == 'Delete': log.info('Purging previous portscan data for host {0}'.format(str(ip))) - if self.logic.port_repository.get_ports_by_ip_and_protocol(ip, 'tcp'): - self.logic.deleteAllPortsAndScriptsForHostFromDB(hostid, 'tcp') - if self.logic.port_repository.get_ports_by_ip_and_protocol(ip, 'udp'): - self.logic.deleteAllPortsAndScriptsForHostFromDB(hostid, 'udp') - self.logic.deleteHost(ip) + if self.logic.portRepository.getPortsByIPAndProtocol(ip, 'tcp'): + self.logic.portRepository.deleteAllPortsAndScriptsByHostId(hostid, 'tcp') + if self.logic.portRepository.getPortsByIPAndProtocol(ip, 'udp'): + self.logic.portRepository.deleteAllPortsAndScriptsByHostId(hostid, 'udp') + self.logic.hostRepository.deleteHost(ip) self.view.updateInterface() return @@ -328,8 +331,8 @@ def handleHostAction(self, ip, hostid, actions, action): if '-sU' in command: proto = 'udp' - if self.logic.port_repository.get_ports_by_ip_and_protocol(ip, proto): # if we are running nmap we need to purge previous portscan results (of the same protocol) - self.logic.deleteAllPortsAndScriptsForHostFromDB(hostid, proto) + if self.logic.portRepository.getPortsByIPAndProtocol(ip, proto): # if we are running nmap we need to purge previous portscan results (of the same protocol) + self.logic.portRepository.deleteAllPortsAndScriptsByHostId(hostid, proto) tabTitle = self.settings.hostActions[i][1] self.runCommand(name, tabTitle, ip, '','', command, getTimestamp(True), outputfile, self.view.createNewTabForHost(ip, tabTitle, invisibleTab)) @@ -460,9 +463,9 @@ def handleProcessAction(self, selectedProcesses, action): # selectedPr for p in selectedProcesses: if p[1]!="Running": if p[1]=="Waiting": - if str(self.logic.process_repository.get_status_by_process_id(p[2])) == 'Running': + if str(self.logic.processRepository.getStatusByProcessId(p[2])) == 'Running': self.killProcess(self.view.ProcessesTableModel.getProcessPidForId(p[2]), p[2]) - self.logic.storeProcessCancelStatusInDB(str(p[2])) + self.logic.processRepository.storeProcessCancelStatus(str(p[2])) else: log.info("This process has already been terminated. Skipping.") else: @@ -471,65 +474,65 @@ def handleProcessAction(self, selectedProcesses, action): # selectedPr return if action.text() == 'Clear': # h.ide all the processes that are not running - self.logic.toggleProcessDisplayStatus() + self.logic.processRepository.toggleProcessDisplayStatus() self.view.updateProcessesTableView() #################### LEFT PANEL INTERFACE UPDATE FUNCTIONS #################### def isHostInDB(self, host): - return self.logic.host_repository.exists(host) + return self.logic.hostRepository.exists(host) def getHostsFromDB(self, filters): - return self.logic.host_repository.get_hosts(filters) + return self.logic.hostRepository.getHosts(filters) def getServiceNamesFromDB(self, filters): - return self.logic.service_repository.get_service_names(filters) + return self.logic.serviceRepository.getServiceNames(filters) def getProcessStatusForDBId(self, dbId): - return self.logic.process_repository.get_status_by_process_id(dbId) + return self.logic.processRepository.getStatusByProcessId(dbId) def getPidForProcess(self,dbId): - return self.logic.process_repository.get_pid_by_process_id(dbId) + return self.logic.processRepository.getPIDByProcessId(dbId) def storeCloseTabStatusInDB(self,pid): - return self.logic.storeCloseTabStatusInDB(pid) + return self.logic.processRepository.storeCloseStatus(pid) def getServiceNameForHostAndPort(self, hostIP, port): - return self.logic.service_repository.get_service_names_by_host_ip_and_port(hostIP, port) + return self.logic.serviceRepository.getServiceNamesByHostIPAndPort(hostIP, port) #################### RIGHT PANEL INTERFACE UPDATE FUNCTIONS #################### def getPortsAndServicesForHostFromDB(self, hostIP, filters): - return self.logic.port_repository.get_ports_and_services_by_host_ip(hostIP, filters) + return self.logic.portRepository.getPortsAndServicesByHostIP(hostIP, filters) def getHostsAndPortsForServiceFromDB(self, serviceName, filters): - return self.logic.host_repository.get_hosts_and_ports_by_service_name(serviceName, filters) + return self.logic.hostRepository.getHostsAndPortsByServiceName(serviceName, filters) def getHostInformation(self, hostIP): - return self.logic.host_repository.get_host_information(hostIP) + return self.logic.hostRepository.getHostInformation(hostIP) def getPortStatesForHost(self, hostid): - return self.logic.port_repository.get_port_states_by_host_id(hostid) + return self.logic.portRepository.getPortStatesByHostId(hostid) def getScriptsFromDB(self, hostIP): - return self.logic.getScriptsFromDB(hostIP) + return self.logic.scriptRepository.getScriptsByHostIP(hostIP) def getCvesFromDB(self, hostIP): - return self.logic.cve_repository.get_cves_by_host_ip(hostIP) + return self.logic.cveRepository.getCVEsByHostIP(hostIP) - def getScriptOutputFromDB(self,scriptDBId): - return self.logic.getScriptOutputFromDB(scriptDBId) + def getScriptOutputFromDB(self, scriptDBId): + return self.logic.scriptRepository.getScriptOutputById(scriptDBId) def getNoteFromDB(self, hostid): - return self.logic.getNoteFromDB(hostid) + return self.logic.noteRepository.getNoteByHostId(hostid) - def getHostsForTool(self, toolname, closed = 'False'): - return self.logic.getHostsForTool(toolname, closed) + def getHostsForTool(self, toolName, closed='False'): + return self.logic.processRepository.getHostsByToolName(toolName, closed) #################### BOTTOM PANEL INTERFACE UPDATE FUNCTIONS #################### - def getProcessesFromDB(self, filters, showProcesses = 'noNmap', sort = 'desc', ncol = 'id'): - return self.logic.getProcessesFromDB(filters, showProcesses, sort, ncol) + def getProcessesFromDB(self, filters, showProcesses='noNmap', sort='desc', ncol='id'): + return self.logic.processRepository.getProcesses(filters, showProcesses, sort, ncol) #################### PROCESSES #################### @@ -542,7 +545,7 @@ def checkProcessQueue(self): self.processTableUiUpdateTimer.start(1000) if (self.fastProcessesRunning <= int(self.settings.general_max_fast_processes)): next_proc = self.fastProcessQueue.get() - if not self.logic.process_repository.is_cancelled_process(str(next_proc.id)): + if not self.logic.processRepository.isCancelledProcess(str(next_proc.id)): log.debug('Running: '+ str(next_proc.command)) next_proc.display.clear() self.processes.append(next_proc) @@ -550,7 +553,7 @@ def checkProcessQueue(self): # Add Timeout next_proc.waitForFinished(10) next_proc.start(next_proc.command) - self.logic.storeProcessRunningStatusInDB(next_proc.id, next_proc.pid()) + self.logic.processRepository.storeProcessRunningStatus(next_proc.id, next_proc.pid()) elif not self.fastProcessQueue.empty(): log.debug('> next process was canceled, checking queue again..') self.checkProcessQueue() @@ -560,13 +563,13 @@ def checkProcessQueue(self): def cancelProcess(self, dbId): log.info('Canceling process: ' + str(dbId)) - self.logic.storeProcessCancelStatusInDB(str(dbId)) # mark it as cancelled + self.logic.processRepository.storeProcessCancelStatus(str(dbId)) # mark it as cancelled self.updateUITimer.stop() self.updateUITimer.start(1500) # update the interface soon def killProcess(self, pid, dbId): log.info('Killing process: ' + str(pid)) - self.logic.storeProcessKillStatusInDB(str(dbId)) # mark it as killed + self.logic.processRepository.storeProcessKillStatus(str(dbId)) try: os.kill(int(pid), signal.SIGTERM) except OSError: @@ -588,7 +591,7 @@ def handleProcStop(*vargs): self.processTimers[qProcess.id] = None procTime = timer.elapsed() / 1000 qProcess.elapsed = procTime - self.logic.storeProcessRunningElapsedInDB(qProcess.id, procTime) + self.logic.processRepository.storeProcessRunningElapsedTime(qProcess.id, procTime) def handleProcUpdate(*vargs): procTime = timer.elapsed() / 1000 @@ -612,7 +615,7 @@ def handleProcUpdate(*vargs): qProcess.finished.connect(handleProcStop) updateElapsed.timeout.connect(handleProcUpdate) - textbox.setProperty('dbId', str(self.logic.addProcessToDB(qProcess))) + textbox.setProperty('dbId', str(self.logic.processRepository.storeProcess(qProcess))) updateElapsed.start(1000) self.processTimers[qProcess.id] = updateElapsed self.processMeasurements[qProcess.pid()] = 0 @@ -638,7 +641,7 @@ def handleProcUpdate(*vargs): nextStage = stage + 1 qProcess.finished.connect( lambda: self.runStagedNmap(str(hostIp), discovery=discovery, stage=nextStage, - stop=self.logic.process_repository.is_killed_process(str(qProcess.id)))) + stop=self.logic.processRepository.isKilledProcess(str(qProcess.id)))) return qProcess.pid() # return the pid so that we can kill the process if needed @@ -654,7 +657,7 @@ def runPython(self): outputfile = '/tmp/a' qProcess = MyQProcess(name, tabTitle, hostIp, port, protocol, command, startTime, outputfile, textbox) - textbox.setProperty('dbId', str(self.logic.addProcessToDB(qProcess))) + textbox.setProperty('dbId', str(self.logic.processRepository.storeProcess(qProcess))) log.info('Queuing: ' + str(command)) self.fastProcessQueue.put(qProcess) @@ -718,7 +721,7 @@ def importFinished(self): self.view.displayAddHostsOverlay(False) # if nmap import was the first action, we need to hide the overlay (note: we shouldn't need to do this everytime. this can be improved) def screenshotFinished(self, ip, port, filename): - dbId = self.logic.addScreenshotToDB(str(ip),str(port),str(filename)) + dbId = self.logic.processRepository.storeScreenshot(str(ip), str(port), str(filename)) imageviewer = self.view.createNewTabForHost(ip, 'screenshot ('+port+'/tcp)', True, '', str(self.logic.outputfolder)+'/screenshots/'+str(filename)) imageviewer.setProperty('dbId', QVariant(str(dbId))) self.view.switchTabClick() # to make sure the screenshot tab appears when it is launched from the host services tab @@ -726,18 +729,18 @@ def screenshotFinished(self, ip, port, filename): self.updateUITimer.start(900) def processCrashed(self, proc): - self.logic.storeProcessCrashStatusInDB(str(proc.id)) + self.logic.processRepository.storeProcessCrashStatus(str(proc.id)) log.info('Process {qProcessId} Crashed!'.format(qProcessId=str(proc.id))) qProcessOutput = "\n\t" + str(proc.display.toPlainText()).replace('\n','').replace("b'","") #self.view.closeHostToolTab(self, index)) - self.view.findFinishedServiceTab(str(self.logic.process_repository.get_pid_by_process_id(str(proc.id)))) + self.view.findFinishedServiceTab(str(self.logic.processRepository.getPIDByProcessId(str(proc.id)))) log.info('Process {qProcessId} Output: {qProcessOutput}'.format(qProcessId=str(proc.id), qProcessOutput=qProcessOutput)) # this function handles everything after a process ends #def processFinished(self, qProcess, crashed=False): def processFinished(self, qProcess): try: - if not self.logic.process_repository.is_killed_process(str(qProcess.id)): # if process was not killed + if not self.logic.processRepository.isKilledProcess(str(qProcess.id)): # if process was not killed if not qProcess.outputfile == '': self.logic.moveToolOutput(qProcess.outputfile) # move tool output from runningfolder to output folder if there was an output file print(qProcess.command) @@ -762,10 +765,10 @@ def processFinished(self, qProcess): log.info("Process {qProcessId} is done!".format(qProcessId=qProcess.id)) - self.logic.process_repository.store_process_output(str(qProcess.id), qProcess.display.toPlainText()) + self.logic.processRepository.storeProcessOutput(str(qProcess.id), qProcess.display.toPlainText()) if 'hydra' in qProcess.name: # find the corresponding widget and tell it to update its UI - self.view.findFinishedBruteTab(str(self.logic.process_repository.get_pid_by_process_id(str(qProcess.id)))) + self.view.findFinishedBruteTab(str(self.logic.processRepository.getPIDByProcessId(str(qProcess.id)))) try: self.fastProcessesRunning =- 1 diff --git a/db/filters.py b/db/filters.py index d2db73e..435069f 100644 --- a/db/filters.py +++ b/db/filters.py @@ -18,14 +18,14 @@ from app.auxiliary import sanitise -def apply_filters(filters): +def applyFilters(filters): query_filter = "" - query_filter += apply_hosts_filters(filters) - query_filter += apply_port_filters(filters) + query_filter += applyHostsFilters(filters) + query_filter += applyPortFilters(filters) return query_filter -def apply_hosts_filters(filters): +def applyHostsFilters(filters): query_filter = "" if not filters.down: query_filter += " AND hosts.status != 'down'" @@ -40,7 +40,7 @@ def apply_hosts_filters(filters): return query_filter -def apply_port_filters(filters): +def applyPortFilters(filters): query_filter = "" if not filters.portopen: query_filter += " AND ports.state != 'open' AND ports.state != 'open|filtered'" diff --git a/db/repositories/CVERepository.py b/db/repositories/CVERepository.py index 9a1cc63..e40186c 100644 --- a/db/repositories/CVERepository.py +++ b/db/repositories/CVERepository.py @@ -19,12 +19,12 @@ class CVERepository: - def __init__(self, db_adapter: Database): - self.db_adapter = db_adapter + def __init__(self, dbAdapter: Database): + self.dbAdapter = dbAdapter - def get_cves_by_host_ip(self, host_ip): + def getCVEsByHostIP(self, hostIP): query = ('SELECT cves.name, cves.severity, cves.product, cves.version, cves.url, cves.source, ' 'cves.exploitId, cves.exploit, cves.exploitUrl FROM cve AS cves ' 'INNER JOIN hostObj AS hosts ON hosts.id = cves.hostId ' 'WHERE hosts.ip = ?') - return self.db_adapter.metadata.bind.execute(query, str(host_ip)).fetchall() + return self.dbAdapter.metadata.bind.execute(query, str(hostIP)).fetchall() diff --git a/db/repositories/HostRepository.py b/db/repositories/HostRepository.py index 1ac313e..3067286 100644 --- a/db/repositories/HostRepository.py +++ b/db/repositories/HostRepository.py @@ -17,33 +17,50 @@ """ from app.auxiliary import Filters from db.database import Database, hostObj -from db.filters import apply_filters, apply_hosts_filters +from db.filters import applyFilters, applyHostsFilters class HostRepository: - def __init__(self, db_adapter: Database): - self.db_adapter = db_adapter + def __init__(self, dbAdapter: Database): + self.dbAdapter = dbAdapter def exists(self, host: str): query = 'SELECT host.ip FROM hostObj AS host WHERE host.ip == ? OR host.hostname == ?' - result = self.db_adapter.metadata.bind.execute(query, str(host), str(host)).fetchall() + result = self.dbAdapter.metadata.bind.execute(query, str(host), str(host)).fetchall() return True if result else False - def get_hosts(self, filters): + def getHosts(self, filters): query = 'SELECT * FROM hostObj AS hosts WHERE 1=1' - query += apply_hosts_filters(filters) - return self.db_adapter.metadata.bind.execute(query).fetchall() + query += applyHostsFilters(filters) + return self.dbAdapter.metadata.bind.execute(query).fetchall() - def get_hosts_and_ports_by_service_name(self, service_name, filters: Filters): + def getHostsAndPortsByServiceName(self, service_name, filters: Filters): query = ("SELECT hosts.ip,ports.portId,ports.protocol,ports.state,ports.hostId,ports.serviceId," "services.name,services.product,services.version,services.extrainfo,services.fingerprint " "FROM portObj AS ports " + "INNER JOIN hostObj AS hosts ON hosts.id = ports.hostId " + "LEFT OUTER JOIN serviceObj AS services ON services.id=ports.serviceId " + "WHERE services.name=?") - query += apply_filters(filters) - return self.db_adapter.metadata.bind.execute(query, str(service_name)).fetchall() + query += applyFilters(filters) + return self.dbAdapter.metadata.bind.execute(query, str(service_name)).fetchall() - def get_host_information(self, host_ip_address: str): - session = self.db_adapter.session() + def getHostInformation(self, host_ip_address: str): + session = self.dbAdapter.session() return session.query(hostObj).filter_by(ip=str(host_ip_address)).first() + + def deleteHost(self, hostIP): + session = self.dbAdapter.session() + h = session.query(hostObj).filter_by(ip=str(hostIP)).first() + session.delete(h) + session.commit() + + def toggleHostCheckStatus(self, ipAddress): + session = self.dbAdapter.session() + host = session.query(hostObj).filter_by(ip=ipAddress).first() + if host: + if host.checked == 'False': + host.checked = 'True' + else: + host.checked = 'False' + session.add(host) + self.dbAdapter.commit() diff --git a/db/repositories/NoteRepository.py b/db/repositories/NoteRepository.py new file mode 100644 index 0000000..26e66b7 --- /dev/null +++ b/db/repositories/NoteRepository.py @@ -0,0 +1,41 @@ +""" +LEGION (https://govanguard.io) +Copyright (c) 2018 GoVanguard + + This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public + License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later + version. + + This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied + warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more + details. + + You should have received a copy of the GNU General Public License along with this program. + If not, see . + +Author(s): Dmitriy Dubson (d.dubson@gmail.com) +""" +from db.database import Database, note +from six import u as unicode + + +class NoteRepository: + def __init__(self, dbAdapter: Database, log): + self.dbAdapter = dbAdapter + self.log = log + + def getNoteByHostId(self, hostId): + return self.dbAdapter.session().query(note).filter_by(hostId=str(hostId)).first() + + def storeNotes(self, hostId, notes): + if len(notes) == 0: + notes = unicode("".format(hostId=hostId)) + self.log.debug("Storing notes for {hostId}, Notes {notes}".format(hostId=hostId, notes=notes)) + t_note = self.getNoteByHostId(hostId) + if t_note: + t_note.text = unicode(notes) + else: + t_note = note(hostId, unicode(notes)) + session = self.dbAdapter.session() + session.add(t_note) + self.dbAdapter.commit() diff --git a/db/repositories/PortRepository.py b/db/repositories/PortRepository.py index e297769..aab2748 100644 --- a/db/repositories/PortRepository.py +++ b/db/repositories/PortRepository.py @@ -15,27 +15,42 @@ Author(s): Dmitriy Dubson (d.dubson@gmail.com) """ -from db.database import Database -from db.filters import apply_port_filters +from db.database import Database, portObj, l1ScriptObj +from db.filters import applyPortFilters class PortRepository: - def __init__(self, db_adapter: Database): - self.db_adapter = db_adapter + def __init__(self, dbAdapter: Database): + self.dbAdapter = dbAdapter - def get_ports_by_ip_and_protocol(self, host_ip, protocol): + def getPortsByIPAndProtocol(self, host_ip, protocol): query = ("SELECT ports.portId FROM portObj AS ports INNER JOIN hostObj AS hosts ON hosts.id = ports.hostId " "WHERE hosts.ip = ? and ports.protocol = ?") - return self.db_adapter.metadata.bind.execute(query, str(host_ip), str(protocol)).first() + return self.dbAdapter.metadata.bind.execute(query, str(host_ip), str(protocol)).first() - def get_port_states_by_host_id(self, host_id): + def getPortStatesByHostId(self, host_id): query = 'SELECT port.state FROM portObj as port WHERE port.hostId = ?' - return self.db_adapter.metadata.bind.execute(query, str(host_id)).fetchall() + return self.dbAdapter.metadata.bind.execute(query, str(host_id)).fetchall() - def get_ports_and_services_by_host_ip(self, host_ip, filters): + def getPortsAndServicesByHostIP(self, host_ip, filters): query = ("SELECT hosts.ip, ports.portId, ports.protocol, ports.state, ports.hostId, ports.serviceId, " "services.name, services.product, services.version, services.extrainfo, services.fingerprint " "FROM portObj AS ports INNER JOIN hostObj AS hosts ON hosts.id = ports.hostId " "LEFT OUTER JOIN serviceObj AS services ON services.id = ports.serviceId WHERE hosts.ip = ?") - query += apply_port_filters(filters) - return self.db_adapter.metadata.bind.execute(query, str(host_ip)).fetchall() + query += applyPortFilters(filters) + return self.dbAdapter.metadata.bind.execute(query, str(host_ip)).fetchall() + + # used to delete all port/script data related to a host - to overwrite portscan info with the latest scan + def deleteAllPortsAndScriptsByHostId(self, hostID, protocol): + session = self.dbAdapter.session() + ports_for_host = session.query(portObj)\ + .filter(portObj.hostId == hostID)\ + .filter(portObj.protocol == str(protocol)).all() + + for p in ports_for_host: + scripts_for_ports = session.query(l1ScriptObj).filter(l1ScriptObj.portId == p.id).all() + for s in scripts_for_ports: + session.delete(s) + for p in ports_for_host: + session.delete(p) + session.commit() diff --git a/db/repositories/ProcessRepository.py b/db/repositories/ProcessRepository.py index 4ae9bb5..2c0bf4c 100644 --- a/db/repositories/ProcessRepository.py +++ b/db/repositories/ProcessRepository.py @@ -16,17 +16,52 @@ Author(s): Dmitriy Dubson (d.dubson@gmail.com) """ from app.auxiliary import getTimestamp -from db.database import * from six import u as unicode +from db.database import process, process_output, Database + class ProcessRepository: - def __init__(self, db_adapter: Database, log): - self.db_adapter = db_adapter + def __init__(self, dbAdapter: Database, log): + self.dbAdapter = dbAdapter self.log = log - def store_process_output(self, process_id: str, output: str): - session = self.db_adapter.session() + # the showProcesses flag is used to ensure we don't display processes in the process table after we have cleared + # them or when an existing project is opened. + # to speed up the queries we replace the columns we don't need by zeros (the reason we need all the columns is + # we are using the same model to display process information everywhere) + def getProcesses(self, filters, showProcesses: str = 'noNmap', sort: str = 'desc', ncol: str = 'id'): + # we do not fetch nmap processes because these are not displayed in the host tool tabs / tools + if showProcesses == 'noNmap': + query = ('SELECT "0", "0", "0", process.name, "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0" ' + 'FROM process AS process WHERE process.closed = "False" AND process.name != "nmap" ' + 'GROUP BY process.name') + result = self.dbAdapter.metadata.bind.execute(query).fetchall() + elif showProcesses == 'False': + query = ('SELECT process.id, process.hostIp, process.tabTitle, process.outputfile, output.output ' + 'FROM process AS process INNER JOIN process_output AS output ON process.id = output.processId ' + 'WHERE process.display = ? AND process.closed = "False" order by process.id desc') + result = self.dbAdapter.metadata.bind.execute(query, str(showProcesses)).fetchall() + else: + query = ('SELECT * FROM process AS process WHERE process.display=? order by {0} {1}'.format(ncol, sort)) + result = self.dbAdapter.metadata.bind.execute(query, str(showProcesses)).fetchall() + + return result + + def storeProcess(self, proc): + p_output = process_output() + p = process(str(proc.pid()), str(proc.name), str(proc.tabTitle), + str(proc.hostIp), str(proc.port), str(proc.protocol), + unicode(proc.command), proc.startTime, "", str(proc.outputfile), + 'Waiting', [p_output], 100, 0) + self.log.info(f"Adding process: {p}") + self.dbAdapter.session().add(p) + self.dbAdapter.commit() + proc.id = p.id + return p.id + + def storeProcessOutput(self, process_id: str, output: str): + session = self.dbAdapter.session() proc = session.query(process).filter_by(id=process_id).first() if not proc: @@ -41,28 +76,116 @@ def store_process_output(self, process_id: str, output: str): proc.endTime = getTimestamp(True) if proc.status == "Killed" or proc.status == "Cancelled" or proc.status == "Crashed": - self.db_adapter.commit() + self.dbAdapter.commit() return True else: proc.status = 'Finished' session.add(proc) - self.db_adapter.commit() + self.dbAdapter.commit() - def get_status_by_process_id(self, process_id: str): - return self.get_field_by_process_id("status", process_id) + def getStatusByProcessId(self, process_id: str): + return self.getFieldByProcessId("status", process_id) - def get_pid_by_process_id(self, process_id: str): - return self.get_field_by_process_id("pid", process_id) + def getPIDByProcessId(self, process_id: str): + return self.getFieldByProcessId("pid", process_id) - def is_killed_process(self, process_id: str) -> bool: - status = self.get_field_by_process_id("status", process_id) + def isKilledProcess(self, process_id: str) -> bool: + status = self.getFieldByProcessId("status", process_id) return True if status == "Killed" else False - def is_cancelled_process(self, process_id: str) -> bool: - status = self.get_field_by_process_id("status", process_id) + def isCancelledProcess(self, process_id: str) -> bool: + status = self.getFieldByProcessId("status", process_id) return True if status == "Cancelled" else False - def get_field_by_process_id(self, field_name: str, process_id: str): + def getFieldByProcessId(self, field_name: str, process_id: str): query = f"SELECT process.{field_name} FROM process AS process WHERE process.id=?" - p = self.db_adapter.metadata.bind.execute(query, str(process_id)).fetchall() + p = self.dbAdapter.metadata.bind.execute(query, str(process_id)).fetchall() return p[0][0] if p else -1 + + def getHostsByToolName(self, toolName: str, closed: str = "False"): + if closed == 'FetchAll': + query = ('SELECT "0", "0", "0", "0", "0", process.hostIp, process.port, process.protocol, "0", "0", ' + 'process.outputfile, "0", "0", "0" FROM process AS process WHERE process.name=?') + else: + query = ('SELECT process.id, "0", "0", "0", "0", "0", "0", process.hostIp, process.port, ' + 'process.protocol, "0", "0", process.outputfile, "0", "0", "0" FROM process AS process ' + 'WHERE process.name=? and process.closed="False"') + + return self.dbAdapter.metadata.bind.execute(query, str(toolName)).fetchall() + + def storeProcessCrashStatus(self, processId: str): + session = self.dbAdapter.session() + proc = session.query(process).filter_by(id=processId).first() + if proc and not proc.status == 'Killed' and not proc.status == 'Cancelled': + proc.status = 'Crashed' + proc.endTime = getTimestamp(True) + session.add(proc) + self.dbAdapter.commit() + + def storeProcessCancelStatus(self, processId: str): + session = self.dbAdapter.session() + proc = session.query(process).filter_by(id=processId).first() + if proc: + proc.status = 'Cancelled' + proc.endTime = getTimestamp(True) + session.add(proc) + self.dbAdapter.commit() + + def storeProcessKillStatus(self, processId: str): + session = self.dbAdapter.session() + proc = session.query(process).filter_by(id=processId).first() + if proc and not proc.status == 'Finished': + proc.status = 'Killed' + proc.endTime = getTimestamp(True) + session.add(proc) + self.dbAdapter.commit() + + def storeProcessRunningStatus(self, processId: str, pid): + session = self.dbAdapter.session() + proc = session.query(process).filter_by(id=processId).first() + if proc: + proc.status = 'Running' + proc.pid = str(pid) + session.add(proc) + self.dbAdapter.commit() + + def storeProcessRunningElapsedTime(self, processId: str, elapsed): + session = self.dbAdapter.session() + proc = session.query(process).filter_by(id=processId).first() + if proc: + proc.elapsed = elapsed + session.add(proc) + self.dbAdapter.commit() + + def storeCloseStatus(self, processId): + session = self.dbAdapter.session() + proc = session.query(process).filter_by(id=processId).first() + if proc: + proc.closed = 'True' + session.add(proc) + self.dbAdapter.commit() + + def storeScreenshot(self, ip: str, port: str, filename: str): + p = process(0, "screenshooter", "screenshot (" + str(port) + "/tcp)", str(ip), str(port), "tcp", "", + getTimestamp(True), getTimestamp(True), str(filename), "Finished", [process_output()], 2, 0) + session = self.dbAdapter.session() + session.add(p) + session.commit() + return p.id + + def toggleProcessDisplayStatus(self, resetAll=False): + session = self.dbAdapter.session() + proc = session.query(process).filter_by(display='True').all() + for p in proc: + session.add(self.toggleProcessStatusField(p, resetAll)) + self.dbAdapter.commit() + + @staticmethod + def toggleProcessStatusField(p, reset_all): + not_running = p.status != 'Running' + not_waiting = p.status != 'Waiting' + + if (reset_all and not_running) or (not_running and not_waiting): + p.display = 'False' + + return p diff --git a/db/repositories/ScriptRepository.py b/db/repositories/ScriptRepository.py new file mode 100644 index 0000000..ad76650 --- /dev/null +++ b/db/repositories/ScriptRepository.py @@ -0,0 +1,34 @@ +""" +LEGION (https://govanguard.io) +Copyright (c) 2018 GoVanguard + + This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public + License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later + version. + + This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied + warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more + details. + + You should have received a copy of the GNU General Public License along with this program. + If not, see . + +Author(s): Dmitriy Dubson (d.dubson@gmail.com) +""" +from db.database import Database + + +class ScriptRepository: + def __init__(self, dbAdapter: Database): + self.dbAdapter = dbAdapter + + def getScriptsByHostIP(self, hostIP): + query = ("SELECT host.id, host.scriptId, port.portId, port.protocol FROM l1ScriptObj AS host " + "INNER JOIN hostObj AS hosts ON hosts.id = host.hostId " + "LEFT OUTER JOIN portObj AS port ON port.id = host.portId WHERE hosts.ip=?") + + return self.dbAdapter.metadata.bind.execute(query, str(hostIP)).fetchall() + + def getScriptOutputById(self, scriptDBId): + query = "SELECT script.output FROM l1ScriptObj as script WHERE script.id = ?" + return self.dbAdapter.metadata.bind.execute(query, str(scriptDBId)).fetchall() \ No newline at end of file diff --git a/db/repositories/ServiceRepository.py b/db/repositories/ServiceRepository.py index b2a164a..c67c1f5 100644 --- a/db/repositories/ServiceRepository.py +++ b/db/repositories/ServiceRepository.py @@ -17,25 +17,25 @@ """ from app.auxiliary import sanitise, Filters from db.database import hostObj -from db.filters import apply_filters +from db.filters import applyFilters class ServiceRepository: def __init__(self, db_adapter): self.db_adapter = db_adapter - def get_service_names(self, filters: Filters): + def getServiceNames(self, filters: Filters): query = ("SELECT DISTINCT service.name FROM serviceObj as service " "INNER JOIN portObj as ports " "INNER JOIN hostObj AS hosts " "ON hosts.id = ports.hostId AND service.id=ports.serviceId WHERE 1=1") - query += apply_filters(filters) + query += applyFilters(filters) query += ' ORDER BY service.name ASC' return self.db_adapter.metadata.bind.execute(query).fetchall() - def get_service_names_by_host_ip_and_port(self, host_ip, port): + def getServiceNamesByHostIPAndPort(self, host_ip, port): query = ("SELECT services.name FROM serviceObj AS services " "INNER JOIN hostObj AS hosts ON hosts.id = ports.hostId " "INNER JOIN portObj AS ports ON services.id=ports.serviceId " "WHERE hosts.ip=? and ports.portId = ?") - return self.db_adapter.metadata.bind.execute(query, str(host_ip), str(port)).first() \ No newline at end of file + return self.db_adapter.metadata.bind.execute(query, str(host_ip), str(port)).first() diff --git a/log/legion-db.log b/log/legion-db.log deleted file mode 100644 index 8b13789..0000000 --- a/log/legion-db.log +++ /dev/null @@ -1 +0,0 @@ - diff --git a/log/legion-startup.log b/log/legion-startup.log deleted file mode 100644 index 8b13789..0000000 --- a/log/legion-startup.log +++ /dev/null @@ -1 +0,0 @@ - diff --git a/log/legion.log b/log/legion.log deleted file mode 100644 index 8b13789..0000000 --- a/log/legion.log +++ /dev/null @@ -1 +0,0 @@ - diff --git a/tests/app/test_logic.py b/tests/app/test_logic.py index 440720d..9c7df95 100644 --- a/tests/app/test_logic.py +++ b/tests/app/test_logic.py @@ -20,13 +20,6 @@ from unittest.mock import MagicMock, patch -def build_mock_process(status: str, display: str) -> MagicMock: - process = MagicMock() - process.status = status - process.display = display - return process - - class LogicTest(unittest.TestCase): @patch('utilities.stenoLogging.get_logger') def setUp(self, get_logger) -> None: @@ -77,53 +70,3 @@ def test_removeTemporaryFiles_whenProjectIsTemporary_shouldRemoveProjectAndOutpu self.shell.remove_file.assert_called_once_with("project-name") self.shell.remove_directory.assert_has_calls([mock.call("./output/folder"), mock.call("./running/folder")]) - - def test_toggleProcessDisplayStatus_whenResetAllIsTrue_setDisplayToFalseForAllProcessesThatAreNotRunning( - self): - from app.logic import Logic - logic = Logic("test-session", self.mock_db_session, self.shell) - - process1 = build_mock_process(status="Waiting", display="True") - process2 = build_mock_process(status="Waiting", display="True") - logic.db = MagicMock() - logic.db.session.return_value = self.mock_db_session - mock_query_response = MagicMock() - mock_filtered_response = MagicMock() - mock_filtered_response.all.return_value = [process1, process2] - mock_query_response.filter_by.return_value = mock_filtered_response - self.mock_db_session.query.return_value = mock_query_response - logic.toggleProcessDisplayStatus(resetAll=True) - - self.assertEqual("False", process1.display) - self.assertEqual("False", process2.display) - self.mock_db_session.add.assert_has_calls([ - mock.call(process1), - mock.call(process2), - ]) - logic.db.commit.assert_called_once() - - def test_toggleProcessDisplayStatus_whenResetAllIFalse_setDisplayToFalseForAllProcessesThatAreNotRunningOrWaiting( - self): - from app.logic import Logic - logic = Logic("test-session", self.mock_db_session, self.shell) - - process1 = build_mock_process(status="Random Status", display="True") - process2 = build_mock_process(status="Another Random Status", display="True") - process3 = build_mock_process(status="Running", display="True") - logic.db = MagicMock() - logic.db.session.return_value = self.mock_db_session - mock_query_response = MagicMock() - mock_filtered_response = MagicMock() - mock_filtered_response.all.return_value = [process1, process2] - mock_query_response.filter_by.return_value = mock_filtered_response - self.mock_db_session.query.return_value = mock_query_response - logic.toggleProcessDisplayStatus() - - self.assertEqual("False", process1.display) - self.assertEqual("False", process2.display) - self.assertEqual("True", process3.display) - self.mock_db_session.add.assert_has_calls([ - mock.call(process1), - mock.call(process2), - ]) - logic.db.commit.assert_called_once() diff --git a/tests/db/helpers/db_helpers.py b/tests/db/helpers/db_helpers.py index c54a064..66f93be 100644 --- a/tests/db/helpers/db_helpers.py +++ b/tests/db/helpers/db_helpers.py @@ -1,25 +1,31 @@ from unittest.mock import MagicMock -def mock_execute_fetchall(return_value): +def mockExecuteFetchAll(return_value): mock_db_execute = MagicMock() mock_db_execute.fetchall.return_value = return_value return mock_db_execute -def mock_first_by_side_effect(return_value): +def mockExecuteAll(return_value): + mock_db_execute = MagicMock() + mock_db_execute.all.return_value = return_value + return mock_db_execute + + +def mockFirstBySideEffect(return_value): mock_filter_by = MagicMock() mock_filter_by.first.side_effect = return_value return mock_filter_by -def mock_first_by_return_value(return_value): +def mockFirstByReturnValue(return_value): mock_filter_by = MagicMock() mock_filter_by.first.return_value = return_value return mock_filter_by -def mock_query_with_filter_by(return_value): +def mockQueryWithFilterBy(return_value): mock_query = MagicMock() mock_query.filter_by.return_value = return_value return mock_query diff --git a/tests/db/repositories/test_CVERepository.py b/tests/db/repositories/test_CVERepository.py index 5b3fba3..2a96cc8 100644 --- a/tests/db/repositories/test_CVERepository.py +++ b/tests/db/repositories/test_CVERepository.py @@ -18,7 +18,7 @@ import unittest from unittest.mock import patch, MagicMock -from tests.db.helpers.db_helpers import mock_execute_fetchall +from tests.db.helpers.db_helpers import mockExecuteFetchAll class CVERepositoryTest(unittest.TestCase): @@ -26,16 +26,14 @@ class CVERepositoryTest(unittest.TestCase): def setUp(self, get_logger) -> None: self.mock_db_adapter = MagicMock() - def test_get_cves_by_host_ip_WhenProvidedAHostIp_ReturnsCVEs(self): + def test_getCVEsByHostIP_WhenProvidedAHostIp_ReturnsCVEs(self): from db.repositories.CVERepository import CVERepository - self.mock_db_adapter.metadata.bind.execute.return_value = mock_execute_fetchall([['cve1'], ['cve2']]) - expected_query = ( - 'SELECT cves.name, cves.severity, cves.product, cves.version, cves.url, cves.source, ' - 'cves.exploitId, cves.exploit, cves.exploitUrl FROM cve AS cves ' - 'INNER JOIN hostObj AS hosts ON hosts.id = cves.hostId ' - 'WHERE hosts.ip = ?' - ) - cve_repository = CVERepository(self.mock_db_adapter) - result = cve_repository.get_cves_by_host_ip("some_host") + self.mock_db_adapter.metadata.bind.execute.return_value = mockExecuteFetchAll([['cve1'], ['cve2']]) + expected_query = ("SELECT cves.name, cves.severity, cves.product, cves.version, cves.url, cves.source, " + "cves.exploitId, cves.exploit, cves.exploitUrl FROM cve AS cves " + "INNER JOIN hostObj AS hosts ON hosts.id = cves.hostId " + "WHERE hosts.ip = ?") + cveRepository = CVERepository(self.mock_db_adapter) + result = cveRepository.getCVEsByHostIP("some_host") self.assertEqual([['cve1'], ['cve2']], result) self.mock_db_adapter.metadata.bind.execute.assert_called_once_with(expected_query, "some_host") diff --git a/tests/db/repositories/test_HostRepository.py b/tests/db/repositories/test_HostRepository.py index 87bc25b..b605b31 100644 --- a/tests/db/repositories/test_HostRepository.py +++ b/tests/db/repositories/test_HostRepository.py @@ -18,12 +18,12 @@ import unittest from unittest.mock import MagicMock, patch -from tests.db.helpers.db_helpers import mock_execute_fetchall, mock_query_with_filter_by, mock_first_by_return_value +from tests.db.helpers.db_helpers import mockExecuteFetchAll, mockQueryWithFilterBy, mockFirstByReturnValue -exists_query = 'SELECT host.ip FROM hostObj AS host WHERE host.ip == ? OR host.hostname == ?' +existsQuery = 'SELECT host.ip FROM hostObj AS host WHERE host.ip == ? OR host.hostname == ?' -def expected_get_hosts_and_ports_query(with_filter: str = "") -> str: +def expectedGetHostsAndPortsQuery(with_filter: str = "") -> str: query = ( "SELECT hosts.ip,ports.portId,ports.protocol,ports.state,ports.hostId,ports.serviceId,services.name," "services.product,services.version,services.extrainfo,services.fingerprint FROM portObj AS ports " @@ -37,92 +37,104 @@ def expected_get_hosts_and_ports_query(with_filter: str = "") -> str: class HostRepositoryTest(unittest.TestCase): @patch('utilities.stenoLogging.get_logger') def setUp(self, get_logger) -> None: - self.mock_db_adapter = MagicMock() - - def get_hosts_and_ports_test_case(self, filters, service_name, expected_query): from db.repositories.HostRepository import HostRepository - - repository = HostRepository(self.mock_db_adapter) - self.mock_db_adapter.metadata.bind.execute.return_value = mock_execute_fetchall( + self.mockDbAdapter = MagicMock() + self.mockDbSession = MagicMock() + self.mockProcess = MagicMock() + self.mockDbAdapter.session.return_value = self.mockDbSession + self.hostRepository = HostRepository(self.mockDbAdapter) + + def getHostsAndPortsTestCase(self, filters, service_name, expectedQuery): + self.mockDbAdapter.metadata.bind.execute.return_value = mockExecuteFetchAll( [{'name': 'service_name1'}, {'name': 'service_name2'}]) - service_names = repository.get_hosts_and_ports_by_service_name(service_name, filters) + service_names = self.hostRepository.getHostsAndPortsByServiceName(service_name, filters) - self.mock_db_adapter.metadata.bind.execute.assert_called_once_with(expected_query, service_name) + self.mockDbAdapter.metadata.bind.execute.assert_called_once_with(expectedQuery, service_name) self.assertEqual([{'name': 'service_name1'}, {'name': 'service_name2'}], service_names) def test_exists_WhenProvidedAExistingHosts_ReturnsTrue(self): - from db.repositories.HostRepository import HostRepository - self.mock_db_adapter.metadata.bind.execute.return_value = mock_execute_fetchall([['some-ip']]) - - host_repository = HostRepository(self.mock_db_adapter) - self.assertTrue(host_repository.exists("some_host")) - self.mock_db_adapter.metadata.bind.execute.assert_called_once_with(exists_query, "some_host", "some_host") + self.mockDbAdapter.metadata.bind.execute.return_value = mockExecuteFetchAll([['some-ip']]) + self.assertTrue(self.hostRepository.exists("some_host")) + self.mockDbAdapter.metadata.bind.execute.assert_called_once_with(existsQuery, "some_host", "some_host") def test_exists_WhenProvidedANonExistingHosts_ReturnsFalse(self): - from db.repositories.HostRepository import HostRepository - self.mock_db_adapter.metadata.bind.execute.return_value = mock_execute_fetchall([]) + self.mockDbAdapter.metadata.bind.execute.return_value = mockExecuteFetchAll([]) - host_repository = HostRepository(self.mock_db_adapter) - self.assertFalse(host_repository.exists("some_host")) - self.mock_db_adapter.metadata.bind.execute.assert_called_once_with(exists_query, "some_host", "some_host") + self.assertFalse(self.hostRepository.exists("some_host")) + self.mockDbAdapter.metadata.bind.execute.assert_called_once_with(existsQuery, "some_host", "some_host") - def test_get_hosts_InvokedWithNoFilters_ReturnsHosts(self): - from db.repositories.HostRepository import HostRepository + def test_getHosts_InvokedWithNoFilters_ReturnsHosts(self): from app.auxiliary import Filters - self.mock_db_adapter.metadata.bind.execute.return_value = mock_execute_fetchall([['host1'], ['host2']]) - expected_query = "SELECT * FROM hostObj AS hosts WHERE 1=1" - host_repository = HostRepository(self.mock_db_adapter) + self.mockDbAdapter.metadata.bind.execute.return_value = mockExecuteFetchAll([['host1'], ['host2']]) + expectedQuery = "SELECT * FROM hostObj AS hosts WHERE 1=1" filters: Filters = Filters() filters.apply(up=True, down=True, checked=True, portopen=True, portfiltered=True, portclosed=True, tcp=True, udp=True) - result = host_repository.get_hosts(filters) + result = self.hostRepository.getHosts(filters) self.assertEqual([['host1'], ['host2']], result) - self.mock_db_adapter.metadata.bind.execute.assert_called_once_with(expected_query) + self.mockDbAdapter.metadata.bind.execute.assert_called_once_with(expectedQuery) - def test_get_hosts_InvokedWithAFewFilters_ReturnsFilteredHosts(self): - from db.repositories.HostRepository import HostRepository + def test_getHosts_InvokedWithAFewFilters_ReturnsFilteredHosts(self): from app.auxiliary import Filters - self.mock_db_adapter.metadata.bind.execute.return_value = mock_execute_fetchall([['host1'], ['host2']]) - expected_query = ("SELECT * FROM hostObj AS hosts WHERE 1=1" - " AND hosts.status != 'down' AND hosts.checked != 'True'") - host_repository = HostRepository(self.mock_db_adapter) + self.mockDbAdapter.metadata.bind.execute.return_value = mockExecuteFetchAll([['host1'], ['host2']]) + expectedQuery = ("SELECT * FROM hostObj AS hosts WHERE 1=1" + " AND hosts.status != 'down' AND hosts.checked != 'True'") filters: Filters = Filters() filters.apply(up=True, down=False, checked=False, portopen=True, portfiltered=True, portclosed=True, tcp=True, udp=True) - result = host_repository.get_hosts(filters) + result = self.hostRepository.getHosts(filters) self.assertEqual([['host1'], ['host2']], result) - self.mock_db_adapter.metadata.bind.execute.assert_called_once_with(expected_query) + self.mockDbAdapter.metadata.bind.execute.assert_called_once_with(expectedQuery) - def test_get_host_info_WhenProvidedHostIpAddress_FetchesHostInformation(self): + def test_getHostInfo_WhenProvidedHostIpAddress_FetchesHostInformation(self): from db.database import hostObj - from db.repositories.HostRepository import HostRepository - mock_db_session = MagicMock() expected_host_info: hostObj = MagicMock() - self.mock_db_adapter.session.return_value = mock_db_session - mock_db_session.query.return_value = mock_query_with_filter_by(mock_first_by_return_value(expected_host_info)) + self.mockDbSession.query.return_value = mockQueryWithFilterBy(mockFirstByReturnValue(expected_host_info)) - repository = HostRepository(self.mock_db_adapter) - actual_host_info = repository.get_host_information("127.0.0.1") + actual_host_info = self.hostRepository.getHostInformation("127.0.0.1") self.assertEqual(actual_host_info, expected_host_info) - def test_get_hosts_and_ports_InvokedWithNoFilters_FetchesHostsAndPortsMatchingKeywords(self): + def test_getHostsAndPorts_InvokedWithNoFilters_FetchesHostsAndPortsMatchingKeywords(self): from app.auxiliary import Filters filters: Filters = Filters() filters.apply(up=True, down=True, checked=True, portopen=True, portfiltered=True, portclosed=True, tcp=True, udp=True) - expected_query = expected_get_hosts_and_ports_query() - self.get_hosts_and_ports_test_case(filters=filters, - service_name="some_service_name", expected_query=expected_query) + expectedQuery = expectedGetHostsAndPortsQuery() + self.getHostsAndPortsTestCase(filters=filters, + service_name="some_service_name", expectedQuery=expectedQuery) - def test_get_hosts_and_ports_InvokedWithFewFilters_FetchesHostsAndPortsWithFiltersApplied(self): + def test_getHostsAndPorts_InvokedWithFewFilters_FetchesHostsAndPortsWithFiltersApplied(self): from app.auxiliary import Filters filters: Filters = Filters() filters.apply(up=True, down=False, checked=True, portopen=True, portfiltered=True, portclosed=True, tcp=True, udp=False) - expected_query = expected_get_hosts_and_ports_query( + expectedQuery = expectedGetHostsAndPortsQuery( with_filter=" AND hosts.status != 'down' AND ports.protocol != 'udp'") - self.get_hosts_and_ports_test_case(filters=filters, - service_name="some_service_name", expected_query=expected_query) + self.getHostsAndPortsTestCase(filters=filters, + service_name="some_service_name", expectedQuery=expectedQuery) + + def test_deleteHost_InvokedWithAHostId_DeletesProcess(self): + self.mockDbSession.query.return_value = mockQueryWithFilterBy(mockFirstByReturnValue(self.mockProcess)) + + self.hostRepository.deleteHost("some-host-id") + self.mockDbSession.delete.assert_called_once_with(self.mockProcess) + self.mockDbSession.commit.assert_called_once() + + def test_toggleHostCheckStatus_WhenHostIsSetToTrue_TogglesToFalse(self): + self.mockProcess.checked = 'True' + self.mockDbSession.query.return_value = mockQueryWithFilterBy(mockFirstByReturnValue(self.mockProcess)) + self.hostRepository.toggleHostCheckStatus("some-ip-address") + self.assertEqual('False', self.mockProcess.checked) + self.mockDbSession.add.assert_called_once_with(self.mockProcess) + self.mockDbAdapter.commit.assert_called_once() + + def test_toggleHostCheckStatus_WhenHostIsSetToFalse_TogglesToTrue(self): + self.mockProcess.checked = 'False' + self.mockDbSession.query.return_value = mockQueryWithFilterBy(mockFirstByReturnValue(self.mockProcess)) + self.hostRepository.toggleHostCheckStatus("some-ip-address") + self.assertEqual('True', self.mockProcess.checked) + self.mockDbSession.add.assert_called_once_with(self.mockProcess) + self.mockDbAdapter.commit.assert_called_once() diff --git a/tests/db/repositories/test_NoteRepository.py b/tests/db/repositories/test_NoteRepository.py new file mode 100644 index 0000000..c34b93b --- /dev/null +++ b/tests/db/repositories/test_NoteRepository.py @@ -0,0 +1,54 @@ +""" +LEGION (https://govanguard.io) +Copyright (c) 2018 GoVanguard + + This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public + License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later + version. + + This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied + warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more + details. + + You should have received a copy of the GNU General Public License along with this program. + If not, see . + +Author(s): Dmitriy Dubson (d.dubson@gmail.com) +""" +import unittest +from unittest.mock import MagicMock, patch + +from db.repositories.NoteRepository import NoteRepository +from tests.db.helpers.db_helpers import mockQueryWithFilterBy, mockFirstByReturnValue + + +class NoteRepositoryTest(unittest.TestCase): + @patch('utilities.stenoLogging.get_logger') + def setUp(self, get_logger) -> None: + from db.database import note + self.mockDbAdapter = MagicMock() + self.mockDbSession = MagicMock() + self.someNote: note = MagicMock() + self.mockLog = MagicMock() + self.noteRepository: NoteRepository = NoteRepository(self.mockDbAdapter, self.mockLog) + + def test_getNoteByHostId_WhenProvidedHostId_ReturnsNote(self): + self.mockDbAdapter.session.return_value = self.mockDbSession + self.mockDbSession.query.return_value = mockQueryWithFilterBy(mockFirstByReturnValue("some-note")) + + note = self.noteRepository.getNoteByHostId("some-host-id") + self.assertEqual("some-note", note) + + def test_storeNotes_WhenProvidedHostIdAndNoteAndNoteAlreadyExists_UpdatesNote(self): + self.mockDbAdapter.session.return_value = self.mockDbSession + self.mockDbSession.query.return_value = mockQueryWithFilterBy(mockFirstByReturnValue(self.someNote)) + self.noteRepository.storeNotes("some-host-id", "some-note") + self.mockDbSession.add.assert_called_once_with(self.someNote) + self.mockDbAdapter.commit.assert_called_once() + + def test_storeNotes_WhenProvidedHostIdAndNoteAndNoteDoesNotExist_SavesNewNote(self): + self.mockDbAdapter.session.return_value = self.mockDbSession + self.mockDbSession.query.return_value = mockQueryWithFilterBy(mockFirstByReturnValue(None)) + self.noteRepository.storeNotes("some-host-id", "some-note") + self.mockDbSession.add.assert_called_once() + self.mockDbAdapter.commit.assert_called_once() diff --git a/tests/db/repositories/test_PortRepository.py b/tests/db/repositories/test_PortRepository.py index 1305118..6bccc0e 100644 --- a/tests/db/repositories/test_PortRepository.py +++ b/tests/db/repositories/test_PortRepository.py @@ -16,44 +16,43 @@ Author(s): Dmitriy Dubson (d.dubson@gmail.com) """ import unittest +from unittest import mock from unittest.mock import patch, MagicMock -from tests.db.helpers.db_helpers import mock_first_by_return_value, mock_execute_fetchall +from tests.db.helpers.db_helpers import mockFirstByReturnValue, mockExecuteFetchAll, mockExecuteAll, \ + mockQueryWithFilterBy class PortRepositoryTest(unittest.TestCase): @patch('utilities.stenoLogging.get_logger') def setUp(self, get_logger) -> None: - self.mock_db_adapter = MagicMock() - - def test_get_ports_by_ip_and_protocol_ReturnsPorts(self): from db.repositories.PortRepository import PortRepository + self.mockDbAdapter = MagicMock() + self.mockDbSession = MagicMock() + self.mockDbAdapter.session.return_value = self.mockDbSession + self.repository = PortRepository(self.mockDbAdapter) + def test_getPortsByIPAndProtocol_ReturnsPorts(self): expected_query = ("SELECT ports.portId FROM portObj AS ports " "INNER JOIN hostObj AS hosts ON hosts.id = ports.hostId " "WHERE hosts.ip = ? and ports.protocol = ?") - repository = PortRepository(self.mock_db_adapter) - self.mock_db_adapter.metadata.bind.execute.return_value = mock_first_by_return_value( + self.mockDbAdapter.metadata.bind.execute.return_value = mockFirstByReturnValue( [['port-id1'], ['port-id2']]) - ports = repository.get_ports_by_ip_and_protocol("some_host_ip", "tcp") + ports = self.repository.getPortsByIPAndProtocol("some_host_ip", "tcp") - self.mock_db_adapter.metadata.bind.execute.assert_called_once_with(expected_query, "some_host_ip", "tcp") + self.mockDbAdapter.metadata.bind.execute.assert_called_once_with(expected_query, "some_host_ip", "tcp") self.assertEqual([['port-id1'], ['port-id2']], ports) - def test_get_port_states_by_host_id_ReturnsPortsStates(self): - from db.repositories.PortRepository import PortRepository - + def test_getPortStatesByHostId_ReturnsPortsStates(self): expected_query = 'SELECT port.state FROM portObj as port WHERE port.hostId = ?' - repository = PortRepository(self.mock_db_adapter) - self.mock_db_adapter.metadata.bind.execute.return_value = mock_execute_fetchall( + self.mockDbAdapter.metadata.bind.execute.return_value = mockExecuteFetchAll( [['port-state1'], ['port-state2']]) - port_states = repository.get_port_states_by_host_id("some_host_id") + port_states = self.repository.getPortStatesByHostId("some_host_id") - self.mock_db_adapter.metadata.bind.execute.assert_called_once_with(expected_query, "some_host_id") + self.mockDbAdapter.metadata.bind.execute.assert_called_once_with(expected_query, "some_host_id") self.assertEqual([['port-state1'], ['port-state2']], port_states) - def test_get_ports_and_services_by_host_ip_InvokedWithNoFilters_ReturnsPortsAndServices(self): - from db.repositories.PortRepository import PortRepository + def test_getPortsAndServicesByHostIP_InvokedWithNoFilters_ReturnsPortsAndServices(self): from app.auxiliary import Filters expected_query = ("SELECT hosts.ip, ports.portId, ports.protocol, ports.state, ports.hostId, ports.serviceId, " @@ -61,19 +60,17 @@ def test_get_ports_and_services_by_host_ip_InvokedWithNoFilters_ReturnsPortsAndS "FROM portObj AS ports INNER JOIN hostObj AS hosts ON hosts.id = ports.hostId " "LEFT OUTER JOIN serviceObj AS services ON services.id = ports.serviceId " "WHERE hosts.ip = ?") - self.mock_db_adapter.metadata.bind.execute.return_value = mock_execute_fetchall([['ip1'], ['ip2']]) - repository = PortRepository(self.mock_db_adapter) + self.mockDbAdapter.metadata.bind.execute.return_value = mockExecuteFetchAll([['ip1'], ['ip2']]) filters: Filters = Filters() filters.apply(up=True, down=True, checked=True, portopen=True, portfiltered=True, portclosed=True, tcp=True, udp=True) - results = repository.get_ports_and_services_by_host_ip("some_host_ip", filters) + results = self.repository.getPortsAndServicesByHostIP("some_host_ip", filters) - self.mock_db_adapter.metadata.bind.execute.assert_called_once_with(expected_query, "some_host_ip") + self.mockDbAdapter.metadata.bind.execute.assert_called_once_with(expected_query, "some_host_ip") self.assertEqual([['ip1'], ['ip2']], results) - def test_get_ports_and_services_by_host_ip_InvokedWithFewFilters_ReturnsPortsAndServices(self): - from db.repositories.PortRepository import PortRepository + def test_getPortsAndServicesByHostIP_InvokedWithFewFilters_ReturnsPortsAndServices(self): from app.auxiliary import Filters expected_query = ("SELECT hosts.ip, ports.portId, ports.protocol, ports.state, ports.hostId, ports.serviceId, " @@ -81,13 +78,33 @@ def test_get_ports_and_services_by_host_ip_InvokedWithFewFilters_ReturnsPortsAnd "FROM portObj AS ports INNER JOIN hostObj AS hosts ON hosts.id = ports.hostId " "LEFT OUTER JOIN serviceObj AS services ON services.id = ports.serviceId " "WHERE hosts.ip = ? AND ports.protocol != 'tcp' AND ports.protocol != 'udp'") - self.mock_db_adapter.metadata.bind.execute.return_value = mock_execute_fetchall([['ip1'], ['ip2']]) - repository = PortRepository(self.mock_db_adapter) + self.mockDbAdapter.metadata.bind.execute.return_value = mockExecuteFetchAll([['ip1'], ['ip2']]) filters: Filters = Filters() filters.apply(up=True, down=True, checked=True, portopen=True, portfiltered=True, portclosed=True, tcp=False, udp=False) - results = repository.get_ports_and_services_by_host_ip("some_host_ip", filters) + results = self.repository.getPortsAndServicesByHostIP("some_host_ip", filters) - self.mock_db_adapter.metadata.bind.execute.assert_called_once_with(expected_query, "some_host_ip") + self.mockDbAdapter.metadata.bind.execute.assert_called_once_with(expected_query, "some_host_ip") self.assertEqual([['ip1'], ['ip2']], results) + + def test_deleteAllPortsAndScriptsByHostId_WhenProvidedByHostIDAndProtocol_DeletesAllPortsAndScripts(self): + mockFilterHost = mockProtocolFilter = mockReturnAll = MagicMock() + mockPort1 = mockPort2 = MagicMock() + mockReturnAll.all.return_value = [mockPort1, mockPort2] + mockProtocolFilter.filter.return_value = mockReturnAll + mockFilterHost.filter.return_value = mockProtocolFilter + + mockFilterScript = mockReturnAllScripts = MagicMock() + mockReturnAllScripts.all.return_value = ['some-script1', 'some-script2'] + mockFilterScript.filter.return_value = mockReturnAllScripts + + self.mockDbSession.query.side_effect = [mockFilterHost, mockFilterScript, mockFilterScript] + + self.repository.deleteAllPortsAndScriptsByHostId("some-host-id", "some-protocol") + self.mockDbSession.delete.assert_has_calls([ + mock.call('some-script1'), mock.call('some-script2'), + mock.call('some-script1'), mock.call('some-script2'), + mock.call(mockPort1), mock.call(mockPort2) + ]) + self.mockDbSession.commit.assert_called_once() diff --git a/tests/db/repositories/test_ProcessRepository.py b/tests/db/repositories/test_ProcessRepository.py index 3b6b992..b1671a7 100644 --- a/tests/db/repositories/test_ProcessRepository.py +++ b/tests/db/repositories/test_ProcessRepository.py @@ -19,169 +19,288 @@ from unittest import mock from unittest.mock import MagicMock, patch -from tests.db.helpers.db_helpers import mock_execute_fetchall, mock_first_by_side_effect, mock_first_by_return_value, \ - mock_query_with_filter_by +from tests.db.helpers.db_helpers import mockExecuteFetchAll, mockFirstBySideEffect, mockFirstByReturnValue, \ + mockQueryWithFilterBy + + +def build_mock_process(status: str, display: str) -> MagicMock: + process = MagicMock() + process.status = status + process.display = display + return process class ProcessRepositoryTest(unittest.TestCase): @patch('utilities.stenoLogging.get_logger') def setUp(self, get_logger) -> None: - self.mock_db_adapter = MagicMock() - self.mock_logger = MagicMock() - - def test_store_process_output_WhenProvidedExistingProcessIdAndOutput_StoresProcessOutput(self): from db.repositories.ProcessRepository import ProcessRepository + self.mockProcess = MagicMock() + self.mockDbSession = MagicMock() + self.mockDbAdapter = MagicMock() + self.mockLogger = MagicMock() + self.mockFilters = MagicMock() + self.mockDbAdapter.session.return_value = self.mockDbSession + self.processRepository = ProcessRepository(self.mockDbAdapter, self.mockLogger) + + def test_getProcesses_WhenProvidedShowProcessesWithNoNmapFlag_ReturnsProcesses(self): + expectedQuery = ('SELECT "0", "0", "0", process.name, "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0" ' + 'FROM process AS process WHERE process.closed = "False" AND process.name != "nmap" ' + 'GROUP BY process.name') + self.mockDbAdapter.metadata.bind.execute.return_value = mockExecuteFetchAll( + [['some-process'], ['some-process2']]) + processes = self.processRepository.getProcesses(self.mockFilters, showProcesses='noNmap') + self.assertEqual(processes, [['some-process'], ['some-process2']]) + self.mockDbAdapter.metadata.bind.execute.assert_called_once_with(expectedQuery) + + def test_getProcesses_WhenProvidedShowProcessesWithFlagFalse_ReturnsProcesses(self): + expectedQuery = ('SELECT process.id, process.hostIp, process.tabTitle, process.outputfile, output.output ' + 'FROM process AS process INNER JOIN process_output AS output ON process.id = output.processId ' + 'WHERE process.display = ? AND process.closed = "False" order by process.id desc') + self.mockDbAdapter.metadata.bind.execute.return_value = mockExecuteFetchAll( + [['some-process'], ['some-process2']]) + processes = self.processRepository.getProcesses(self.mockFilters, showProcesses='False') + self.assertEqual(processes, [['some-process'], ['some-process2']]) + self.mockDbAdapter.metadata.bind.execute.assert_called_once_with(expectedQuery, 'False') + + def test_getProcesses_WhenProvidedShowProcessesWithNoFlag_ReturnsProcesses(self): + expectedQuery = "SELECT * FROM process AS process WHERE process.display=? order by id asc" + self.mockDbAdapter.metadata.bind.execute.return_value = mockExecuteFetchAll( + [['some-process'], ['some-process2']]) + processes = self.processRepository.getProcesses(self.mockFilters, "True", sort='asc', ncol='id') + self.assertEqual(processes, [['some-process'], ['some-process2']]) + self.mockDbAdapter.metadata.bind.execute.assert_called_once_with(expectedQuery, 'True') + + def test_storeProcess_WhenProvidedAProcess_StoreProcess(self): + processId = self.processRepository.storeProcess(self.mockProcess) + + self.mockDbSession.add.assert_called_once() + self.mockDbAdapter.commit.assert_called_once() + + def test_storeProcessOutput_WhenProvidedExistingProcessIdAndOutput_StoresProcessOutput(self): from db.database import process, process_output expected_process: process = MagicMock() process.status = 'Running' expected_process_output: process_output = MagicMock() - mock_db_session = MagicMock() - self.mock_db_adapter.session.return_value = mock_db_session mock_query = MagicMock() - mock_query.filter_by.return_value = mock_first_by_side_effect([expected_process, expected_process_output]) - mock_db_session.query.return_value = mock_query + mock_query.filter_by.return_value = mockFirstBySideEffect([expected_process, expected_process_output]) + self.mockDbSession.query.return_value = mock_query - process_repository = ProcessRepository(self.mock_db_adapter, self.mock_logger) - process_repository.store_process_output("some_process_id", "this is some cool output") + self.processRepository.storeProcessOutput("some_process_id", "this is some cool output") - mock_db_session.add.assert_has_calls([ + self.mockDbSession.add.assert_has_calls([ mock.call(expected_process_output), mock.call(expected_process) ]) - self.mock_db_adapter.commit.assert_called_once() - - def test_store_process_output_WhenProvidedProcessIdDoesNotExist_DoesNotPerformAnyUpdate(self): - from db.repositories.ProcessRepository import ProcessRepository + self.mockDbAdapter.commit.assert_called_once() - mock_db_session = MagicMock() - self.mock_db_adapter.session.return_value = mock_db_session - mock_db_session.query.return_value = mock_query_with_filter_by(mock_first_by_return_value(False)) + def test_storeProcessOutput_WhenProvidedProcessIdDoesNotExist_DoesNotPerformAnyUpdate(self): + self.mockDbAdapter.session.return_value = self.mockDbSession + self.mockDbSession.query.return_value = mockQueryWithFilterBy(mockFirstByReturnValue(False)) - process_repository = ProcessRepository(self.mock_db_adapter, self.mock_logger) - process_repository.store_process_output("some_non_existant_process_id", "this is some cool output") + self.processRepository.storeProcessOutput("some_non_existent_process_id", "this is some cool output") - mock_db_session.add.assert_not_called() - self.mock_db_adapter.commit.assert_not_called() + self.mockDbSession.add.assert_not_called() + self.mockDbAdapter.commit.assert_not_called() - def test_store_process_output_WhenProvidedExistingProcessIdAndOutputButProcKilled_StoresOutputButStatusNotUpdated( + def test_storeProcessOutput_WhenProvidedExistingProcessIdAndOutputButProcKilled_StoresOutputButStatusNotUpdated( self): - self.when_process_does_not_finish_gracefully("Killed") + self.whenProcessDoesNotFinishGracefully("Killed") - def test_store_process_output_WhenProvidedExistingProcessIdAndOutputButProcCancelled_StoresOutputButStatusNotUpdated( + def test_storeProcessOutput_WhenProvidedExistingProcessIdAndOutputButProcCancelled_StoresOutputButStatusNotUpdated( self): - self.when_process_does_not_finish_gracefully("Cancelled") + self.whenProcessDoesNotFinishGracefully("Cancelled") - def test_store_process_output_WhenProvidedExistingProcessIdAndOutputButProcCrashed_StoresOutputButStatusNotUpdated( + def test_storeProcessOutput_WhenProvidedExistingProcessIdAndOutputButProcCrashed_StoresOutputButStatusNotUpdated( self): - self.when_process_does_not_finish_gracefully("Crashed") - - def test_get_status_by_process_id_WhenGivenProcId_FetchesProcessStatus(self): - from db.repositories.ProcessRepository import ProcessRepository + self.whenProcessDoesNotFinishGracefully("Crashed") - expected_query = 'SELECT process.status FROM process AS process WHERE process.id=?' - self.mock_db_adapter.metadata.bind.execute.return_value = mock_execute_fetchall([['Running']]) + def test_getStatusByProcessId_WhenGivenProcId_FetchesProcessStatus(self): + expectedQuery = 'SELECT process.status FROM process AS process WHERE process.id=?' + self.mockDbAdapter.metadata.bind.execute.return_value = mockExecuteFetchAll([['Running']]) - process_repository = ProcessRepository(self.mock_db_adapter, self.mock_logger) - actual_status = process_repository.get_status_by_process_id("some_process_id") + actual_status = self.processRepository.getStatusByProcessId("some_process_id") self.assertEqual(actual_status, 'Running') - self.mock_db_adapter.metadata.bind.execute.assert_called_once_with(expected_query, "some_process_id") + self.mockDbAdapter.metadata.bind.execute.assert_called_once_with(expectedQuery, "some_process_id") - def test_get_status_by_process_id_WhenProcIdDoesNotExist_ReturnsNegativeOne(self): - from db.repositories.ProcessRepository import ProcessRepository + def test_getStatusByProcessId_WhenProcIdDoesNotExist_ReturnsNegativeOne(self): + expectedQuery = 'SELECT process.status FROM process AS process WHERE process.id=?' + self.mockDbAdapter.metadata.bind.execute.return_value = mockExecuteFetchAll(False) + + actual_status = self.processRepository.getStatusByProcessId("some_process_id") + + self.assertEqual(actual_status, -1) + self.mockDbAdapter.metadata.bind.execute.assert_called_once_with(expectedQuery, "some_process_id") + + def test_getPIDByProcessId_WhenGivenProcId_FetchesProcessId(self): + expectedQuery = 'SELECT process.pid FROM process AS process WHERE process.id=?' + self.mockDbAdapter.metadata.bind.execute.return_value = mockExecuteFetchAll([['1234']]) + + actual_status = self.processRepository.getPIDByProcessId("some_process_id") + + self.assertEqual(actual_status, '1234') + self.mockDbAdapter.metadata.bind.execute.assert_called_once_with(expectedQuery, "some_process_id") - expected_query = 'SELECT process.status FROM process AS process WHERE process.id=?' - self.mock_db_adapter.metadata.bind.execute.return_value = mock_execute_fetchall(False) + def test_getPIDByProcessId_WhenProcIdDoesNotExist_ReturnsNegativeOne(self): + expectedQuery = 'SELECT process.pid FROM process AS process WHERE process.id=?' + self.mockDbAdapter.metadata.bind.execute.return_value = mockExecuteFetchAll(False) - process_repository = ProcessRepository(self.mock_db_adapter, self.mock_logger) - actual_status = process_repository.get_status_by_process_id("some_process_id") + actual_status = self.processRepository.getPIDByProcessId("some_process_id") self.assertEqual(actual_status, -1) - self.mock_db_adapter.metadata.bind.execute.assert_called_once_with(expected_query, "some_process_id") + self.mockDbAdapter.metadata.bind.execute.assert_called_once_with(expectedQuery, "some_process_id") - def test_get_pid_by_process_id_WhenGivenProcId_FetchesProcessId(self): - from db.repositories.ProcessRepository import ProcessRepository + def test_isKilledProcess_WhenProvidedKilledProcessId_ReturnsTrue(self): + expectedQuery = "SELECT process.status FROM process AS process WHERE process.id=?" + self.mockDbAdapter.metadata.bind.execute.return_value = mockExecuteFetchAll([["Killed"]]) - expected_query = 'SELECT process.pid FROM process AS process WHERE process.id=?' - self.mock_db_adapter.metadata.bind.execute.return_value = mock_execute_fetchall([['1234']]) + self.assertTrue(self.processRepository.isKilledProcess("some_process_id")) + self.mockDbAdapter.metadata.bind.execute.assert_called_once_with(expectedQuery, "some_process_id") - process_repository = ProcessRepository(self.mock_db_adapter, self.mock_logger) - actual_status = process_repository.get_pid_by_process_id("some_process_id") + def test_isKilledProcess_WhenProvidedNonKilledProcessId_ReturnsFalse(self): + expectedQuery = "SELECT process.status FROM process AS process WHERE process.id=?" + self.mockDbAdapter.metadata.bind.execute.return_value = mockExecuteFetchAll([["Running"]]) - self.assertEqual(actual_status, '1234') - self.mock_db_adapter.metadata.bind.execute.assert_called_once_with(expected_query, "some_process_id") + self.assertFalse(self.processRepository.isKilledProcess("some_process_id")) + self.mockDbAdapter.metadata.bind.execute.assert_called_once_with(expectedQuery, "some_process_id") - def test_get_pid_by_process_id_WhenProcIdDoesNotExist_ReturnsNegativeOne(self): - from db.repositories.ProcessRepository import ProcessRepository + def test_isCancelledProcess_WhenProvidedCancelledProcessId_ReturnsTrue(self): + expectedQuery = "SELECT process.status FROM process AS process WHERE process.id=?" + self.mockDbAdapter.metadata.bind.execute.return_value = mockExecuteFetchAll([["Cancelled"]]) - expected_query = 'SELECT process.pid FROM process AS process WHERE process.id=?' - self.mock_db_adapter.metadata.bind.execute.return_value = mock_execute_fetchall(False) + self.assertTrue(self.processRepository.isCancelledProcess("some_process_id")) + self.mockDbAdapter.metadata.bind.execute.assert_called_once_with(expectedQuery, "some_process_id") - process_repository = ProcessRepository(self.mock_db_adapter, self.mock_logger) - actual_status = process_repository.get_pid_by_process_id("some_process_id") + def test_isCancelledProcess_WhenProvidedNonCancelledProcessId_ReturnsFalse(self): + expectedQuery = "SELECT process.status FROM process AS process WHERE process.id=?" + self.mockDbAdapter.metadata.bind.execute.return_value = mockExecuteFetchAll([["Running"]]) - self.assertEqual(actual_status, -1) - self.mock_db_adapter.metadata.bind.execute.assert_called_once_with(expected_query, "some_process_id") + self.assertFalse(self.processRepository.isCancelledProcess("some_process_id")) + self.mockDbAdapter.metadata.bind.execute.assert_called_once_with(expectedQuery, "some_process_id") - def test_is_killed_process_WhenProvidedKilledProcessId_ReturnsTrue(self): - from db.repositories.ProcessRepository import ProcessRepository + def test_storeProcessCrashStatus_WhenProvidedProcessId_StoresProcessCrashStatus(self): + self.mockProcessStatusAndReturnSingle("Running") + self.processRepository.storeProcessCrashStatus("some-process-id") + self.assertProcessStatusUpdatedTo("Crashed") - expected_query = "SELECT process.status FROM process AS process WHERE process.id=?" - self.mock_db_adapter.metadata.bind.execute.return_value = mock_execute_fetchall([["Killed"]]) + def test_storeProcessCancelledStatus_WhenProvidedProcessId_StoresProcessCancelledStatus(self): + self.mockProcessStatusAndReturnSingle("Running") + self.processRepository.storeProcessCancelStatus("some-process-id") + self.assertProcessStatusUpdatedTo("Cancelled") - process_repository = ProcessRepository(self.mock_db_adapter, self.mock_logger) + def test_storeProcessRunningStatus_WhenProvidedProcessId_StoresProcessRunningStatus(self): + self.mockProcessStatusAndReturnSingle("Waiting") + self.processRepository.storeProcessRunningStatus("some-process-id", "3123") + self.assertProcessStatusUpdatedTo("Running") - self.assertTrue(process_repository.is_killed_process("some_process_id")) - self.mock_db_adapter.metadata.bind.execute.assert_called_once_with(expected_query, "some_process_id") + def test_storeProcessKillStatus_WhenProvidedProcessId_StoresProcessKillStatus(self): + self.mockProcessStatusAndReturnSingle("Running") + self.processRepository.storeProcessKillStatus("some-process-id") + self.assertProcessStatusUpdatedTo("Killed") - def test_is_killed_process_WhenProvidedNonKilledProcessId_ReturnsFalse(self): - from db.repositories.ProcessRepository import ProcessRepository + def test_storeProcessRunningElapsedTime_WhenProvidedProcessId_StoresProcessRunningElapsedTime(self): + self.mockProcess.elapsed = "some-time" + self.mockDbSession.query.return_value = mockQueryWithFilterBy(mockFirstByReturnValue(self.mockProcess)) - expected_query = "SELECT process.status FROM process AS process WHERE process.id=?" - self.mock_db_adapter.metadata.bind.execute.return_value = mock_execute_fetchall([["Running"]]) + self.processRepository.storeProcessRunningElapsedTime("some-process-id", "another-time") + self.assertEqual("another-time", self.mockProcess.elapsed) + self.mockDbSession.add.assert_called_once_with(self.mockProcess) + self.mockDbAdapter.commit.assert_called_once() - process_repository = ProcessRepository(self.mock_db_adapter, self.mock_logger) + def test_getHostsByToolName_WhenProvidedToolNameAndClosedFalse_StoresProcessRunningElapsedTime(self): + expectedQuery = ('SELECT process.id, "0", "0", "0", "0", "0", "0", process.hostIp, process.port, ' + 'process.protocol, "0", "0", process.outputfile, "0", "0", "0" FROM process AS process ' + 'WHERE process.name=? and process.closed="False"') + self.mockDbAdapter.metadata.bind.execute.return_value = mockExecuteFetchAll([["some-host1"], ["some-host2"]]) - self.assertFalse(process_repository.is_killed_process("some_process_id")) - self.mock_db_adapter.metadata.bind.execute.assert_called_once_with(expected_query, "some_process_id") + hosts = self.processRepository.getHostsByToolName("some-toolname", "False") + self.assertEqual([["some-host1"], ["some-host2"]], hosts) + self.mockDbAdapter.metadata.bind.execute.assert_called_once_with(expectedQuery, "some-toolname") - def test_is_cancelled_process_WhenProvidedCancelledProcessId_ReturnsTrue(self): - from db.repositories.ProcessRepository import ProcessRepository + def test_getHostsByToolName_WhenProvidedToolNameAndClosedAsFetchAll_StoresProcessRunningElapsedTime(self): + expectedQuery = ('SELECT "0", "0", "0", "0", "0", process.hostIp, process.port, process.protocol, "0", "0", ' + 'process.outputfile, "0", "0", "0" FROM process AS process WHERE process.name=?') + self.mockDbAdapter.metadata.bind.execute.return_value = mockExecuteFetchAll([["some-host1"], ["some-host2"]]) - expected_query = "SELECT process.status FROM process AS process WHERE process.id=?" - self.mock_db_adapter.metadata.bind.execute.return_value = mock_execute_fetchall([["Cancelled"]]) + hosts = self.processRepository.getHostsByToolName("some-toolname", "FetchAll") + self.assertEqual([["some-host1"], ["some-host2"]], hosts) + self.mockDbAdapter.metadata.bind.execute.assert_called_once_with(expectedQuery, "some-toolname") - process_repository = ProcessRepository(self.mock_db_adapter, self.mock_logger) + def test_storeCloseStatus_WhenProvidedProcessId_StoresCloseStatus(self): + self.mockProcess.closed = 'False' + self.mockDbSession.query.return_value = mockQueryWithFilterBy(mockFirstByReturnValue(self.mockProcess)) + self.processRepository.storeCloseStatus("some-process-id") - self.assertTrue(process_repository.is_cancelled_process("some_process_id")) - self.mock_db_adapter.metadata.bind.execute.assert_called_once_with(expected_query, "some_process_id") + self.assertEqual('True', self.mockProcess.closed) + self.mockDbSession.add.assert_called_once_with(self.mockProcess) + self.mockDbAdapter.commit.assert_called_once() - def test_is_cancelled_process_WhenProvidedNonCancelledProcessId_ReturnsFalse(self): - from db.repositories.ProcessRepository import ProcessRepository + def test_storeScreenshot_WhenProvidedIPAndPortAndFileName_StoresScreenshot(self): + processId = self.processRepository.storeScreenshot("some-ip", "some-port", "some-filename") - expected_query = "SELECT process.status FROM process AS process WHERE process.id=?" - self.mock_db_adapter.metadata.bind.execute.return_value = mock_execute_fetchall([["Running"]]) + self.mockDbSession.add.assert_called_once() + self.mockDbSession.commit.assert_called_once() - process_repository = ProcessRepository(self.mock_db_adapter, self.mock_logger) + def test_toggleProcessDisplayStatus_whenResetAllIsTrue_setDisplayToFalseForAllProcessesThatAreNotRunning( + self): + process1 = build_mock_process(status="Waiting", display="True") + process2 = build_mock_process(status="Waiting", display="True") + mock_query_response = MagicMock() + mock_filtered_response = MagicMock() + mock_filtered_response.all.return_value = [process1, process2] + mock_query_response.filter_by.return_value = mock_filtered_response + self.mockDbSession.query.return_value = mock_query_response + self.processRepository.toggleProcessDisplayStatus(resetAll=True) + + self.assertEqual("False", process1.display) + self.assertEqual("False", process2.display) + self.mockDbSession.add.assert_has_calls([ + mock.call(process1), + mock.call(process2), + ]) + self.mockDbAdapter.commit.assert_called_once() - self.assertFalse(process_repository.is_cancelled_process("some_process_id")) - self.mock_db_adapter.metadata.bind.execute.assert_called_once_with(expected_query, "some_process_id") + def test_toggleProcessDisplayStatus_whenResetAllIFalse_setDisplayToFalseForAllProcessesThatAreNotRunningOrWaiting( + self): + process1 = build_mock_process(status="Random Status", display="True") + process2 = build_mock_process(status="Another Random Status", display="True") + process3 = build_mock_process(status="Running", display="True") + mock_query_response = MagicMock() + mock_filtered_response = MagicMock() + mock_filtered_response.all.return_value = [process1, process2] + mock_query_response.filter_by.return_value = mock_filtered_response + self.mockDbSession.query.return_value = mock_query_response + self.processRepository.toggleProcessDisplayStatus() + + self.assertEqual("False", process1.display) + self.assertEqual("False", process2.display) + self.assertEqual("True", process3.display) + self.mockDbSession.add.assert_has_calls([ + mock.call(process1), + mock.call(process2), + ]) + self.mockDbAdapter.commit.assert_called_once() - def when_process_does_not_finish_gracefully(self, process_status: str): - from db.repositories.ProcessRepository import ProcessRepository + def mockProcessStatusAndReturnSingle(self, processStatus: str): + self.mockProcess.status = processStatus + self.mockDbSession.query.return_value = mockQueryWithFilterBy(mockFirstByReturnValue(self.mockProcess)) + + def assertProcessStatusUpdatedTo(self, expected_status: str): + self.assertEqual(expected_status, self.mockProcess.status) + self.mockDbSession.add.assert_called_once_with(self.mockProcess) + self.mockDbAdapter.commit.assert_called_once() + + def whenProcessDoesNotFinishGracefully(self, process_status: str): from db.database import process, process_output expected_process: process = MagicMock() expected_process.status = process_status expected_process_output: process_output = MagicMock() - mock_db_session = MagicMock() - self.mock_db_adapter.session.return_value = mock_db_session - mock_db_session.query.return_value = mock_query_with_filter_by( - mock_first_by_side_effect([expected_process, expected_process_output])) + self.mockDbSession.query.return_value = mockQueryWithFilterBy( + mockFirstBySideEffect([expected_process, expected_process_output])) - process_repository = ProcessRepository(self.mock_db_adapter, self.mock_logger) - process_repository.store_process_output("some_process_id", "this is some cool output") + self.processRepository.storeProcessOutput("some_process_id", "this is some cool output") - mock_db_session.add.assert_called_once_with(expected_process_output) - self.mock_db_adapter.commit.assert_called_once() + self.mockDbSession.add.assert_called_once_with(expected_process_output) + self.mockDbAdapter.commit.assert_called_once() diff --git a/tests/db/repositories/test_ScriptRepository.py b/tests/db/repositories/test_ScriptRepository.py new file mode 100644 index 0000000..e3d56e1 --- /dev/null +++ b/tests/db/repositories/test_ScriptRepository.py @@ -0,0 +1,30 @@ +import unittest +from unittest.mock import MagicMock + +from tests.db.helpers.db_helpers import mockExecuteFetchAll + + +class ScriptRepositoryTest(unittest.TestCase): + def setUp(self) -> None: + from db.repositories.ScriptRepository import ScriptRepository + self.mockDbAdapter = MagicMock() + self.scriptRepository = ScriptRepository(self.mockDbAdapter) + + def test_getScriptsByHostIP_WhenProvidedAHostIP_ReturnsAllScripts(self): + expectedQuery = ("SELECT host.id, host.scriptId, port.portId, port.protocol FROM l1ScriptObj AS host " + "INNER JOIN hostObj AS hosts ON hosts.id = host.hostId " + "LEFT OUTER JOIN portObj AS port ON port.id = host.portId WHERE hosts.ip=?") + self.mockDbAdapter.metadata.bind.execute.return_value = mockExecuteFetchAll( + [['some-script1'], ['some-script2']]) + scripts = self.scriptRepository.getScriptsByHostIP("some-host-ip") + self.assertEqual([['some-script1'], ['some-script2']], scripts) + self.mockDbAdapter.metadata.bind.execute.assert_called_once_with(expectedQuery, "some-host-ip") + + def test_getScriptOutputById_WhenProvidedAScriptId_ReturnsScriptOutput(self): + expectedQuery = "SELECT script.output FROM l1ScriptObj as script WHERE script.id = ?" + self.mockDbAdapter.metadata.bind.execute.return_value = mockExecuteFetchAll( + [['some-script-output1'], ['some-script-output2']]) + + scripts = self.scriptRepository.getScriptOutputById("some-id") + self.assertEqual([['some-script-output1'], ['some-script-output2']], scripts) + self.mockDbAdapter.metadata.bind.execute.assert_called_once_with(expectedQuery, "some-id") diff --git a/tests/db/repositories/test_ServiceRepository.py b/tests/db/repositories/test_ServiceRepository.py index a535c42..0a6ed63 100644 --- a/tests/db/repositories/test_ServiceRepository.py +++ b/tests/db/repositories/test_ServiceRepository.py @@ -18,63 +18,58 @@ import unittest from unittest.mock import MagicMock, patch -from tests.db.helpers.db_helpers import mock_execute_fetchall, mock_first_by_return_value, mock_query_with_filter_by +from tests.db.helpers.db_helpers import mockExecuteFetchAll, mockFirstByReturnValue, mockQueryWithFilterBy class ServiceRepositoryTest(unittest.TestCase): @patch('utilities.stenoLogging.get_logger') def setUp(self, get_logger) -> None: - self.mock_db_adapter = MagicMock() - - def get_service_names_test_case(self, filters, expected_query): from db.repositories.ServiceRepository import ServiceRepository + self.mockDbAdapter = MagicMock() + self.repository = ServiceRepository(self.mockDbAdapter) - repository = ServiceRepository(self.mock_db_adapter) - - self.mock_db_adapter.metadata.bind.execute.return_value = mock_execute_fetchall( + def getServiceNamesTestCase(self, filters, expectedQuery): + self.mockDbAdapter.metadata.bind.execute.return_value = mockExecuteFetchAll( [{'name': 'service_name1'}, {'name': 'service_name2'}]) - service_names = repository.get_service_names(filters) + service_names = self.repository.getServiceNames(filters) - self.mock_db_adapter.metadata.bind.execute.assert_called_once_with(expected_query) + self.mockDbAdapter.metadata.bind.execute.assert_called_once_with(expectedQuery) self.assertEqual([{'name': 'service_name1'}, {'name': 'service_name2'}], service_names) - def test_get_service_names_InvokedWithNoFilters_FetchesAllServiceNames(self): + def test_getServiceNames_InvokedWithNoFilters_FetchesAllServiceNames(self): from app.auxiliary import Filters - expected_query = query = ("SELECT DISTINCT service.name FROM serviceObj as service " - "INNER JOIN portObj as ports " - "INNER JOIN hostObj AS hosts " - "ON hosts.id = ports.hostId AND service.id=ports.serviceId WHERE 1=1 " - "ORDER BY service.name ASC") + expectedQuery = query = ("SELECT DISTINCT service.name FROM serviceObj as service " + "INNER JOIN portObj as ports " + "INNER JOIN hostObj AS hosts " + "ON hosts.id = ports.hostId AND service.id=ports.serviceId WHERE 1=1 " + "ORDER BY service.name ASC") filters: Filters = Filters() filters.apply(up=True, down=True, checked=True, portopen=True, portfiltered=True, portclosed=True, tcp=True, udp=True) - self.get_service_names_test_case(filters=filters, expected_query=expected_query) + self.getServiceNamesTestCase(filters=filters, expectedQuery=expectedQuery) - def test_get_service_names_InvokedWithFewFilters_FetchesAllServiceNamesWithFiltersApplied(self): + def test_getServiceNames_InvokedWithFewFilters_FetchesAllServiceNamesWithFiltersApplied(self): from app.auxiliary import Filters - expected_query = ("SELECT DISTINCT service.name FROM serviceObj as service " - "INNER JOIN portObj as ports " - "INNER JOIN hostObj AS hosts " - "ON hosts.id = ports.hostId AND service.id=ports.serviceId WHERE 1=1 " - "AND hosts.status != 'down' AND ports.protocol != 'udp' " - "ORDER BY service.name ASC") + expectedQuery = ("SELECT DISTINCT service.name FROM serviceObj as service " + "INNER JOIN portObj as ports " + "INNER JOIN hostObj AS hosts " + "ON hosts.id = ports.hostId AND service.id=ports.serviceId WHERE 1=1 " + "AND hosts.status != 'down' AND ports.protocol != 'udp' " + "ORDER BY service.name ASC") filters: Filters = Filters() filters.apply(up=True, down=False, checked=True, portopen=True, portfiltered=True, portclosed=True, tcp=True, udp=False) - self.get_service_names_test_case(filters=filters, expected_query=expected_query) + self.getServiceNamesTestCase(filters=filters, expectedQuery=expectedQuery) - def test_get_service_names_by_host_ip_and_port_WhenProvidedWithHostIpAndPort_ReturnsServiceNames(self): - from db.repositories.ServiceRepository import ServiceRepository - self.mock_db_adapter.metadata.bind.execute.return_value = mock_first_by_return_value( + def test_getServiceNamesByHostIPAndPort_WhenProvidedWithHostIpAndPort_ReturnsServiceNames(self): + self.mockDbAdapter.metadata.bind.execute.return_value = mockFirstByReturnValue( [['service-name1'], ['service-name2']]) - - expected_query = ("SELECT services.name FROM serviceObj AS services " - "INNER JOIN hostObj AS hosts ON hosts.id = ports.hostId " - "INNER JOIN portObj AS ports ON services.id=ports.serviceId " - "WHERE hosts.ip=? and ports.portId = ?") - service_repository = ServiceRepository(self.mock_db_adapter) - result = service_repository.get_service_names_by_host_ip_and_port("some_host", "1234") + expectedQuery = ("SELECT services.name FROM serviceObj AS services " + "INNER JOIN hostObj AS hosts ON hosts.id = ports.hostId " + "INNER JOIN portObj AS ports ON services.id=ports.serviceId " + "WHERE hosts.ip=? and ports.portId = ?") + result = self.repository.getServiceNamesByHostIPAndPort("some_host", "1234") self.assertEqual([['service-name1'], ['service-name2']], result) - self.mock_db_adapter.metadata.bind.execute.assert_called_once_with(expected_query, "some_host", "1234") + self.mockDbAdapter.metadata.bind.execute.assert_called_once_with(expectedQuery, "some_host", "1234") diff --git a/tests/db/test_filters.py b/tests/db/test_filters.py index 1e5021d..7565a99 100644 --- a/tests/db/test_filters.py +++ b/tests/db/test_filters.py @@ -24,104 +24,104 @@ class FiltersTest(unittest.TestCase): def setUp(self, get_logger) -> None: return - def test_apply_filters_InvokedWithNoFilters_ReturnsEmptyString(self): + def test_applyFilters_InvokedWithNoFilters_ReturnsEmptyString(self): from app.auxiliary import Filters - from db.filters import apply_filters + from db.filters import applyFilters filters: Filters = Filters() filters.apply(up=True, down=True, checked=True, portopen=True, portfiltered=True, portclosed=True, tcp=True, udp=True) - result = apply_filters(filters) + result = applyFilters(filters) self.assertEqual("", result) - def test_apply_filters_InvokedWithHostDownFilter_ReturnsQueryFilterWithHostsDown(self): + def test_applyFilters_InvokedWithHostDownFilter_ReturnsQueryFilterWithHostsDown(self): from app.auxiliary import Filters - from db.filters import apply_filters + from db.filters import applyFilters filters: Filters = Filters() filters.apply(up=True, down=False, checked=True, portopen=True, portfiltered=True, portclosed=True, tcp=True, udp=True) - result = apply_filters(filters) + result = applyFilters(filters) self.assertEqual(" AND hosts.status != 'down'", result) - def test_apply_filters_InvokedWithHostUpFilter_ReturnsQueryFilterWithHostsUp(self): + def test_applyFilters_InvokedWithHostUpFilter_ReturnsQueryFilterWithHostsUp(self): from app.auxiliary import Filters - from db.filters import apply_filters + from db.filters import applyFilters filters: Filters = Filters() filters.apply(up=False, down=True, checked=True, portopen=True, portfiltered=True, portclosed=True, tcp=True, udp=True) - result = apply_filters(filters) + result = applyFilters(filters) self.assertEqual(" AND hosts.status != 'up'", result) - def test_apply_filters_InvokedWithHostCheckedFilter_ReturnsQueryFilterWithHostsChecked(self): + def test_applyFilters_InvokedWithHostCheckedFilter_ReturnsQueryFilterWithHostsChecked(self): from app.auxiliary import Filters - from db.filters import apply_filters + from db.filters import applyFilters filters: Filters = Filters() filters.apply(up=True, down=True, checked=False, portopen=True, portfiltered=True, portclosed=True, tcp=True, udp=True) - result = apply_filters(filters) + result = applyFilters(filters) self.assertEqual(" AND hosts.checked != 'True'", result) - def test_apply_filters_InvokedWithPortOpenFilter_ReturnsQueryFilterWithPortOpen(self): + def test_applyFilters_InvokedWithPortOpenFilter_ReturnsQueryFilterWithPortOpen(self): from app.auxiliary import Filters - from db.filters import apply_filters + from db.filters import applyFilters filters: Filters = Filters() filters.apply(up=True, down=True, checked=True, portopen=False, portfiltered=True, portclosed=True, tcp=True, udp=True) - result = apply_filters(filters) + result = applyFilters(filters) self.assertEqual(" AND ports.state != 'open' AND ports.state != 'open|filtered'", result) - def test_apply_filters_InvokedWithPortClosedFilter_ReturnsQueryFilterWithPortClosed(self): + def test_applyFilters_InvokedWithPortClosedFilter_ReturnsQueryFilterWithPortClosed(self): from app.auxiliary import Filters - from db.filters import apply_filters + from db.filters import applyFilters filters: Filters = Filters() filters.apply(up=True, down=True, checked=True, portopen=True, portfiltered=True, portclosed=False, tcp=True, udp=True) - result = apply_filters(filters) + result = applyFilters(filters) self.assertEqual(" AND ports.state != 'closed'", result) - def test_apply_filters_InvokedWithPortFilteredFilter_ReturnsQueryFilterWithPortFiltered(self): + def test_applyFilters_InvokedWithPortFilteredFilter_ReturnsQueryFilterWithPortFiltered(self): from app.auxiliary import Filters - from db.filters import apply_filters + from db.filters import applyFilters filters: Filters = Filters() filters.apply(up=True, down=True, checked=True, portopen=True, portfiltered=False, portclosed=True, tcp=True, udp=True) - result = apply_filters(filters) + result = applyFilters(filters) self.assertEqual(" AND ports.state != 'filtered' AND ports.state != 'open|filtered'", result) - def test_apply_filters_InvokedWithTcpProtocolFilter_ReturnsQueryFilterWithTcp(self): + def test_applyFilters_InvokedWithTcpProtocolFilter_ReturnsQueryFilterWithTcp(self): from app.auxiliary import Filters - from db.filters import apply_filters + from db.filters import applyFilters filters: Filters = Filters() filters.apply(up=True, down=True, checked=True, portopen=True, portfiltered=True, portclosed=True, tcp=False, udp=True) - result = apply_filters(filters) + result = applyFilters(filters) self.assertEqual(" AND ports.protocol != 'tcp'", result) - def test_apply_filters_InvokedWithUdpProtocolFilter_ReturnsQueryFilterWithUdp(self): + def test_applyFilters_InvokedWithUdpProtocolFilter_ReturnsQueryFilterWithUdp(self): from app.auxiliary import Filters - from db.filters import apply_filters + from db.filters import applyFilters filters: Filters = Filters() filters.apply(up=True, down=True, checked=True, portopen=True, portfiltered=True, portclosed=True, tcp=True, udp=False) - result = apply_filters(filters) + result = applyFilters(filters) self.assertEqual(" AND ports.protocol != 'udp'", result) - def test_apply_filters_InvokedWithKeywordsFilter_ReturnsQueryFilterWithKeywords(self): + def test_applyFilters_InvokedWithKeywordsFilter_ReturnsQueryFilterWithKeywords(self): from app.auxiliary import Filters - from db.filters import apply_filters + from db.filters import applyFilters filters: Filters = Filters() keyword = "some-keyword" filters.apply(up=True, down=True, checked=True, portopen=True, portfiltered=True, portclosed=True, tcp=True, udp=True, keywords=[keyword]) - result = apply_filters(filters) + result = applyFilters(filters) self.assertEqual(" AND (hosts.ip LIKE '%some-keyword%' OR hosts.osMatch LIKE '%some-keyword%'" " OR hosts.hostname LIKE '%some-keyword%')", result)