In [10]:
import cupy as cp
import numpy as np
from typing import List, Optional, Iterator
from time import time, sleep
from itertools import cycle

class NormGenerator:
    """Produce standard-normal CuPy matrices, always keeping the next one prepared.

    Every instance owns its own ``cp.cuda.Stream``. The work for the *next*
    matrix is issued inside ``with self._stream:`` as soon as the previous one
    has been handed out, and ``get_matrix`` synchronizes on that stream before
    returning the array. The host time spent blocked in that synchronization
    is accumulated and exposed via ``get_time_spent_synchronizing``.
    """

    def __init__(self, rows: int, cols: int, seed: int):
        """Create the generator and immediately start preparing the first matrix.

        Args:
            rows: Number of rows of every generated matrix (must be > 0).
            cols: Number of columns of every generated matrix (must be > 0).
            seed: Non-negative seed for the first CuPy random generator.

        Raises:
            AssertionError: If ``rows``/``cols`` are not positive or ``seed``
                is negative.
        """
        # seed == 0 is valid: _create_new_matrix accepts it, and
        # ConcurrentNormGenerator draws its seeds from a range that starts at 0.
        assert rows > 0 and cols > 0 and seed >= 0
        self._rows = rows
        self._cols = cols
        self._stream = cp.cuda.Stream()
        self._cp_random_gen = None
        self._generated_array = None
        self._time_spent_synchronizing = 0.0
        self._create_new_matrix(seed=seed)

    def __iter__(self) -> Iterator[cp.ndarray]:
        """Yield matrices forever; each item is the result of ``get_matrix()``."""
        while True:
            yield self.get_matrix()

    def _create_new_matrix(self, seed: Optional[int] = None) -> None:
        """Issue the generation of the next matrix on this instance's stream.

        Args:
            seed: When given, the random generator is replaced by a fresh
                ``cp.random.default_rng(seed)`` before drawing; when ``None``
                the existing generator keeps being used.

        Raises:
            AssertionError: If a previously prepared matrix has not been
                collected yet, or if ``seed`` is negative.
        """
        assert self._generated_array is None, 'Error: create_new_matrix called when get_matrix needs to be called first'
        assert seed is None or seed >= 0
        with self._stream:
            if seed is not None:
                self._cp_random_gen = cp.random.default_rng(seed)
            self._generated_array = self._cp_random_gen.standard_normal(
                size=[self._rows, self._cols]
            )

    def get_time_spent_synchronizing(self) -> float:
        """Return the total seconds spent waiting in ``stream.synchronize()``."""
        return self._time_spent_synchronizing

    def get_matrix(self, seed: Optional[int] = None) -> cp.ndarray:
        """Return the prepared matrix and start preparing the following one.

        Args:
            seed: Optional seed applied to the *next* matrix (not to the one
                being returned); see ``_create_new_matrix``.

        Returns:
            The ``rows x cols`` array that was prepared by the previous call
            (or by the constructor).

        Raises:
            AssertionError: If no prepared matrix is available.
        """
        assert self._generated_array is not None, 'Error: get_matrix called when create_new_matrix needs to be called first'

        # Ensure the pending work is complete and track how long we waited.
        start = time()
        self._stream.synchronize()
        end = time()
        self._time_spent_synchronizing += end - start

        # Hand the finished array out and clear the slot *before* queuing the
        # replacement, because _create_new_matrix asserts the slot is empty.
        ret_mat = self._generated_array
        self._generated_array = None
        self._create_new_matrix(seed=seed)
        return ret_mat

class ConcurrentNormGenerator:
    """Round-robin over a buffer of ``NormGenerator`` objects.

    A NumPy generator seeded with ``seed`` produces a stream of integer seeds.
    The first ``buffer_size`` seeds (after ``skip_steps`` discarded draws)
    initialise the buffered ``NormGenerator`` instances; every later
    ``get_matrix`` call draws one more seed and passes it to the generator
    whose matrix is being collected, so that generator's next matrix comes from
    a fresh CuPy generator with that seed. As a consequence the i-th matrix
    returned is always the one produced from the i-th (non-skipped) seed,
    whatever ``buffer_size`` is.

    NOTE(review): each buffered ``NormGenerator`` owns its own CUDA stream,
    presumably so that several matrices can be generated while earlier ones
    are consumed -- confirm on the target hardware.
    """

    def __init__(self, rows: int, cols: int, seed: int, buffer_size: int, skip_steps: int = 0):
        """Build the seed stream and pre-fill the buffer.

        Args:
            rows: Rows of every generated matrix.
            cols: Columns of every generated matrix.
            seed: Seed of the NumPy generator that produces per-matrix seeds.
            buffer_size: Number of ``NormGenerator`` instances kept (must be > 0).
            skip_steps: Number of seeds to draw and discard first; values
                ``<= 0`` discard nothing because ``range`` is then empty.

        Raises:
            AssertionError: If ``buffer_size`` is not positive.
        """
        assert buffer_size > 0
        self._np_rng = np.random.default_rng(seed)
        for _ in range(skip_steps):
            self._get_seed()

        # One seed is consumed per buffered generator, in buffer order.
        self._cache: List[NormGenerator] = [
            NormGenerator(rows=rows, cols=cols, seed=self._get_seed())
            for _ in range(buffer_size)
        ]

        self._cache_iterator = cycle(self._cache)

    def __iter__(self) -> Iterator[cp.ndarray]:
        """Return an endless iterator over ``get_matrix()`` results."""
        return self._generator()

    def _generator(self) -> Iterator[cp.ndarray]:
        """Yield ``get_matrix()`` forever."""
        while True:
            yield self.get_matrix()

    def _get_seed(self) -> int:
        """Draw the next per-matrix seed, an integer in ``[0, 10**9)``.

        The value is converted to a built-in ``int`` so it matches the
        annotation instead of leaking a NumPy scalar, and the bound is given
        as an integer rather than the float ``1e9``.
        """
        return int(self._np_rng.integers(low=0, high=10**9))

    def get_time_spent_synchronizing(self) -> float:
        """Return the synchronization time summed over all buffered generators."""
        return sum(ng.get_time_spent_synchronizing() for ng in self._cache)

    def get_matrix(self) -> cp.ndarray:
        """Collect the next matrix in round-robin order.

        A new seed is drawn first and handed to the selected generator, which
        uses it for the matrix it starts preparing after returning this one.
        """
        seed = self._get_seed()
        return next(self._cache_iterator).get_matrix(seed=seed)

if __name__ == "__main__":
    # Demo: build a generator with a buffer of 100 matrices of shape 1024 x 10000.
    my_norm_gen = ConcurrentNormGenerator(rows=1024, cols=10000, seed=42, buffer_size=100)

    # The lines using my_norm_gen live under the same guard as its creation;
    # otherwise importing this module would fail with a NameError because
    # my_norm_gen is only defined when run as a script.
    random_matrix = next(iter(my_norm_gen))
    print(random_matrix)

def compare_two_gens(gen1,gen2,k:int):
    """Check that two iterables produce equal arrays for their first ``k`` items.

    Both arguments are turned into iterators, then ``k`` times one item is
    pulled from each (first ``gen1``, then ``gen2``) and the pair is compared
    with ``cp.array_equal``.

    Raises:
        AssertionError: If any compared pair differs.
        StopIteration: If either iterable runs out before ``k`` items.
    """
    left_iter = iter(gen1)
    right_iter = iter(gen2)
    for _ in range(k):
        left_item = next(left_iter)
        right_item = next(right_iter)
        assert cp.array_equal(left_item, right_item)

# Reproducibility check: with the same seed, a buffer of 10 and a buffer of 1
# must yield the same first 100 matrices.
compare_two_gens(
    *(
        ConcurrentNormGenerator(rows=1024, cols=10000, seed=42, buffer_size=depth)
        for depth in (10, 1)
    ),
    k=100,
)