diff --git a/dlls/ntdll/thread.c b/dlls/ntdll/thread.c
index f73e141911..b1ec6762a6 100644
--- a/dlls/ntdll/thread.c
+++ b/dlls/ntdll/thread.c
@@ -149,6 +149,312 @@ static ULONG_PTR get_image_addr(void)
 }
 #endif
 
+/*
+ * Patch for MH:W after Release of 12/3/2020
+ *
+ * The idea of this patch is not to perform I/O
+ * operations with the wineserver when getting and
+ * setting the context of a given thread, but to
+ * keep a local copy cached (max MHW_MAX_CONTEXTS)
+ */
+
+#define MHW_MAX_TH_MAP (256)
+#define MHW_MAX_CONTEXTS MHW_MAX_TH_MAP
+#define MHW_STEAM_ID "582010"
+#define MHW_DEBUG_LOG "MHW_DEBUG_LOG"
+#define MHW_DEBUG_LOG_PATH "/tmp/mhw_debug.log"
+#define MHW_CHECK_TODO (-2)
+#define MHW_CHECK_YES (1)
+#define MHW_CHECK_NO (0)
+#define MHW_UNLIKELY(x) __builtin_expect(!!(x),0)
+
+/* HANDLE -> pthread_t pairs; slots [0, cur_el) are in use, guarded by mtx */
+typedef struct {
+    HANDLE handles[MHW_MAX_TH_MAP];
+    pthread_t th[MHW_MAX_TH_MAP];
+    int cur_el;
+    pthread_mutex_t mtx;
+} MHW_TH_MAP;
+
+/* pthread_t -> cached context; slots [0, cur_el) are in use, guarded by mtx */
+typedef struct {
+    pthread_t th[MHW_MAX_CONTEXTS];
+    context_t ctx[MHW_MAX_CONTEXTS];
+    int cur_el;
+    pthread_mutex_t mtx;
+} MHW_CTX_MAP;
+
+static MHW_TH_MAP mhw_th_map = { .cur_el=0, .mtx=PTHREAD_MUTEX_INITIALIZER};
+static MHW_CTX_MAP mhw_ctx_map = { .cur_el=0, .mtx=PTHREAD_MUTEX_INITIALIZER};
+static pthread_mutex_t mhw_ctx_mtx_debug = PTHREAD_MUTEX_INITIALIZER;
+static FILE *mhw_fp_debug = 0;
+
+/* Record that handle h refers to pthread th (a FIXME is logged when the map is full) */
+static void mhw_th_map_add(HANDLE h, pthread_t th) {
+    pthread_mutex_lock(&mhw_th_map.mtx);
+    if(mhw_th_map.cur_el < MHW_MAX_TH_MAP) {
+        const int ce = mhw_th_map.cur_el;
+        mhw_th_map.handles[ce] = h;
+        mhw_th_map.th[ce] = th;
+        ++mhw_th_map.cur_el;
+    } else {
+        FIXME( "Reached a limit for threads in MH:W patch\n" );
+    }
+    pthread_mutex_unlock(&mhw_th_map.mtx);
+}
+
+/* Forget handle h; returns the pthread it mapped to, or 0 if h is unknown */
+static pthread_t mhw_th_map_remove_by_h(HANDLE h) {
+    pthread_t rv = 0;
+    pthread_mutex_lock(&mhw_th_map.mtx);
+    for(int i = 0; i < mhw_th_map.cur_el; ++i) {
+        if(mhw_th_map.handles[i] == h) {
+            /* the last entry takes the place of the removed one */
+            const int last_el = mhw_th_map.cur_el - 1;
+            rv = mhw_th_map.th[i];
+            mhw_th_map.th[i] = mhw_th_map.th[last_el];
+            mhw_th_map.handles[i] = mhw_th_map.handles[last_el];
+            --mhw_th_map.cur_el;
+            break;
+        }
+    }
+    pthread_mutex_unlock(&mhw_th_map.mtx);
+    return rv;
+}
+
+/* Forget pthread th; returns the handle it was registered with, or 0 if unknown */
+static HANDLE mhw_th_map_remove_by_th(pthread_t th) {
+    HANDLE rv = 0;
+    pthread_mutex_lock(&mhw_th_map.mtx);
+    for(int i = 0; i < mhw_th_map.cur_el; ++i) {
+        if(mhw_th_map.th[i] == th) {
+            /* the last entry takes the place of the removed one */
+            const int last_el = mhw_th_map.cur_el - 1;
+            rv = mhw_th_map.handles[i];
+            mhw_th_map.th[i] = mhw_th_map.th[last_el];
+            mhw_th_map.handles[i] = mhw_th_map.handles[last_el];
+            --mhw_th_map.cur_el;
+            break;
+        }
+    }
+    pthread_mutex_unlock(&mhw_th_map.mtx);
+    return rv;
+}
+
+/* Look up the pthread registered for handle h; returns 0 if h is unknown */
+static pthread_t mhw_th_map_find_by_h(HANDLE h) {
+    pthread_t rv = 0;
+    pthread_mutex_lock(&mhw_th_map.mtx);
+    for(int i = 0; i < mhw_th_map.cur_el; ++i) {
+        if(mhw_th_map.handles[i] == h) {
+            rv = mhw_th_map.th[i];
+            break;
+        }
+    }
+    pthread_mutex_unlock(&mhw_th_map.mtx);
+    return rv;
+}
+
+/* Store a copy of *ctx as the cached context of th, replacing any previous one */
+static void mhw_set_context(pthread_t th, const context_t* ctx) {
+    int idx = 0;
+    pthread_mutex_lock(&mhw_ctx_map.mtx);
+    /* look for an existing entry; idx == cur_el afterwards means "not found" */
+    while(idx < mhw_ctx_map.cur_el && mhw_ctx_map.th[idx] != th) ++idx;
+    if(idx < MHW_MAX_CONTEXTS) {
+        mhw_ctx_map.ctx[idx] = *ctx;
+        if(idx == mhw_ctx_map.cur_el) {
+            /* new entry appended at the end */
+            mhw_ctx_map.th[idx] = th;
+            ++mhw_ctx_map.cur_el;
+        }
+    } else {
+        FIXME( "Reached a limit for contexts in MH:W patch\n" );
+    }
+    pthread_mutex_unlock(&mhw_ctx_map.mtx);
+}
+
+/* Copy the register groups selected by flags from the cached context to *to */
+static void mhw_get_context_flags(const context_t* from, context_t* to, unsigned int flags) {
+    /* the register groups are interpreted according to the cpu field,
+     * so it has to be handed out together with them */
+    to->cpu = from->cpu;
+    /* NOTE(review): flags are OR-ed into whatever *to already holds;
+     * confirm that every caller initialises to->flags */
+    to->flags |= flags;
+    if (flags & SERVER_CTX_CONTROL) to->ctl = from->ctl;
+    if (flags & SERVER_CTX_INTEGER) to->integer = from->integer;
+    if (flags & SERVER_CTX_SEGMENTS) to->seg = from->seg;
+    if (flags & SERVER_CTX_FLOATING_POINT) to->fp = from->fp;
+    if (flags & SERVER_CTX_DEBUG_REGISTERS) to->debug = from->debug;
+    if (flags & SERVER_CTX_EXTENDED_REGISTERS) to->ext = from->ext;
+}
+
+/*
+ * Fill *ctx from the context cached for th; returns 1 on a hit, 0 on a miss.
+ * NOTE(review): a hit does not check that the cached entry actually holds
+ * the register groups requested through flags.
+ */
+static int mhw_get_context(pthread_t th, context_t* ctx, unsigned int flags) {
+    int rv = 0;
+    pthread_mutex_lock(&mhw_ctx_map.mtx);
+    for(int i = 0; i < mhw_ctx_map.cur_el; ++i) {
+        if(mhw_ctx_map.th[i] == th) {
+            mhw_get_context_flags(&mhw_ctx_map.ctx[i], ctx, flags);
+            rv = 1;
+            break;
+        }
+    }
+    pthread_mutex_unlock(&mhw_ctx_map.mtx);
+    return rv;
+}
+
+/* Discard the context cached for th, if any */
+static void mhw_remove_context(pthread_t th) {
+    pthread_mutex_lock(&mhw_ctx_map.mtx);
+    for(int i = 0; i < mhw_ctx_map.cur_el; ++i) {
+        if(mhw_ctx_map.th[i] == th) {
+            /* the last entry takes the place of the removed one; all the
+             * bookkeeping is done on the context map, not the thread map */
+            const int last_el = mhw_ctx_map.cur_el - 1;
+            if(last_el != i) {
+                mhw_ctx_map.th[i] = mhw_ctx_map.th[last_el];
+                mhw_ctx_map.ctx[i] = mhw_ctx_map.ctx[last_el];
+            }
+            --mhw_ctx_map.cur_el;
+            break;
+        }
+    }
+    pthread_mutex_unlock(&mhw_ctx_map.mtx);
+}
+
+/*
+ * Open the debug log once. The file is the value of the MHW_DEBUG_LOG
+ * variable when that is an absolute path, MHW_DEBUG_LOG_PATH otherwise.
+ */
+static void mhw_debug_init(void) {
+    const char *path;
+    FILE *fp_lcl;
+    if(mhw_fp_debug) return;
+    path = getenv(MHW_DEBUG_LOG);
+    if(!path || path[0] != '/') path = MHW_DEBUG_LOG_PATH;
+    fp_lcl = fopen(path, "w");
+    if(!fp_lcl) return;
+    /* several threads may get here; only the first stream is kept */
+    if(!__sync_bool_compare_and_swap(&mhw_fp_debug, 0, fp_lcl))
+        fclose(fp_lcl);
+}
+
+/* Print sz bytes of buf as one hexadecimal number prefixed by 0x */
+static void mhw_debug_printf_buffer(FILE *fp, const void* buf, int sz) {
+    const unsigned char *uc = (const unsigned char*)buf;
+    fprintf(fp, "0x");
+    for(int i = 0; i < sz; ++i)
+        fprintf(fp, "%02X", uc[i]);
+}
+
+/* Dump every register group of an i386 context */
+static void mhw_debug_printf_x86_regs(FILE *fp, const context_t *ctx) {
+    fprintf(fp, "ctl: [%u %u %u %u %u %u]\n", ctx->ctl.i386_regs.eip, ctx->ctl.i386_regs.ebp, ctx->ctl.i386_regs.esp, ctx->ctl.i386_regs.eflags, ctx->ctl.i386_regs.cs, ctx->ctl.i386_regs.ss);
+    fprintf(fp, "integer: [%u %u %u %u %u %u]\n", ctx->integer.i386_regs.eax, ctx->integer.i386_regs.ebx, ctx->integer.i386_regs.ecx, ctx->integer.i386_regs.edx, ctx->integer.i386_regs.esi, ctx->integer.i386_regs.edi);
+    fprintf(fp, "seg: [%u %u %u %u]\n", ctx->seg.i386_regs.ds, ctx->seg.i386_regs.es, ctx->seg.i386_regs.fs, ctx->seg.i386_regs.gs);
+    fprintf(fp, "fp: [%u %u %u %u %u %u %u %u ", ctx->fp.i386_regs.ctrl, ctx->fp.i386_regs.status, ctx->fp.i386_regs.tag, ctx->fp.i386_regs.err_off, ctx->fp.i386_regs.err_sel, ctx->fp.i386_regs.data_off, ctx->fp.i386_regs.data_sel, ctx->fp.i386_regs.cr0npx);
+    mhw_debug_printf_buffer(fp, &ctx->fp.i386_regs.regs[0], 80);
+    fprintf(fp, "]\n");
+    fprintf(fp, "debug: [%u %u %u %u %u %u]\n", ctx->debug.i386_regs.dr0, ctx->debug.i386_regs.dr1, ctx->debug.i386_regs.dr2, ctx->debug.i386_regs.dr3, ctx->debug.i386_regs.dr6, ctx->debug.i386_regs.dr7);
+    fprintf(fp, "ext: [");
+    mhw_debug_printf_buffer(fp, &ctx->ext.i386_regs[0], 512);
+    fprintf(fp, "]\n");
+}
+
+/* Dump every register group of an x86-64 context */
+static void mhw_debug_printf_x86_64_regs(FILE *fp, const context_t *ctx) {
+    fprintf(fp, "ctl: [%lu %lu %lu %u %u %u]\n", ctx->ctl.x86_64_regs.rip, ctx->ctl.x86_64_regs.rbp, ctx->ctl.x86_64_regs.rsp, ctx->ctl.x86_64_regs.cs, ctx->ctl.x86_64_regs.ss, ctx->ctl.x86_64_regs.flags);
+    fprintf(fp, "integer: [%lu %lu %lu %lu %lu %lu %lu %lu %lu %lu %lu %lu %lu %lu]\n", ctx->integer.x86_64_regs.rax, ctx->integer.x86_64_regs.rbx, ctx->integer.x86_64_regs.rcx, ctx->integer.x86_64_regs.rdx, ctx->integer.x86_64_regs.rsi, ctx->integer.x86_64_regs.rdi, ctx->integer.x86_64_regs.r8, ctx->integer.x86_64_regs.r9, ctx->integer.x86_64_regs.r10, ctx->integer.x86_64_regs.r11, ctx->integer.x86_64_regs.r12, ctx->integer.x86_64_regs.r13, ctx->integer.x86_64_regs.r14, ctx->integer.x86_64_regs.r15);
+    fprintf(fp, "seg: [%u %u %u %u]\n", ctx->seg.x86_64_regs.ds, ctx->seg.x86_64_regs.es, ctx->seg.x86_64_regs.fs, ctx->seg.x86_64_regs.gs);
+    fprintf(fp, "fp: [");
+    mhw_debug_printf_buffer(fp, &ctx->fp.x86_64_regs.fpregs[0], sizeof(ctx->fp.x86_64_regs.fpregs));
+    fprintf(fp, "]\n");
+    fprintf(fp, "debug: [%lu %lu %lu %lu %lu %lu]\n", ctx->debug.x86_64_regs.dr0, ctx->debug.x86_64_regs.dr1, ctx->debug.x86_64_regs.dr2, ctx->debug.x86_64_regs.dr3, ctx->debug.x86_64_regs.dr6, ctx->debug.x86_64_regs.dr7);
+    fprintf(fp, "ext: [");
+    mhw_debug_printf_buffer(fp, &ctx->ext.i386_regs[0], 512);
+    fprintf(fp, "]\n");
+}
+
+/* Non-zero when the MHW_DEBUG_LOG environment variable is set (checked once) */
+static int mhw_is_debug_log(void) {
+    static int log_flag = MHW_CHECK_TODO;
+    /* don't care if we execute the below
+     * spuriously */
+    if(MHW_UNLIKELY(log_flag == MHW_CHECK_TODO)) {
+        const char* p_debug_log = getenv(MHW_DEBUG_LOG);
+        const int is_log = (!p_debug_log) ? MHW_CHECK_NO : MHW_CHECK_YES;
+        /* atomic CAS */
+        __sync_bool_compare_and_swap(&log_flag, MHW_CHECK_TODO, is_log);
+    }
+
+    return log_flag == MHW_CHECK_YES;
+}
+
+/* Log one get/set context query: caller, target handle, flags and the registers */
+static void mhw_debug_print_context_qry(const char* fn, HANDLE this_th, HANDLE tgt_th, const context_t* ctx, const BOOL b_self, const int i_flags) {
+    if(!mhw_is_debug_log())
+        return;
+    mhw_debug_init();
+    if(!mhw_fp_debug) return;
+    pthread_mutex_lock(&mhw_ctx_mtx_debug);
+    fprintf(mhw_fp_debug, "%s\t%p\t%p\t[%lu]\n", fn, (void*)this_th, (void*)tgt_th, pthread_self());
+    fprintf(mhw_fp_debug, "%i\t%08X\n", b_self, i_flags);
+    fprintf(mhw_fp_debug, "ctx: [cpu %i flags %08X]\n", ctx->cpu, ctx->flags);
+    switch(ctx->cpu) {
+    case CPU_x86:
+        mhw_debug_printf_x86_regs(mhw_fp_debug, ctx);
+        break;
+    case CPU_x86_64:
+        mhw_debug_printf_x86_64_regs(mhw_fp_debug, ctx);
+        break;
+    default:
+        break;
+    }
+    fprintf(mhw_fp_debug, "\n");
+    fflush(mhw_fp_debug);
+    pthread_mutex_unlock(&mhw_ctx_mtx_debug);
+}
+
+/* Log the current content of the handle/thread map, tagged with fn */
+static void mhw_debug_print_th_map(const char *fn) {
+    if(!mhw_is_debug_log())
+        return;
+    mhw_debug_init();
+    if(!mhw_fp_debug) return;
+    pthread_mutex_lock(&mhw_ctx_mtx_debug);
+    pthread_mutex_lock(&mhw_th_map.mtx);
+    fprintf(mhw_fp_debug, "th_map %s [", fn);
+    for(int i = 0; i < mhw_th_map.cur_el; ++i) {
+        fprintf(mhw_fp_debug, "{ %lu, %p } ", mhw_th_map.th[i], mhw_th_map.handles[i]);
+    }
+    fprintf(mhw_fp_debug, "]\n");
+    fprintf(mhw_fp_debug, "\n");
+    fflush(mhw_fp_debug);
+    pthread_mutex_unlock(&mhw_th_map.mtx);
+    pthread_mutex_unlock(&mhw_ctx_mtx_debug);
+}
+
+/* Non-zero when SteamGameId identifies MH:W (checked once) */
+static int mhw_is_running(void) {
+    static int mhw_running_flag = MHW_CHECK_TODO;
+    /* don't care if we execute the below
+     * spuriously */
+    if(MHW_UNLIKELY(mhw_running_flag == MHW_CHECK_TODO)) {
+        const char* p_gameid = getenv("SteamGameId");
+        const int is_running = (!p_gameid || strcmp(p_gameid, MHW_STEAM_ID)) ? MHW_CHECK_NO : MHW_CHECK_YES;
+        /* atomic CAS */
+        __sync_bool_compare_and_swap(&mhw_running_flag, MHW_CHECK_TODO, is_running);
+    }
+
+    return mhw_running_flag == MHW_CHECK_YES;
+}
+
 /***********************************************************************
  *           thread_init
  *
@@ -552,6 +858,13 @@ NTSTATUS WINAPI RtlCreateUserThread( HANDLE process, SECURITY_DESCRIPTOR *descr,
     if (handle_ptr) *handle_ptr = handle;
     else NtClose( handle );
 
+    /* when handle_ptr is NULL the handle has just been closed above,
+     * so there is nothing valid left to track */
+    if(handle_ptr && mhw_is_running()) {
+        mhw_th_map_add(handle, pthread_id);
+        mhw_debug_print_th_map("RtlCreateUserThread");
+    }
+
     return STATUS_SUCCESS;
 
 error:
@@ -658,7 +971,6 @@ NTSTATUS WINAPI NtAlertThread( HANDLE handle )
     return STATUS_NOT_IMPLEMENTED;
 }
 
-
 /******************************************************************************
  *              NtTerminateThread  (NTDLL.@)
  *              ZwTerminateThread  (NTDLL.@)
@@ -678,6 +990,20 @@ NTSTATUS WINAPI NtTerminateThread( HANDLE handle, LONG exit_code )
     SERVER_END_REQ;
 
+    /* must run before abort_thread(): presumably that call does not return,
+     * so cleanup placed after it would never happen for the calling thread */
+    if (mhw_is_running()) {
+        if(self || handle == GetCurrentThread()) {
+            mhw_remove_context(pthread_self());
+            mhw_th_map_remove_by_th(pthread_self());
+        } else {
+            pthread_t th = mhw_th_map_find_by_h(handle);
+            if(th) mhw_remove_context(th);
+            mhw_th_map_remove_by_h(handle);
+        }
+        mhw_debug_print_th_map("NtTerminateThread");
+    }
+
     if (self) abort_thread( exit_code );
     return ret;
 }
 
@@ -744,6 +1070,22 @@ NTSTATUS set_thread_context( HANDLE handle, const context_t *context, BOOL *self
     NTSTATUS ret;
     DWORD dummy, i;
 
+    /*
+     * Short circuit in case of MH:W running
+     */
+    if(mhw_is_running()) {
+        const BOOL is_self = (handle == GetCurrentThread());
+        const pthread_t th = is_self ? pthread_self() : mhw_th_map_find_by_h(handle);
+        if(th) {
+            *self = is_self;
+            mhw_set_context(th, context);
+            return STATUS_SUCCESS;
+        }
+        /* unknown thread: fall through to the server rather than
+         * reporting success for a context that was stored nowhere */
+        FIXME( "Can't resolve MH:W handle map for %p\n", handle);
+    }
+
     SERVER_START_REQ( set_thread_context )
     {
         req->handle = wine_server_obj_handle( handle );
@@ -778,6 +1120,7 @@ NTSTATUS set_thread_context( HANDLE handle, const context_t *context, BOOL *self
         if (ret == STATUS_PENDING) ret = STATUS_ACCESS_DENIED;
     }
 
+    mhw_debug_print_context_qry("set_thread_context", GetCurrentThread(), handle, context, *self, 0);
     return ret;
 }
 
@@ -790,6 +1133,23 @@ NTSTATUS get_thread_context( HANDLE handle, context_t *context, unsigned int fla
     NTSTATUS ret;
     DWORD dummy, i;
 
+    /* NOTE(review): *self may not be initialised by the caller yet, so it is not read here */
+    mhw_debug_print_context_qry("get_thread_context (in)", GetCurrentThread(), handle, context, handle == GetCurrentThread(), flags);
+    /*
+     * Short circuit in case of MH:W running
+     */
+    if(mhw_is_running()) {
+        const BOOL is_self = (handle == GetCurrentThread());
+        const pthread_t th = is_self ? pthread_self() : mhw_th_map_find_by_h(handle);
+        *self = is_self;
+        if(!th) {
+            FIXME( "Can't resolve MH:W handle map for %p\n", handle);
+        } else if(mhw_get_context(th, context, flags)) {
+            /* return success only if we could resolve it */
+            return STATUS_SUCCESS;
+        }
+    }
+
     SERVER_START_REQ( get_thread_context )
     {
         req->handle = wine_server_obj_handle( handle );
@@ -825,6 +1185,18 @@ NTSTATUS get_thread_context( HANDLE handle, context_t *context, unsigned int fla
         NtResumeThread( handle, &dummy );
         if (ret == STATUS_PENDING) ret = STATUS_ACCESS_DENIED;
     }
+
+    /*
+     * If we're at this stage and MH:W is running, cache the
+     * result of the server operation for the thread it belongs
+     * to, and only when the server actually produced one
+     */
+    if(ret == STATUS_SUCCESS && mhw_is_running()) {
+        const pthread_t th = (handle == GetCurrentThread()) ? pthread_self() : mhw_th_map_find_by_h(handle);
+        if(th) mhw_set_context(th, context);
+    }
+
+    mhw_debug_print_context_qry("get_thread_context (out)", GetCurrentThread(), handle, context, *self, flags);
     return ret;
 }
 