Skip to content

Commit

Permalink
Merge pull request #309 from GoSecure/fix-filemapping-disconnect
Browse files Browse the repository at this point in the history
Watch for disconnection events in DeviceRedirectionMITM
  • Loading branch information
xshill committed Apr 1, 2021
2 parents 8eb764d + 15e72b1 commit 64a3ab4
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 6 deletions.
19 changes: 17 additions & 2 deletions pyrdp/mitm/DeviceRedirectionMITM.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#
# This file is part of the PyRDP project.
# Copyright (C) 2019-2020 GoSecure Inc.
# Copyright (C) 2019-2021 GoSecure Inc.
# Licensed under the GPLv3 or later.
#

Expand All @@ -16,6 +16,7 @@
from pyrdp.logging.StatCounter import StatCounter, STAT
from pyrdp.mitm.FileMapping import FileMapping
from pyrdp.mitm.state import RDPMITMState
from pyrdp.mitm.TCPMITM import TCPMITM
from pyrdp.pdu import DeviceAnnounce, DeviceCloseRequestPDU, DeviceCloseResponsePDU, DeviceCreateRequestPDU, \
DeviceCreateResponsePDU, DeviceDirectoryControlResponsePDU, DeviceIORequestPDU, DeviceIOResponsePDU, \
DeviceListAnnounceRequest, DeviceQueryDirectoryRequestPDU, DeviceQueryDirectoryResponsePDU, DeviceReadRequestPDU, \
Expand Down Expand Up @@ -52,13 +53,14 @@ class DeviceRedirectionMITM(Subject):
FORGED_COMPLETION_ID = 1000000

def __init__(self, client: DeviceRedirectionLayer, server: DeviceRedirectionLayer, log: LoggerAdapter,
statCounter: StatCounter, state: RDPMITMState):
statCounter: StatCounter, state: RDPMITMState, tcp: TCPMITM):
"""
:param client: device redirection layer for the client side
:param server: device redirection layer for the server side
:param log: logger for this component
:param statCounter: stat counter object
:param state: shared RDP MITM state
:param tcp: TCP MITM component
"""
super().__init__()

Expand All @@ -67,6 +69,7 @@ def __init__(self, client: DeviceRedirectionLayer, server: DeviceRedirectionLaye
self.state = state
self.log = log
self.statCounter = statCounter
self.tcp = tcp
self.mappings: Dict[(int, int), FileMapping] = {}
self.filesystemRoot = self.config.filesystemDir / self.state.sessionID

Expand All @@ -90,6 +93,14 @@ def __init__(self, client: DeviceRedirectionLayer, server: DeviceRedirectionLaye
onPDUReceived=self.onServerPDUReceived,
)

self.tcp.client.createObserver(
onDisconnection=self.onDisconnection
)

self.tcp.server.createObserver(
onDisconnection=self.onDisconnection
)

def deviceRoot(self, deviceID: int):
return self.filesystemRoot / f"device{deviceID}"

Expand All @@ -102,6 +113,10 @@ def createDeviceRoot(self, deviceID: int):
def config(self):
return self.state.config

def onDisconnection(self, reason):
for mapping in self.mappings.values():
mapping.onDisconnection(reason)

def onClientPDUReceived(self, pdu: DeviceRedirectionPDU):
self.statCounter.increment(STAT.DEVICE_REDIRECTION, STAT.DEVICE_REDIRECTION_CLIENT)
self.handlePDU(pdu, self.server)
Expand Down
11 changes: 10 additions & 1 deletion pyrdp/mitm/FileMapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def __init__(self, file: io.BinaryIO, dataPath: Path, filesystemPath: Path, file
self.written = False

def seek(self, offset: int):
self.file.seek(offset)
if not self.file.closed:
self.file.seek(offset)

def write(self, data: bytes):
self.file.write(data)
Expand All @@ -54,6 +55,9 @@ def getSha1Hash(self):
return sha1.hexdigest()

def finalize(self):
if self.file.closed:
return

self.log.debug("Closing file %(path)s", {"path": self.dataPath})
self.file.close()

Expand Down Expand Up @@ -82,6 +86,11 @@ def finalize(self):
"path": str(self.filesystemPath.relative_to(self.filesystemDir)), "shasum": fileHash
})

def onDisconnection(self, reason):
if not self.file.closed:
self.file.close()
Path(self.file.name).unlink(missing_ok=True)

@staticmethod
def generate(remotePath: str, outDir: Path, filesystemDir: Path, log: LoggerAdapter):
remotePath = Path(remotePath.replace("\\", "/"))
Expand Down
2 changes: 1 addition & 1 deletion pyrdp/mitm/RDPMITM.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def buildDeviceChannel(self, client: MCSServerChannel, server: MCSClientChannel)
LayerChainItem.chain(client, clientSecurity, clientVirtualChannel, clientLayer)
LayerChainItem.chain(server, serverSecurity, serverVirtualChannel, serverLayer)

deviceRedirection = DeviceRedirectionMITM(clientLayer, serverLayer, self.getLog(MCSChannelName.DEVICE_REDIRECTION), self.statCounter, self.state)
deviceRedirection = DeviceRedirectionMITM(clientLayer, serverLayer, self.getLog(MCSChannelName.DEVICE_REDIRECTION), self.statCounter, self.state, self.tcp)
self.channelMITMs[client.channelID] = deviceRedirection

if self.config.enableCrawler:
Expand Down
2 changes: 1 addition & 1 deletion test/test_DeviceRedirectionMITM.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def setUp(self):
self.state = Mock()
self.state.config = MagicMock()
self.state.config.outDir = Path("/tmp")
self.mitm = DeviceRedirectionMITM(self.client, self.server, self.log, self.statCounter, self.state)
self.mitm = DeviceRedirectionMITM(self.client, self.server, self.log, self.statCounter, self.state, Mock())

@patch("pyrdp.mitm.FileMapping.FileMapping.generate")
def sendCreateResponse(self, request, response, generate):
Expand Down
1 change: 1 addition & 0 deletions test/test_FileMapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def createMapping(self, mkdir: MagicMock, mkstemp: MagicMock, mock_open_object):
mkstemp.return_value = (1, str(self.outDir / "tmp" / "tmp_test"))
mapping = FileMapping.generate("/test", self.outDir, Path("filesystems"), self.log)
mapping.getSha1Hash = Mock(return_value = self.hash)
mapping.file.closed = False
return mapping, mkdir, mkstemp, mock_open_object

def test_generate_createsTempFile(self):
Expand Down
2 changes: 1 addition & 1 deletion test/test_X224MITM.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pyrdp.pdu import X224ConnectionRequestPDU, NegotiationRequestPDU


class FileMappingTest(unittest.TestCase):
class X224MITMTest(unittest.TestCase):
def setUp(self):
self.mitm = X224MITM(Mock(), Mock(), Mock(), Mock(), Mock(), MagicMock())

Expand Down

0 comments on commit 64a3ab4

Please sign in to comment.