From 8c098640eea7284c040cc07aa910f36c08cfe483 Mon Sep 17 00:00:00 2001 From: Francis Labelle Date: Tue, 17 Nov 2020 17:18:30 -0500 Subject: [PATCH 01/17] Remove completionID from the key used to identify open files. Fixes #264. --- pyrdp/mitm/DeviceRedirectionMITM.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyrdp/mitm/DeviceRedirectionMITM.py b/pyrdp/mitm/DeviceRedirectionMITM.py index 0c6ee804c..e0752c4d5 100644 --- a/pyrdp/mitm/DeviceRedirectionMITM.py +++ b/pyrdp/mitm/DeviceRedirectionMITM.py @@ -217,7 +217,7 @@ def handleCreateResponse(self, request: DeviceCreateRequestPDU, response: Device mapping = FileMapping.generate(remotePath, self.config.fileDir) proxy = FileProxy(mapping.localPath, "wb") - key = (response.deviceID, response.completionID, response.fileID) + key = (response.deviceID, response.fileID) self.openedFiles[key] = proxy self.openedMappings[key] = mapping @@ -234,7 +234,7 @@ def handleReadResponse(self, request: DeviceReadRequestPDU, response: DeviceRead :param request: the device read request :param response: the device IO response to the request """ - key = (response.deviceID, response.completionID, request.fileID) + key = (response.deviceID, request.fileID) if key in self.openedFiles: file = self.openedFiles[key] @@ -258,7 +258,7 @@ def handleCloseResponse(self, request: DeviceCloseRequestPDU, response: DeviceCl """ self.statCounter.increment(STAT.DEVICE_REDIRECTION_FILE_CLOSE) - key = (response.deviceID, response.completionID, request.fileID) + key = (response.deviceID, request.fileID) if key in self.openedFiles: file = self.openedFiles.pop(key) From 2092562a57c63999ba59d3e3d9097ebf9f896b42 Mon Sep 17 00:00:00 2001 From: Francis Labelle Date: Thu, 19 Nov 2020 16:57:15 -0500 Subject: [PATCH 02/17] Rename files to their sha1 hash when they get closed --- pyrdp/mitm/DeviceRedirectionMITM.py | 16 +++++++++++++--- pyrdp/mitm/FileMapping.py | 4 ++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/pyrdp/mitm/DeviceRedirectionMITM.py b/pyrdp/mitm/DeviceRedirectionMITM.py index e0752c4d5..3022e5b24 100644 --- a/pyrdp/mitm/DeviceRedirectionMITM.py +++ b/pyrdp/mitm/DeviceRedirectionMITM.py @@ -283,16 +283,26 @@ def handleCloseResponse(self, request: DeviceCloseRequestPDU, response: DeviceCl currentMapping.hash = sha1.hexdigest() + isDuplicate = False + # Check if a file with the same hash exists. If so, keep that one and remove the current file. - for localPath, mapping in self.fileMap.items(): + for _, mapping in self.fileMap.items(): if mapping is currentMapping: continue if mapping.hash == currentMapping.hash: - currentMapping.localPath.unlink() - self.fileMap.pop(currentMapping.localPath.name) + isDuplicate = True break + if isDuplicate: + currentMapping.localPath.unlink() + self.fileMap.pop(currentMapping.localPath.name) + else: + oldName = currentMapping.localPath.name + currentMapping.renameToHash() + self.fileMap.pop(oldName) + self.fileMap[currentMapping.localPath.name] = currentMapping + self.saveMapping() def handleClientLogin(self): diff --git a/pyrdp/mitm/FileMapping.py b/pyrdp/mitm/FileMapping.py index a265d39e3..6dabfd802 100644 --- a/pyrdp/mitm/FileMapping.py +++ b/pyrdp/mitm/FileMapping.py @@ -30,6 +30,10 @@ def __init__(self, remotePath: Path, localPath: Path, creationTime: datetime.dat self.creationTime = creationTime self.hash: str = fileHash + def renameToHash(self): + newPath = self.localPath.parents[0] / self.hash + self.localPath = self.localPath.rename(newPath) + @staticmethod def generate(remotePath: Path, outDir: Path): localName = f"{names.get_first_name()}{names.get_last_name()}" From 826431a143db38a22519341f1b1eb85308956f23 Mon Sep 17 00:00:00 2001 From: Francis Labelle Date: Mon, 23 Nov 2020 17:36:31 -0500 Subject: [PATCH 03/17] Fix bug preventing responses to forged requests from being dropped --- pyrdp/mitm/DeviceRedirectionMITM.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrdp/mitm/DeviceRedirectionMITM.py b/pyrdp/mitm/DeviceRedirectionMITM.py index 3022e5b24..4d1e6b84b 100644 --- a/pyrdp/mitm/DeviceRedirectionMITM.py +++ b/pyrdp/mitm/DeviceRedirectionMITM.py @@ -135,7 +135,7 @@ def handlePDU(self, pdu: DeviceRedirectionPDU, destination: DeviceRedirectionLay if isinstance(pdu, DeviceIORequestPDU) and destination is self.client: self.handleIORequest(pdu) elif isinstance(pdu, DeviceIOResponsePDU) and destination is self.server: - dropPDU = pdu.completionID in self.forgedRequests + dropPDU = (pdu.deviceID, pdu.completionID) in self.forgedRequests self.handleIOResponse(pdu) elif isinstance(pdu, DeviceListAnnounceRequest): From 389c3ac43889ca17b5e48f09bb364a222b32c2a2 Mon Sep 17 00:00:00 2001 From: Francis Labelle Date: Mon, 23 Nov 2020 17:37:02 -0500 Subject: [PATCH 04/17] Start implementings DeviceRedirectionMITM tests --- .gitignore | 3 + test/test_DeviceRedirectionMITM.py | 145 +++++++++++++++++++++++++++++ 2 files changed, 148 insertions(+) create mode 100644 test/test_DeviceRedirectionMITM.py diff --git a/.gitignore b/.gitignore index 8a4fab6f2..622b86f17 100644 --- a/.gitignore +++ b/.gitignore @@ -36,3 +36,6 @@ mitm.json # twisted /twisted/plugins/dropin.cache + +# code coverage +htmlcov/ diff --git a/test/test_DeviceRedirectionMITM.py b/test/test_DeviceRedirectionMITM.py new file mode 100644 index 000000000..5bdef5254 --- /dev/null +++ b/test/test_DeviceRedirectionMITM.py @@ -0,0 +1,145 @@ +import unittest +from pathlib import Path +from unittest.mock import Mock, MagicMock, patch, mock_open + +from pyrdp.enum import IOOperationSeverity +from pyrdp.logging.StatCounter import StatCounter, STAT +from pyrdp.mitm.DeviceRedirectionMITM import DeviceRedirectionMITM +from pyrdp.pdu import DeviceIOResponsePDU + + +def MockIOError(): + ioError = Mock(deviceID = 0, completionID = 0, ioStatus = IOOperationSeverity.STATUS_SEVERITY_ERROR << 30) + return ioError + + +@patch("builtins.open", new_callable=mock_open) +class DeviceRedirectionMITMTest(unittest.TestCase): + def setUp(self): + self.client = Mock() + self.server = Mock() + self.log = Mock() + self.statCounter = Mock() + self.state = Mock() + self.state.config = MagicMock() + self.state.config.outDir = Path("/tmp") + self.mitm = DeviceRedirectionMITM(self.client, self.server, self.log, self.statCounter, self.state) + + def test_stats(self, *args): + self.mitm.handlePDU = Mock() + self.mitm.statCounter = StatCounter() + + self.mitm.onClientPDUReceived(Mock()) + self.assertEqual(self.mitm.statCounter.stats[STAT.DEVICE_REDIRECTION], 1) + self.assertEqual(self.mitm.statCounter.stats[STAT.DEVICE_REDIRECTION_CLIENT], 1) + + self.mitm.onServerPDUReceived(Mock()) + self.assertEqual(self.mitm.statCounter.stats[STAT.DEVICE_REDIRECTION], 2) + self.assertEqual(self.mitm.statCounter.stats[STAT.DEVICE_REDIRECTION_SERVER], 1) + + self.mitm.handleIORequest(Mock()) + self.assertEqual(self.mitm.statCounter.stats[STAT.DEVICE_REDIRECTION_IOREQUEST], 1) + + self.mitm.handleIOResponse(Mock()) + self.assertEqual(self.mitm.statCounter.stats[STAT.DEVICE_REDIRECTION_IORESPONSE], 1) + + error = MockIOError() + self.mitm.handleIORequest(error) + self.mitm.handleIOResponse(error) + self.assertEqual(self.mitm.statCounter.stats[STAT.DEVICE_REDIRECTION_IOERROR], 1) + + self.mitm.handleCloseResponse(Mock(), Mock()) + self.assertEqual(self.mitm.statCounter.stats[STAT.DEVICE_REDIRECTION_FILE_CLOSE], 1) + + self.mitm.sendForgedFileRead(Mock(), Mock()) + self.assertEqual(self.mitm.statCounter.stats[STAT.DEVICE_REDIRECTION_FORGED_FILE_READ], 1) + + self.mitm.sendForgedDirectoryListing(Mock(), MagicMock()) + self.assertEqual(self.mitm.statCounter.stats[STAT.DEVICE_REDIRECTION_FORGED_DIRECTORY_LISTING], 1) + + def test_ioError_showsWarning(self, *args): + self.log.warning = Mock() + error = MockIOError() + + self.mitm.handleIORequest(error) + self.mitm.handleIOResponse(error) + self.log.warning.assert_called_once() + + def test_deviceListAnnounce_logsDevices(self, *args): + pdu = Mock() + pdu.deviceList = [Mock(), Mock(), Mock()] + + self.mitm.observer = Mock() + self.mitm.handleDeviceListAnnounceRequest(pdu) + + self.assertEqual(self.log.info.call_count, len(pdu.deviceList)) + self.assertEqual(self.mitm.observer.onDeviceAnnounce.call_count, len(pdu.deviceList)) + + def test_handleClientLogin_logsCredentials(self, *args): + creds = "PASSWORD" + self.log.info = Mock() + + self.state.credentialsCandidate = creds + self.state.inputBuffer = "" + self.mitm.handleClientLogin() + self.log.info.assert_called_once() + self.assertTrue(creds in self.log.info.call_args[0][1].values()) + + self.log.info.reset_mock() + self.state.credentialsCandidate = "" + self.state.inputBuffer = creds + self.mitm.handleClientLogin() + self.log.info.assert_called_once() + self.assertTrue(creds in self.log.info.call_args[0][1].values()) + + def test_handleIOResponse_uniqueResponse(self, *args): + handler = Mock() + self.mitm.responseHandlers[1234] = handler + + pdu = Mock(deviceID = 0, completionID = 0, majorFunction = 1234, ioStatus = 0) + self.mitm.handleIORequest(pdu) + self.mitm.handleIOResponse(pdu) + handler.assert_called_once() + + # Second response should not go through + self.mitm.handleIOResponse(pdu) + handler.assert_called_once() + + + def test_handleIOResponse_matchingOnly(self, *args): + handler = Mock() + self.mitm.responseHandlers[1234] = handler + + request = Mock(deviceID = 0, completionID = 0) + matching_response = Mock(deviceID = 0, completionID = 0, majorFunction = 1234, ioStatus = 0) + bad_completionID = Mock(deviceID = 0, completionID = 1, majorFunction = 1234, ioStatus = 0) + bad_deviceID = Mock(deviceID = 1, completionID = 0, majorFunction = 1234, ioStatus = 0) + + self.mitm.handleIORequest(request) + self.mitm.handleIOResponse(matching_response) + handler.assert_called_once() + + self.mitm.handleIORequest(request) + + self.mitm.handleIOResponse(bad_completionID) + handler.assert_called_once() + self.log.error.assert_called_once() + self.log.error.reset_mock() + + self.mitm.handleIOResponse(bad_deviceID) + handler.assert_called_once() + self.log.error.assert_called_once() + self.log.error.reset_mock() + + def test_handlePDU_hidesForgedResponses(self, *args): + handler = Mock() + completionID = self.mitm.sendForgedFileRead(0, "forged") + request = self.mitm.forgedRequests[(0, completionID)] + request.handlers[1234] = handler + + self.assertEqual(len(self.mitm.forgedRequests), 1) + response = Mock(deviceID = 0, completionID = completionID, majorFunction = 1234, ioStatus = 0) + response.__class__ = DeviceIOResponsePDU + self.mitm.handlePDU(response, self.mitm.server) + handler.assert_called_once() + self.mitm.server.sendPDU.assert_not_called() From 5fd6e7219ba22b67f026f858743cb8f3fa690363 Mon Sep 17 00:00:00 2001 From: Francis Labelle Date: Tue, 24 Nov 2020 19:26:29 -0500 Subject: [PATCH 05/17] Add tests for file operations --- test/test_DeviceRedirectionMITM.py | 102 ++++++++++++++++++++++++++++- 1 file changed, 100 insertions(+), 2 deletions(-) diff --git a/test/test_DeviceRedirectionMITM.py b/test/test_DeviceRedirectionMITM.py index 5bdef5254..6562beba7 100644 --- a/test/test_DeviceRedirectionMITM.py +++ b/test/test_DeviceRedirectionMITM.py @@ -2,7 +2,7 @@ from pathlib import Path from unittest.mock import Mock, MagicMock, patch, mock_open -from pyrdp.enum import IOOperationSeverity +from pyrdp.enum import CreateOption, FileAccessMask, IOOperationSeverity from pyrdp.logging.StatCounter import StatCounter, STAT from pyrdp.mitm.DeviceRedirectionMITM import DeviceRedirectionMITM from pyrdp.pdu import DeviceIOResponsePDU @@ -105,7 +105,6 @@ def test_handleIOResponse_uniqueResponse(self, *args): self.mitm.handleIOResponse(pdu) handler.assert_called_once() - def test_handleIOResponse_matchingOnly(self, *args): handler = Mock() self.mitm.responseHandlers[1234] = handler @@ -143,3 +142,102 @@ def test_handlePDU_hidesForgedResponses(self, *args): self.mitm.handlePDU(response, self.mitm.server) handler.assert_called_once() self.mitm.server.sendPDU.assert_not_called() + + def test_handleCreateResponse_createsNoFile(self, mock_open): + createRequest = Mock( + deviceID = 0, + completionID = 0, + desiredAccess = (FileAccessMask.GENERIC_READ | FileAccessMask.FILE_READ_DATA), + createOptions = CreateOption.FILE_NON_DIRECTORY_FILE, + path = "file", + ) + createResponse = Mock(deviceID = 0, completionID = 0, fileID = 0) + + with patch("pyrdp.mitm.FileMapping.FileMapping.generate") as generate: + self.mitm.handleCreateResponse(createRequest, createResponse) + self.assertEqual(len(self.mitm.openedFiles), 1) + generate.assert_called_once() + mock_open.assert_not_called() + + def test_handleReadResponse_createsFile(self, mock_open): + request = Mock( + deviceID = 0, + completionID = 0, + fileID = 0, + desiredAccess = (FileAccessMask.GENERIC_READ | FileAccessMask.FILE_READ_DATA), + createOptions = CreateOption.FILE_NON_DIRECTORY_FILE, + path = "file", + ) + response = Mock(deviceID = 0, completionID = 0, fileID = 0, payload = "test payload") + self.mitm.saveMapping = Mock() + + with patch("pyrdp.mitm.FileMapping.FileMapping.generate") as generate: + self.mitm.handleCreateResponse(request, response) + self.mitm.handleReadResponse(request, response) + mock_open.assert_called_once() + self.mitm.saveMapping.assert_called_once() + + # Make sure it checks the file ID + request.fileID, response.fileID = 1, 1 + mock_write = Mock() + list(self.mitm.openedFiles.values())[0].write = mock_write + self.mitm.handleReadResponse(request, response) + mock_write.assert_not_called() + + def test_handleCloseResponse_closesFile(self, mock_open): + request = Mock( + deviceID=0, + completionID=0, + fileID=0, + desiredAccess=(FileAccessMask.GENERIC_READ | FileAccessMask.FILE_READ_DATA), + createOptions=CreateOption.FILE_NON_DIRECTORY_FILE, + path="file", + ) + response = Mock(deviceID=0, completionID=0, fileID=0, payload="test payload") + self.mitm.saveMapping = Mock() + + with patch("pyrdp.mitm.FileMapping.FileMapping.generate") as generate: + close = Mock() + + self.mitm.handleCreateResponse(request, response) + + mapping = list(self.mitm.openedMappings.values())[0] + mapping.renameToHash = Mock() + self.mitm.fileMap[mapping.localPath.name] = Mock() + + file = list(self.mitm.openedFiles.values())[0] + file.close = close + file.file = Mock() + + self.mitm.handleCloseResponse(request, response) + + close.assert_called_once() + mapping.renameToHash.assert_called_once() + self.mitm.saveMapping.assert_called_once() + + def test_handleCloseResponse_removesDuplicates(self, mock_open): + request = Mock( + deviceID=0, + completionID=0, + fileID=0, + desiredAccess=(FileAccessMask.GENERIC_READ | FileAccessMask.FILE_READ_DATA), + createOptions=CreateOption.FILE_NON_DIRECTORY_FILE, + path="file", + ) + response = Mock(deviceID=0, completionID=0, fileID=0, payload="test payload") + self.mitm.saveMapping = Mock() + hash = "hash" + + with patch("pyrdp.mitm.FileMapping.FileMapping.generate") as generate, patch("hashlib.sha1") as sha1: + sha1.return_value.hexdigest = Mock(return_value = hash) + self.mitm.handleCreateResponse(request, response) + + list(self.mitm.openedFiles.values())[0].file = Mock() + mapping = list(self.mitm.openedMappings.values())[0] + mapping.localPath.unlink = Mock() + self.mitm.fileMap[mapping.localPath.name] = Mock() + self.mitm.fileMap["duplicate"] = Mock(hash = hash) + + self.mitm.handleCloseResponse(request, response) + mapping.localPath.unlink.assert_called_once() + self.mitm.saveMapping.assert_called_once() From 6bb7030f1e6d36e18f271c8ed39fe491ac32bb75 Mon Sep 17 00:00:00 2001 From: Francis Labelle Date: Tue, 24 Nov 2020 19:35:16 -0500 Subject: [PATCH 06/17] Check that PAKID_CORE_USER_LOGGEDON goes through --- test/test_DeviceRedirectionMITM.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/test/test_DeviceRedirectionMITM.py b/test/test_DeviceRedirectionMITM.py index 6562beba7..004f2c3f9 100644 --- a/test/test_DeviceRedirectionMITM.py +++ b/test/test_DeviceRedirectionMITM.py @@ -2,10 +2,10 @@ from pathlib import Path from unittest.mock import Mock, MagicMock, patch, mock_open -from pyrdp.enum import CreateOption, FileAccessMask, IOOperationSeverity +from pyrdp.enum import CreateOption, FileAccessMask, IOOperationSeverity, DeviceRedirectionPacketID from pyrdp.logging.StatCounter import StatCounter, STAT from pyrdp.mitm.DeviceRedirectionMITM import DeviceRedirectionMITM -from pyrdp.pdu import DeviceIOResponsePDU +from pyrdp.pdu import DeviceIOResponsePDU, DeviceRedirectionPDU def MockIOError(): @@ -92,6 +92,13 @@ def test_handleClientLogin_logsCredentials(self, *args): self.log.info.assert_called_once() self.assertTrue(creds in self.log.info.call_args[0][1].values()) + self.mitm.handleClientLogin = Mock() + pdu = Mock(packetID = DeviceRedirectionPacketID.PAKID_CORE_USER_LOGGEDON) + pdu.__class__ = DeviceRedirectionPDU + + self.mitm.handlePDU(pdu, self.client) + self.mitm.handleClientLogin.assert_called_once() + def test_handleIOResponse_uniqueResponse(self, *args): handler = Mock() self.mitm.responseHandlers[1234] = handler From 28b4bcc212692b95e2e4a063619344cfe30bb714 Mon Sep 17 00:00:00 2001 From: Francis Labelle Date: Tue, 24 Nov 2020 20:11:48 -0500 Subject: [PATCH 07/17] Add tests for forged file reads --- test/test_DeviceRedirectionMITM.py | 64 ++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/test/test_DeviceRedirectionMITM.py b/test/test_DeviceRedirectionMITM.py index 004f2c3f9..33eb9d33d 100644 --- a/test/test_DeviceRedirectionMITM.py +++ b/test/test_DeviceRedirectionMITM.py @@ -248,3 +248,67 @@ def test_handleCloseResponse_removesDuplicates(self, mock_open): self.mitm.handleCloseResponse(request, response) mapping.localPath.unlink.assert_called_once() self.mitm.saveMapping.assert_called_once() + + +class ForgedRequestTest(unittest.TestCase): + def setUp(self): + self.request = DeviceRedirectionMITM.ForgedRequest(0, 0, Mock()) + + def test_sendIORequest_sendsToClient(self): + self.request.sendIORequest(Mock()) + self.request.mitm.client.sendPDU.assert_called_once() + + def test_onCloseResponse_completesRequest(self): + self.request.onCloseResponse(Mock()) + self.assertTrue(self.request.isComplete) + + def test_onCreateResponse_checksStatus(self): + self.request.onCreateResponse(Mock(ioStatus = 1)) + self.assertIsNone(self.request.fileID) + + +class ForgedFileReadRequestTest(unittest.TestCase): + def setUp(self): + self.request = DeviceRedirectionMITM.ForgedFileReadRequest(0, 0, Mock(), "file") + + def test_onCreateResponse_sendsReadRequest(self): + self.request.sendReadRequest = Mock() + self.request.onCreateResponse(Mock(ioStatus = 0)) + self.request.sendReadRequest.assert_called_once() + + def test_onCreateResponse_completesRequest(self): + self.request.onCreateResponse(Mock(ioStatus = 1)) + self.request.mitm.observer.onFileDownloadComplete.assert_called_once() + self.assertTrue(self.request.isComplete) + + def test_handleFileComplete_sendsCloseRequest(self): + self.request.sendCloseRequest = Mock() + self.request.fileID = Mock() + self.request.handleFileComplete(1) + self.request.sendCloseRequest.assert_called_once() + + def test_onReadResponse_closesOnError(self): + self.request.fileID = Mock() + self.request.sendCloseRequest = Mock() + self.request.mitm.observer.onFileDownloadComplete = Mock() + self.request.onReadResponse(Mock(ioStatus = 1)) + self.request.sendCloseRequest.assert_called_once() + self.request.mitm.observer.onFileDownloadComplete.assert_called_once() + + def test_onReadResponse_updatesProgress(self): + payload = b"testing" + self.request.sendReadRequest = Mock() + self.request.mitm.observer.onFileDownloadResult = Mock() + self.request.onReadResponse(Mock(ioStatus = 0, payload = payload)) + + self.assertEqual(self.request.offset, len(payload)) + self.request.mitm.observer.onFileDownloadResult.assert_called_once() + self.request.sendReadRequest.assert_called_once() + + def test_onReadResponse_closesWhenDone(self): + self.request.fileID = Mock() + self.request.sendCloseRequest = Mock() + self.request.mitm.observer.onFileDownloadComplete = Mock() + self.request.onReadResponse(Mock(ioStatus = 0, payload = b"")) + self.request.sendCloseRequest.assert_called_once() + self.request.mitm.observer.onFileDownloadComplete.assert_called_once() From 24fe50eba57cca18370164528518c85f308457c2 Mon Sep 17 00:00:00 2001 From: Francis Labelle Date: Wed, 25 Nov 2020 15:35:36 -0500 Subject: [PATCH 08/17] Fix trailing slash not getting removed when wildcard is used in directory listing --- pyrdp/mitm/DeviceRedirectionMITM.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrdp/mitm/DeviceRedirectionMITM.py b/pyrdp/mitm/DeviceRedirectionMITM.py index 4d1e6b84b..2226f55d9 100644 --- a/pyrdp/mitm/DeviceRedirectionMITM.py +++ b/pyrdp/mitm/DeviceRedirectionMITM.py @@ -526,7 +526,7 @@ def send(self): openPath = self.path[: self.path.index("*")] if openPath.endswith("\\"): - openPath = self.path[: -1] + openPath = openPath[: -1] # We need to start by opening the directory. request = DeviceCreateRequestPDU( From 6c58677301b0997c6818c5ce60b4beddc7f4be66 Mon Sep 17 00:00:00 2001 From: Francis Labelle Date: Wed, 25 Nov 2020 15:53:48 -0500 Subject: [PATCH 09/17] Add tests for forged directory listing requests --- test/test_DeviceRedirectionMITM.py | 69 +++++++++++++++++++++++++++++- 1 file changed, 67 insertions(+), 2 deletions(-) diff --git a/test/test_DeviceRedirectionMITM.py b/test/test_DeviceRedirectionMITM.py index 33eb9d33d..539d76849 100644 --- a/test/test_DeviceRedirectionMITM.py +++ b/test/test_DeviceRedirectionMITM.py @@ -2,10 +2,11 @@ from pathlib import Path from unittest.mock import Mock, MagicMock, patch, mock_open -from pyrdp.enum import CreateOption, FileAccessMask, IOOperationSeverity, DeviceRedirectionPacketID +from pyrdp.enum import CreateOption, FileAccessMask, IOOperationSeverity, DeviceRedirectionPacketID, MajorFunction, \ + MinorFunction from pyrdp.logging.StatCounter import StatCounter, STAT from pyrdp.mitm.DeviceRedirectionMITM import DeviceRedirectionMITM -from pyrdp.pdu import DeviceIOResponsePDU, DeviceRedirectionPDU +from pyrdp.pdu import DeviceIOResponsePDU, DeviceRedirectionPDU, DeviceQueryDirectoryRequestPDU def MockIOError(): @@ -312,3 +313,67 @@ def test_onReadResponse_closesWhenDone(self): self.request.onReadResponse(Mock(ioStatus = 0, payload = b"")) self.request.sendCloseRequest.assert_called_once() self.request.mitm.observer.onFileDownloadComplete.assert_called_once() + + +class ForgedDirectoryListingRequestTest(unittest.TestCase): + def setUp(self): + self.request = DeviceRedirectionMITM.ForgedDirectoryListingRequest(0, 0, Mock(), "directory") + + def test_send_removesTrailingSlash(self): + self.request.sendIORequest = Mock() + self.request.path = "directory\\" + + self.request.send() + ioRequest = self.request.sendIORequest.call_args[0][0] + self.assertEqual(ioRequest.path, "directory") + + def test_send_handlesWildcard(self): + self.request.sendIORequest = Mock() + self.request.path = "directory\\*" + + self.request.send() + ioRequest = self.request.sendIORequest.call_args[0][0] + self.assertEqual(ioRequest.path, "directory") + + def test_send_handlesNormalPath(self): + self.request.sendIORequest = Mock() + self.request.send() + + ioRequest = self.request.sendIORequest.call_args[0][0] + self.request.sendIORequest.assert_called_once() + self.assertEqual(ioRequest.path, "directory") + + def test_onCreateResponse_completesOnError(self): + self.request.onCreateResponse(Mock(ioStatus = 1)) + self.assertTrue(self.request.isComplete) + + def test_onCreateResponse_sendsDirectoryRequest(self): + self.request.sendIORequest = Mock() + self.request.onCreateResponse(Mock(ioStatus = 0)) + self.request.sendIORequest.assert_called_once() + self.assertEqual(self.request.sendIORequest.call_args[0][0].majorFunction, MajorFunction.IRP_MJ_DIRECTORY_CONTROL) + self.assertEqual(self.request.sendIORequest.call_args[0][0].minorFunction, MinorFunction.IRP_MN_QUERY_DIRECTORY) + + def test_onDirectoryControlResponse_completesOnError(self): + self.request.sendIORequest = Mock() + self.request.onDirectoryControlResponse(Mock(ioStatus = 1, minorFunction = MinorFunction.IRP_MN_QUERY_DIRECTORY)) + self.request.sendIORequest.assert_called_once() + self.assertEqual(self.request.sendIORequest.call_args[0][0].majorFunction, MajorFunction.IRP_MJ_CLOSE) + self.request.mitm.observer.onDirectoryListingComplete.assert_called_once() + + def test_onDirectoryControlResponse_handlesSuccessfulResponse(self): + self.request.sendIORequest = Mock() + response = MagicMock( + ioStatus = 0, + minorFunction = MinorFunction.IRP_MN_QUERY_DIRECTORY, + fileInformation = [MagicMock()] + ) + + self.request.onDirectoryControlResponse(response) + + # Sends result to observer + self.request.mitm.observer.onDirectoryListingResult.assert_called_once() + + # Sends follow-up directory listing request + self.assertEqual(self.request.sendIORequest.call_args[0][0].majorFunction, MajorFunction.IRP_MJ_DIRECTORY_CONTROL) + self.assertEqual(self.request.sendIORequest.call_args[0][0].minorFunction, MinorFunction.IRP_MN_QUERY_DIRECTORY) From e23e754ddc9bab64f71e127f07e4f09b2a6a44f1 Mon Sep 17 00:00:00 2001 From: Francis Labelle Date: Wed, 25 Nov 2020 16:13:40 -0500 Subject: [PATCH 10/17] Add step in ci.yml to run unit tests --- .github/workflows/ci.yml | 100 +++++++++++++++++++++------------------ 1 file changed, 54 insertions(+), 46 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 98351c76c..edc0cb4f4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -35,52 +35,56 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v1 - with: - python-version: '3.7' # Version range or exact version of a Python version to use, using semvers version range syntax. - architecture: 'x64' - - - name: Python version - run: python --version - - name: Pip version - run: pip --version - - - name: Install setuptools - run: sudo apt install python3-setuptools - - name: Install PyRDP dependencies - run: sudo apt install libdbus-1-dev libdbus-glib-1-dev libgl1-mesa-glx git python3-dev - - name: Install wheel - working-directory: . - run: pip install wheel - - name: Install PyRDP - working-directory: . - run: pip install -U -e .[full] - - - name: Install ci dependencies - run: pip install -r requirements-ci.txt - - - name: Extract test files - uses: DuckSoft/extract-7z-action@v1.0 - with: - pathSource: test/files/test_files.zip - pathTarget: test/files - - - name: Integration Test with a prerecorded PCAP. - working-directory: ./ - run: coverage run test/test_prerecorded.py - - - name: pyrdp-mitm.py initialization integration test - working-directory: ./ - run: coverage run --append test/test_mitm_initialization.py dummy_value - - - name: pyrdp-player.py read a replay in headless mode test - working-directory: ./ - run: coverage run --append bin/pyrdp-player.py --headless test/files/test_session.replay - - - name: Coverage - working-directory: ./ - run: coverage report --fail-under=40 + - uses: actions/checkout@v2 + - uses: actions/setup-python@v1 + with: + python-version: '3.7' # Version range or exact version of a Python version to use, using semvers version range syntax. + architecture: 'x64' + + - name: Python version + run: python --version + - name: Pip version + run: pip --version + + - name: Install setuptools + run: sudo apt install python3-setuptools + - name: Install PyRDP dependencies + run: sudo apt install libdbus-1-dev libdbus-glib-1-dev libgl1-mesa-glx git python3-dev + - name: Install wheel + working-directory: . + run: pip install wheel + - name: Install PyRDP + working-directory: . + run: pip install -U -e .[full] + + - name: Install ci dependencies + run: pip install -r requirements-ci.txt + + - name: Extract test files + uses: DuckSoft/extract-7z-action@v1.0 + with: + pathSource: test/files/test_files.zip + pathTarget: test/files + + - name: Integration Test with a prerecorded PCAP. + working-directory: ./ + run: coverage run test/test_prerecorded.py + + - name: pyrdp-mitm.py initialization integration test + working-directory: ./ + run: coverage run --append test/test_mitm_initialization.py dummy_value + + - name: pyrdp-player.py read a replay in headless mode test + working-directory: ./ + run: coverage run --append bin/pyrdp-player.py --headless test/files/test_session.replay + + - name: Run unit tests + working-directory: ./ + run: coverage run --append -m unittest discover -v + + - name: Coverage + working-directory: ./ + run: coverage report --fail-under=40 @@ -122,3 +126,7 @@ jobs: - name: pyrdp-player.py read a replay in headless mode test working-directory: ./ run: python bin/pyrdp-player.py --headless test/files/test_session.replay + + - name: Run unit tests + working-directory: ./ + run: coverage run --append -m unittest discover -v From 7ff88f892b7b42571dafaab6d8e7dfba7fffa769 Mon Sep 17 00:00:00 2001 From: Francis Labelle Date: Wed, 25 Nov 2020 16:43:33 -0500 Subject: [PATCH 11/17] Use coverage on Windows --- .github/workflows/ci.yml | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index edc0cb4f4..a31d6de07 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -82,7 +82,7 @@ jobs: working-directory: ./ run: coverage run --append -m unittest discover -v - - name: Coverage + - name: Coverage report working-directory: ./ run: coverage report --fail-under=40 @@ -108,6 +108,9 @@ jobs: - name: Install PyRDP working-directory: . run: pip install -U -e .[full] + - name: Install coverage + working-directory: . + run: pip install coverage - name: Extract test files uses: DuckSoft/extract-7z-action@v1.0 @@ -117,16 +120,20 @@ jobs: - name: Integration Test with a prerecorded PCAP. working-directory: ./ - run: python test/test_prerecorded.py + run: coverage run test/test_prerecorded.py - name: pyrdp-mitm.py initialization test working-directory: ./ - run: python test/test_mitm_initialization.py dummy_value + run: coverage run --append test/test_mitm_initialization.py dummy_value - name: pyrdp-player.py read a replay in headless mode test working-directory: ./ - run: python bin/pyrdp-player.py --headless test/files/test_session.replay + run: coverage run --append bin/pyrdp-player.py --headless test/files/test_session.replay - name: Run unit tests working-directory: ./ run: coverage run --append -m unittest discover -v + + - name: Coverage report + working-directory: ./ + run: coverage report --fail-under=40 From e9e9f045d0cfb1707c5ebb75748484936d8a02d7 Mon Sep 17 00:00:00 2001 From: Francis Labelle Date: Fri, 27 Nov 2020 10:40:02 -0500 Subject: [PATCH 12/17] Prevent empty lines from being used as patterns --- pyrdp/mitm/FileCrawlerMITM.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrdp/mitm/FileCrawlerMITM.py b/pyrdp/mitm/FileCrawlerMITM.py index 656287a60..bdf5bfed4 100644 --- a/pyrdp/mitm/FileCrawlerMITM.py +++ b/pyrdp/mitm/FileCrawlerMITM.py @@ -126,7 +126,7 @@ def parsePatterns(self, path: str) -> List[str]: try: with open(path, "r") as f: for line in f: - if line and line[0] in ["#", " ", "\n"]: + if not line or line[0] in ["#", " ", "\n"]: continue patternList.append(line.lower().rstrip()) From 985569be3d542433207db7ff8adf28457d271b04 Mon Sep 17 00:00:00 2001 From: Francis Labelle Date: Fri, 27 Nov 2020 11:59:14 -0500 Subject: [PATCH 13/17] Fix bug where forged file responses were being dropped. The forged request IDs were never incremented causing requests to be "forgotten" when a response to a simultaneous request was received. --- pyrdp/mitm/DeviceRedirectionMITM.py | 2 +- test/test_DeviceRedirectionMITM.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/pyrdp/mitm/DeviceRedirectionMITM.py b/pyrdp/mitm/DeviceRedirectionMITM.py index 2226f55d9..fa807c8cd 100644 --- a/pyrdp/mitm/DeviceRedirectionMITM.py +++ b/pyrdp/mitm/DeviceRedirectionMITM.py @@ -328,7 +328,7 @@ def findNextRequestID(self) -> int: """ completionID = DeviceRedirectionMITM.FORGED_COMPLETION_ID - while completionID in self.forgedRequests: + while completionID in [key[1] for key in self.forgedRequests]: completionID += 1 return completionID diff --git a/test/test_DeviceRedirectionMITM.py b/test/test_DeviceRedirectionMITM.py index 539d76849..179a70f26 100644 --- a/test/test_DeviceRedirectionMITM.py +++ b/test/test_DeviceRedirectionMITM.py @@ -250,6 +250,13 @@ def test_handleCloseResponse_removesDuplicates(self, mock_open): mapping.localPath.unlink.assert_called_once() self.mitm.saveMapping.assert_called_once() + def test_findNextRequestID_incrementsRequestID(self, *args): + baseID = self.mitm.findNextRequestID() + self.mitm.sendForgedFileRead(0, Mock()) + self.assertEqual(self.mitm.findNextRequestID(), baseID + 1) + self.mitm.sendForgedFileRead(1, Mock()) + self.assertEqual(self.mitm.findNextRequestID(), baseID + 2) + class ForgedRequestTest(unittest.TestCase): def setUp(self): From dfba32b7320220e05b098b9fa246dd38d388f5e9 Mon Sep 17 00:00:00 2001 From: Francis Labelle Date: Fri, 27 Nov 2020 13:58:20 -0500 Subject: [PATCH 14/17] Make sessionID part of RDPMITMState --- bin/pyrdp-convert.py | 2 +- pyrdp/mitm/ClipboardMITM.py | 6 ++++-- pyrdp/mitm/FileCrawlerMITM.py | 2 +- pyrdp/mitm/RDPMITM.py | 6 +++--- pyrdp/mitm/state.py | 5 ++++- test/test_prerecorded.py | 2 +- 6 files changed, 14 insertions(+), 9 deletions(-) diff --git a/bin/pyrdp-convert.py b/bin/pyrdp-convert.py index 349e95c25..701aae66e 100755 --- a/bin/pyrdp-convert.py +++ b/bin/pyrdp-convert.py @@ -122,7 +122,7 @@ def sendBytesStub(_: bytes): # We'll set up the recorder ourselves config.recordReplays = False - state = RDPMITMState(config) + state = RDPMITMState(config, log.sessionID) sink, outfile = getSink(format, output_path) transport = ConversionLayer(sink) if sink else FileLayer(outfile) diff --git a/pyrdp/mitm/ClipboardMITM.py b/pyrdp/mitm/ClipboardMITM.py index 8b6d8779a..51b9cf5b1 100644 --- a/pyrdp/mitm/ClipboardMITM.py +++ b/pyrdp/mitm/ClipboardMITM.py @@ -14,6 +14,7 @@ from pyrdp.enum import ClipboardFormatNumber, ClipboardMessageFlags, ClipboardMessageType, PlayerPDUType, FileContentsFlags from pyrdp.layer import ClipboardLayer from pyrdp.logging.StatCounter import StatCounter, STAT +from pyrdp.mitm.state import RDPMITMState from pyrdp.pdu import ClipboardPDU, FormatDataRequestPDU, FormatDataResponsePDU, FileContentsRequestPDU, FileContentsResponsePDU from pyrdp.parser.rdp.virtual_channel.clipboard import FileDescriptor from pyrdp.recording import Recorder @@ -32,7 +33,7 @@ class PassiveClipboardStealer: """ def __init__(self, config: MITMConfig, client: ClipboardLayer, server: ClipboardLayer, log: LoggerAdapter, recorder: Recorder, - statCounter: StatCounter): + statCounter: StatCounter, state: RDPMITMState): """ :param client: clipboard layer for the client side :param server: clipboard layer for the server side @@ -44,13 +45,14 @@ def __init__(self, config: MITMConfig, client: ClipboardLayer, server: Clipboard self.server = server self.config = config self.log = log + self.state = state self.recorder = recorder self.forwardNextDataResponse = True self.files = [] self.transfers = {} self.timeouts = {} # Track active timeout monitoring tasks. - self.fileDir = f"{self.config.fileDir}/{self.log.sessionID}" + self.fileDir = f"{self.config.fileDir}/{self.state.sessionID}" self.client.createObserver( onPDUReceived = self.onClientPDUReceived, diff --git a/pyrdp/mitm/FileCrawlerMITM.py b/pyrdp/mitm/FileCrawlerMITM.py index bdf5bfed4..c457c0e98 100644 --- a/pyrdp/mitm/FileCrawlerMITM.py +++ b/pyrdp/mitm/FileCrawlerMITM.py @@ -216,7 +216,7 @@ def crawlListing(self, requestID: int): def downloadFile(self, file: VirtualFile): remotePath = file.path - basePath = f"{self.config.fileDir}/{self.log.sessionID}" + basePath = f"{self.config.fileDir}/{self.state.sessionID}" localPath = f"{basePath}{remotePath}" self.log.info("Saving %(remotePath)s to %(localPath)s", {"remotePath": remotePath, "localPath": localPath}) diff --git a/pyrdp/mitm/RDPMITM.py b/pyrdp/mitm/RDPMITM.py index 5df6a13e8..e18188351 100644 --- a/pyrdp/mitm/RDPMITM.py +++ b/pyrdp/mitm/RDPMITM.py @@ -85,7 +85,7 @@ def __init__(self, mainLogger: SessionLogger, crawlerLogger: SessionLogger, conf self.statCounter = StatCounter() """Class to keep track of connection-related statistics such as # of mouse events, # of output events, etc.""" - self.state = state if state is not None else RDPMITMState(self.config) + self.state = state if state is not None else RDPMITMState(self.config, self.log.sessionID) """The MITM state""" self.client = RDPLayerSet() @@ -152,7 +152,7 @@ def __init__(self, mainLogger: SessionLogger, crawlerLogger: SessionLogger, conf replayFileName = "rdp_replay_{}_{}_{}.pyrdp"\ .format(date.strftime('%Y%m%d_%H-%M-%S'), date.microsecond // 1000, - self.log.sessionID) + self.state.sessionID) self.recorder.setRecordFilename(replayFileName) self.recorder.addTransport(FileLayer(self.config.replayDir / replayFileName)) @@ -339,7 +339,7 @@ def buildClipboardChannel(self, client: MCSServerChannel, server: MCSClientChann if self.config.disableActiveClipboardStealing: mitm = PassiveClipboardStealer(self.config, clientLayer, serverLayer, self.getLog(MCSChannelName.CLIPBOARD), - self.recorder, self.statCounter) + self.recorder, self.statCounter, self.state) else: mitm = ActiveClipboardStealer(self.config, clientLayer, serverLayer, self.getLog(MCSChannelName.CLIPBOARD), self.recorder, self.statCounter) diff --git a/pyrdp/mitm/state.py b/pyrdp/mitm/state.py index e93bf9aa0..35998994a 100644 --- a/pyrdp/mitm/state.py +++ b/pyrdp/mitm/state.py @@ -21,7 +21,7 @@ class RDPMITMState: State object for the RDP MITM. This is for data that needs to be shared across components. """ - def __init__(self, config: MITMConfig): + def __init__(self, config: MITMConfig, sessionID: str): self.requestedProtocols: Optional[NegotiationProtocols] = None """The original request protocols""" @@ -73,6 +73,9 @@ def __init__(self, config: MITMConfig): self.ctrlPressed = False """The current keybaord ctrl state""" + self.sessionID = sessionID + """The current session ID""" + self.securitySettings.addObserver(self.crypters[ParserMode.CLIENT]) self.securitySettings.addObserver(self.crypters[ParserMode.SERVER]) diff --git a/test/test_prerecorded.py b/test/test_prerecorded.py index a0e745aed..0a30e9b46 100644 --- a/test/test_prerecorded.py +++ b/test/test_prerecorded.py @@ -122,7 +122,7 @@ def sendBytesStub(_: bytes): config.outDir = output_directory # replay_transport = FileLayer(output_path) - state = RDPMITMState(config) + state = RDPMITMState(config, log.sessionID) super().__init__(log, log, config, state, CustomMITMRecorder([], state)) self.client.tcp.sendBytes = sendBytesStub From a13294e0912090c50a97de918f195118fb84f968 Mon Sep 17 00:00:00 2001 From: Francis Labelle Date: Fri, 27 Nov 2020 13:58:20 -0500 Subject: [PATCH 15/17] Make sessionID part of RDPMITMState --- bin/pyrdp-convert.py | 2 +- pyrdp/mitm/ClipboardMITM.py | 10 ++++++---- pyrdp/mitm/FileCrawlerMITM.py | 2 +- pyrdp/mitm/RDPMITM.py | 8 ++++---- pyrdp/mitm/state.py | 5 ++++- test/test_prerecorded.py | 2 +- 6 files changed, 17 insertions(+), 12 deletions(-) diff --git a/bin/pyrdp-convert.py b/bin/pyrdp-convert.py index 349e95c25..701aae66e 100755 --- a/bin/pyrdp-convert.py +++ b/bin/pyrdp-convert.py @@ -122,7 +122,7 @@ def sendBytesStub(_: bytes): # We'll set up the recorder ourselves config.recordReplays = False - state = RDPMITMState(config) + state = RDPMITMState(config, log.sessionID) sink, outfile = getSink(format, output_path) transport = ConversionLayer(sink) if sink else FileLayer(outfile) diff --git a/pyrdp/mitm/ClipboardMITM.py b/pyrdp/mitm/ClipboardMITM.py index 8b6d8779a..29d1931fd 100644 --- a/pyrdp/mitm/ClipboardMITM.py +++ b/pyrdp/mitm/ClipboardMITM.py @@ -14,6 +14,7 @@ from pyrdp.enum import ClipboardFormatNumber, ClipboardMessageFlags, ClipboardMessageType, PlayerPDUType, FileContentsFlags from pyrdp.layer import ClipboardLayer from pyrdp.logging.StatCounter import StatCounter, STAT +from pyrdp.mitm.state import RDPMITMState from pyrdp.pdu import ClipboardPDU, FormatDataRequestPDU, FormatDataResponsePDU, FileContentsRequestPDU, FileContentsResponsePDU from pyrdp.parser.rdp.virtual_channel.clipboard import FileDescriptor from pyrdp.recording import Recorder @@ -32,7 +33,7 @@ class PassiveClipboardStealer: """ def __init__(self, config: MITMConfig, client: ClipboardLayer, server: ClipboardLayer, log: LoggerAdapter, recorder: Recorder, - statCounter: StatCounter): + statCounter: StatCounter, state: RDPMITMState): """ :param client: clipboard layer for the client side :param server: clipboard layer for the server side @@ -44,13 +45,14 @@ def __init__(self, config: MITMConfig, client: ClipboardLayer, server: Clipboard self.server = server self.config = config self.log = log + self.state = state self.recorder = recorder self.forwardNextDataResponse = True self.files = [] self.transfers = {} self.timeouts = {} # Track active timeout monitoring tasks. - self.fileDir = f"{self.config.fileDir}/{self.log.sessionID}" + self.fileDir = f"{self.config.fileDir}/{self.state.sessionID}" self.client.createObserver( onPDUReceived = self.onClientPDUReceived, @@ -206,8 +208,8 @@ class ActiveClipboardStealer(PassiveClipboardStealer): """ def __init__(self, config: MITMConfig, client: ClipboardLayer, server: ClipboardLayer, log: LoggerAdapter, recorder: Recorder, - statCounter: StatCounter): - super().__init__(config, client, server, log, recorder, statCounter) + statCounter: StatCounter, state: RDPMITMState): + super().__init__(config, client, server, log, recorder, statCounter, state) def handlePDU(self, pdu: ClipboardPDU, destination: ClipboardLayer): """ diff --git a/pyrdp/mitm/FileCrawlerMITM.py b/pyrdp/mitm/FileCrawlerMITM.py index bdf5bfed4..c457c0e98 100644 --- a/pyrdp/mitm/FileCrawlerMITM.py +++ b/pyrdp/mitm/FileCrawlerMITM.py @@ -216,7 +216,7 @@ def crawlListing(self, requestID: int): def downloadFile(self, file: VirtualFile): remotePath = file.path - basePath = f"{self.config.fileDir}/{self.log.sessionID}" + basePath = f"{self.config.fileDir}/{self.state.sessionID}" localPath = f"{basePath}{remotePath}" self.log.info("Saving %(remotePath)s to %(localPath)s", {"remotePath": remotePath, "localPath": localPath}) diff --git a/pyrdp/mitm/RDPMITM.py b/pyrdp/mitm/RDPMITM.py index 5df6a13e8..9babc199f 100644 --- a/pyrdp/mitm/RDPMITM.py +++ b/pyrdp/mitm/RDPMITM.py @@ -85,7 +85,7 @@ def __init__(self, mainLogger: SessionLogger, crawlerLogger: SessionLogger, conf self.statCounter = StatCounter() """Class to keep track of connection-related statistics such as # of mouse events, # of output events, etc.""" - self.state = state if state is not None else RDPMITMState(self.config) + self.state = state if state is not None else RDPMITMState(self.config, self.log.sessionID) """The MITM state""" self.client = RDPLayerSet() @@ -152,7 +152,7 @@ def __init__(self, mainLogger: SessionLogger, crawlerLogger: SessionLogger, conf replayFileName = "rdp_replay_{}_{}_{}.pyrdp"\ .format(date.strftime('%Y%m%d_%H-%M-%S'), date.microsecond // 1000, - self.log.sessionID) + self.state.sessionID) self.recorder.setRecordFilename(replayFileName) self.recorder.addTransport(FileLayer(self.config.replayDir / replayFileName)) @@ -339,10 +339,10 @@ def buildClipboardChannel(self, client: MCSServerChannel, server: MCSClientChann if self.config.disableActiveClipboardStealing: mitm = PassiveClipboardStealer(self.config, clientLayer, serverLayer, self.getLog(MCSChannelName.CLIPBOARD), - self.recorder, self.statCounter) + self.recorder, self.statCounter, self.state) else: mitm = ActiveClipboardStealer(self.config, clientLayer, serverLayer, self.getLog(MCSChannelName.CLIPBOARD), - self.recorder, self.statCounter) + self.recorder, self.statCounter, self.state) self.channelMITMs[client.channelID] = mitm def buildDeviceChannel(self, client: MCSServerChannel, server: MCSClientChannel): diff --git a/pyrdp/mitm/state.py b/pyrdp/mitm/state.py index e93bf9aa0..35998994a 100644 --- a/pyrdp/mitm/state.py +++ b/pyrdp/mitm/state.py @@ -21,7 +21,7 @@ class RDPMITMState: State object for the RDP MITM. This is for data that needs to be shared across components. """ - def __init__(self, config: MITMConfig): + def __init__(self, config: MITMConfig, sessionID: str): self.requestedProtocols: Optional[NegotiationProtocols] = None """The original request protocols""" @@ -73,6 +73,9 @@ def __init__(self, config: MITMConfig): self.ctrlPressed = False """The current keybaord ctrl state""" + self.sessionID = sessionID + """The current session ID""" + self.securitySettings.addObserver(self.crypters[ParserMode.CLIENT]) self.securitySettings.addObserver(self.crypters[ParserMode.SERVER]) diff --git a/test/test_prerecorded.py b/test/test_prerecorded.py index a0e745aed..0a30e9b46 100644 --- a/test/test_prerecorded.py +++ b/test/test_prerecorded.py @@ -122,7 +122,7 @@ def sendBytesStub(_: bytes): config.outDir = output_directory # replay_transport = FileLayer(output_path) - state = RDPMITMState(config) + state = RDPMITMState(config, log.sessionID) super().__init__(log, log, config, state, CustomMITMRecorder([], state)) self.client.tcp.sendBytes = sendBytesStub From bcfb1203d8458c220b0e146a3d0a03d87270eb26 Mon Sep 17 00:00:00 2001 From: Francis Labelle Date: Wed, 23 Dec 2020 11:24:21 -0500 Subject: [PATCH 16/17] Standardize outputs of DeviceRedirectionMITM and FileCrawlerMITM --- pyrdp/mitm/DeviceRedirectionMITM.py | 137 ++++++---------------- pyrdp/mitm/FileCrawlerMITM.py | 176 +++++++++++++--------------- pyrdp/mitm/FileMapping.py | 121 ++++++++++--------- pyrdp/mitm/config.py | 7 ++ test/test_DeviceRedirectionMITM.py | 120 ++++++++----------- test/test_FileMapping.py | 85 ++++++++++++++ 6 files changed, 317 insertions(+), 329 deletions(-) create mode 100644 test/test_FileMapping.py diff --git a/pyrdp/mitm/DeviceRedirectionMITM.py b/pyrdp/mitm/DeviceRedirectionMITM.py index fa807c8cd..881c4fbf7 100644 --- a/pyrdp/mitm/DeviceRedirectionMITM.py +++ b/pyrdp/mitm/DeviceRedirectionMITM.py @@ -4,20 +4,17 @@ # Licensed under the GPLv3 or later. # -import hashlib -import json from logging import LoggerAdapter -from pathlib import Path from typing import Dict, Optional, Union -from pyrdp.core import FileProxy, ObservedBy, Observer, Subject -from pyrdp.enum import CreateOption, DeviceRedirectionPacketID, DeviceType, DirectoryAccessMask, FileAccessMask, FileAttributes, \ +from pyrdp.core import ObservedBy, Observer, Subject +from pyrdp.enum import CreateOption, DeviceRedirectionPacketID, DeviceType, DirectoryAccessMask, FileAccessMask, \ + FileAttributes, \ FileCreateDisposition, FileCreateOptions, FileShareAccess, FileSystemInformationClass, IOOperationSeverity, \ MajorFunction, MinorFunction from pyrdp.layer import DeviceRedirectionLayer from pyrdp.logging.StatCounter import StatCounter, STAT -from pyrdp.mitm.config import MITMConfig -from pyrdp.mitm.FileMapping import FileMapping, FileMappingDecoder, FileMappingEncoder +from pyrdp.mitm.FileMapping import FileMapping from pyrdp.mitm.state import RDPMITMState from pyrdp.pdu import DeviceAnnounce, DeviceCloseRequestPDU, DeviceCloseResponsePDU, DeviceCreateRequestPDU, \ DeviceCreateResponsePDU, DeviceDirectoryControlResponsePDU, DeviceIORequestPDU, DeviceIOResponsePDU, \ @@ -60,7 +57,8 @@ def __init__(self, client: DeviceRedirectionLayer, server: DeviceRedirectionLaye :param client: device redirection layer for the client side :param server: device redirection layer for the server side :param log: logger for this component - :param config: MITM configuration + :param statCounter: stat counter object + :param state: shared RDP MITM state """ super().__init__() @@ -69,10 +67,8 @@ def __init__(self, client: DeviceRedirectionLayer, server: DeviceRedirectionLaye self.state = state self.log = log self.statCounter = statCounter - self.openedFiles: Dict[int, FileProxy] = {} - self.openedMappings: Dict[int, FileMapping] = {} - self.fileMap: Dict[str, FileMapping] = {} - self.fileMapPath = self.config.outDir / "mapping.json" + self.mappings: Dict[(int, int), FileMapping] = {} + self.filesystemRoot = self.config.filesystemDir / self.state.sessionID self.currentIORequests: Dict[(int, int), DeviceIORequestPDU] = {} self.forgedRequests: Dict[(int, int), DeviceRedirectionMITM.ForgedRequest] = {} @@ -94,28 +90,18 @@ def __init__(self, client: DeviceRedirectionLayer, server: DeviceRedirectionLaye onPDUReceived=self.onServerPDUReceived, ) - try: - with open(self.fileMapPath, "r") as f: - self.fileMap: Dict[str, FileMapping] = json.loads(f.read(), cls=FileMappingDecoder) - except IOError: - self.log.warning("Could not read the RDPDR file mapping at %(path)s. The file may not exist or it may have incorrect permissions. A new mapping will be created.", { - "path": str(self.fileMapPath), - }) - except json.JSONDecodeError: - self.log.error("Failed to decode file mapping, overwriting previous file") + def deviceRoot(self, deviceID: int): + return self.filesystemRoot / f"device{deviceID}" + + def createDeviceRoot(self, deviceID: int): + path = self.deviceRoot(deviceID) + path.mkdir(parents=True, exist_ok=True) + return path @property def config(self): return self.state.config - def saveMapping(self): - """ - Save the file mapping to a file in JSON format. - """ - - with open(self.fileMapPath, "w") as f: - f.write(json.dumps(self.fileMap, cls=FileMappingEncoder, indent=4, sort_keys=True)) - def onClientPDUReceived(self, pdu: DeviceRedirectionPDU): self.statCounter.increment(STAT.DEVICE_REDIRECTION, STAT.DEVICE_REDIRECTION_CLIENT) self.handlePDU(pdu, self.server) @@ -198,6 +184,7 @@ def handleDeviceListAnnounceRequest(self, pdu: DeviceListAnnounceRequest): "deviceName": device.preferredDOSName }) + self.createDeviceRoot(device.deviceID) self.observer.onDeviceAnnounce(device) def handleCreateResponse(self, request: DeviceCreateRequestPDU, response: DeviceCreateResponsePDU): @@ -210,23 +197,12 @@ def handleCreateResponse(self, request: DeviceCreateRequestPDU, response: Device """ isFileRead = request.desiredAccess & (FileAccessMask.GENERIC_READ | FileAccessMask.FILE_READ_DATA) != 0 - isNotDirectory = request.createOptions & CreateOption.FILE_NON_DIRECTORY_FILE != 0 - - if isFileRead and isNotDirectory: - remotePath = Path(request.path) - mapping = FileMapping.generate(remotePath, self.config.fileDir) - proxy = FileProxy(mapping.localPath, "wb") + isDirectory = request.createOptions & CreateOption.FILE_NON_DIRECTORY_FILE == 0 + if isFileRead and not isDirectory: + mapping = FileMapping.generate(request.path, self.config.fileDir, self.deviceRoot(response.deviceID), self.log) key = (response.deviceID, response.fileID) - self.openedFiles[key] = proxy - self.openedMappings[key] = mapping - - proxy.createObserver( - onFileCreated = lambda _: self.log.info("Saving file '%(remotePath)s' to '%(localPath)s'", { - "localPath": mapping.localPath, "remotePath": mapping.remotePath - }), - onFileClosed = lambda _: self.log.debug("Closing file %(path)s", {"path": mapping.localPath}) - ) + self.mappings[key] = mapping def handleReadResponse(self, request: DeviceReadRequestPDU, response: DeviceReadResponsePDU): """ @@ -236,74 +212,25 @@ def handleReadResponse(self, request: DeviceReadRequestPDU, response: DeviceRead """ key = (response.deviceID, request.fileID) - if key in self.openedFiles: - file = self.openedFiles[key] - file.seek(request.offset) - file.write(response.payload) - - # Save the mapping permanently - mapping = self.openedMappings[key] - fileName = mapping.localPath.name - - if fileName not in self.fileMap: - self.fileMap[fileName] = mapping - self.saveMapping() + if key in self.mappings: + mapping = self.mappings[key] + mapping.seek(request.offset) + mapping.write(response.payload) def handleCloseResponse(self, request: DeviceCloseRequestPDU, response: DeviceCloseResponsePDU): """ Close the file if it was open. Compute the hash of the file, then delete it if we already have a file with the same hash. :param request: the device close request - :param _: the device IO response to the request + :param response: the device IO response to the request """ self.statCounter.increment(STAT.DEVICE_REDIRECTION_FILE_CLOSE) key = (response.deviceID, request.fileID) - if key in self.openedFiles: - file = self.openedFiles.pop(key) - file.close() - - if file.file is None: - return - - currentMapping = self.openedMappings.pop(key) - - # Compute the hash for the final file - with open(currentMapping.localPath, "rb") as f: - sha1 = hashlib.sha1() - - while True: - buffer = f.read(65536) - - if len(buffer) == 0: - break - - sha1.update(buffer) - - currentMapping.hash = sha1.hexdigest() - - isDuplicate = False - - # Check if a file with the same hash exists. If so, keep that one and remove the current file. - for _, mapping in self.fileMap.items(): - if mapping is currentMapping: - continue - - if mapping.hash == currentMapping.hash: - isDuplicate = True - break - - if isDuplicate: - currentMapping.localPath.unlink() - self.fileMap.pop(currentMapping.localPath.name) - else: - oldName = currentMapping.localPath.name - currentMapping.renameToHash() - self.fileMap.pop(oldName) - self.fileMap[currentMapping.localPath.name] = currentMapping - - self.saveMapping() + if key in self.mappings: + mapping = self.mappings.pop(key) + mapping.finalize() def handleClientLogin(self): """ @@ -311,7 +238,9 @@ def handleClientLogin(self): """ if self.state.credentialsCandidate or self.state.inputBuffer: - self.log.info("Credentials candidate from heuristic: %(credentials_candidate)s", {"credentials_candidate" : (self.state.credentialsCandidate or self.state.inputBuffer) }) + self.log.info("Credentials candidate from heuristic: %(credentials_candidate)s", { + "credentials_candidate" : (self.state.credentialsCandidate or self.state.inputBuffer) + }) # Deactivate the logger for this client self.state.loggedIn = True @@ -343,7 +272,7 @@ def sendForgedFileRead(self, deviceID: int, path: str) -> int: if not self.config.extractFiles: self.log.info('Ignored attempt to forge file reads because file extraction is disabled.') - return + return 0 self.statCounter.increment(STAT.DEVICE_REDIRECTION_FORGED_FILE_READ) @@ -367,7 +296,7 @@ def sendForgedDirectoryListing(self, deviceID: int, path: str) -> int: if not self.config.extractFiles: self.log.info('Ignored attempt to forge directory listing because file extraction is disabled.') - return + return 0 self.statCounter.increment(STAT.DEVICE_REDIRECTION_FORGED_DIRECTORY_LISTING) diff --git a/pyrdp/mitm/FileCrawlerMITM.py b/pyrdp/mitm/FileCrawlerMITM.py index c457c0e98..6877a8069 100644 --- a/pyrdp/mitm/FileCrawlerMITM.py +++ b/pyrdp/mitm/FileCrawlerMITM.py @@ -3,18 +3,19 @@ # Copyright (C) 2019 GoSecure Inc. # Licensed under the GPLv3 or later. # +import fnmatch from collections import defaultdict from logging import LoggerAdapter from pathlib import Path -from typing import BinaryIO, Dict, List, Optional, Set +from typing import Dict, List, Optional, Set from pyrdp.enum.virtual_channel.device_redirection import DeviceType -from pyrdp.mitm.config import MITMConfig from pyrdp.mitm.DeviceRedirectionMITM import DeviceRedirectionMITM, DeviceRedirectionMITMObserver +from pyrdp.mitm.FileMapping import FileMapping +from pyrdp.mitm.config import MITMConfig from pyrdp.mitm.state import RDPMITMState from pyrdp.pdu import DeviceAnnounce -import fnmatch class VirtualFile: """ @@ -61,12 +62,11 @@ def __init__(self, mainLogger: LoggerAdapter, fileLogger: LoggerAdapter, config: self.deviceRedirection: Optional[DeviceRedirectionMITM] = None # Pending crawler requests - self.fileDownloadRequests: Dict[int, Path] = {} self.directoryListingRequests: Dict[int, Path] = {} self.directoryListingLists = defaultdict(list) # Download management - self.downloadFiles: Dict[str, BinaryIO] = {} + self.fileMappings: Dict[str, FileMapping] = {} self.downloadDirectories: Set[int] = set() # Crawler detection patterns @@ -101,9 +101,6 @@ def preparePatterns(self): Should only be called once. """ - matchPath = None - ignorePath = None - # Get the default file in pyrdp/mitm/crawler_config if self.config.crawlerMatchFileName: matchPath = Path(self.config.crawlerMatchFileName).absolute() @@ -136,10 +133,22 @@ def parsePatterns(self, path: str) -> List[str]: return patternList + def onDeviceAnnounce(self, device: DeviceAnnounce): + if device.deviceType == DeviceType.RDPDR_DTYP_FILESYSTEM: + + drive = VirtualFile(device.deviceID, device.preferredDOSName, "/", True) + + self.devices[drive.deviceID] = drive + self.unvisitedDrive.append(drive) + + # If the crawler hasn't started, start one instance + if len(self.devices) == 1: + self.dispatchDownload() + def dispatchDownload(self): """ Processes each queue in order of priority. - File download have priority over directory download. + File downloads have priority over directory downloads. Crawl each folder before visiting another drive. """ @@ -161,15 +170,56 @@ def dispatchDownload(self): # List an unvisited drive elif len(self.unvisitedDrive) != 0: drive = self.unvisitedDrive.pop() - - # TODO : Maybe dump whole drive if there isn't a lot of files? - # Maybe if theres no directory at the root directory -> dump all? self.log.info("Begin crawling disk %(disk)s", {"disk" : drive.name}) self.fileLogger.info("Begin crawling disk %(disk)s", {"disk" : drive.name}) self.listDirectory(drive.deviceID, drive.path) else: self.log.info("Done crawling.") + def listDirectory(self, deviceID: int, path: str, download: bool = False): + """ + List the directory + :param deviceID: Drive we are actually listing. + :param path: Path of the directory we are listing. + :param download: Wether or not we need to download this directory. + """ + listingPath = str(Path(path).absolute()).replace("/", "\\") + + if not listingPath.endswith("*"): + if not listingPath.endswith("\\"): + listingPath += "\\" + + listingPath += "*" + + requestID = self.deviceRedirection.sendForgedDirectoryListing(deviceID, listingPath) + + # If the directory is flagged for download, keep trace of the incoming request to trigger download. + if download: + self.downloadDirectories.add(requestID) + + self.directoryListingRequests[requestID] = Path(path).absolute() + + def onDirectoryListingResult(self, deviceID: int, requestID: int, fileName: str, isDirectory: bool): + if requestID not in self.directoryListingRequests: + return + + path = self.directoryListingRequests[requestID] + filePath = path / fileName + + file = VirtualFile(deviceID, fileName, str(filePath), isDirectory) + directoryList = self.directoryListingLists[requestID] + directoryList.append(file) + + def onDirectoryListingComplete(self, deviceID: int, requestID: int): + self.directoryListingRequests.pop(requestID, {}) + + # If directory was flagged for download + if requestID in self.downloadDirectories: + self.downloadDirectories.remove(requestID) + self.addListingToDownloadQueue(requestID) + else: + self.crawlListing(requestID) + def addListingToDownloadQueue(self, requestID: int): directoryList = self.directoryListingLists.pop(requestID, {}) @@ -181,6 +231,7 @@ def addListingToDownloadQueue(self, requestID: int): self.matchedDirectoryQueue.append(item) else: self.matchedFileQueue.append(item) + self.dispatchDownload() def crawlListing(self, requestID: int): @@ -211,74 +262,34 @@ def crawlListing(self, requestID: int): if matched: self.matchedFileQueue.append(item) - self.fileLogger.info("%(file)s - %(isDirectory)s - %(isDownloaded)s", {"file" : item.path, "isDirectory": item.isDirectory, "isDownloaded": matched}) + self.fileLogger.info("%(file)s - %(isDirectory)s - %(isMatched)s", { + "file" : item.path, + "isDirectory": item.isDirectory, + "isMatched": matched + }) + self.dispatchDownload() def downloadFile(self, file: VirtualFile): remotePath = file.path - basePath = f"{self.config.fileDir}/{self.state.sessionID}" - localPath = f"{basePath}{remotePath}" - - self.log.info("Saving %(remotePath)s to %(localPath)s", {"remotePath": remotePath, "localPath": localPath}) - - try: - # Create parent directory, don't raise error if it already exists - Path(localPath).parent.mkdir(parents=True, exist_ok=True) - targetFile = open(localPath, "wb") - except Exception as e: - self.log.exception(e) - self.log.error("Cannot save file: %(localPath)s", {"localPath": localPath}) - return - - self.downloadFiles[remotePath] = targetFile + mapping = FileMapping.generate( + remotePath, + self.config.fileDir, + self.deviceRedirection.createDeviceRoot(file.deviceID), + self.log + ) + + self.fileMappings[remotePath] = mapping self.deviceRedirection.sendForgedFileRead(file.deviceID, remotePath) - def listDirectory(self, deviceID: int, path: str, download: bool = False): - """ - List the directory - :param deviceID: Drive we are actually listing. - :param path: Path of the directory we are listing. - :param download: Wether or not we need to download this directory. - """ - listingPath = str(Path(path).absolute()).replace("/", "\\") - - if not listingPath.endswith("*"): - if not listingPath.endswith("\\"): - listingPath += "\\" - - listingPath += "*" - - requestID = self.deviceRedirection.sendForgedDirectoryListing(deviceID, listingPath) - - # If the directory is flagged for download, keep trace of the incoming request to trigger download. - if download: - self.downloadDirectories.add(requestID) - - self.directoryListingRequests[requestID] = Path(path).absolute() - - def onDeviceAnnounce(self, device: DeviceAnnounce): - if device.deviceType == DeviceType.RDPDR_DTYP_FILESYSTEM: - - drive = VirtualFile(device.deviceID, device.preferredDOSName, "/", True) - - self.devices[drive.deviceID] = drive - self.unvisitedDrive.append(drive) - - # If the crawler hasn't started, start one instance - if len(self.devices) == 1: - self.dispatchDownload() - def onFileDownloadResult(self, deviceID: int, requestID: int, path: str, offset: int, data: bytes): - remotePath = path.replace("\\", "/") - - targetFile = self.downloadFiles[remotePath] - targetFile.write(data) + mapping = self.fileMappings[path] + mapping.seek(offset) + mapping.write(data) def onFileDownloadComplete(self, deviceID: int, requestID: int, path: str, errorCode: int): - remotePath = path.replace("\\", "/") - - file = self.downloadFiles.pop(remotePath) - file.close() + mapping = self.fileMappings.pop(path) + mapping.finalize() if errorCode != 0: # TODO : Handle common error codes like : @@ -286,29 +297,8 @@ def onFileDownloadComplete(self, deviceID: int, requestID: int, path: str, error # Doc : https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-erref/18d8fbe8-a967-4f1c-ae50-99ca8e491d2d self.log.error("Error happened when downloading %(remotePath)s. The file may not have been saved completely. Error code: %(errorCode)s", { - "remotePath": remotePath, + "remotePath": path, "errorCode": "0x%08lx" % errorCode, }) self.dispatchDownload() - - def onDirectoryListingResult(self, deviceID: int, requestID: int, fileName: str, isDirectory: bool): - if requestID not in self.directoryListingRequests: - return - - path = self.directoryListingRequests[requestID] - filePath = path / fileName - - file = VirtualFile(deviceID, fileName, str(filePath), isDirectory) - directoryList = self.directoryListingLists[requestID] - directoryList.append(file) - - def onDirectoryListingComplete(self, deviceID: int, requestID: int): - self.directoryListingRequests.pop(requestID, {}) - - # If directory was flagged for download - if requestID in self.downloadDirectories: - self.downloadDirectories.remove(requestID) - self.addListingToDownloadQueue(requestID) - else: - self.crawlListing(requestID) diff --git a/pyrdp/mitm/FileMapping.py b/pyrdp/mitm/FileMapping.py index 6dabfd802..9550b6d55 100644 --- a/pyrdp/mitm/FileMapping.py +++ b/pyrdp/mitm/FileMapping.py @@ -4,12 +4,11 @@ # Licensed under the GPLv3 or later. # -import datetime -import json +import hashlib +import tempfile +from logging import LoggerAdapter from pathlib import Path -from typing import Dict - -import names +from typing import io class FileMapping: @@ -18,73 +17,79 @@ class FileMapping: transferred over RDP. """ - def __init__(self, remotePath: Path, localPath: Path, creationTime: datetime.datetime, fileHash: str): + def __init__(self, file: io.BinaryIO, dataPath: Path, filesystemPath: Path, filesystemDir: Path, log: LoggerAdapter): """ - :param remotePath: the path of the file on the original machine - :param localPath: the path of the file on the intercepting machine - :param creationTime: the creation time of the local file - :param fileHash: the file hash in hex format (empty string if the file is not complete) + :param file: the file handle for dataPath + :param dataPath: path where the file is actually saved + :param filesystemPath: the path to the replicated filesystem, which will be symlinked to dataPath + :param log: logger """ - self.remotePath = remotePath - self.localPath = localPath - self.creationTime = creationTime - self.hash: str = fileHash + self.file = file + self.filesystemPath = filesystemPath + self.dataPath = dataPath + self.filesystemDir = filesystemDir + self.log = log + self.written = False - def renameToHash(self): - newPath = self.localPath.parents[0] / self.hash - self.localPath = self.localPath.rename(newPath) + def seek(self, offset: int): + self.file.seek(offset) - @staticmethod - def generate(remotePath: Path, outDir: Path): - localName = f"{names.get_first_name()}{names.get_last_name()}" - creationTime = datetime.datetime.now() + def write(self, data: bytes): + self.file.write(data) + self.written = True - index = 2 - suffix = "" + def getHash(self): + with open(self.dataPath, "rb") as f: + sha1 = hashlib.sha1() - while True: - if not (outDir / f"{localName}{suffix}").exists(): - break - else: - suffix = f"_{index}" - index += 1 + while True: + buffer = f.read(65536) - localName += suffix + if len(buffer) == 0: + break - return FileMapping(remotePath, outDir / localName, creationTime, "") + sha1.update(buffer) + return sha1.hexdigest() -class FileMappingEncoder(json.JSONEncoder): - """ - JSON encoder for FileMapping objects. - """ + def finalize(self): + self.log.debug("Closing file %(path)s", {"path": self.dataPath}) + self.file.close() - def default(self, o): - if isinstance(o, datetime.datetime): - return o.isoformat() - elif not isinstance(o, FileMapping): - return super().default(o) + fileHash = self.getHash() - return { - "remotePath": str(o.remotePath), - "localPath": str(o.localPath), - "creationTime": o.creationTime, - "sha1": o.hash - } + # Go up one directory because files are saved to outDir / tmp while we're downloading them + hashPath = (self.dataPath.parents[1] / fileHash) + # Don't keep the file if we haven't written anything to it or it's a duplicate, otherwise rename and move to files dir + if not self.written or hashPath.exists(): + self.dataPath.unlink() + else: + self.dataPath = self.dataPath.rename(hashPath) -class FileMappingDecoder(json.JSONDecoder): - """ - JSON decoder for FileMapping objects. - """ + # Whether it's a duplicate or a new file, we need to create a link to it in the filesystem clone + if self.written: + self.filesystemPath.parents[0].mkdir(exist_ok=True) + self.filesystemPath.unlink(missing_ok=True) + self.filesystemPath.symlink_to(hashPath) + + self.log.info("SHA1 '%(path)s' = '%(hash)s'", { + "path": self.filesystemPath.relative_to(self.filesystemDir), "hash": fileHash + }) + + @staticmethod + def generate(remotePath: str, outDir: Path, filesystemDir: Path, log: LoggerAdapter): + remotePath = Path(remotePath.replace("\\", "/")) + filesystemPath = filesystemDir / remotePath.relative_to("/") + + tmpOutDir = outDir / "tmp" + tmpOutDir.mkdir(exist_ok=True) - def __init__(self): - super().__init__(object_hook=self.decodeFileMapping) + handle, tmpPath = tempfile.mkstemp("", "", tmpOutDir) + file = open(handle, "wb") - def decodeFileMapping(self, dct: Dict): - for key in ["remotePath", "localPath", "creationTime"]: - if key not in dct: - return dct + log.info("Saving file '%(remotePath)s' to '%(localPath)s'", { + "localPath": tmpPath, "remotePath": remotePath + }) - creationTime = datetime.datetime.strptime(dct["creationTime"], "%Y-%m-%dT%H:%M:%S.%f") - return FileMapping(Path(dct["remotePath"]), Path(dct["localPath"]), creationTime, dct["sha1"]) \ No newline at end of file + return FileMapping(file, Path(tmpPath), filesystemPath, filesystemDir, log) diff --git a/pyrdp/mitm/config.py b/pyrdp/mitm/config.py index a88d28fd2..f12a5cf4b 100644 --- a/pyrdp/mitm/config.py +++ b/pyrdp/mitm/config.py @@ -96,6 +96,13 @@ def fileDir(self) -> Path: """ return self.outDir / "files" + @property + def filesystemDir(self) -> Path: + """ + Get the directory for filesystem clones. + """ + return self.outDir / "filesystems" + @property def certDir(self) -> Path: """ diff --git a/test/test_DeviceRedirectionMITM.py b/test/test_DeviceRedirectionMITM.py index 179a70f26..cd347de47 100644 --- a/test/test_DeviceRedirectionMITM.py +++ b/test/test_DeviceRedirectionMITM.py @@ -1,12 +1,12 @@ import unittest from pathlib import Path -from unittest.mock import Mock, MagicMock, patch, mock_open +from unittest.mock import Mock, MagicMock, patch from pyrdp.enum import CreateOption, FileAccessMask, IOOperationSeverity, DeviceRedirectionPacketID, MajorFunction, \ MinorFunction from pyrdp.logging.StatCounter import StatCounter, STAT from pyrdp.mitm.DeviceRedirectionMITM import DeviceRedirectionMITM -from pyrdp.pdu import DeviceIOResponsePDU, DeviceRedirectionPDU, DeviceQueryDirectoryRequestPDU +from pyrdp.pdu import DeviceIOResponsePDU, DeviceRedirectionPDU def MockIOError(): @@ -14,7 +14,6 @@ def MockIOError(): return ioError -@patch("builtins.open", new_callable=mock_open) class DeviceRedirectionMITMTest(unittest.TestCase): def setUp(self): self.client = Mock() @@ -26,7 +25,12 @@ def setUp(self): self.state.config.outDir = Path("/tmp") self.mitm = DeviceRedirectionMITM(self.client, self.server, self.log, self.statCounter, self.state) - def test_stats(self, *args): + @patch("pyrdp.mitm.FileMapping.FileMapping.generate") + def sendCreateResponse(self, request, response, generate): + self.mitm.handleCreateResponse(request, response) + return generate + + def test_stats(self): self.mitm.handlePDU = Mock() self.mitm.statCounter = StatCounter() @@ -58,7 +62,7 @@ def test_stats(self, *args): self.mitm.sendForgedDirectoryListing(Mock(), MagicMock()) self.assertEqual(self.mitm.statCounter.stats[STAT.DEVICE_REDIRECTION_FORGED_DIRECTORY_LISTING], 1) - def test_ioError_showsWarning(self, *args): + def test_ioError_showsWarning(self): self.log.warning = Mock() error = MockIOError() @@ -66,7 +70,7 @@ def test_ioError_showsWarning(self, *args): self.mitm.handleIOResponse(error) self.log.warning.assert_called_once() - def test_deviceListAnnounce_logsDevices(self, *args): + def test_deviceListAnnounce_logsDevices(self): pdu = Mock() pdu.deviceList = [Mock(), Mock(), Mock()] @@ -76,7 +80,7 @@ def test_deviceListAnnounce_logsDevices(self, *args): self.assertEqual(self.log.info.call_count, len(pdu.deviceList)) self.assertEqual(self.mitm.observer.onDeviceAnnounce.call_count, len(pdu.deviceList)) - def test_handleClientLogin_logsCredentials(self, *args): + def test_handleClientLogin_logsCredentials(self): creds = "PASSWORD" self.log.info = Mock() @@ -100,7 +104,7 @@ def test_handleClientLogin_logsCredentials(self, *args): self.mitm.handlePDU(pdu, self.client) self.mitm.handleClientLogin.assert_called_once() - def test_handleIOResponse_uniqueResponse(self, *args): + def test_handleIOResponse_uniqueResponse(self): handler = Mock() self.mitm.responseHandlers[1234] = handler @@ -113,7 +117,7 @@ def test_handleIOResponse_uniqueResponse(self, *args): self.mitm.handleIOResponse(pdu) handler.assert_called_once() - def test_handleIOResponse_matchingOnly(self, *args): + def test_handleIOResponse_matchingOnly(self): handler = Mock() self.mitm.responseHandlers[1234] = handler @@ -138,20 +142,21 @@ def test_handleIOResponse_matchingOnly(self, *args): self.log.error.assert_called_once() self.log.error.reset_mock() - def test_handlePDU_hidesForgedResponses(self, *args): + def test_handlePDU_hidesForgedResponses(self): + majorFunction = MajorFunction.IRP_MJ_CREATE handler = Mock() completionID = self.mitm.sendForgedFileRead(0, "forged") request = self.mitm.forgedRequests[(0, completionID)] - request.handlers[1234] = handler + request.handlers[majorFunction] = handler self.assertEqual(len(self.mitm.forgedRequests), 1) - response = Mock(deviceID = 0, completionID = completionID, majorFunction = 1234, ioStatus = 0) + response = Mock(deviceID = 0, completionID = completionID, majorFunction = majorFunction, ioStatus = 0) response.__class__ = DeviceIOResponsePDU self.mitm.handlePDU(response, self.mitm.server) handler.assert_called_once() self.mitm.server.sendPDU.assert_not_called() - def test_handleCreateResponse_createsNoFile(self, mock_open): + def test_handleCreateResponse_createsMapping(self): createRequest = Mock( deviceID = 0, completionID = 0, @@ -161,13 +166,11 @@ def test_handleCreateResponse_createsNoFile(self, mock_open): ) createResponse = Mock(deviceID = 0, completionID = 0, fileID = 0) - with patch("pyrdp.mitm.FileMapping.FileMapping.generate") as generate: - self.mitm.handleCreateResponse(createRequest, createResponse) - self.assertEqual(len(self.mitm.openedFiles), 1) - generate.assert_called_once() - mock_open.assert_not_called() + generate = self.sendCreateResponse(createRequest, createResponse) + self.assertEqual(len(self.mitm.mappings), 1) + generate.assert_called_once() - def test_handleReadResponse_createsFile(self, mock_open): + def test_handleReadResponse_writesData(self): request = Mock( deviceID = 0, completionID = 0, @@ -179,51 +182,19 @@ def test_handleReadResponse_createsFile(self, mock_open): response = Mock(deviceID = 0, completionID = 0, fileID = 0, payload = "test payload") self.mitm.saveMapping = Mock() - with patch("pyrdp.mitm.FileMapping.FileMapping.generate") as generate: - self.mitm.handleCreateResponse(request, response) - self.mitm.handleReadResponse(request, response) - mock_open.assert_called_once() - self.mitm.saveMapping.assert_called_once() + self.sendCreateResponse(request, response) + mapping = list(self.mitm.mappings.values())[0] + mapping.write = Mock() - # Make sure it checks the file ID - request.fileID, response.fileID = 1, 1 - mock_write = Mock() - list(self.mitm.openedFiles.values())[0].write = mock_write - self.mitm.handleReadResponse(request, response) - mock_write.assert_not_called() + self.mitm.handleReadResponse(request, response) + mapping.write.assert_called_once() - def test_handleCloseResponse_closesFile(self, mock_open): - request = Mock( - deviceID=0, - completionID=0, - fileID=0, - desiredAccess=(FileAccessMask.GENERIC_READ | FileAccessMask.FILE_READ_DATA), - createOptions=CreateOption.FILE_NON_DIRECTORY_FILE, - path="file", - ) - response = Mock(deviceID=0, completionID=0, fileID=0, payload="test payload") - self.mitm.saveMapping = Mock() - - with patch("pyrdp.mitm.FileMapping.FileMapping.generate") as generate: - close = Mock() - - self.mitm.handleCreateResponse(request, response) - - mapping = list(self.mitm.openedMappings.values())[0] - mapping.renameToHash = Mock() - self.mitm.fileMap[mapping.localPath.name] = Mock() - - file = list(self.mitm.openedFiles.values())[0] - file.close = close - file.file = Mock() + # Make sure it checks the file ID + request.fileID, response.fileID = 1, 1 + self.mitm.handleReadResponse(request, response) + mapping.write.assert_called_once() - self.mitm.handleCloseResponse(request, response) - - close.assert_called_once() - mapping.renameToHash.assert_called_once() - self.mitm.saveMapping.assert_called_once() - - def test_handleCloseResponse_removesDuplicates(self, mock_open): + def test_handleCloseResponse_finalizesMapping(self): request = Mock( deviceID=0, completionID=0, @@ -234,29 +205,30 @@ def test_handleCloseResponse_removesDuplicates(self, mock_open): ) response = Mock(deviceID=0, completionID=0, fileID=0, payload="test payload") self.mitm.saveMapping = Mock() - hash = "hash" - with patch("pyrdp.mitm.FileMapping.FileMapping.generate") as generate, patch("hashlib.sha1") as sha1: - sha1.return_value.hexdigest = Mock(return_value = hash) - self.mitm.handleCreateResponse(request, response) + self.sendCreateResponse(request, response) + mapping = list(self.mitm.mappings.values())[0] + mapping.finalize = Mock() - list(self.mitm.openedFiles.values())[0].file = Mock() - mapping = list(self.mitm.openedMappings.values())[0] - mapping.localPath.unlink = Mock() - self.mitm.fileMap[mapping.localPath.name] = Mock() - self.mitm.fileMap["duplicate"] = Mock(hash = hash) + self.mitm.handleCloseResponse(request, response) - self.mitm.handleCloseResponse(request, response) - mapping.localPath.unlink.assert_called_once() - self.mitm.saveMapping.assert_called_once() + mapping.finalize.assert_called_once() - def test_findNextRequestID_incrementsRequestID(self, *args): + def test_findNextRequestID_incrementsRequestID(self): baseID = self.mitm.findNextRequestID() self.mitm.sendForgedFileRead(0, Mock()) self.assertEqual(self.mitm.findNextRequestID(), baseID + 1) self.mitm.sendForgedFileRead(1, Mock()) self.assertEqual(self.mitm.findNextRequestID(), baseID + 2) + def test_sendForgedFileRead_failsWhenDisabled(self): + self.mitm.config.extractFiles = False + self.assertFalse(self.mitm.sendForgedFileRead(1, "/test")) + + def test_sendForgedDirectoryListing_failsWhenDisabled(self): + self.mitm.config.extractFiles = False + self.assertFalse(self.mitm.sendForgedDirectoryListing(1, "/")) + class ForgedRequestTest(unittest.TestCase): def setUp(self): diff --git a/test/test_FileMapping.py b/test/test_FileMapping.py new file mode 100644 index 000000000..db385f846 --- /dev/null +++ b/test/test_FileMapping.py @@ -0,0 +1,85 @@ +import unittest +from pathlib import Path +from unittest.mock import Mock, patch, MagicMock, mock_open + +from pyrdp.mitm.FileMapping import FileMapping + + +class FileMappingTest(unittest.TestCase): + def setUp(self): + self.log = Mock() + self.outDir = Path("test/") + self.hash = "testHash" + + @patch("builtins.open", new_callable=mock_open) + @patch("tempfile.mkstemp") + @patch("pathlib.Path.mkdir") + def createMapping(self, mkdir: MagicMock, mkstemp: MagicMock, mock_open_object): + mkstemp.return_value = (1, str(self.outDir / "tmp" / "tmp_test")) + mapping = FileMapping.generate("/test", self.outDir, Path("filesystems"), self.log) + mapping.getHash = Mock(return_value = self.hash) + return mapping, mkdir, mkstemp, mock_open_object + + def test_generate_createsTempFile(self): + mapping, mkdir, mkstemp, mock_open_object = self.createMapping() + mkstemp.return_value = (1, str(self.outDir / "tmp" / "tmp_test")) + + mkdir.assert_called_once_with(exist_ok = True) + mkstemp.assert_called_once() + mock_open_object.assert_called_once() + + tmpDir = mkstemp.call_args[0][-1] + self.assertEqual(tmpDir, self.outDir / "tmp") + + def test_write_setsWritten(self): + mapping, *_ = self.createMapping() + self.assertFalse(mapping.written) + mapping.write(b"data") + self.assertTrue(mapping.written) + + def test_finalize_removesUnwrittenFiles(self): + mapping, *_ = self.createMapping() + + with patch("pathlib.Path.unlink", autospec=True) as mock_unlink: + mapping.finalize() + self.assertTrue(any(args[0][0] == mapping.dataPath for args in mock_unlink.call_args_list)) + + @patch("pathlib.Path.exists", new_callable=lambda: Mock(return_value=True)) + @patch("pathlib.Path.symlink_to") + @patch("pathlib.Path.mkdir") + def test_finalize_removesDuplicates(self, *_): + mapping, *_ = self.createMapping() + mapping.write(b"data") + + with patch("pathlib.Path.unlink", autospec=True) as mock_unlink: + mapping.finalize() + self.assertTrue(any(args[0][0] == mapping.dataPath for args in mock_unlink.call_args_list)) + + @patch("pathlib.Path.unlink") + @patch("pathlib.Path.exists", new_callable=lambda: Mock(return_value=False)) + @patch("pathlib.Path.symlink_to") + @patch("pathlib.Path.mkdir") + def test_finalize_movesFileToOutDir(self, *_): + mapping, *_ = self.createMapping() + mapping.write(b"data") + + with patch("pathlib.Path.rename") as mock_rename: + mapping.finalize() + mock_rename.assert_called_once() + self.assertEqual(mock_rename.call_args[0][0].parents[0], self.outDir) + + @patch("pathlib.Path.rename") + @patch("pathlib.Path.unlink") + @patch("pathlib.Path.exists", new_callable=lambda: Mock(return_value=False)) + def test_finalize_createsSymlink(self, *_): + mapping, *_ = self.createMapping() + mapping.write(b"data") + + with patch("pathlib.Path.symlink_to") as mock_symlink_to, patch("pathlib.Path.mkdir", autospec=True) as mock_mkdir: + mapping.finalize() + + mock_mkdir.assert_called_once() + mock_symlink_to.assert_called_once() + + self.assertEqual(mock_mkdir.call_args[0][0], mapping.filesystemPath.parents[0]) + self.assertEqual(mock_symlink_to.call_args[0][0], self.outDir / self.hash) From b0dc149b371b4325ba27d2a83f81c8a8e6ed4f3e Mon Sep 17 00:00:00 2001 From: Francis Labelle Date: Wed, 23 Dec 2020 12:41:26 -0500 Subject: [PATCH 17/17] Remove missing_ok for Python 3.7 compatibility --- pyrdp/mitm/FileMapping.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pyrdp/mitm/FileMapping.py b/pyrdp/mitm/FileMapping.py index 9550b6d55..8d0c1baaf 100644 --- a/pyrdp/mitm/FileMapping.py +++ b/pyrdp/mitm/FileMapping.py @@ -70,7 +70,10 @@ def finalize(self): # Whether it's a duplicate or a new file, we need to create a link to it in the filesystem clone if self.written: self.filesystemPath.parents[0].mkdir(exist_ok=True) - self.filesystemPath.unlink(missing_ok=True) + + if self.filesystemPath.exists(): + self.filesystemPath.unlink() + self.filesystemPath.symlink_to(hashPath) self.log.info("SHA1 '%(path)s' = '%(hash)s'", {