# Compressed Sensing Simulation on snRNAseq

### Getting Started

In [34]:
import h5py
import numpy as np
import pandas as pd
import scanpy as sc
import spams
from scipy.stats import entropy
from scipy.spatial import distance
from sklearn.model_selection import train_test_split
import scipy.sparse as sp
from scipy.io import mmread
from scipy.stats import spearmanr, pearsonr, entropy
from scipy.spatial import distance
import os

THREADS = 16

In [4]:
# Define file paths
matrix_file = "dataset/GSE165371_cb_adult_mouse/cb_adult_mouse.mtx.gz"
barcodes_file = "dataset/GSE165371_cb_adult_mouse/cb_adult_mouse_barcodes.txt"
genes_file = "dataset/GSE165371_cb_adult_mouse/cb_adult_mouse_genes.txt"

# Load the data
matrix = mmread(matrix_file).tocsc()  # Read the matrix market file
barcodes = pd.read_csv(barcodes_file, header=None).iloc[:, 0].tolist()  # Read barcodes
genes = pd.read_csv(genes_file, header=None).iloc[:, 0].tolist()  # Read genes

In [5]:
print(matrix.shape)
print(len(barcodes))  # Number of barcodes
print(len(genes))     # Number of genes

(24409, 611034)
611034
24409


In [6]:
# AnnData creation
adata = sc.AnnData(
    X=matrix.T,  # Transpose the matrix to match AnnData's format
    obs=pd.DataFrame(index=barcodes),  # Barcodes -> obs
    var=pd.DataFrame(index=genes)      # Genes -> var
)

print(adata)

# Save the AnnData object to an H5AD file (optional)
adata.write("dataset/cb_adult_mouse.h5ad")

AnnData object with n_obs × n_vars = 611034 × 24409


In [7]:
# verify
print(adata.var)
print(adata.obs)

Empty DataFrame
Columns: []
Index: [Xkr4, Gm1992, Gm37381, Rp1, Sox17, Mrpl15, Lypla1, Gm37988, Tcea1, Rgs20, Atp6v1h, Rb1cc1, 4732440D04Rik, Fam150a, St18, Pcmtd1, Gm26901, Sntg1, Rrs1, Adhfe1, Mybl1, Vcpip1, 1700034P13Rik, Sgk3, Mcmdc2, Snhg6, Tcf24, Ppp1r42, Gm15818, Cops5, Cspp1, Arfgef1, Cpa6, Prex2, A830018L16Rik, Sulf1, Slco5a1, Gm29283, Prdm14, Ncoa2, Gm29570, Tram1, Lactb2, Eya1, Trpa1, Kcnb2, Terf1, Sbspon, 4930444P10Rik, Rpl7, Rdh10, Gm28095, Stau2, Gm7568, Ube2w, Tceb1, D030040B21Rik, Tmem70, Ly96, Gm28376, Jph1, Gdap1, Pi15, Gm28154, Gm16070, Crispld1, Gm28153, Defb41, Tfap2d, Tfap2b, Pkhd1, Il17f, Mcm3, 6720483E21Rik, Paqr8, Efhc1, Tram2, Tmem14a, Gsta3, Gm28836, Kcnq5, Rims1, Gm29107, Ogfrl1, Gm28822, B3gat2, Smap1, Sdhaf4, Fam135a, Col9a1, Col19a1, Lmbrd1, Adgrb3, Phf3, Ptp4a1, Gm29669, Lgsn, Khdrbs2, Prim2, Rab23, ...]

[24409 rows x 0 columns]
Empty DataFrame
Columns: []
Index: [IXa_M003_TTTACTGGTACAGTAA, IXa_M003_AACCTTTAGCGACTTT, IXa_M003_TGTGCGGAGGTCGACA, IXa_M003_

In [8]:
# downsample adata

# Set a random seed for reproducibility 
np.random.seed(23)

# Randomly select 1000 genes
selected_genes = np.random.choice(adata.var_names, size=1000, replace=False)

# Randomly select 2000 cells
selected_cells = np.random.choice(adata.obs_names, size=10000, replace=False)

# Subset the AnnData object
adata_downsampled = adata[selected_cells, selected_genes]

# Print the new AnnData object
print(adata_downsampled)

# Save the downsampled AnnData object
adata_downsampled.write("dataset/cb_adult_mouse_downsampled.h5ad")

View of AnnData object with n_obs × n_vars = 10000 × 1000


In [3]:
def split_and_save_data(adata=adata_downsampled, output_dir="./dataset/"):
    """
    Splits `adata.X` into train, validation, and test subsets and saves them in the specified directory.
    
    Parameters:
    - adata: AnnData object, containing the dataset
    - output_dir: str, directory to save the subsets
    """
    os.makedirs(output_dir, exist_ok=True)
    
    # Ensure `adata.X` is in a compatible format
    X = adata.X
    
    # Convert sparse matrices to compressed row format
    if sp.issparse(X):  
        X = X.tocsr()
    else:
        X = np.asarray(X, dtype=np.float64)
    
    # Get total number of samples (cells)
    num_cells = X.shape[0]
    
    # Define split sizes (50% Train, 25% Validate, 25% Test)
    train_idx, temp_idx = train_test_split(np.arange(num_cells), test_size=0.50, random_state=23)
    validate_idx, test_idx = train_test_split(temp_idx, test_size=0.50, random_state=23)
    
    # Extract data using indices
    train_data = X[train_idx]
    validate_data = X[validate_idx]
    test_data = X[test_idx]
    
    # Convert all subsets to dense arrays before saving
    train_data = np.asarray(train_data.todense() if sp.issparse(train_data) else train_data, dtype=np.float64)
    validate_data = np.asarray(validate_data.todense() if sp.issparse(validate_data) else validate_data, dtype=np.float64)
    test_data = np.asarray(test_data.todense() if sp.issparse(test_data) else test_data, dtype=np.float64)

    # Convert train_data to Fortran-contiguous format for SPAMS
    train_data = np.asfortranarray(train_data)
    
    # Save subsets
    np.save(os.path.join(output_dir, "train_data.npy"), train_data)
    np.save(os.path.join(output_dir, "validate_data.npy"), validate_data)
    np.save(os.path.join(output_dir, "test_data.npy"), test_data)
    
    # Print dataset shapes
    print(f"Train data shape: {train_data.shape}")
    print(f"Validation data shape: {validate_data.shape}")
    print(f"Test data shape: {test_data.shape}")
    print(f"Datasets saved in {output_dir}")
    

split_and_save_data()

Train data shape: (5000, 1000)
Validation data shape: (2500, 1000)
Test data shape: (2500, 1000)
Datasets saved in ./dataset/


In [35]:
def load_saved_data(input_dir="./dataset/"):
    """
    Loads the train, validation, and test datasets from the specified directory.
    
    Parameters:
    - input_dir: str, directory containing the saved subsets
    
    Returns:
    - train_data: numpy array
    - validate_data: numpy array
    - test_data: numpy array
    """
    train_data_path = os.path.join(input_dir, "train_data.npy")
    validate_data_path = os.path.join(input_dir, "validate_data.npy")
    test_data_path = os.path.join(input_dir, "test_data.npy")
    
    # Load data
    train_data = np.load(train_data_path, allow_pickle=False)
    validate_data = np.load(validate_data_path, allow_pickle=False)
    test_data = np.load(test_data_path, allow_pickle=False)
    
    # Ensure all loaded data is in dense format
    if isinstance(validate_data, np.ndarray) and validate_data.ndim == 0:
        validate_data = validate_data.item()  # Fix potential scalar issue
    elif sp.issparse(validate_data):
        validate_data = validate_data.toarray()
    
    if isinstance(test_data, np.ndarray) and test_data.ndim == 0:
        test_data = test_data.item()  # Fix potential scalar issue
    elif sp.issparse(test_data):
        test_data = test_data.toarray()

    # Print dataset shapes
    print(f"Train data shape: {train_data.shape}")
    print(f"Validation data shape: {validate_data.shape}")
    print(f"Test data shape: {test_data.shape}")
    
    return train_data, validate_data, test_data


train_data, validate_data, test_data = load_saved_data()

Train data shape: (5000, 1000)
Validation data shape: (2500, 1000)
Test data shape: (2500, 1000)


## Utilities

In [36]:
# Utility and measurement functions
def random_phi_subsets_g(m, g, n, d_thresh=0.4):
    Phi = np.zeros((m, g))  # Initialize measurement matrix (pools × genes)
    Phi[np.random.choice(m, np.random.randint(n[0], n[1]+1), replace=False), 0] = 1  # Assign first gene
    
    for i in range(1, g):  # Assign pools for remaining genes
        dmax = 1
        while dmax > d_thresh:
            p = np.zeros(m)
            p[np.random.choice(m, np.random.randint(n[0], n[1]+1), replace=False)] = 1  # Randomly assign pools
            dmax = 1 - distance.cdist(Phi[:, :i].T, [p], 'correlation').min()  # Check minimum correlation
        
        Phi[:, i] = p  # Assign gene to pools
    
    Phi = (Phi.T / Phi.sum(1)).T  # Normalize by row sum
    return Phi


def get_observations(X0, Phi, snr=5, return_noise=False):
    noise = np.array([np.random.randn(X0.shape[1]) for _ in range(X0.shape[0])])
    noise *= np.linalg.norm(X0) / np.linalg.norm(noise) / snr  # Adjust noise level
    
    if return_noise:
        return Phi.dot(X0 + noise), noise
    else:
        return Phi.dot(X0 + noise)
    

def compare_distances(A, B, random_samples=[], s=200, pvalues=False):
    if len(random_samples) == 0:
        random_samples = np.zeros(A.shape[1], dtype=bool)
        random_samples[:min(s, A.shape[1])] = True  # Select `s` random samples
        np.random.shuffle(random_samples)

    dist_x = distance.pdist(A[:, random_samples].T, 'euclidean')
    dist_y = distance.pdist(B[:, random_samples].T, 'euclidean')

    pear = pearsonr(dist_x, dist_y)  # Pearson correlation
    spear = spearmanr(dist_x, dist_y)  # Spearman correlation
    
    if pvalues:
        return pear, spear  # Return correlations with p-values
    else:
        return pear[0], spear[0]  # Return only correlation coefficients


def correlations(A, B):
    p = (1 - distance.correlation(A.flatten(), B.flatten()))  # Overall correlation
    spear = spearmanr(A.flatten(), B.flatten())  # Spearman correlation
    
    dist_genes = np.zeros(A.shape[0])  # Per-gene correlation
    for i in range(A.shape[0]):
        dist_genes[i] = 1 - distance.correlation(A[i], B[i])
    
    pg = np.average(dist_genes[np.isfinite(dist_genes)])  # Mean per-gene correlation
    
    dist_sample = np.zeros(A.shape[1])  # Per-sample correlation
    for i in range(A.shape[1]):
        dist_sample[i] = 1 - distance.correlation(A[:, i], B[:, i])
    
    ps = np.average(dist_sample[np.isfinite(dist_sample)])  # Mean per-sample correlation
    
    return p, spear[0], pg, ps


def compare_results(A, B):
    results = list(correlations(A, B))[:-1]  # Get correlation metrics (excluding `ps`)
    results += list(compare_distances(A, B))  # Compare Euclidean distances
    results += list(compare_distances(A.T, B.T))  # Compare sample-wise distances
    return results

### Sparse Decoding

In [37]:
def sparse_decode(Y, D, k, worstFit=1., mink=0, method='omp', nonneg=False):
    """
    Sparse decoding method - obtain module activations from composite measurements.
    
    Parameters:
    - Y : Observed measurement data (pools × samples)
    - D : Dictionary (module patterns) used for reconstruction
    - k : Sparsity constraint (number of nonzero coefficients)
    - worstFit : Minimum required reconstruction accuracy (default = 1)
    - mink : Minimum sparsity allowed (default = 0)
    - method : 'omp' (Orthogonal Matching Pursuit) or 'lasso' (L1 regularization)
    - nonneg : If True, forces non-negative solutions (for LASSO only)
    
    Returns:
    - W : Estimated module activations (modules × samples)
    """
    
    if method == 'omp':
        while k > mink:
            W = spams.omp(np.asfortranarray(Y), np.asfortranarray(D), L=k, numThreads=THREADS)
            W = np.asarray(W.todense())

            fit = 1 - np.linalg.norm(Y - D.dot(W))**2 / np.linalg.norm(Y)**2  # Compute fit accuracy

            if fit < worstFit:
                break  # Stop if fit is too low
            else:
                k -= 1  # Decrease sparsity constraint
                
            
    elif method == 'lasso':
        Ynorm = np.linalg.norm(Y)**2 / Y.shape[1]  # Normalize Y
        W = spams.lasso(np.asfortranarray(Y), np.asfortranarray(D),
                        lambda1=k * Ynorm, mode=1, numThreads=THREADS, pos=nonneg)
        W = np.asarray(W.todense())
    
    return W

### Learning Modules Dictionary 

In [38]:
def smaf(X, d, lda1, lda2, maxItr=10, UW=None, posW=False, posU=True, use_chol=False, 
         module_lower=1, activity_lower=1, donorm=False, mode=1, mink=5, U0=[], 
         U0_delta=0.1, doprint=False):
    """
    Sparse Module Activity Factorization (SMAF) to learn gene modules.
    
    Parameters:
    - X: np.array, shape (genes, cells), gene expression matrix. Note: dims opposite of anndata 
    - d: int, number of gene modules (dictionary size)
    - lda1: float, regularization term for activity sparsity
    - lda2: float, regularization term for module sparsity
    - maxItr: int, number of iterations for optimization
    - UW: tuple, (U, W) initial matrices, optional
    - posW: bool, enforce non-negativity on W
    - posU: bool, enforce non-negativity on U
    - use_chol: bool, use Cholesky decomposition for faster computation
    - module_lower: float, lower bound for module entropy
    - activity_lower: float, lower bound for activity entropy
    - donorm: bool, normalize U matrix
    - mode: int, optimization mode
    - mink: int, minimum sparsity constraint
    - U0: list, optional initialization for U
    - U0_delta: float, delta for projected gradient descent
    - doprint: bool, print progress
    
    Returns:
    - U, W: np.arrays, learned dictionary and activity matrices
    """
    # Initialize U and W matrices if not provided
    if UW is None:
        U, W = spams.nmf(np.asfortranarray(X), return_lasso=True, K=d, numThreads=THREADS)
        W = np.asarray(W.todense())  # Convert sparse W to dense format
    else:
        U, W = UW  # Use provided matrices
    
    # Compute initial reconstruction of X
    Xhat = U.dot(W)
    # Compute normalization factor for regularization
    Xnorm = np.linalg.norm(X) ** 2 / X.shape[1]
    
    # Iterate to optimize U and W
    for itr in range(maxItr):
        if mode == 1:
            # Solve for U using Lasso regression with sparsity regularization
            U = spams.lasso(np.asfortranarray(X.T), D=np.asfortranarray(W.T),
                            lambda1=lda2 * Xnorm, mode=1, numThreads=THREADS,
                            cholesky=use_chol, pos=posU)
            U = np.asarray(U.todense()).T  # Convert to dense and transpose
        elif mode == 2:
            # Optionally use projected gradient descent if U0 is provided
            if len(U0) > 0:
                U = projected_grad_desc(W.T, X.T, U.T, U0.T, lda2, U0_delta, maxItr=400)
                U = U.T  # Transpose back
            else:
                # Solve for U using Lasso regression with different lambda settings
                U = spams.lasso(np.asfortranarray(X.T), D=np.asfortranarray(W.T),
                                lambda1=lda2, lambda2=0.0, mode=2, numThreads=THREADS,
                                cholesky=use_chol, pos=posU)
                U = np.asarray(U.todense()).T  # Convert to dense and transpose
        
        # Normalize U if required
        if donorm:
            U = U / np.linalg.norm(U, axis=0)
            U[np.isnan(U)] = 0  # Replace NaN values with zero
        
        # Solve for W
        if mode == 1:
            wf = (1 - lda2)  # Worst-fit tolerance for sparsity
            W = sparse_decode(X, U, lda1, worstFit=wf, mink=mink)
        elif mode == 2:
            if len(U0) > 0:
                W = projected_grad_desc(U, X, W, [], lda1, 0., nonneg=posW, maxItr=400)
            else:
                W = spams.lasso(np.asfortranarray(X), D=np.asfortranarray(U),
                                lambda1=lda1, lambda2=1.0, mode=2, numThreads=THREADS,
                                cholesky=use_chol, pos=posW)
                W = np.asarray(W.todense())  # Convert to dense
        
        # Compute updated reconstruction of X
        Xhat = U.dot(W)
        
        # Compute module and activity sizes based on entropy
        module_size = np.average([np.exp(entropy(abs(u))) for u in U.T if u.sum() > 0])
        activity_size = np.average([np.exp(entropy(abs(w))) for w in W.T])
        
        # Print progress if required
        if doprint:
            print(distance.correlation(X.flatten(), Xhat.flatten()), module_size, activity_size, lda1, lda2)
        
        # Adjust sparsity parameters dynamically
        if module_size < module_lower:
            lda2 /= 2.  # Decrease sparsity regularization for U
        if activity_size < activity_lower:
            lda2 /= 2.  # Decrease sparsity regularization for W
    
    return U, W  # Return learned matrices

In [22]:
# Load the train_data and transpose it to match expected input (genes, cells)
train_data = np.load("./dataset/train_data.npy")

def run_smaf_and_save(train_data=train_data.T, output_file="results/gene_module_dictionary.csv"):
    """
    Runs SMAF decomposition on the training data and saves the learned dictionary as a CSV file.
    
    Parameters:
    - train_data: np.array, training dataset
    - output_file: str, file path to save the dictionary matrix
    - d, lda1, lda2, maxItr, use_chol, donorm, mode, mink, doprint: SMAF parameters
    """
    # Perform SMAF decomposition
    U, W = smaf(train_data, d=40, lda1=8, lda2=0.2, maxItr=20, use_chol=False, donorm=True, mode=1, mink=0., doprint=True)
    
    # Remove modules that have zero contribution
    nz = (U.sum(axis=0) > 0)
    U = U[:, nz]
    
    # Save the dictionary matrix to CSV
    pd.DataFrame(U).to_csv(output_file, index=False)
    print(f"Gene module dictionary saved to {output_file}")
    
    print(U.shape)
    print(W.shape)


run_smaf_and_save()

0.15502220165222347 51.08450062419147 1.0 8 0.2
0.07616726224363368 122.85337308340179 1.0 8 0.2
0.07348208686974511 121.70957825186119 1.0 8 0.2
0.07216777791673556 119.24445698817844 1.0 8 0.2
0.0715573237881234 117.88506495012612 1.0 8 0.2
0.0711028791864795 116.33589362791858 1.0 8 0.2
0.07082055445425584 115.88190981014434 1.0 8 0.2
0.07060187298814125 115.8994030137363 1.0 8 0.2
0.07038884555162794 115.83776969833636 1.0 8 0.2
0.0702874692485953 115.73724827044016 1.0 8 0.2
0.0701799333030102 115.71211324366072 1.0 8 0.2
0.07005602541531508 115.49863463655402 1.0 8 0.2
0.06998573298173383 115.22242108259366 1.0 8 0.2
0.06994670113748769 115.0518029587914 1.0 8 0.2
0.06992206841026738 114.8717706226925 1.0 8 0.2
0.06988907641031816 114.81097996573457 1.0 8 0.2
0.06981681562561204 114.55603245988581 1.0 8 0.2
0.06977074818239337 114.31519840356438 1.0 8 0.2
0.06972571717414933 114.05332493366424 1.0 8 0.2
0.06968392151289193 113.89299404740609 1.0 8 0.2
Gene module dictionary saved

### Random Measurement Matrix

In [23]:
# X
adata_downsampled = sc.read_h5ad("dataset/cb_adult_mouse_downsampled.h5ad")
X = adata_downsampled.X

# U
df = pd.read_csv("./results/gene_module_dictionary.csv")
U = df.to_numpy()

# transpose X to match expected
X = X.T

In [24]:
# print(f"phi shape: {phi.shape}")  # (40, g)
print(f"U shape: {U.shape}")      # Should be (g, something)
print(f"X shape: {X.shape}")

U shape: (1000, 40)
X shape: (1000, 10000)


In [39]:
def random_double_balanced(m, g, max_pools_per_gene, min_pools_per_gene):
    """
    Generates a random measurement matrix (`phi`) with "double balanced" characteristics.
    
    This function ensures:
    - Each gene is assigned to `max_pools_per_gene` pools initially.
    - Genes with fewer than `min_pools_per_gene` assignments are reassigned.
    - The resulting matrix is normalized so that each row sums to 1.

    Parameters:
    - m : int
        Number of pools (rows in the matrix).
    - g : int
        Number of genes (columns in the matrix).
    - max_pools_per_gene : int
        Maximum number of pools a gene can be assigned to.
    - min_pools_per_gene : int
        Minimum number of pools a gene must be assigned to.

    Returns:
    - phi : np.ndarray
        A (m × g) measurement matrix where each gene is assigned to a set of pools.
    """

    # Initialize an empty measurement matrix (pools × genes)
    phi = np.zeros((m, g))

    # Randomly assign each gene to pools up to `max_pools_per_gene` times
    for i in range(max_pools_per_gene):
        idx = np.random.choice(g, g, replace=False)  # Shuffle gene indices
        idx = idx % m  # Ensure indices fit within the pool size (modulo operation)
        phi[idx, np.arange(g)] = 1  # Assign genes to random pools

    # Ensure each gene is assigned to at least `min_pools_per_gene` pools
    for i in np.where(phi.sum(0) < min_pools_per_gene)[0]:  # Identify under-assigned genes
        p = phi.sum(1).max() - phi.sum(1)  # Compute imbalance for each pool
        p[np.where(phi[:, i])[0]] = 0  # Exclude pools that already contain the gene

        # If there are pools available for reassignment
        if p.sum() > 0:
            p = p / p.sum()  # Normalize probability distribution
            num_to_assign = min((p > 0).sum(), int(min_pools_per_gene - phi[:, i].sum()))
            idx = np.random.choice(m, num_to_assign, replace=False, p=p)  # Select new pools
            phi[idx, i] = 1  # Assign gene to additional pools

    # Normalize the matrix so that each row sums to 1 (avoid bias in pooling)
    phi = (phi.T / phi.sum(1)).T

    return phi

In [28]:
# Initialize an array to store the "best" coherence scores
best = np.ones(500)  # Stores the 500 best coherence values (initialized to 1, the worst possible score)
Phi_coh = [None for _ in best]  # Stores the corresponding best measurement matrices

# Generate 50,000 random measurement matrices (`phi`)
for x in range(50000):  
    if np.mod(x, 5000) == 0:  # Print progress every 5000 iterations
        print(x)

    # Generate a new random measurement matrix (`phi`)
    phi = random_double_balanced(40, X.shape[0], 4, 4)  

    # Compute the 90th percentile of the cosine distance between projected feature vectors
    coh_90 = np.percentile(1 - distance.pdist(phi.dot(U).T, 'cosine'), 90)

    # If the new `phi` has a better coherence score, replace the worst-performing matrix
    if coh_90 < best.max():  # Only keep `phi` if it has a lower (better) coherence score
        i = np.argmax(best)  # Find the index of the worst current coherence score
        best[i] = coh_90  # Replace the worst score with the new, better coherence score
        Phi_coh[i] = phi  # Store the corresponding measurement matrix

0
5000
10000
15000
20000
25000
30000
35000
40000
45000


In [29]:
# Load the validate_adata and transpose it to match expected input (genes, cells)
validate_data = np.load("./dataset/validate_data.npy")
validate_data = validate_data.T

In [30]:
# Of the "best" measurements, test ability to recover original gene expression patterns, and again pick the best.

# Initialize parameters for sparse decoding
sparsity = 0.02  # Defines the sparsity constraint (controls how many nonzero coefficients are allowed)
best = np.zeros(50)  # Stores the best reconstruction scores (initialized to 0, worst possible)
Phi = [None for _ in best]  # Stores the corresponding best measurement matrices

# Iterate through the "best" measurement matrices stored in `Phi_coh`
for phi in Phi_coh:
    
    # Generate simulated observations by applying `phi` to the validation data
    y = get_observations(validate_data, phi, snr=5)  # Simulate pooled noisy measurements
    
    # Use sparse decoding (LASSO) to recover gene module activations
    w = sparse_decode(y, phi.dot(U), sparsity, method='lasso')  
    
    # Reconstruct gene expression using the estimated module activations
    x2 = U.dot(w)  # Approximate gene expression matrix from recovered module weights
    
    # Compare reconstructed expression (`x2`) with the original validation data (`validate_data`)
    r = compare_results(validate_data, x2)  
    
    # If the new measurement matrix `phi` produces a better reconstruction, update `best` and `Phi`
    if r[2] > best.min():  # If the new reconstruction score is better than the worst stored score
        i = np.argmin(best)  # Find the index of the worst-performing matrix
        best[i] = r[2]  # Replace the worst reconstruction score with the new, better score
        Phi[i] = phi  # Store the corresponding measurement matrix

  dist = 1.0 - uv / math.sqrt(uu * vv)


In [31]:
# This computes matrix coherence of the "best" measurements

# Sort the best reconstruction scores in descending order
xs = np.argsort(best)  # Get indices that would sort `best` in ascending order
best = best[xs[::-1]]  # Reorder `best` so the highest scores come first

# Reorder the corresponding measurement 909matrices (`Phi`) based on sorting indices
Phi = [Phi[i] for i in xs]  # Reorder `Phi` so that the best-performing matrices come first

# Compute coherence metrics for the best measurement matrices
d_gene = np.array([
    np.percentile(1 - distance.pdist(phi.dot(U).T, 'cosine'), 90) for phi in Phi
])  # Compute the 90th percentile of cosine distance

d_gene99 = np.array([
    np.percentile(1 - distance.pdist(phi.dot(U).T, 'cosine'), 99) for phi in Phi
])  # Compute the 99th percentile of cosine distance

In [32]:
print(f"Best recovery score: {best.max()}")

Best recovery score: 0.3261876017780862


## Simulation

In [40]:
def run_smaf_and_get_U(train_data, d=100, lda1=8, lda2=0.2, maxItr=100, use_chol=False, donorm=True, mode=1, mink=0., doprint=True):
    """
    Runs SMAF decomposition on the training data and returns the learned dictionary matrix U.
    """
    U, W = smaf(train_data, d=d, lda1=lda1, lda2=lda2, maxItr=maxItr, 
                use_chol=use_chol, donorm=donorm, mode=mode, mink=mink, doprint=doprint)
    
    nz = (U.sum(axis=0) > 0)  # Remove zero-contribution modules
    U = U[:, nz]
    
    return U, W

def generate_best_measurement_matrices(U, X, num_trials=50000, num_best=500):
    """
    Generates and selects the best measurement matrices based on coherence.
    """
    best = np.ones(num_best)  # Initialize with worst coherence scores (1 is max distance)
    Phi_coh = [None] * num_best  # Store corresponding measurement matrices
    
    for x in range(num_trials):
        if x % (num_trials // 10) == 0:
            print(f"Iteration: {x}")
        
        phi = random_double_balanced(U.shape[1], X.shape[0], 4, 4)  # Ensure m matches U.shape[1]
        coh_90 = np.percentile(1 - distance.pdist(phi.dot(U).T, 'cosine'), 90)
        
        if coh_90 < best.max():
            i = np.argmax(best)
            best[i] = coh_90
            Phi_coh[i] = phi
    
    return best, Phi_coh

def test_measurement_recovery(validate_data, U, Phi_coh, sparsity=0.02, num_best=50):
    """
    Tests recovery ability of measurement matrices and selects the best ones.
    """
    best = np.zeros(num_best)
    Phi = [None] * num_best
    
    for phi in Phi_coh:
        y = get_observations(validate_data, phi, snr=5)
        w = sparse_decode(y, phi.dot(U), sparsity, method='lasso')  
        x2 = U.dot(w)
        r = compare_results(validate_data, x2)
        
        if r[2] > best.min():
            i = np.argmin(best)
            best[i] = r[2]
            Phi[i] = phi
    
    return best, Phi

def simulate_smaf(train_data_path, validate_data_path, d=100, sparsity=0.02, num_trials=50000, num_best=500):
    """
    Simulates SMAF with different d values, selects best measurement matrices, and evaluates recovery.
    """
    # Load and process data
    train_data = np.load(train_data_path).T  # Ensure correct shape (genes, cells)
    validate_data = np.load(validate_data_path).T
    adata_downsampled = sc.read_h5ad("dataset/cb_adult_mouse_downsampled.h5ad")
    X = adata_downsampled.X.T  # Ensure correct shape
    
    # Run SMAF and get dictionary matrix U
    U, W = run_smaf_and_get_U(train_data, d=d)
    print(f"U shape after SMAF: {U.shape}")
    
    # Generate and select best measurement matrices
    best_coh, Phi_coh = generate_best_measurement_matrices(U, X, num_trials=num_trials, num_best=num_best)
    
    # Test recovery ability
    best_rec, Phi = test_measurement_recovery(validate_data, U, Phi_coh, sparsity=sparsity, num_best=50)
    
    # Compute coherence metrics
    d_gene = np.array([np.percentile(1 - distance.pdist(phi.dot(U).T, 'cosine'), 90) for phi in Phi])
    d_gene99 = np.array([np.percentile(1 - distance.pdist(phi.dot(U).T, 'cosine'), 99) for phi in Phi])
    
    print(f"Best recovery score: {best_rec.max()}")
    
    return {
        "best_coherence_scores": best_coh,
        "best_recovery_scores": best_rec,
        "coherence_90th": d_gene,
        "coherence_99th": d_gene99,
        "Phi": Phi
    }

In [21]:
simulate_smaf("./dataset/train_data.npy", "./dataset/validate_data.npy",
              d=100, sparsity=0.02, num_trials=50000, num_best=500)

0.13716847965263046 39.42491961009741 1.0 8 0.2


  U = U / np.linalg.norm(U, axis=0)


0.06830910362598908 105.24376326263244 1.0 8 0.2
0.06576318398781 106.76051877565425 1.0 8 0.2
0.06490138129237111 105.94299215131639 1.0 8 0.2
0.06458111275522416 105.96783580914311 1.0 8 0.2
0.06443295283628536 105.51636789172528 1.0 8 0.2
0.06431985773985716 105.50199980353507 1.0 8 0.2
0.06424343796516829 105.49988815147992 1.0 8 0.2
0.06420032679624943 105.4537685314417 1.0 8 0.2
0.0641622715946244 105.55876931002179 1.0 8 0.2
0.06412387769311756 105.520060188813 1.0 8 0.2
0.06409181649538842 105.3929193021982 1.0 8 0.2
0.06405985540583448 105.29433988003073 1.0 8 0.2
0.0640378154323138 105.2487973773282 1.0 8 0.2
0.06402214520908078 105.25670678169118 1.0 8 0.2
0.06401035730172777 105.23411699623065 1.0 8 0.2
0.06400153234948036 105.21526810267471 1.0 8 0.2
0.06399631394054739 105.21519357176044 1.0 8 0.2
0.06399193050084395 105.24585403997308 1.0 8 0.2
0.06398584880291225 105.20768694298792 1.0 8 0.2
0.06397972815932995 105.15471895224346 1.0 8 0.2
0.06397593850202066 105.145097

  dist = 1.0 - uv / math.sqrt(uu * vv)


Best recovery score: 0.3353381939788718


{'best_coherence_scores': array([0.92749531, 0.93197694, 0.93127699, 0.93206352, 0.93235504,
        0.93260943, 0.93172664, 0.93274824, 0.93274708, 0.93256702,
        0.93164821, 0.93235884, 0.93277997, 0.93191573, 0.93270294,
        0.9327777 , 0.93254515, 0.93152935, 0.93221939, 0.93245903,
        0.93107717, 0.93273631, 0.93277415, 0.93269241, 0.93267564,
        0.93246771, 0.9322289 , 0.93250975, 0.93200648, 0.93043652,
        0.93278122, 0.93172987, 0.93279743, 0.93066326, 0.93272534,
        0.93200247, 0.93211739, 0.92988055, 0.93228189, 0.93247731,
        0.93203171, 0.93279769, 0.93126391, 0.93086945, 0.93134266,
        0.93196701, 0.93186564, 0.93144869, 0.93231065, 0.9310375 ,
        0.93189368, 0.93262259, 0.93220959, 0.9311321 , 0.93227375,
        0.93224063, 0.93057161, 0.93261035, 0.93224598, 0.93266901,
        0.93240455, 0.93197976, 0.93127239, 0.93258588, 0.9314723 ,
        0.93230595, 0.93253357, 0.93129779, 0.93249745, 0.93093552,
        0.93161304, 0.9

In [41]:
simulate_smaf("./dataset/train_data.npy", "./dataset/validate_data.npy",
              d=40, sparsity=0.08, num_trials=50000, num_best=500)

0.15684627817494612 50.030156886794074 1.0 8 0.2
0.07639166190128588 123.2906494173812 1.0 8 0.2
0.07387501007794117 122.24512877471568 1.0 8 0.2
0.07236307823036048 119.37838316269531 1.0 8 0.2
0.07165109678078874 117.74047363907985 1.0 8 0.2
0.07124570161546584 116.53214174185239 1.0 8 0.2
0.07097159800581476 116.11588815780676 1.0 8 0.2
0.0708241277540359 115.88073531631089 1.0 8 0.2
0.07070752609922548 115.81274890603027 1.0 8 0.2
0.07056919851032029 115.70331296967856 1.0 8 0.2
0.0704652825118588 115.66826292802925 1.0 8 0.2
0.07038098607269971 115.65746647772357 1.0 8 0.2
0.07031584204975616 115.54355574664555 1.0 8 0.2
0.07026101647294358 115.39270537685903 1.0 8 0.2
0.070207383240528 115.20554438198437 1.0 8 0.2
0.07016773328841819 115.07961093604702 1.0 8 0.2
0.07013312327569454 114.95190384909138 1.0 8 0.2
0.0701076306818249 114.8860072661161 1.0 8 0.2
0.07008829927669136 114.79432243466961 1.0 8 0.2
0.07006910926234022 114.7328768127683 1.0 8 0.2
0.07005892151778048 114.6582

  dist = 1.0 - uv / math.sqrt(uu * vv)


Best recovery score: 0.3710135688963039


{'best_coherence_scores': array([0.97526686, 0.97560796, 0.97495882, 0.97500935, 0.97515243,
        0.97543535, 0.97499919, 0.97521057, 0.9753365 , 0.97567739,
        0.97567543, 0.97557822, 0.97553381, 0.97465552, 0.97569113,
        0.97575972, 0.97492282, 0.97563431, 0.97521187, 0.97571645,
        0.97527465, 0.97384352, 0.97524532, 0.97568256, 0.97580163,
        0.97579792, 0.9754263 , 0.97573904, 0.97493711, 0.97520395,
        0.9757764 , 0.97544704, 0.9753187 , 0.97566231, 0.9755673 ,
        0.97537401, 0.9755889 , 0.97568637, 0.97556838, 0.97574261,
        0.97556246, 0.9755217 , 0.97547584, 0.97532597, 0.97523313,
        0.97558304, 0.97549109, 0.97551192, 0.97507469, 0.97571418,
        0.97532816, 0.97579139, 0.97560131, 0.97495033, 0.97548518,
        0.97574653, 0.97545559, 0.97560569, 0.97573039, 0.97577275,
        0.97538959, 0.97545522, 0.97535731, 0.97523266, 0.97551706,
        0.97574278, 0.97573478, 0.97540259, 0.9755583 , 0.97543291,
        0.97576004, 0.9