diff --git a/pyrdp/logging/StatCounter.py b/pyrdp/logging/StatCounter.py new file mode 100644 index 000000000..6d6219caa --- /dev/null +++ b/pyrdp/logging/StatCounter.py @@ -0,0 +1,163 @@ +# +# This file is part of the PyRDP project. +# Copyright (C) 2019 GoSecure Inc. +# Licensed under the GPLv3 or later. +# + +import time +from logging import LoggerAdapter + + +class STAT: + """ + Type of statistics that a StatCounter object can hold. + """ + + CONNECTION_TIME = "connectionTime" + # Duration (in secs) for the TCP connection + + CLIENT_SERVER_RATIO = "clientServerRatio" + # Ratio of the # of messages coming from the client vs from the server. High value (>1) means high client interaction. + + TOTAL_INPUT = "totalInput" + # # of messages coming from the client to the server. + + TOTAL_OUTPUT = "totalOutput" + # # of messages coming from the server to the client. + + IO_INPUT = "input" + # Packet coming from the client to the server for the io channel + + IO_INPUT_FASTPATH = "fastPathInput" + # Packet coming from the client to the server for the io channel as a fastpath packet + + IO_INPUT_SLOWPATH = "slowPathInput" + # Packet coming from the client to the server for the io channel as a slowpath packet + + IO_OUTPUT = "output" + # Packet coming from the server to the client for the io channel + + IO_OUTPUT_FASTPATH = "fastPathOutput" + # Packet coming from the server to the client for the io channel as a fastpath packet + + IO_OUTPUT_SLOWPATH = "slowPathOutput" + # Packet coming from the server to the client for the io channel as a slowpath packet + + MCS = "mcs" + # Packet Coming from either end for any channel + + MCS_OUTPUT = "mcsOutput" + # Packet Coming from the server to the client for any channel + + MCS_OUTPUT_ = "mcsOutput_" + # Packet Coming from the server to the client for a given channel (must append channel # after it) + + MCS_INPUT = "mcsInput" + # Packet Coming from the client to the server for any channel + + MCS_INPUT_ = "mcsInput_" + # Packet Coming from the client to the server for a given channel (must append channel # after it) + + VIRTUAL_CHANNEL = "virtualChannel" + # Packet Coming from either end for any virtual channel that doesnt have a specific implementation (ex clipboard) + + VIRTUAL_CHANNEL_INPUT = "virtualChannelInput" + # Packet Coming from the client to the server for any virtual channel that doesnt have a specific implementation (ex clipboard) + + VIRTUAL_CHANNEL_OUTPUT = "virtualChannelOutput" + # Packet Coming from the server to the client for any virtual channel that doesnt have a specific implementation (ex clipboard) + + DEVICE_REDIRECTION = "deviceRedirection" + # Packet coming from either end for the rdpdr channel + + DEVICE_REDIRECTION_CLIENT = "deviceRedirectionClient" + # Packet coming from the client to the server for the rdpdr channel + + DEVICE_REDIRECTION_SERVER = "deviceRedirectionServer" + # Packet coming from the server to the client for the rdpdr channel + + DEVICE_REDIRECTION_IOREQUEST = "deviceRedirectionIORequest" + # IORequest packets for the rdpdr channel + + DEVICE_REDIRECTION_IORESPONSE = "deviceRedirectionIOResponse" + # IOResponse packets for the rdpdr channel + + DEVICE_REDIRECTION_IOERROR = "deviceRedirectionIOError" + # IO error packets for the rdpdr channel + + DEVICE_REDIRECTION_FILE_CLOSE = "deviceRedirectionFileClose" + # File Close packets for the rdpdr channel + + DEVICE_REDIRECTION_FORGED_FILE_READ = "deviceRedirectionForgedFileRead" + # File read packets forged by pyrdp for the rdpdr channel + + DEVICE_REDIRECTION_FORGED_DIRECTORY_LISTING = "deviceRedirectionForgedDirectoryListing" + # Directory listing packets forged by pyrdp for the rdpdr channel + + CLIPBOARD = "clipboard" + # Number of clipboard PDUs coming from either end + + CLIPBOARD_CLIENT = "clipboardClient" + # Number of clipboard PDUs coming from the client + + CLIPBOARD_SERVER = "clipboardServer" + # Number of clipboard PDUs coming from the server + + CLIPBOARD_COPY = "clipboardCopies" + # Number of times data has been copied by either end + + CLIPBOARD_PASTE = "clipboardPastes" + # Number of times data has been pasted by either end + + +class StatCounter: + """ + Class that keeps track of various statistics during an RDP connection (See STAT) + """ + + def __init__(self): + self.stats = {"report": 1.0} # 1.0 = True + + def increment(self, *args: str): + """ + Increments all statistics passed in arguments + :param args: list of statistics to increment by one. See STAT for list of allowed values. + """ + for stat in args: + if stat not in self.stats: + self.stats[stat] = 0 + self.stats[stat] += 1 + + def incrementWith(self, statDestination: str, *statsSource: str): + """ + Increments statDestination by all provided statSources + """ + if statDestination not in self.stats: + self.stats[statDestination] = 0 + for statSource in statsSource: + if statSource in self.stats: + self.stats[statDestination] += self.stats[statSource] + + def start(self): + """ + Initialize some statistics such as connectionTime + """ + self.stats[STAT.CONNECTION_TIME] = time.time() + + def stop(self): + """ + Calculates the last statistics such as interaction ratio and connectionTime + """ + self.stats[STAT.CONNECTION_TIME] = time.time() - self.stats[STAT.CONNECTION_TIME] + self.incrementWith(STAT.TOTAL_INPUT, STAT.MCS_INPUT, STAT.IO_INPUT_FASTPATH, STAT.VIRTUAL_CHANNEL_INPUT, STAT.CLIPBOARD_CLIENT, STAT.DEVICE_REDIRECTION_CLIENT) + self.incrementWith(STAT.TOTAL_OUTPUT, STAT.MCS_OUTPUT, STAT.IO_OUTPUT_FASTPATH, STAT.VIRTUAL_CHANNEL_OUTPUT, STAT.CLIPBOARD_SERVER, STAT.DEVICE_REDIRECTION_SERVER) + if self.stats[STAT.TOTAL_OUTPUT] > 0: + self.stats[STAT.CLIENT_SERVER_RATIO] = self.stats[STAT.TOTAL_INPUT] / self.stats[STAT.TOTAL_OUTPUT] + + def logReport(self, log: LoggerAdapter): + """ + Create an INFO log message to log the Connection report using the keys in self.stats. + :param log: Logger to use to log the report + """ + keys = ", ".join([f"{key}: %({key})s" for key in self.stats.keys()]) + log.info(f"Connection report: {keys}", self.stats) diff --git a/pyrdp/mitm/BasePathMITM.py b/pyrdp/mitm/BasePathMITM.py index 55c813981..0e7b07512 100644 --- a/pyrdp/mitm/BasePathMITM.py +++ b/pyrdp/mitm/BasePathMITM.py @@ -9,17 +9,18 @@ from pyrdp.enum import ScanCode from pyrdp.pdu.pdu import PDU from pyrdp.layer.layer import Layer - +from pyrdp.logging.StatCounter import StatCounter, STAT class BasePathMITM: """ Base MITM component for the fast-path and slow-path layers. """ - def __init__(self, state: RDPMITMState, client: Layer, server: Layer): + def __init__(self, state: RDPMITMState, client: Layer, server: Layer, statCounter: StatCounter): self.state = state self.client = client self.server = server + self.statCounter = statCounter def onClientPDUReceived(self, pdu: PDU): raise NotImplementedError("onClientPDUReceived must be overridden") diff --git a/pyrdp/mitm/ClipboardMITM.py b/pyrdp/mitm/ClipboardMITM.py index 5dbc6d8f5..d90bb133d 100644 --- a/pyrdp/mitm/ClipboardMITM.py +++ b/pyrdp/mitm/ClipboardMITM.py @@ -9,6 +9,7 @@ from pyrdp.core import decodeUTF16LE from pyrdp.enum import ClipboardFormatNumber, ClipboardMessageFlags, ClipboardMessageType, PlayerPDUType from pyrdp.layer import ClipboardLayer +from pyrdp.logging.StatCounter import StatCounter, STAT from pyrdp.pdu import ClipboardPDU, FormatDataRequestPDU, FormatDataResponsePDU from pyrdp.recording import Recorder @@ -18,13 +19,15 @@ class PassiveClipboardStealer: MITM component for the clipboard layer. Logs clipboard data when it is pasted. """ - def __init__(self, client: ClipboardLayer, server: ClipboardLayer, log: LoggerAdapter, recorder: Recorder): + def __init__(self, client: ClipboardLayer, server: ClipboardLayer, log: LoggerAdapter, recorder: Recorder, + statCounter: StatCounter): """ :param client: clipboard layer for the client side :param server: clipboard layer for the server side :param log: logger for this component :param recorder: recorder for clipboard data """ + self.statCounter = statCounter self.client = client self.server = server self.log = log @@ -40,9 +43,11 @@ def __init__(self, client: ClipboardLayer, server: ClipboardLayer, log: LoggerAd ) def onClientPDUReceived(self, pdu: ClipboardPDU): + self.statCounter.increment(STAT.CLIPBOARD, STAT.CLIPBOARD_CLIENT) self.handlePDU(pdu, self.server) def onServerPDUReceived(self, pdu: ClipboardPDU): + self.statCounter.increment(STAT.CLIPBOARD, STAT.CLIPBOARD_SERVER) self.handlePDU(pdu, self.client) def handlePDU(self, pdu: ClipboardPDU, destination: ClipboardLayer): @@ -63,6 +68,10 @@ def handlePDU(self, pdu: ClipboardPDU, destination: ClipboardLayer): self.log.info("Clipboard data: %(clipboardData)r", {"clipboardData": clipboardData}) self.recorder.record(pdu, PlayerPDUType.CLIPBOARD_DATA) + if self.forwardNextDataResponse: + # Means it's NOT a crafted response + self.statCounter.increment(STAT.CLIPBOARD_PASTE) + self.forwardNextDataResponse = True def decodeClipboardData(self, data: bytes) -> str: @@ -79,8 +88,9 @@ class ActiveClipboardStealer(PassiveClipboardStealer): clipboard is updated. """ - def __init__(self, client: ClipboardLayer, server: ClipboardLayer, log: LoggerAdapter, recorder: Recorder): - super().__init__(client, server, log, recorder) + def __init__(self, client: ClipboardLayer, server: ClipboardLayer, log: LoggerAdapter, recorder: Recorder, + statCounter: StatCounter): + super().__init__(client, server, log, recorder, statCounter) def handlePDU(self, pdu: ClipboardPDU, destination: ClipboardLayer): """ @@ -99,6 +109,8 @@ def sendPasteRequest(self, destination: ClipboardLayer): Sets forwardNextDataResponse to False to make sure that this request is not actually transferred to the other end. """ + self.statCounter.increment(STAT.CLIPBOARD_COPY) + formatDataRequestPDU = FormatDataRequestPDU(ClipboardFormatNumber.GENERIC) destination.sendPDU(formatDataRequestPDU) - self.forwardNextDataResponse = False \ No newline at end of file + self.forwardNextDataResponse = False diff --git a/pyrdp/mitm/DeviceRedirectionMITM.py b/pyrdp/mitm/DeviceRedirectionMITM.py index 5702c1a4c..61cad4820 100644 --- a/pyrdp/mitm/DeviceRedirectionMITM.py +++ b/pyrdp/mitm/DeviceRedirectionMITM.py @@ -15,6 +15,7 @@ FileCreateDisposition, FileCreateOptions, FileShareAccess, FileSystemInformationClass, IOOperationSeverity, \ MajorFunction, MinorFunction from pyrdp.layer import DeviceRedirectionLayer +from pyrdp.logging.StatCounter import StatCounter, STAT from pyrdp.mitm.config import MITMConfig from pyrdp.mitm.FileMapping import FileMapping, FileMappingDecoder, FileMappingEncoder from pyrdp.mitm.state import RDPMITMState @@ -54,7 +55,8 @@ class DeviceRedirectionMITM(Subject): FORGED_COMPLETION_ID = 1000000 - def __init__(self, client: DeviceRedirectionLayer, server: DeviceRedirectionLayer, log: LoggerAdapter, config: MITMConfig, state: RDPMITMState): + def __init__(self, client: DeviceRedirectionLayer, server: DeviceRedirectionLayer, log: LoggerAdapter, + config: MITMConfig, statCounter: StatCounter, state: RDPMITMState): """ :param client: device redirection layer for the client side :param server: device redirection layer for the server side @@ -67,6 +69,7 @@ def __init__(self, client: DeviceRedirectionLayer, server: DeviceRedirectionLaye self.server = server self.state = state self.log = log + self.statCounter = statCounter self.config = config self.currentIORequests: Dict[int, DeviceIORequestPDU] = {} self.openedFiles: Dict[int, FileProxy] = {} @@ -108,9 +111,11 @@ def saveMapping(self): f.write(json.dumps(self.fileMap, cls=FileMappingEncoder, indent=4, sort_keys=True)) def onClientPDUReceived(self, pdu: DeviceRedirectionPDU): + self.statCounter.increment(STAT.DEVICE_REDIRECTION, STAT.DEVICE_REDIRECTION_CLIENT) self.handlePDU(pdu, self.server) def onServerPDUReceived(self, pdu: DeviceRedirectionPDU): + self.statCounter.increment(STAT.DEVICE_REDIRECTION, STAT.DEVICE_REDIRECTION_SERVER) self.handlePDU(pdu, self.client) def handlePDU(self, pdu: DeviceRedirectionPDU, destination: DeviceRedirectionLayer): @@ -143,6 +148,7 @@ def handleIORequest(self, pdu: DeviceIORequestPDU): :param pdu: the device IO request """ + self.statCounter.increment(STAT.DEVICE_REDIRECTION_IOREQUEST) self.currentIORequests[pdu.completionID] = pdu def handleIOResponse(self, pdu: DeviceIOResponsePDU): @@ -151,6 +157,8 @@ def handleIOResponse(self, pdu: DeviceIOResponsePDU): :param pdu: the device IO response. """ + self.statCounter.increment(STAT.DEVICE_REDIRECTION_IORESPONSE) + if pdu.completionID in self.forgedRequests: request = self.forgedRequests[pdu.completionID] request.handleResponse(pdu) @@ -162,6 +170,7 @@ def handleIOResponse(self, pdu: DeviceIOResponsePDU): requestPDU = self.currentIORequests.pop(pdu.completionID) if pdu.ioStatus >> 30 == IOOperationSeverity.STATUS_SEVERITY_ERROR: + self.statCounter.increment(STAT.DEVICE_REDIRECTION_IOERROR) self.log.warning("Received an IO Response with an error IO status: %(responsePDU)s for request %(requestPDU)s", {"responsePDU": repr(pdu), "requestPDU": repr(requestPDU)}) if pdu.majorFunction in self.responseHandlers: @@ -240,6 +249,8 @@ def handleCloseResponse(self, request: DeviceCloseRequestPDU, _: DeviceCloseResp :param _: the device IO response to the request """ + self.statCounter.increment(STAT.DEVICE_REDIRECTION_FILE_CLOSE) + if request.fileID in self.openedFiles: file = self.openedFiles.pop(request.fileID) file.close() @@ -314,6 +325,8 @@ def sendForgedFileRead(self, deviceID: int, path: str) -> int: :param path: path of the file to download. The path should use '\' instead of '/' to separate directories. """ + self.statCounter.increment(STAT.DEVICE_REDIRECTION_FORGED_FILE_READ) + completionID = self.findNextRequestID() request = DeviceRedirectionMITM.ForgedFileReadRequest(deviceID, completionID, self, path) self.forgedRequests[completionID] = request @@ -332,6 +345,8 @@ def sendForgedDirectoryListing(self, deviceID: int, path: str) -> int: \Documents\* """ + self.statCounter.increment(STAT.DEVICE_REDIRECTION_FORGED_DIRECTORY_LISTING) + completionID = self.findNextRequestID() request = DeviceRedirectionMITM.ForgedDirectoryListingRequest(deviceID, completionID, self, path) self.forgedRequests[completionID] = request diff --git a/pyrdp/mitm/FastPathMITM.py b/pyrdp/mitm/FastPathMITM.py index bdc30157f..ed8caad61 100644 --- a/pyrdp/mitm/FastPathMITM.py +++ b/pyrdp/mitm/FastPathMITM.py @@ -5,6 +5,7 @@ # from pyrdp.layer import FastPathLayer +from pyrdp.logging.StatCounter import StatCounter, STAT from pyrdp.mitm.state import RDPMITMState from pyrdp.pdu import FastPathPDU, FastPathScanCodeEvent from pyrdp.player import keyboard @@ -16,23 +17,25 @@ class FastPathMITM(BasePathMITM): MITM component for the fast-path layer. """ - def __init__(self, client: FastPathLayer, server: FastPathLayer, state: RDPMITMState): + def __init__(self, client: FastPathLayer, server: FastPathLayer, state: RDPMITMState, statCounter: StatCounter): """ :param client: fast-path layer for the client side :param server: fast-path layer for the server side :param state: the MITM state. """ - super().__init__(state, client, server) + + super().__init__(state, client, server, statCounter) self.client.createObserver( - onPDUReceived = self.onClientPDUReceived, + onPDUReceived=self.onClientPDUReceived, ) self.server.createObserver( - onPDUReceived = self.onServerPDUReceived, + onPDUReceived=self.onServerPDUReceived, ) def onClientPDUReceived(self, pdu: FastPathPDU): + self.statCounter.increment(STAT.IO_INPUT_FASTPATH) if self.state.forwardInput: self.server.sendPDU(pdu) @@ -42,5 +45,6 @@ def onClientPDUReceived(self, pdu: FastPathPDU): self.onScanCode(event.scanCode, event.isReleased, event.rawHeaderByte & keyboard.KBDFLAGS_EXTENDED != 0) def onServerPDUReceived(self, pdu: FastPathPDU): + self.statCounter.increment(STAT.IO_OUTPUT_FASTPATH) if self.state.forwardOutput: - self.client.sendPDU(pdu) \ No newline at end of file + self.client.sendPDU(pdu) diff --git a/pyrdp/mitm/MCSMITM.py b/pyrdp/mitm/MCSMITM.py index 021ba4af9..d044849d6 100644 --- a/pyrdp/mitm/MCSMITM.py +++ b/pyrdp/mitm/MCSMITM.py @@ -10,6 +10,7 @@ from pyrdp.enum import ClientCapabilityFlag, EncryptionLevel, EncryptionMethod, HighColorDepth, MCSChannelName, \ PlayerPDUType, SupportedColorDepth from pyrdp.layer import MCSLayer +from pyrdp.logging.StatCounter import StatCounter, STAT from pyrdp.mcs import MCSClientChannel, MCSServerChannel from pyrdp.mitm.state import RDPMITMState from pyrdp.parser import ClientConnectionParser, GCCParser, ServerConnectionParser @@ -29,7 +30,8 @@ class MCSMITM: """ def __init__(self, client: MCSLayer, server: MCSLayer, state: RDPMITMState, recorder: Recorder, - buildChannelCallback: Callable[[MCSServerChannel, MCSClientChannel], None], log: LoggerAdapter): + buildChannelCallback: Callable[[MCSServerChannel, MCSClientChannel], None], + log: LoggerAdapter, statCounter: StatCounter): """ :param client: MCS layer for the client side :param server: MCS layer for the server side @@ -40,6 +42,7 @@ def __init__(self, client: MCSLayer, server: MCSLayer, state: RDPMITMState, reco """ self.log = log + self.statCounter = statCounter self.client = client self.server = server self.state = state @@ -151,7 +154,9 @@ def onConnectResponse(self, pdu: MCSConnectResponsePDU): for index in range(len(serverData.networkData.channels)): channelID = serverData.networkData.channels[index] - self.state.channelMap[channelID] = self.state.channelDefinitions[index].name + name = self.state.channelDefinitions[index].name + self.log.info("%(channelName)s <---> Channel #%(channelId)d", {"channelName": name, "channelId": channelID}) + self.state.channelMap[channelID] = name # Replace the server's public key with our own key so we can decrypt the incoming client random cert = serverData.securityData.serverCertificate @@ -183,7 +188,6 @@ def onConnectResponse(self, pdu: MCSConnectResponsePDU): self.client.sendPDU(modifiedMCSPDU) - def onErectDomainRequest(self, pdu: MCSErectDomainRequestPDU): """ Forward an erect domain request to the server. @@ -233,7 +237,10 @@ def onSendDataRequest(self, pdu: MCSSendDataRequestPDU): :param pdu: the send data request """ + self.statCounter.increment(STAT.MCS, STAT.MCS_INPUT) + if pdu.channelID in self.serverChannels: + self.statCounter.increment(STAT.MCS_INPUT_ + str(pdu.channelID)) self.clientChannels[pdu.channelID].recv(pdu.payload) def onSendDataIndication(self, pdu: MCSSendDataIndicationPDU): @@ -242,7 +249,10 @@ def onSendDataIndication(self, pdu: MCSSendDataIndicationPDU): :param pdu: the send data indication """ + self.statCounter.increment(STAT.MCS, STAT.MCS_OUTPUT) + if pdu.channelID in self.clientChannels: + self.statCounter.increment(STAT.MCS_OUTPUT_ + str(pdu.channelID)) self.serverChannels[pdu.channelID].recv(pdu.payload) def onClientDisconnectProviderUltimatum(self, pdu: MCSDisconnectProviderUltimatumPDU): diff --git a/pyrdp/mitm/mitm.py b/pyrdp/mitm/RDPMITM.py similarity index 96% rename from pyrdp/mitm/mitm.py rename to pyrdp/mitm/RDPMITM.py index 3cb9cc755..2c23f1b52 100644 --- a/pyrdp/mitm/mitm.py +++ b/pyrdp/mitm/RDPMITM.py @@ -17,6 +17,7 @@ from pyrdp.logging import RC4LoggingObserver from pyrdp.logging.adapters import SessionLogger from pyrdp.logging.observers import FastPathLogger, LayerLogger, MCSLogger, SecurityLogger, SlowPathLogger, X224Logger +from pyrdp.logging.StatCounter import StatCounter from pyrdp.mcs import MCSClientChannel, MCSServerChannel from pyrdp.mitm.AttackerMITM import AttackerMITM from pyrdp.mitm.ClipboardMITM import ActiveClipboardStealer @@ -65,6 +66,9 @@ def __init__(self, log: SessionLogger, config: MITMConfig): self.config = config """The MITM configuration""" + self.statCounter = StatCounter() + """Class to keep track of connection-related statistics such as # of mouse events, # of output events, etc.""" + self.state = RDPMITMState() """The MITM state""" @@ -84,19 +88,19 @@ def __init__(self, log: SessionLogger, config: MITMConfig): """MITM components for virtual channels""" serverConnector = self.connectToServer() - self.tcp = TCPMITM(self.client.tcp, self.server.tcp, self.player.tcp, self.getLog("tcp"), self.state, self.recorder, serverConnector) + self.tcp = TCPMITM(self.client.tcp, self.server.tcp, self.player.tcp, self.getLog("tcp"), self.state, self.recorder, serverConnector, self.statCounter) """TCP MITM component""" self.x224 = X224MITM(self.client.x224, self.server.x224, self.getLog("x224"), self.state, serverConnector, self.startTLS) """X224 MITM component""" - self.mcs = MCSMITM(self.client.mcs, self.server.mcs, self.state, self.recorder, self.buildChannel, self.getLog("mcs")) + self.mcs = MCSMITM(self.client.mcs, self.server.mcs, self.state, self.recorder, self.buildChannel, self.getLog("mcs"), self.statCounter) """MCS MITM component""" self.security: SecurityMITM = None """Security MITM component""" - self.slowPath = SlowPathMITM(self.client.slowPath, self.server.slowPath, self.state) + self.slowPath = SlowPathMITM(self.client.slowPath, self.server.slowPath, self.state, self.statCounter) """Slow-path MITM component""" self.fastPath: FastPathMITM = None @@ -154,8 +158,6 @@ def getServerLog(self, name: str) -> SessionLogger: """ return self.serverLog.createChild(name) - - async def connectToServer(self): """ Coroutine that connects to the target RDP server and the attacker. @@ -177,8 +179,6 @@ async def connectToServer(self): except asyncio.TimeoutError: self.log.error("Failed to connect to recording host: timeout expired") - - def startTLS(self): """ Execute a startTLS on both the client and server side. @@ -189,8 +189,6 @@ def startTLS(self): self.client.tcp.startTLS(contextForClient) self.server.tcp.startTLS(contextForServer) - - def buildChannel(self, client: MCSServerChannel, server: MCSClientChannel): """ Build a MITM component for an MCS channel. The client side has an MCSServerChannel because from the point of view @@ -234,7 +232,7 @@ def buildIOChannel(self, client: MCSServerChannel, server: MCSClientChannel): self.server.fastPath.addObserver(RecordingFastPathObserver(self.recorder, PlayerPDUType.FAST_PATH_OUTPUT)) self.security = SecurityMITM(self.client.security, self.server.security, self.getLog("security"), self.config, self.state, self.recorder) - self.fastPath = FastPathMITM(self.client.fastPath, self.server.fastPath, self.state) + self.fastPath = FastPathMITM(self.client.fastPath, self.server.fastPath, self.state, self.statCounter) if self.player.tcp.transport: self.attacker = AttackerMITM(self.client.fastPath, self.server.fastPath, self.player.player, self.log, self.state, self.recorder) @@ -274,7 +272,8 @@ def buildClipboardChannel(self, client: MCSServerChannel, server: MCSClientChann LayerChainItem.chain(client, clientSecurity, clientVirtualChannel, clientLayer) LayerChainItem.chain(server, serverSecurity, serverVirtualChannel, serverLayer) - mitm = ActiveClipboardStealer(clientLayer, serverLayer, self.getLog(MCSChannelName.CLIPBOARD), self.recorder) + mitm = ActiveClipboardStealer(clientLayer, serverLayer, self.getLog(MCSChannelName.CLIPBOARD), self.recorder, + self.statCounter) self.channelMITMs[client.channelID] = mitm def buildDeviceChannel(self, client: MCSServerChannel, server: MCSClientChannel): @@ -297,7 +296,7 @@ def buildDeviceChannel(self, client: MCSServerChannel, server: MCSClientChannel) LayerChainItem.chain(client, clientSecurity, clientVirtualChannel, clientLayer) LayerChainItem.chain(server, serverSecurity, serverVirtualChannel, serverLayer) - deviceRedirection = DeviceRedirectionMITM(clientLayer, serverLayer, self.getLog(MCSChannelName.DEVICE_REDIRECTION), self.config, self.state) + deviceRedirection = DeviceRedirectionMITM(clientLayer, serverLayer, self.getLog(MCSChannelName.DEVICE_REDIRECTION), self.config, self.statCounter, self.state) self.channelMITMs[client.channelID] = deviceRedirection if self.attacker: @@ -318,7 +317,7 @@ def buildVirtualChannel(self, client: MCSServerChannel, server: MCSClientChannel LayerChainItem.chain(client, clientSecurity, clientLayer) LayerChainItem.chain(server, serverSecurity, serverLayer) - mitm = VirtualChannelMITM(clientLayer, serverLayer) + mitm = VirtualChannelMITM(clientLayer, serverLayer, self.statCounter) self.channelMITMs[client.channelID] = mitm def sendPayload(self): diff --git a/pyrdp/mitm/SlowPathMITM.py b/pyrdp/mitm/SlowPathMITM.py index 8dc732ba2..3b6da52d1 100644 --- a/pyrdp/mitm/SlowPathMITM.py +++ b/pyrdp/mitm/SlowPathMITM.py @@ -6,6 +6,7 @@ from pyrdp.enum import CapabilityType, KeyboardFlag, OrderFlag, VirtualChannelCompressionFlag from pyrdp.layer import SlowPathLayer, SlowPathObserver +from pyrdp.logging.StatCounter import StatCounter, STAT from pyrdp.mitm.state import RDPMITMState from pyrdp.pdu import Capability, ConfirmActivePDU, DemandActivePDU, InputPDU, KeyboardEvent, SlowPathPDU from pyrdp.mitm.BasePathMITM import BasePathMITM @@ -15,12 +16,12 @@ class SlowPathMITM(BasePathMITM): MITM component for the slow-path layer. """ - def __init__(self, client: SlowPathLayer, server: SlowPathLayer, state: RDPMITMState): + def __init__(self, client: SlowPathLayer, server: SlowPathLayer, state: RDPMITMState, statCounter: StatCounter): """ :param client: slow-path layer for the client side :param server: slow-path layer for the server side """ - super().__init__(state, client, server) + super().__init__(state, client, server, statCounter) self.clientObserver = self.client.createObserver( onPDUReceived = self.onClientPDUReceived, @@ -33,6 +34,7 @@ def __init__(self, client: SlowPathLayer, server: SlowPathLayer, state: RDPMITMS ) def onClientPDUReceived(self, pdu: SlowPathPDU): + self.statCounter.increment(STAT.IO_INPUT_SLOWPATH) SlowPathObserver.onPDUReceived(self.clientObserver, pdu) if self.state.forwardInput: @@ -45,6 +47,7 @@ def onClientPDUReceived(self, pdu: SlowPathPDU): self.onScanCode(event.keyCode, event.flags & KeyboardFlag.KBDFLAGS_DOWN == 0, event.flags & KeyboardFlag.KBDFLAGS_EXTENDED != 0) def onServerPDUReceived(self, pdu: SlowPathPDU): + self.statCounter.increment(STAT.IO_OUTPUT_SLOWPATH) SlowPathObserver.onPDUReceived(self.serverObserver, pdu) if self.state.forwardOutput: diff --git a/pyrdp/mitm/TCPMITM.py b/pyrdp/mitm/TCPMITM.py index 6667eae4d..e9e2ff3e7 100644 --- a/pyrdp/mitm/TCPMITM.py +++ b/pyrdp/mitm/TCPMITM.py @@ -8,6 +8,7 @@ from typing import Coroutine from pyrdp.layer import TwistedTCPLayer +from pyrdp.logging.StatCounter import StatCounter from pyrdp.mitm.state import RDPMITMState from pyrdp.pdu.player import PlayerConnectionClosePDU from pyrdp.recording import Recorder @@ -18,7 +19,8 @@ class TCPMITM: MITM component for the TCP layer. """ - def __init__(self, client: TwistedTCPLayer, server: TwistedTCPLayer, attacker: TwistedTCPLayer, log: LoggerAdapter, state: RDPMITMState, recorder: Recorder, serverConnector: Coroutine): + def __init__(self, client: TwistedTCPLayer, server: TwistedTCPLayer, attacker: TwistedTCPLayer, log: LoggerAdapter, + state: RDPMITMState, recorder: Recorder, serverConnector: Coroutine, statCounter: StatCounter): """ :param client: TCP layer for the client side :param server: TCP layer for the server side @@ -28,8 +30,8 @@ def __init__(self, client: TwistedTCPLayer, server: TwistedTCPLayer, attacker: T :param serverConnector: coroutine that connects to the server side, closed when the client disconnects """ - self.connectionTime = 0 - # To keep track of the duration of the TCP connection. + self.statCounter = statCounter + # To keep track of useful statistics for the connection. self.client = client self.server = server @@ -66,7 +68,11 @@ def onClientConnection(self): """ Log the fact that a new client has connected. """ + + # Statistics + self.statCounter.start() self.connectionTime = time.time() + ip = self.client.transport.client[0] self.log.info("New client connected from %(clientIp)s", {"clientIp": ip}) @@ -76,11 +82,10 @@ def onClientDisconnection(self, reason): :param reason: reason for disconnection """ - self.connectionTime = time.time() - self.connectionTime - + self.statCounter.stop() self.recordConnectionClose() self.log.info("Client connection closed. %(reason)s", {"reason": reason.value}) - self.log.info("Client connection time: %(connectionTime)s secs", {"connectionTime": self.connectionTime}) + self.statCounter.logReport(self.log) self.serverConnector.close() self.server.disconnect(True) diff --git a/pyrdp/mitm/VirtualChannelMITM.py b/pyrdp/mitm/VirtualChannelMITM.py index 56a33f11a..7451e5b02 100644 --- a/pyrdp/mitm/VirtualChannelMITM.py +++ b/pyrdp/mitm/VirtualChannelMITM.py @@ -5,6 +5,8 @@ # from pyrdp.layer import RawLayer +from pyrdp.logging import StatCounter +from pyrdp.logging.StatCounter import STAT from pyrdp.pdu import PDU @@ -13,7 +15,7 @@ class VirtualChannelMITM: Generic MITM component for any virtual channel. """ - def __init__(self, client: RawLayer, server: RawLayer): + def __init__(self, client: RawLayer, server: RawLayer, statCounter: StatCounter): """ :param client: layer for the client side :param server: layer for the server side @@ -21,6 +23,7 @@ def __init__(self, client: RawLayer, server: RawLayer): self.client = client self.server = server + self.statCounter = statCounter self.client.createObserver( onPDUReceived = self.onClientPDUReceived @@ -36,6 +39,7 @@ def onClientPDUReceived(self, pdu: PDU): :param pdu: the PDU that was received """ + self.statCounter.increment(STAT.VIRTUAL_CHANNEL_INPUT) self.server.sendPDU(pdu) def onServerPDUReceived(self, pdu: PDU): @@ -44,4 +48,5 @@ def onServerPDUReceived(self, pdu: PDU): :param pdu: the PDU that was received """ - self.client.sendPDU(pdu) \ No newline at end of file + self.statCounter.increment(STAT.VIRTUAL_CHANNEL, STAT.VIRTUAL_CHANNEL_OUTPUT) + self.client.sendPDU(pdu) diff --git a/pyrdp/mitm/__init__.py b/pyrdp/mitm/__init__.py index d26ebd44e..50812d9c6 100644 --- a/pyrdp/mitm/__init__.py +++ b/pyrdp/mitm/__init__.py @@ -5,4 +5,4 @@ # from pyrdp.mitm.config import MITMConfig -from pyrdp.mitm.mitm import RDPMITM +from pyrdp.mitm.RDPMITM import RDPMITM