In [None]:
"""
data_gen.py

Creates:
 - reference_expression_1000x100.csv
 - reference_labels_1000.csv
 - query_expression_1000x100.csv
 - query_coordinates_1000x2.csv
 - mock_sc_dataset.npz      (compact binary with arrays)

Optional: saves .h5ad files if scanpy is available.

Run:
    python data_gen.py
"""
import os
import numpy as np
import pandas as pd

# Optional: produce AnnData objects if scanpy is installed
try:
    import scanpy as sc  # type: ignore
    HAVE_SCANPY = True
except Exception:
    HAVE_SCANPY = False

def _marker_shifts(rng, n_types, n_genes, markers_per_type):
    """Return an (n_types, n_genes) matrix of log-rate shifts on each type's marker genes."""
    type_shifts = np.zeros((n_types, n_genes))
    for t in range(n_types):
        # Type t gets a contiguous block of marker genes starting at
        # t * markers_per_type, wrapping modulo n_genes. The blocks are
        # disjoint only while n_types * markers_per_type <= n_genes; beyond
        # that, later types overwrite shared positions.
        start = (t * markers_per_type) % n_genes
        marker_idx = np.arange(start, start + markers_per_type) % n_genes
        type_shifts[t, marker_idx] = rng.normal(loc=1.5, scale=0.4, size=markers_per_type)
    return type_shifts


def _balanced_labels(rng, n_cells, n_types):
    """Return n_cells shuffled integer labels with equal counts per type plus a random remainder."""
    per_type = n_cells // n_types
    labels = np.repeat(np.arange(n_types), per_type)
    remainder = n_cells - per_type * n_types
    if remainder > 0:
        # Leftover cells (n_cells not divisible by n_types) get random types.
        labels = np.concatenate([labels, rng.integers(0, n_types, size=remainder)])
    rng.shuffle(labels)
    return labels


def _sample_counts(rng, log_mean, noise_scale):
    """Add Gaussian noise to a per-cell log-rate matrix and draw Poisson counts from exp(rates)."""
    # One draw of shape log_mean.shape consumes the RNG in the same row-major
    # order as drawing one row per cell, so seeded output is unchanged.
    log_rates = log_mean + rng.normal(loc=0.0, scale=noise_scale, size=log_mean.shape)
    # Clip lambda so an extreme noise draw cannot push the sampler out of range.
    return rng.poisson(lam=np.clip(np.exp(log_rates), 0, 1e6)).astype(int)


def _cluster_coords(rng, labels, n_types, radius=50.0, spread=5.0, offset=100.0):
    """Place each cell in a 2-D Gaussian blob centred on its type's point on a circle."""
    coords = np.zeros((len(labels), 2))
    angles = np.linspace(0, 2 * np.pi, n_types, endpoint=False)
    # Centres lie on a circle of the given radius, shifted so coordinates are positive.
    centers = np.vstack([radius * np.cos(angles), radius * np.sin(angles)]).T + offset
    for t in range(n_types):
        idx = np.where(labels == t)[0]
        if len(idx) > 0:
            coords[idx] = rng.normal(loc=centers[t], scale=spread, size=(len(idx), 2))
    return coords


def generate(
    n_cells=1000,
    n_genes=100,
    n_types=5,
    markers_per_type=10,
    seed=42,
    outdir="mock_sc"
):
    """Generate a mock single-cell reference/query dataset and write it to ``outdir``.

    A labelled reference set and an unlabelled query set are simulated from the
    same cell types. Each cell's counts are Poisson draws from
    ``exp(baseline + type_shift + noise)``. The query set also gets a gene-wise
    batch shift and 2-D spatial coordinates clustered by its hidden cell type.

    Parameters
    ----------
    n_cells : int
        Number of cells in each of the reference and query sets (>= 0).
    n_genes : int
        Number of genes (>= 1).
    n_types : int
        Number of latent cell types (>= 1).
    markers_per_type : int
        Number of marker genes shifted up for each type (>= 0).
    seed : int
        Seed for ``numpy.random.default_rng``; the output is fully determined by it.
    outdir : str
        Directory to write into; created if missing.

    Returns
    -------
    dict
        ``ref_df`` (reference counts DataFrame), ``ref_labels`` (int Series
        named ``cell_type``), ``query_df`` (query counts DataFrame),
        ``query_coords`` (DataFrame with ``x``/``y`` columns), and ``npz_path``.

    Raises
    ------
    ValueError
        If any size argument is out of range.

    Notes
    -----
    The output file names always carry the ``1000x100`` suffixes, whatever
    ``n_cells`` and ``n_genes`` are. When ``HAVE_SCANPY`` is true, ``.h5ad``
    files are written as well.
    """
    # Validate up front; n_types == 0 or n_genes == 0 would otherwise fail
    # later with a ZeroDivisionError.
    if n_cells < 0:
        raise ValueError(f"n_cells must be >= 0, got {n_cells}")
    if n_genes < 1:
        raise ValueError(f"n_genes must be >= 1, got {n_genes}")
    if n_types < 1:
        raise ValueError(f"n_types must be >= 1, got {n_types}")
    if markers_per_type < 0:
        raise ValueError(f"markers_per_type must be >= 0, got {markers_per_type}")

    rng = np.random.default_rng(seed)
    os.makedirs(outdir, exist_ok=True)

    gene_names = [f"gene_{i}" for i in range(n_genes)]
    cell_ids_query = [f"Q_cell_{i}" for i in range(n_cells)]
    cell_ids_ref = [f"R_cell_{i}" for i in range(n_cells)]

    # ---- baseline log-expression per gene and per-type marker shifts ----
    base_log_mu = rng.normal(loc=1.0, scale=0.6, size=n_genes)
    type_shifts = _marker_shifts(rng, n_types, n_genes, markers_per_type)

    # ---- Reference set B (annotated) ----
    ref_cell_types = _balanced_labels(rng, n_cells, n_types)
    ref_counts = _sample_counts(
        rng, base_log_mu + type_shifts[ref_cell_types], noise_scale=0.3
    )
    ref_df = pd.DataFrame(ref_counts, index=cell_ids_ref, columns=gene_names)
    ref_labels = pd.Series(ref_cell_types, index=cell_ids_ref, name="cell_type").astype(int)

    # ---- Query set A (unlabeled; true types kept only to drive the simulation) ----
    query_cell_types_hidden = rng.choice(np.arange(n_types), size=n_cells, replace=True)
    # Gene-wise shift simulating a batch/domain difference between the two sets.
    batch_shift = rng.normal(loc=0.2, scale=0.2, size=n_genes)
    query_counts = _sample_counts(
        rng,
        base_log_mu + type_shifts[query_cell_types_hidden] + batch_shift,
        noise_scale=0.35,
    )
    query_df = pd.DataFrame(query_counts, index=cell_ids_query, columns=gene_names)

    # ---- Spatial coordinates for query cells (one cluster per hidden type) ----
    coords = _cluster_coords(rng, query_cell_types_hidden, n_types)
    coords_df = pd.DataFrame(coords, index=cell_ids_query, columns=["x", "y"])

    # ---- Save outputs ----
    ref_expr_csv = os.path.join(outdir, "reference_expression_1000x100.csv")
    ref_labels_csv = os.path.join(outdir, "reference_labels_1000.csv")
    query_expr_csv = os.path.join(outdir, "query_expression_1000x100.csv")
    query_coords_csv = os.path.join(outdir, "query_coordinates_1000x2.csv")
    npz_path = os.path.join(outdir, "mock_sc_dataset.npz")

    ref_df.to_csv(ref_expr_csv)
    ref_labels.to_csv(ref_labels_csv, header=True)
    query_df.to_csv(query_expr_csv)
    coords_df.to_csv(query_coords_csv)

    # Store names as fixed-width unicode arrays. Object-dtype arrays would be
    # pickled into the .npz and could not be read back with np.load's default
    # allow_pickle=False.
    np.savez_compressed(
        npz_path,
        ref_expr=ref_df.values,
        ref_cells=np.asarray(ref_df.index, dtype=str),
        ref_genes=np.asarray(ref_df.columns, dtype=str),
        ref_labels=ref_labels.values,
        query_expr=query_df.values,
        query_cells=np.asarray(query_df.index, dtype=str),
        query_genes=np.asarray(query_df.columns, dtype=str),
        query_coords=coords_df.values
    )

    # ---- Optional: save as h5ad if scanpy (and hence anndata) is available ----
    if HAVE_SCANPY:
        import anndata as ad  # type: ignore
        ad_ref = ad.AnnData(
            X=ref_df.values,
            obs=pd.DataFrame(
                {"cell_id": ref_df.index, "cell_type": ref_labels.values},
                index=ref_df.index,
            ),
            var=pd.DataFrame(index=ref_df.columns),
        )
        ad_query = ad.AnnData(
            X=query_df.values,
            obs=pd.DataFrame({"cell_id": query_df.index}, index=query_df.index),
            var=pd.DataFrame(index=query_df.columns),
        )
        ad_query.obsm["spatial"] = coords_df.values
        ad_ref.write(os.path.join(outdir, "reference_1000x100.h5ad"))
        ad_query.write(os.path.join(outdir, "query_1000x100.h5ad"))

    # ---- return objects for interactive use ----
    return {
        "ref_df": ref_df,
        "ref_labels": ref_labels,
        "query_df": query_df,
        "query_coords": coords_df,
        "npz_path": npz_path
    }

if __name__ == "__main__":
    # Script entry: build the default dataset, then report what was written.
    out = generate()
    saved_names = (
        "reference_expression_1000x100.csv",
        "reference_labels_1000.csv",
        "query_expression_1000x100.csv",
        "query_coordinates_1000x2.csv",
        "mock_sc_dataset.npz",
    )
    print("Saved files in 'mock_sc' directory:")
    print("\n".join(" - " + os.path.join("mock_sc", name) for name in saved_names))
    if HAVE_SCANPY:
        print("Also saved .h5ad files (scanpy available).")


Saved files in 'mock_sc' directory:
 - mock_sc/reference_expression_1000x100.csv
 - mock_sc/reference_labels_1000.csv
 - mock_sc/query_expression_1000x100.csv
 - mock_sc/query_coordinates_1000x2.csv
 - mock_sc/mock_sc_dataset.npz
Also saved .h5ad files (scanpy available).
