In [None]:
from util import SharedNumpyArray
import numpy as np
from torch import nn
import torch as th
import time


def generate_random_data(episodes=2 ** 12, episode_data=2 ** 16):
    """Create a 2-D array of uniform random samples for the demo.

    Args:
        episodes: Number of rows to generate. Defaults to ``2 ** 12``.
        episode_data: Number of columns per row. Defaults to ``2 ** 16``.

    Returns:
        ``numpy.ndarray`` of shape ``(episodes, episode_data)`` holding
        float64 values drawn by ``np.random.rand`` (uniform on ``[0, 1)``).

    Note:
        With the default sizes the array holds 2 ** 28 float64 values
        (about 2 GiB); pass smaller sizes for quick experiments.
    """
    return np.random.rand(episodes, episode_data)


# Build the demo dataset once at module level; the cells below all read `data`.
data = generate_random_data()
print(np.shape(data))

In [None]:
from torch.multiprocessing import Pool
import os


def process_data(data, idx_start, idx_end):
    """Project rows ``idx_start:idx_end`` of ``data`` onto 12 random columns.

    Args:
        data: Either a 2-D NumPy array, or a wrapper object exposing a
            ``read()`` method whose return value is used in its place
            (presumably the ``SharedNumpyArray`` handle -- confirm in
            ``util``).
        idx_start: Index of the first row to process (inclusive).
        idx_end: Index one past the last row to process (exclusive). Values
            past the end of the array are clamped by normal slice semantics.

    Returns:
        Array of shape ``(rows_in_slice, 12)``: the selected rows multiplied
        by a random matrix that is drawn afresh on every call.
    """
    # Unwrap handle objects explicitly instead of swallowing every exception:
    # a plain ndarray has no ``read`` attribute and is used as-is, while a
    # genuine failure inside ``read()`` now propagates to the caller.
    reader = getattr(data, "read", None)
    if callable(reader):
        data = reader()

    # Restrict the work to the rows assigned to this call; without the slice
    # every call would process (and return) the whole array.
    rows = data[idx_start:idx_end]

    print(f"processing in {os.getpid()}")
    mat = np.random.rand(rows.shape[1], 12)
    res = rows @ mat
    time.sleep(10)  # fixed delay on top of the matrix product
    print(f"processing finished in {os.getpid()}")

    return res


# Baseline: one in-process call covering every row of `data`.
start = time.time()
result = process_data(data, 0, len(data))
print(result.shape)
# Use fixed-point formatting: the previous ".02" spec is general format with
# two significant digits and renders e.g. 10.03 as "1e+01".
print(f"time taken {time.time() - start:.2f}s")
result

In [None]:
def process_data_multiprocessing(data, num_processes=8):
    """Run ``process_data`` over row chunks of ``data`` in a worker pool.

    The array itself is placed in every task's argument tuple, so it is sent
    to the workers through the pool's normal argument passing for each chunk.

    Args:
        data: 2-D NumPy array whose rows are split into contiguous chunks.
        num_processes: Number of worker processes, and the maximum number of
            chunks the rows are split into. Defaults to 8.

    Returns:
        The per-chunk results of ``process_data`` concatenated along axis 0,
        in chunk order.

    Raises:
        ValueError: If ``data`` has no rows (there is nothing to concatenate).
    """
    n_rows = len(data)
    # Ceiling division, floored at 1: a plain floor division yields a step of
    # 0 when there are fewer rows than processes (range() then raises) and
    # yields more chunks than processes when the split is uneven.
    chunk_size = max(1, -(-n_rows // num_processes))
    chunks = [(data, begin, min(begin + chunk_size, n_rows))
              for begin in range(0, n_rows, chunk_size)]
    with Pool(processes=num_processes) as pool:
        result = np.concatenate(pool.starmap(process_data, chunks))
    return result


# Time the pool-based run over the same `data`.
start = time.time()
result = process_data_multiprocessing(data)
print(result.shape)
# Use fixed-point formatting: the previous ".02" spec is general format with
# two significant digits and renders e.g. 10.03 as "1e+01".
print(f"time taken {time.time() - start:.2f}s")
result

In [None]:
from util import SharedNumpyArray


def process_data_multiprocessing(data, num_processes=8):
    """Run ``process_data`` over row chunks, passing a ``SharedNumpyArray``.

    Instead of the array itself, each task receives the same
    ``SharedNumpyArray`` wrapper built from ``data``; ``process_data`` calls
    its ``read()`` method to obtain the array. (Presumably the wrapper keeps
    the data in shared memory so workers avoid a copy -- confirm in ``util``.)

    Args:
        data: 2-D NumPy array whose rows are split into contiguous chunks.
        num_processes: Number of worker processes, and the maximum number of
            chunks the rows are split into. Defaults to 8.

    Returns:
        The per-chunk results of ``process_data`` concatenated along axis 0,
        in chunk order.

    Raises:
        ValueError: If ``data`` has no rows (there is nothing to concatenate).
    """
    n_rows = len(data)
    # Ceiling division, floored at 1: a plain floor division yields a step of
    # 0 when there are fewer rows than processes (range() then raises) and
    # yields more chunks than processes when the split is uneven.
    chunk_size = max(1, -(-n_rows // num_processes))

    data_shared = SharedNumpyArray(data)
    try:
        chunks = [(data_shared, begin, min(begin + chunk_size, n_rows))
                  for begin in range(0, n_rows, chunk_size)]
        with Pool(processes=num_processes) as pool:
            result = np.concatenate(pool.starmap(process_data, chunks))
    finally:
        # Always release the wrapper, even when a worker or the
        # concatenation raises; otherwise it would never be unlinked.
        data_shared.unlink()

    return result


# Time the run that hands workers a SharedNumpyArray wrapper.
start = time.time()
result = process_data_multiprocessing(data)
print(result.shape)
# Use fixed-point formatting: the previous ".02" spec is general format with
# two significant digits and renders e.g. 10.03 as "1e+01".
print(f"time taken {time.time() - start:.2f}s")
result