diff --git a/pyrdp/logging/log.py b/pyrdp/logging/log.py index 308017a0c..76b3d1240 100644 --- a/pyrdp/logging/log.py +++ b/pyrdp/logging/log.py @@ -17,6 +17,7 @@ class LOGGER_NAMES: PYRDP = "pyrdp" MITM = f"{PYRDP}.mitm" MITM_CONNECTIONS = f"{MITM}.connections" + MITM_FAKE_SERVER = f"{MITM}.fake_server" PLAYER = f"{PYRDP}.player" PLAYER_UI = f"{PLAYER}.ui" NTLMSSP = f"ntlmssp" diff --git a/pyrdp/mitm/FakeServer.py b/pyrdp/mitm/FakeServer.py index e824e079a..472dd8b64 100644 --- a/pyrdp/mitm/FakeServer.py +++ b/pyrdp/mitm/FakeServer.py @@ -3,13 +3,13 @@ # Copyright (C) 2022 # Licensed under the GPLv3 or later. # -import multiprocessing, os, random, shutil, socket, subprocess, threading, time +import logging, multiprocessing, os, random, shutil, socket, subprocess, threading, time from tkinter import * from PIL import Image, ImageTk from pyvirtualdisplay import Display -from logging import LoggerAdapter +from pyrdp.logging import SessionLogger, LOGGER_NAMES BACKGROUND_COLOR = "#044a91" IMAGES_DIR = os.path.dirname(__file__) + "/images" @@ -159,11 +159,14 @@ def show_loading_animation(self, index): class FakeServer(threading.Thread): - def __init__(self, targetHost: str, targetPort: int, log: LoggerAdapter): + def __init__(self, targetHost: str, targetPort: int = 3389, sessionID: str = None): super().__init__() self.targetHost = targetHost self.targetPort = targetPort - self.log = log + self.log = SessionLogger( + logging.getLogger(LOGGER_NAMES.MITM_FAKE_SERVER), sessionID + ) + self.log.info("test") self._launch_display() diff --git a/pyrdp/mitm/RDPMITM.py b/pyrdp/mitm/RDPMITM.py index 5452f2d3e..237b936df 100644 --- a/pyrdp/mitm/RDPMITM.py +++ b/pyrdp/mitm/RDPMITM.py @@ -78,7 +78,7 @@ def __init__(self, mainLogger: SessionLogger, crawlerLogger: SessionLogger, conf self.statCounter = StatCounter() """Class to keep track of connection-related statistics such as # of mouse events, # of output events, etc.""" - self.state = state if state is not None else RDPMITMState(self.config, self.log.sessionID, self.getLog) + self.state = state if state is not None else RDPMITMState(self.config, self.log.sessionID) """The MITM state""" self.client = RDPLayerSet() diff --git a/pyrdp/mitm/state.py b/pyrdp/mitm/state.py index a68fff032..50085c223 100644 --- a/pyrdp/mitm/state.py +++ b/pyrdp/mitm/state.py @@ -4,13 +4,12 @@ # Licensed under the GPLv3 or later. # -from typing import Callable, Dict, List, Optional +from typing import Dict, List, Optional from Crypto.PublicKey import RSA from pyrdp.enum import NegotiationProtocols, ParserMode from pyrdp.layer import FastPathLayer, SecurityLayer, TLSSecurityLayer -from pyrdp.logging import SessionLogger from pyrdp.parser import createFastPathParser from pyrdp.pdu import ClientChannelDefinition from pyrdp.security import RC4CrypterProxy, SecuritySettings @@ -22,7 +21,7 @@ class RDPMITMState: State object for the RDP MITM. This is for data that needs to be shared across components. """ - def __init__(self, config: MITMConfig, sessionID: str, getLog: Callable[[str], SessionLogger]): + def __init__(self, config: MITMConfig, sessionID: str): self.requestedProtocols: Optional[NegotiationProtocols] = None """The original request protocols""" @@ -94,9 +93,6 @@ def __init__(self, config: MITMConfig, sessionID: str, getLog: Callable[[str], S self.fakeServer = None """The current fake server""" - self.getLog = getLog - """Function to create additional loggers""" - self.securitySettings.addObserver(self.crypters[ParserMode.CLIENT]) self.securitySettings.addObserver(self.crypters[ParserMode.SERVER]) @@ -139,9 +135,12 @@ def useRedirectionHost(self): def useFakeServer(self): from pyrdp.mitm.FakeServer import FakeServer + self.fakeServer = FakeServer( - self.config.targetHost, self.config.targetPort, self.getLog("") + self.config.targetHost, + targetPort=self.config.targetPort, + sessionID=self.sessionID, ) self.effectiveTargetHost = "127.0.0.1" self.effectiveTargetPort = self.fakeServer.port - self.fakeServer.start() \ No newline at end of file + self.fakeServer.start() diff --git a/test/test_prerecorded.py b/test/test_prerecorded.py index 61abc9283..5cb801195 100644 --- a/test/test_prerecorded.py +++ b/test/test_prerecorded.py @@ -122,7 +122,7 @@ def sendBytesStub(_: bytes): config.outDir = output_directory # replay_transport = FileLayer(output_path) - state = RDPMITMState(config, log.sessionID, lambda name : log.createChild(name)) + state = RDPMITMState(config, log.sessionID) super().__init__(log, log, config, state, CustomMITMRecorder([], state)) self.client.tcp.sendBytes = sendBytesStub