In [None]:
import numpy as np
import pandas as pd
import anndata as ad

In [None]:
# Genes to perturb in silico, one single-gene perturbation each (or any from your dataset).
perturbations = [
    'BCL11B', 'XPO7', 'ANK3', 'TBR1',
    'SATB2', 'CUL1', 'RB1CC1', 'HERC1',
]
# predict() receives one single-element gene list per perturbation.
pred = gears_model.predict([[gene] for gene in perturbations])

In [None]:
# One predicted profile per perturbation, stacked row-wise in the order of `perturbations`.
X_pred = np.stack(tuple(pred[gene] for gene in perturbations), axis=0)  # shape: (n_perturbations, n_genes)

In [None]:
# Per-observation metadata: one row per predicted perturbation, indexed by its name.
obs_df = pd.DataFrame(
    {
        'condition': list(perturbations),
        'cell_type': ['Neuron' for _ in perturbations],  # or whatever your cell type is
        'batch': ['GEARS' for _ in perturbations],       # optional but useful
    },
    index=perturbations,
)

In [None]:
# Gene annotations copied from the model's AnnData; make sure a 'gene_name' column exists
# (falling back to the var index when it does not).
var_df = gears_model.adata.var.copy()
var_df = var_df if 'gene_name' in var_df.columns else var_df.assign(gene_name=var_df.index)

In [None]:
# Assemble predictions (rows = perturbations) x genes (columns) into a single AnnData.
adata_pred = ad.AnnData(X_pred, obs=obs_df, var=var_df)

In [None]:
# Number of training cells per condition (displayed as the cell output).
gears_model.adata.obs["condition"].value_counts(dropna=True)


In [None]:
import numpy as np
import pandas as pd
import anndata as ad

def simulate_deterministic_anndata(gears_model, pert_counts):
    """Build a pseudo-cell AnnData by repeating deterministic GEARS predictions.

    Each condition in ``pert_counts`` contributes ``n_cells`` identical rows.
    The row is the model's predicted expression for that perturbation, or the
    model's control expression for control conditions. There is no
    cell-to-cell noise, so the result only mirrors the condition composition
    of ``pert_counts``.

    Parameters
    ----------
    gears_model
        Trained GEARS model. This function reads ``gears_model.adata.var``
        and ``gears_model.ctrl_expression``, and calls ``gears_model.predict``.
    pert_counts : mapping or pandas.Series
        ``condition -> number of cells``, for example
        ``gears_model.adata.obs['condition'].value_counts()``.
        Conditions are ``'+'``-joined gene names in which ``'ctrl'`` marks an
        empty slot. ``'ctrl'``, ``'GENE+ctrl'``, ``'ctrl+GENE'`` and
        ``'GENE1+GENE2'`` are therefore all accepted. Conditions with a count
        of 0 are skipped.

    Returns
    -------
    anndata.AnnData
        A ``sum(counts)`` x ``n_genes`` matrix with:

        - ``obs`` column ``'condition'``: the original label.
        - ``obs`` column ``'perturbation'``: the ``'+'``-joined non-control
          genes, or ``'ctrl'``.
        - ``var`` copied from ``gears_model.adata.var``, with a guaranteed
          ``'gene_name'`` column.

    Raises
    ------
    ValueError
        If a predicted or control expression vector does not have exactly
        one value per gene in ``gears_model.adata.var``.
    """
    var_df = gears_model.adata.var.copy()
    if 'gene_name' not in var_df.columns:
        var_df['gene_name'] = var_df.index
    n_genes = len(var_df)

    ctrl_expr = None  # control vector, converted lazily and at most once
    all_X = []
    all_obs = []

    for cond, n_cells in pert_counts.items():
        n_cells = int(n_cells)
        if n_cells <= 0:
            # e.g. unused categories that value_counts() reports for a categorical column
            continue

        # Strip the 'ctrl' placeholders: 'GENE+ctrl' / 'ctrl+GENE' -> ['GENE'], 'ctrl' -> [].
        genes = [g for g in str(cond).split('+') if g != 'ctrl']

        if genes:
            # NOTE(review): GEARS keys predict() results by '_'.join(genes)
            # (the gene name itself for single-gene perturbations) — confirm for your GEARS version.
            expr = gears_model.predict([genes])['_'.join(genes)]
            pert_label = '+'.join(genes)
        else:
            if ctrl_expr is None:
                ctrl_expr = gears_model.ctrl_expression
                if hasattr(ctrl_expr, 'detach'):
                    # Tensor -> NumPy; detach() avoids failures on autograd-tracked tensors.
                    ctrl_expr = ctrl_expr.detach().cpu().numpy()
            expr = ctrl_expr
            pert_label = 'ctrl'

        expr = np.asarray(expr).reshape(-1)
        if expr.size != n_genes:
            raise ValueError(
                f"Expression for condition {cond!r} has {expr.size} values, "
                f"expected {n_genes} (one per gene in gears_model.adata.var)."
            )

        # Repeat the deterministic profile once per simulated cell.
        all_X.append(np.tile(expr, (n_cells, 1)))
        all_obs.append(pd.DataFrame({
            'condition': [cond] * n_cells,
            'perturbation': [pert_label] * n_cells,
        }))

    if all_X:
        X = np.vstack(all_X)
        obs = pd.concat(all_obs, ignore_index=True)
    else:
        # Nothing to simulate: return an empty but well-formed AnnData instead of crashing in vstack.
        X = np.empty((0, n_genes), dtype=np.float32)
        obs = pd.DataFrame({'condition': [], 'perturbation': []})

    # AnnData expects string obs_names; an integer RangeIndex would be converted with a warning.
    obs.index = obs.index.astype(str)

    return ad.AnnData(X=X, obs=obs, var=var_df)

In [None]:
# Simulate as many deterministic cells per condition as the training data contains.
adata_sim = simulate_deterministic_anndata(
    gears_model,
    gears_model.adata.obs["condition"].value_counts(),
)

In [None]:
# AnnData has no .head(); preview the first two simulated cells as a DataFrame
# (slice first so only two rows are materialized).
adata_sim[:2].to_df()

In [None]:
# Preview the first two training cells for comparison. Slicing before to_df() avoids
# densifying the whole expression matrix when X is sparse.
gears_model.adata[:2].to_df()

In [None]:
# Persist the simulated dataset (write() is an alias of write_h5ad()).
adata_sim.write_h5ad("simulated_adata.h5ad")