In [1]:
import numpy as np

def _sample_unit_column_matrix(d_sources, d_data, dtype, rng):
    """Draw a (d_sources, d_data) matrix uniform in [-1, 1) and scale every column to unit L2 norm."""
    A = rng.uniform(-1, 1, (d_sources, d_data)).astype(dtype)
    # Vectorised column normalisation (one norm per column, broadcast over the rows).
    A /= np.sqrt((A ** 2).sum(axis=0, keepdims=True))
    return A


def generate_mixing_matrix(d_sources: int, d_data=None, lin_type='uniform', cond_threshold=25, n_iter_4_cond=None,
                           dtype=np.float32, rng=None):
    """
    Generate a linear mixing matrix of shape (d_sources, d_data).

    With `lin_type='uniform'` the entries are drawn uniformly from [-1, 1), every column is rescaled to
    unit L2 norm, and candidates are redrawn until the condition number is not above the threshold.
    With `lin_type='orthogonal'` the matrix is the Q factor of a QR decomposition of a uniform random
    matrix: its columns are orthonormal when d_data <= d_sources, its rows are orthonormal otherwise.

    @param d_sources: dimension of the latent sources (number of rows)
    @param d_data: dimension of the mixed data (number of columns); defaults to `d_sources`
    @param lin_type: specifies the type of matrix entries; either `uniform` or `orthogonal`.
    @param cond_threshold: upper bound on the condition number of the matrix to ensure a well-conditioned
        problem (`uniform` only). Must be >= 1, since a condition number is never smaller than 1.
    @param n_iter_4_cond: or instead, number of random matrices to sample in order to estimate the
        threshold: the 25th percentile of their condition numbers is used and `cond_threshold` is
        ignored in this case.
    @param dtype: data type of the returned matrix
    @param rng: optional random source exposing `uniform(low, high, size)` (e.g. `np.random.default_rng(seed)`
        or `np.random.RandomState(seed)`); defaults to the global `np.random` state.
    @return:
        A: mixing matrix
    @rtype: np.ndarray
    @raise ValueError: if `lin_type` is unknown, if `cond_threshold` < 1 (it could never be satisfied),
        or if `n_iter_4_cond` is smaller than 1.
    """
    if d_data is None:
        d_data = d_sources
    if rng is None:
        rng = np.random  # module-level functions share the legacy global state (honours np.random.seed)

    if lin_type == 'orthogonal':
        if d_data > d_sources:
            # Reduced QR of an (m, n) matrix returns Q of shape (m, min(m, n)); factoring the wide
            # matrix directly would yield (d_sources, d_sources). Factor the transpose instead and
            # transpose back so that the result has the requested shape with orthonormal rows.
            Q = np.linalg.qr(rng.uniform(-1, 1, (d_data, d_sources)))[0].T
        else:
            Q = np.linalg.qr(rng.uniform(-1, 1, (d_sources, d_data)))[0]
        return Q.astype(dtype)

    if lin_type != 'uniform':
        raise ValueError('incorrect method')

    if n_iter_4_cond is None:
        if cond_threshold < 1:
            # cond(A) = s_max / s_min >= 1, so the rejection loop below would never terminate.
            raise ValueError('cond_threshold must be >= 1')
        cond_thresh = cond_threshold
    else:
        n_iter = int(n_iter_4_cond)
        if n_iter < 1:
            raise ValueError('n_iter_4_cond must be >= 1')
        cond_list = [np.linalg.cond(_sample_unit_column_matrix(d_sources, d_data, dtype, rng))
                     for _ in range(n_iter)]
        cond_thresh = np.percentile(cond_list, 25)  # only accept those below 25% percentile

    # Rejection sampling: redraw until the candidate is well-conditioned enough.
    A = _sample_unit_column_matrix(d_sources, d_data, dtype, rng)
    while np.linalg.cond(A) > cond_thresh:
        A = _sample_unit_column_matrix(d_sources, d_data, dtype, rng)
    return A


In [2]:
# Example: mix 2 latent sources into 8 observed dimensions with the default 'uniform' entries.
generate_mixing_matrix(d_sources=2, d_data=8)

array([[ 0.683444  , -0.6756179 ,  0.60752434, -0.40626603, -0.785336  ,
         0.6819457 ,  0.5893126 ,  0.99239546],
       [-0.73000294, -0.73725194,  0.79430103,  0.91375476,  0.6190698 ,
        -0.73140275,  0.8079051 ,  0.12309024]], dtype=float32)