In [2]:
from collections import defaultdict
from functools import lru_cache, _make_key
from threading import Lock

def threadsafe_lru(func):
    """Memoize ``func`` with ``lru_cache`` and serialize concurrent calls per key.

    Threads calling with the same arguments take turns on a per-key lock, so an
    uncached computation runs once while the other callers wait and then read
    the cached result. Calls with different arguments use different per-key
    locks; they only share a brief critical section on ``guard``.

    The per-key lock is a plain (non-reentrant) ``Lock``: a decorated function
    that calls itself with the same arguments will deadlock.

    Returns a wrapper carrying ``func``'s metadata and exposing the underlying
    cache's ``cache_info`` and ``cache_clear``.
    """
    from functools import update_wrapper

    cached = lru_cache()(func)
    # Protects ``key_locks``; held only while looking up, creating or releasing entries.
    guard = Lock()
    # key -> [per-key Lock, number of calls currently using it]. An entry is
    # deleted when its count drops to zero, so the dict does not grow forever.
    key_locks = {}

    def _thread_lru(*args, **kwargs):
        # Same key construction lru_cache uses (typed=False is lru_cache's default).
        key = _make_key(args, kwargs, typed=False)
        with guard:
            entry = key_locks.get(key)
            if entry is None:
                entry = key_locks[key] = [Lock(), 0]
            entry[1] += 1
        try:
            with entry[0]:
                return cached(*args, **kwargs)
        finally:
            # Runs on success and on exception, so a failing call never leaks its entry.
            with guard:
                entry[1] -= 1
                if not entry[1]:
                    del key_locks[key]

    update_wrapper(_thread_lru, func)
    _thread_lru.cache_info = cached.cache_info
    _thread_lru.cache_clear = cached.cache_clear
    return _thread_lru

In [3]:
from functools import lru_cache, _make_key
from threading import Lock
from weakref import WeakValueDictionary

def threadsafe_lru_locked(func):
    """Memoize ``func`` with ``lru_cache`` and serialize concurrent calls per key.

    Per-key locks are kept in a ``WeakValueDictionary``, so an entry is
    discarded once no in-flight call still references its lock. ``access_lock``
    makes the lookup-or-create step atomic, so two threads never receive
    different locks for the same key.

    The per-key lock is a plain (non-reentrant) ``Lock``: a decorated function
    that calls itself with the same arguments will deadlock.

    Returns a wrapper carrying ``func``'s metadata and exposing the underlying
    cache's ``cache_info`` and ``cache_clear``.
    """
    from functools import update_wrapper

    cached = lru_cache()(func)
    access_lock = Lock()
    locks = WeakValueDictionary()

    def _thread_lru(*args, **kwargs):
        # Same key construction lru_cache uses (typed=False is lru_cache's default).
        key = _make_key(args, kwargs, typed=False)
        with access_lock:
            lock = locks.get(key)
            if lock is None:
                # Only allocate a Lock on a miss (setdefault(key, Lock()) built one
                # every call). The local ``lock`` is the strong reference that keeps
                # it alive; the weak mapping alone would not.
                lock = locks[key] = Lock()
        with lock:
            return cached(*args, **kwargs)

    update_wrapper(_thread_lru, func)
    _thread_lru.cache_info = cached.cache_info
    _thread_lru.cache_clear = cached.cache_clear
    return _thread_lru

In [17]:
import glow  # noqa, patch print

import random
from collections import Counter
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from time import perf_counter, sleep

random.seed(123)
sleep_length = 0.5  # seconds each real (uncached) call sleeps

arg_range = 3  # each argument is random.randint(0, arg_range), inclusive
num_tasks = 5  # partials submitted per batch

# How many times the body of long_running_function actually ran for each (x, y).
execs = Counter()


# Try commenting out this decorator (or replacing it with @lru_cache)
# and compare the results.
# @threadsafe_lru
@threadsafe_lru_locked
def long_running_function(x, y):
    """Simulate slow work for (x, y), counting each real execution in ``execs``."""
    print(f'({x}, {y}): executing...')
    execs[(x, y)] += 1
    sleep(sleep_length)


# Recursive variant, kept for experimentation. NOTE: it never terminates. The
# inner call repeats the same arguments. Under either threadsafe decorator it
# blocks on the non-reentrant per-key Lock that the outer call already holds
# (deadlock). Under a plain @lru_cache the result is never cached before the
# recursive call, so it recurses until RecursionError.
# @threadsafe_lru
# def long_running_function(x, y):
#     print(f'({x}, {y}): executing...')
#     execs[(x, y)] += 1
#     long_running_function(x, y)
#     sleep(sleep_length)


func_list = [
    partial(long_running_function, random.randint(0, arg_range), random.randint(0, arg_range))
    for _ in range(num_tasks)
]


def _ak_repr(*args, **kwargs):
    return ', '.join(filter(None, [
        ', '.join(f'{a}' for a in args),
        ', '.join(f'{k}={v}' for k, v in kwargs.items()),
    ]))


def create_timing_callback(fn_id, *args, **kwargs):
    """Return a ``Future`` done-callback reporting how long task ``fn_id`` took.

    The clock starts when this function is called. The callback prints the
    elapsed time together with the task's arguments. It also reports tasks that
    were cancelled or raised: the submission loop here never calls
    ``future.result()``, so their exceptions would otherwise go unseen.
    """
    start_time = perf_counter()

    def timer(future):
        exec_time = perf_counter() - start_time
        label = _ak_repr(*args, **kwargs)
        # Check cancellation first: future.exception() raises CancelledError on a cancelled future.
        if future.cancelled():
            print(f'({label}): {fn_id} cancelled after {exec_time:.1f}s')
            return
        exc = future.exception()
        if exc is not None:
            print(f'({label}): {fn_id} failed in {exec_time:.1f}s: {exc!r}')
        else:
            print(f'({label}): {fn_id} done in {exec_time:.1f}s')

    return timer


with ThreadPoolExecutor() as executor:
    fn_id = 0
    for batch in [0, 1]:
        print(f'Batch {batch}')
        for fn in func_list:
            print(f'({_ak_repr(*fn.args, **fn.keywords)}): {fn_id} submitted')
            # Build the timer before submit() so its clock starts at submission time.
            timer = create_timing_callback(fn_id, *fn.args, **fn.keywords)
            executor.submit(fn).add_done_callback(timer)
            fn_id += 1
        sleep(.1)  # short pause between batches
    print('Waiting for results..')
    # Leaving the ``with`` block calls executor.shutdown(wait=True).
dict(execs)

Batch 0
(0, 2): 0 submitted
(0, 2): executing...
(0, 3): 1 submitted
(0, 3): executing...
(2, 0): 2 submitted
(2, 0): executing...
(0, 3): 3 submitted
(2, 2): 4 submitted
(2, 2): executing...
Batch 1
(0, 2): 5 submitted
(0, 3): 6 submitted
(2, 0): 7 submitted
(0, 3): 8 submitted
(2, 2): 9 submitted
Waiting for results..
