<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Slice-specified-nodes-in-dimspec" data-toc-modified-id="Slice-specified-nodes-in-dimspec-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Slice specified nodes in dimspec</a></span></li><li><span><a href="#Test-parallelism" data-toc-modified-id="Test-parallelism-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Test parallelism</a></span><ul class="toc-item"><li><ul class="toc-item"><li><span><a href="#Example-task" data-toc-modified-id="Example-task-2.0.1"><span class="toc-item-num">2.0.1&nbsp;&nbsp;</span>Example task</a></span></li></ul></li><li><span><a href="#Serial-invocation" data-toc-modified-id="Serial-invocation-2.1"><span class="toc-item-num">2.1&nbsp;&nbsp;</span>Serial invocation</a></span><ul class="toc-item"><li><span><a href="#Many-var-parallelisation" data-toc-modified-id="Many-var-parallelisation-2.1.1"><span class="toc-item-num">2.1.1&nbsp;&nbsp;</span>Many var parallelisation</a></span></li><li><span><a href="#Concurrent-assignment" data-toc-modified-id="Concurrent-assignment-2.1.2"><span class="toc-item-num">2.1.2&nbsp;&nbsp;</span>Concurrent assignment</a></span></li></ul></li><li><span><a href="#Use-unix-tools" data-toc-modified-id="Use-unix-tools-2.2"><span class="toc-item-num">2.2&nbsp;&nbsp;</span>Use unix tools</a></span><ul class="toc-item"><li><span><a href="#Threading" data-toc-modified-id="Threading-2.2.1"><span class="toc-item-num">2.2.1&nbsp;&nbsp;</span>Threading</a></span></li><li><span><a href="#Multiprocessing" data-toc-modified-id="Multiprocessing-2.2.2"><span class="toc-item-num">2.2.2&nbsp;&nbsp;</span>Multiprocessing</a></span></li></ul></li></ul></li></ul></div>

In [3]:
import ray
import pyrofiler as pyrof
import numpy as np
import sys
np.random.seed(42)

# Slice specified nodes in dimspec

In [4]:
def _none_slice():
    return slice(None)

def _get_idx(x, idxs, slice_idx, shapes=None):
    if shapes is None:
        shapes = [2]*len(idxs)
    point = np.unravel_index(slice_idx, shapes)
    get_point = {i:p for i,p in zip(idxs, point)}
    if x in idxs:
        p = get_point[x]
        return slice(p,p+1)
    else:
        return _none_slice()

def _slices_for_idxs(idxs, *args, shapes=None, slice_idx=0):
    """Return array of slices along idxs"""
    slices = []
    for indexes in args:
        _slice = [_get_idx(x, idxs, slice_idx, shapes) for x in indexes ]
        slices.append(tuple(_slice))
    return slices
        

# Test parallelism
### Example task

In [8]:
def get_example_task(n_common=8, n_free1=10, n_free2=7):
    """Create two random labelled tensors that share their leading indices.

    Every index has extent 2. Both tensors start with the same ``n_common``
    labels ``0 .. n_common-1``; the first then carries ``n_free1`` labels of
    its own and the second ``n_free2`` further labels, numbered consecutively
    after the first tensor's.

    Parameters
    ----------
    n_common : int, optional
        Number of index labels shared by both tensors (default 8).
    n_free1 : int, optional
        Number of labels private to the first tensor (default 10).
    n_free2 : int, optional
        Number of labels private to the second tensor (default 7).

    Returns
    -------
    tuple
        ``((T1, idxs1), (T2, idxs2))`` where each ``T`` is an array drawn
        with ``np.random.randn`` (so it follows the global NumPy seed) and
        each ``idxs`` is the list of integer labels of its axes.

    Raises
    ------
    ValueError
        If any of the counts is negative.
    """
    if min(n_common, n_free1, n_free2) < 0:
        raise ValueError("index counts must be non-negative")
    shape1 = [2] * (n_common + n_free1)
    shape2 = [2] * (n_common + n_free2)
    T1 = np.random.randn(*shape1)
    T2 = np.random.randn(*shape2)
    common = list(range(n_common))
    first_private = n_common + n_free1
    idxs1 = common + list(range(n_common, first_private))
    idxs2 = common + list(range(first_private, first_private + n_free2))
    return (T1, idxs1), (T2, idxs2)

# Build the two example operands once; the cells below reuse x and y.
x, y = get_example_task()
# Cell output: the index-label lists of both operands.
x[1], y[1]

([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17],
 [0, 1, 2, 3, 4, 5, 6, 7, 18, 19, 20, 21, 22, 23, 24])

## Serial invocation

In [9]:

#@ray.remote
def contract(A, B):
    """Multiply two labelled tensors with ``np.einsum``, keeping every index.

    Parameters
    ----------
    A, B : tuple
        ``(array, labels)`` pairs, where ``labels`` is the sequence of integer
        index labels of the array's axes (einsum sublist format).

    Returns
    -------
    numpy.ndarray
        One axis per distinct label of ``A`` and ``B``, ordered by ascending
        label. Labels present in both operands are aligned with each other;
        because every label appears in the output, none is summed over.
    """
    a, idxa = A
    b, idxb = B
    # einsum lays the output axes out in the order they are listed. Sort the
    # labels: a bare set would leave that order to hash iteration, which is
    # not ascending in general (e.g. {8, 1} iterates as 8, 1).
    result_idx = sorted(set(idxa) | set(idxb))
    return np.einsum(a, idxa, b, idxb, result_idx)

# Time one full contraction of the example operands; C is the reference result.
with pyrof.timing('contract'):
    C = contract(x, y)


contract : 0.323714017868042


### Many var parallelisation

In [11]:
# Labels shared by both operands, and the set of all labels of the result.
contract_idx = set(x[1]) & set(y[1])
result_idx = set(x[1] + y[1])

# Reference result: one full contraction, timed for comparison.
with pyrof.timing(f'contract simple'):
    C = contract(x,y)
    
# Labels to slice over. Every index of the example task has extent 2, so there
# are 2**len(par_vars) patches. The count is named `threads`, but in this cell
# the patches are computed one after another.
par_vars = [1, 4, 17, 5]
threads = 2**len(par_vars)
target_shape = C.shape

# NOTE(review): sliced_contract and target_slice are not defined in any cell
# visible here (the execution count jumps from 9 to 11). Restore the cell that
# defines them, otherwise Restart & Run All fails at this point.
# Presumably sliced_contract(x, y, par_vars, i) contracts patch i of the
# operands — confirm against its definition.
with pyrof.timing('Sequential patches'):
    C_patches = [
        sliced_contract(x, y, par_vars, i)
        for i in range(threads)
    ]

# Uninitialised buffer for the assembled result; every entry is expected to be
# overwritten by the patch assignment below.
with pyrof.timing('allocate result'):
    C_par = np.empty(target_shape)

# One entry per patch; element [0] of each entry is used as the index into the
# result array below.
patch_slces = [
    target_slice(result_idx, par_vars, i)
    for i in range(threads)
]

with pyrof.timing('assignment'):
    for s, patch in zip(patch_slces, C_patches):
        C_par[s[0]] = patch

# The assembled patches must reproduce the full contraction exactly.
assert np.array_equal(C, C_par)


contract simple : 1.0820136070251465
	contract sliced 0 : 0.033449411392211914
	contract sliced 1 : 0.03836464881896973
	contract sliced 2 : 0.06398296356201172
	contract sliced 3 : 0.0529017448425293
	contract sliced 4 : 0.04923677444458008
	contract sliced 5 : 0.04331612586975098
	contract sliced 6 : 0.050642967224121094
	contract sliced 7 : 0.04017186164855957
	contract sliced 8 : 0.042578935623168945
	contract sliced 9 : 0.047518253326416016
	contract sliced 10 : 0.3524761199951172
	contract sliced 11 : 0.24108147621154785
	contract sliced 12 : 0.07741117477416992
	contract sliced 13 : 0.10958695411682129
	contract sliced 14 : 0.06605362892150879
	contract sliced 15 : 0.06382369995117188
Sequential patches : 1.3844408988952637
allocate result : 0.03616809844970703
assignment : 0.40332555770874023


## Use unix tools
### Threading

In [13]:
from multiprocessing import Pool, Array
from multiprocessing.dummy import Pool as ThreadPool
import os

def tonumpyarray(mp_arr):
    """View a synchronized ``multiprocessing`` array as a NumPy array, no copy.

    Parameters
    ----------
    mp_arr : object with ``get_obj()``
        A wrapper such as the one returned by ``multiprocessing.Array``;
        ``get_obj()`` must return the underlying ctypes array.

    Returns
    -------
    numpy.ndarray
        One-dimensional array sharing memory with ``mp_arr``. Its dtype
        follows the ctypes element type, so an ``'i'`` array is seen as
        integers rather than being reinterpreted as float64.
    """
    # as_array reads the element type from the ctypes array itself, whereas
    # np.frombuffer without a dtype treats any buffer as float64.
    return np.ctypeslib.as_array(mp_arr.get_obj())

In [14]:
# Labels shared by both operands, and the set of all labels of the result
# (same definitions as in the "Many var parallelisation" cell).
contract_idx = set(x[1]) & set(y[1])
result_idx = set(x[1] + y[1])

In [15]:
# Reference result for the parallel variants below: one full contraction, timed.
with pyrof.timing(f'contract simple'):
    C = contract(x,y)

contract simple : 0.3542511463165283


In [16]:
# Report the sizes of the result and of both operands, then remember the result
# shape for the buffers allocated in later cells.
# NOTE(review): sys.getsizeof is believed to include an array's data only when
# the array owns its buffer; `.nbytes` reports the data size in every case —
# consider switching if any of these could be a view.
C_size = sys.getsizeof(C)
print(f'result size: {C_size:e}')
print(f'operands size: {sys.getsizeof(x[0]):e}, {sys.getsizeof(y[0]):e}')
target_shape = C.shape

result size: 2.684359e+08
operands size: 2.097520e+06, 2.624640e+05


In [18]:
with pyrof.timing('Total thread contraction time:'):
    # Labels to slice over; each has extent 2 in the example task, so one
    # worker thread is used per patch.
    par_vars = [1, 17, 5]
    threads = 2**len(par_vars)

    # The result buffer is stored as an attribute on the os module so that
    # `work` and later cells can reach it under one name. It is uninitialised:
    # every entry is expected to be overwritten by some patch.
    os.global_C = np.empty(target_shape)

    def work(i):
        """Compute patch ``i`` and store it in ``os.global_C``.

        Relies on ``sliced_contract`` and ``target_slice``, which are defined
        outside this cell; element ``[0]`` of ``target_slice``'s return value
        is used as the index of the patch inside the result.
        NOTE(review): patches are presumably disjoint, so concurrent writes
        do not overlap — confirm against ``target_slice``.
        """
        patch = sliced_contract(x, y, par_vars, i)
        sl = target_slice(result_idx, par_vars, i)
        os.global_C[sl[0]] = patch

    # Use the pool as a context manager so its worker threads are shut down
    # when the map is finished instead of lingering in the kernel.
    with ThreadPool(processes=threads) as pool:
        print('inited thread pool')
        with pyrof.timing('Thread work'):
            _ = pool.map(work, range(threads))

    C_size = sys.getsizeof(os.global_C)
    print(f'  result size: {C_size:e}')

inited thread pool
	contract sliced 0 : 0.04151320457458496
	contract sliced 5 : 0.049321889877319336
	contract sliced 1 : 0.09570074081420898
	contract sliced 6 : 0.0636897087097168
	contract sliced 2	contract sliced 4 	contract sliced 3 : 0.12053894996643066
 : 0.10593318939208984
: 0.1232302188873291
	contract sliced 7 : 0.19748163223266602
Thread work : 0.38308095932006836
  result size: 2.684359e+08
Total thread contraction time: : 0.38835692405700684


In [None]:
# The buffer filled by the worker threads must equal the reference contraction.
assert np.array_equal(C, os.global_C)

### Multiprocessing

In [None]:

# C.size is the element count; flatten() would copy the whole result just to
# measure its length.
flat_size = C.size
with pyrof.timing('init array'):
    # Worker processes write into their own copy of an ordinary ndarray, so
    # their patches would never reach this process and the check at the end
    # would compare against uninitialised memory. Back the result with a
    # multiprocessing shared-memory Array and view it as an ndarray instead.
    # NOTE(review): this relies on the workers inheriting os.global_C and
    # `work` from this process (fork start method) — confirm on the target
    # platform.
    shared_C = Array('d', flat_size)
    os.global_C = tonumpyarray(shared_C).reshape(target_shape)

# Context manager: the worker processes are shut down once the map is done.
with Pool(processes=threads) as pool:
    print('inited pool')
    with pyrof.timing('parallel work'):
        print('started work')
        _ = pool.map(work, range(threads))

# os.global_C is now a view onto the shared buffer, so report the size of its
# data via nbytes rather than the size of the view object.
C_size = os.global_C.nbytes
print(f'result size: {C_size:e}')
assert np.array_equal(C, os.global_C)


In [None]:
# Remove the scratch attribute that the cells above attached to the os module.
del os.global_C