In [18]:
#%% Import libraries
import scanpy as sc
import numpy as np
from geosketch import gs
from fbpca import pca
import time
from memory_profiler import memory_usage

#%% Load dataset (10x Genomics 1.3M neurons)
# adata = sc.datasets.dataset_zeisel_2015()
# For full 1.3M dataset use:
# adata = sc.read_10x_mtx('path/to/matrix/folder/', var_names='gene_ids')

adata = sc.read_10x_mtx(
    'pbmc3k_filtered_gene_bc_matrices.tar',  # Update with actual path
    var_names='gene_symbols',
    cache=True
)

#%% Preprocessing
def preprocess(adata):
    sc.pp.filter_cells(adata, min_genes=200)
    sc.pp.filter_genes(adata, min_cells=3)
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    sc.pp.highly_variable_genes(adata, n_top_genes=2000)
    adata = adata[:, adata.var.highly_variable]
    return adata

adata = preprocess(adata)

#%% Dimensionality reduction
def run_pca(adata):
    start = time.time()
    U, s, Vt = pca(adata.X, k=100)
    adata.obsm['X_pca'] = U[:, :100] * s[:100]
    print(f"PCA time: {time.time()-start:.2f}s")
    return adata

adata = run_pca(adata)

#%% Geometric Sketching
def geometric_sketch(adata, n_sketch=20000):
    start = time.time()
    sketch_idx = gs(adata.obsm['X_pca'], n_sketch, replace=False)
    adata_sketch = adata[sketch_idx].copy()
    print(f"Geometric Sketching time: {time.time()-start:.2f}s")
    return adata_sketch

#%% Uniform Sampling  
def uniform_sample(adata, n_sketch=20000):
    start = time.time()
    np.random.seed(42)
    uniform_idx = np.random.choice(adata.n_obs, size=n_sketch, replace=False)
    adata_uniform = adata[uniform_idx].copy()
    print(f"Uniform Sampling time: {time.time()-start:.2f}s")
    return adata_uniform

#%% Memory profiling
def profile_memory(func, *args):
    mem_usage = memory_usage((func, args))
    return max(mem_usage)

geo_mem = profile_memory(geometric_sketch, adata)
uni_mem = profile_memory(uniform_sample, adata)

print(f"Geometric Sketching peak memory: {geo_mem:.1f} MiB")
print(f"Uniform Sampling peak memory: {uni_mem:.1f} MiB")

#%% Run both methods
adata_geo = geometric_sketch(adata)
adata_uni = uniform_sample(adata)

#%% Downstream analysis
def analyze_sketch(adata_sketch):
    sc.pp.neighbors(adata_sketch)
    sc.tl.leiden(adata_sketch)
    sc.tl.umap(adata_sketch)
    return adata_sketch

adata_geo = analyze_sketch(adata_geo)
adata_uni = analyze_sketch(adata_uni)

#%% Visualization
sc.pl.umap(adata_geo, color='leiden', 
           title='Geometric Sketch Clustering')
sc.pl.umap(adata_uni, color='leiden',
           title='Uniform Sample Clustering')

#%% Rare cell detection comparison (example)
def rare_cell_analysis(adata, adata_sketch):
    # Compare cluster distributions
    full_clusters = adata.obs['leiden'].value_counts(normalize=True)
    sketch_clusters = adata_sketch.obs['leiden'].value_counts(normalize=True)
    
    # Calculate Jensen-Shannon divergence
    from scipy.spatial import distance
    jsd = distance.jensenshannon(full_clusters, sketch_clusters)
    return jsd

# Note: Requires running clustering on full dataset (may be computationally intensive)
# full_jsd = rare_cell_analysis(adata, adata_geo)
# uni_jsd = rare_cell_analysis(adata, adata_uni)


FileNotFoundError: Did not find file pbmc3k_filtered_gene_bc_matrices.tar/matrix.mtx.gz.