diff --git a/pyrdp/mitm/DeviceRedirectionMITM.py b/pyrdp/mitm/DeviceRedirectionMITM.py index 881c4fbf7..20e5c778b 100644 --- a/pyrdp/mitm/DeviceRedirectionMITM.py +++ b/pyrdp/mitm/DeviceRedirectionMITM.py @@ -1,6 +1,6 @@ # # This file is part of the PyRDP project. -# Copyright (C) 2019-2020 GoSecure Inc. +# Copyright (C) 2019-2021 GoSecure Inc. # Licensed under the GPLv3 or later. # @@ -16,6 +16,7 @@ from pyrdp.logging.StatCounter import StatCounter, STAT from pyrdp.mitm.FileMapping import FileMapping from pyrdp.mitm.state import RDPMITMState +from pyrdp.mitm.TCPMITM import TCPMITM from pyrdp.pdu import DeviceAnnounce, DeviceCloseRequestPDU, DeviceCloseResponsePDU, DeviceCreateRequestPDU, \ DeviceCreateResponsePDU, DeviceDirectoryControlResponsePDU, DeviceIORequestPDU, DeviceIOResponsePDU, \ DeviceListAnnounceRequest, DeviceQueryDirectoryRequestPDU, DeviceQueryDirectoryResponsePDU, DeviceReadRequestPDU, \ @@ -52,13 +53,14 @@ class DeviceRedirectionMITM(Subject): FORGED_COMPLETION_ID = 1000000 def __init__(self, client: DeviceRedirectionLayer, server: DeviceRedirectionLayer, log: LoggerAdapter, - statCounter: StatCounter, state: RDPMITMState): + statCounter: StatCounter, state: RDPMITMState, tcp: TCPMITM): """ :param client: device redirection layer for the client side :param server: device redirection layer for the server side :param log: logger for this component :param statCounter: stat counter object :param state: shared RDP MITM state + :param tcp: TCP MITM component """ super().__init__() @@ -67,6 +69,7 @@ def __init__(self, client: DeviceRedirectionLayer, server: DeviceRedirectionLaye self.state = state self.log = log self.statCounter = statCounter + self.tcp = tcp self.mappings: Dict[(int, int), FileMapping] = {} self.filesystemRoot = self.config.filesystemDir / self.state.sessionID @@ -90,6 +93,14 @@ def __init__(self, client: DeviceRedirectionLayer, server: DeviceRedirectionLaye onPDUReceived=self.onServerPDUReceived, ) + self.tcp.client.createObserver( + onDisconnection=self.onDisconnection + ) + + self.tcp.server.createObserver( + onDisconnection=self.onDisconnection + ) + def deviceRoot(self, deviceID: int): return self.filesystemRoot / f"device{deviceID}" @@ -102,6 +113,10 @@ def createDeviceRoot(self, deviceID: int): def config(self): return self.state.config + def onDisconnection(self, reason): + for mapping in self.mappings.values(): + mapping.onDisconnection(reason) + def onClientPDUReceived(self, pdu: DeviceRedirectionPDU): self.statCounter.increment(STAT.DEVICE_REDIRECTION, STAT.DEVICE_REDIRECTION_CLIENT) self.handlePDU(pdu, self.server) diff --git a/pyrdp/mitm/FileMapping.py b/pyrdp/mitm/FileMapping.py index 416918ec4..ac6fda09e 100644 --- a/pyrdp/mitm/FileMapping.py +++ b/pyrdp/mitm/FileMapping.py @@ -33,7 +33,8 @@ def __init__(self, file: io.BinaryIO, dataPath: Path, filesystemPath: Path, file self.written = False def seek(self, offset: int): - self.file.seek(offset) + if not self.file.closed: + self.file.seek(offset) def write(self, data: bytes): self.file.write(data) @@ -54,6 +55,9 @@ def getSha1Hash(self): return sha1.hexdigest() def finalize(self): + if self.file.closed: + return + self.log.debug("Closing file %(path)s", {"path": self.dataPath}) self.file.close() @@ -82,6 +86,11 @@ def finalize(self): "path": str(self.filesystemPath.relative_to(self.filesystemDir)), "shasum": fileHash }) + def onDisconnection(self, reason): + if not self.file.closed: + self.file.close() + Path(self.file.name).unlink(missing_ok=True) + @staticmethod def generate(remotePath: str, outDir: Path, filesystemDir: Path, log: LoggerAdapter): remotePath = Path(remotePath.replace("\\", "/")) diff --git a/pyrdp/mitm/RDPMITM.py b/pyrdp/mitm/RDPMITM.py index 6e7058a96..9f5503f58 100644 --- a/pyrdp/mitm/RDPMITM.py +++ b/pyrdp/mitm/RDPMITM.py @@ -357,7 +357,7 @@ def buildDeviceChannel(self, client: MCSServerChannel, server: MCSClientChannel) LayerChainItem.chain(client, clientSecurity, clientVirtualChannel, clientLayer) LayerChainItem.chain(server, serverSecurity, serverVirtualChannel, serverLayer) - deviceRedirection = DeviceRedirectionMITM(clientLayer, serverLayer, self.getLog(MCSChannelName.DEVICE_REDIRECTION), self.statCounter, self.state) + deviceRedirection = DeviceRedirectionMITM(clientLayer, serverLayer, self.getLog(MCSChannelName.DEVICE_REDIRECTION), self.statCounter, self.state, self.tcp) self.channelMITMs[client.channelID] = deviceRedirection if self.config.enableCrawler: diff --git a/test/test_DeviceRedirectionMITM.py b/test/test_DeviceRedirectionMITM.py index cf5a1da74..5d0672f2e 100644 --- a/test/test_DeviceRedirectionMITM.py +++ b/test/test_DeviceRedirectionMITM.py @@ -29,7 +29,7 @@ def setUp(self): self.state = Mock() self.state.config = MagicMock() self.state.config.outDir = Path("/tmp") - self.mitm = DeviceRedirectionMITM(self.client, self.server, self.log, self.statCounter, self.state) + self.mitm = DeviceRedirectionMITM(self.client, self.server, self.log, self.statCounter, self.state, Mock()) @patch("pyrdp.mitm.FileMapping.FileMapping.generate") def sendCreateResponse(self, request, response, generate): diff --git a/test/test_FileMapping.py b/test/test_FileMapping.py index 2c50b1c0d..58c00f77d 100644 --- a/test/test_FileMapping.py +++ b/test/test_FileMapping.py @@ -23,6 +23,7 @@ def createMapping(self, mkdir: MagicMock, mkstemp: MagicMock, mock_open_object): mkstemp.return_value = (1, str(self.outDir / "tmp" / "tmp_test")) mapping = FileMapping.generate("/test", self.outDir, Path("filesystems"), self.log) mapping.getSha1Hash = Mock(return_value = self.hash) + mapping.file.closed = False return mapping, mkdir, mkstemp, mock_open_object def test_generate_createsTempFile(self): diff --git a/test/test_X224MITM.py b/test/test_X224MITM.py index d6911c965..9e5813667 100644 --- a/test/test_X224MITM.py +++ b/test/test_X224MITM.py @@ -11,7 +11,7 @@ from pyrdp.pdu import X224ConnectionRequestPDU, NegotiationRequestPDU -class FileMappingTest(unittest.TestCase): +class X224MITMTest(unittest.TestCase): def setUp(self): self.mitm = X224MITM(Mock(), Mock(), Mock(), Mock(), Mock(), MagicMock())