run every section seperatedly.

since there are multiple variables in different sections that have the same name.

## detail information

In [1]:
import os
import json
import pandas as pd
import numpy as np
import anndata as ad
# anndata/scanpy 用于 .h5ad

path_adata = "adata_Training.h5ad"
path_gene  = "gene_names.csv"
path_pert  = "pert_counts_Validation.csv"

def human_size(n):
    for unit in ["B","KB","MB","GB","TB"]:
        if n < 1024:
            return f"{n:.1f} {unit}"
        n /= 1024
    return f"{n:.1f} PB"

def file_info(p):
    if not os.path.exists(p):
        return {"path": p, "exists": False}
    st = os.stat(p)
    return {
        "path": p,
        "exists": True,
        "size_bytes": st.st_size,
        "size_human": human_size(st.st_size),
        "mtime": st.st_mtime,
    }

print("文件信息概览：")
for p in [path_adata, path_gene, path_pert]:
    print(file_info(p))

# 1) 读取 .h5ad
print("\n读取 anndata (.h5ad)...")
adata = ad.read_h5ad(path_adata)
print(f"adata: {adata}")
# 尝试预览基本结构
print(f"形状: n_obs={adata.n_obs}, n_vars={adata.n_vars}")
print("obs 列（前若干）:", list(adata.obs.columns)[:10])
print("var 列（前若干）:", list(adata.var.columns)[:10])
# 若 X 为稀疏矩阵，取少量子集转为密集显示
try:
    sub = adata[:3, :5].X
    if hasattr(sub, "toarray"):
        sub = sub.toarray()
    print("X 子集(3x5)示例：\n", np.asarray(sub))
except Exception as e:
    print("无法显示 X 子集：", e)

# 3) 读取 gene_names.csv
print("\n读取 gene_names.csv ...")
try:
    # 自动推断分隔符，如明确为逗号可用 sep=','
    genes_df = pd.read_csv(path_gene)
    print("gene_names.csv 预览：")
    print(genes_df.head())
    print("列信息：", genes_df.columns.tolist())
    print("行数/列数：", genes_df.shape)
except Exception as e:
    print("读取 gene_names.csv 出错：", e)

# 4) 读取 pert_counts_Validation.csv
print("\n读取 pert_counts_Validation.csv ...")
try:
    pert_df = pd.read_csv(path_pert)
    print("pert_counts_Validation.csv 预览：")
    print(pert_df.head())
    print("列信息：", pert_df.columns.tolist())
    print("行数/列数：", pert_df.shape)
except Exception as e:
    print("读取 pert_counts_Validation.csv 出错：", e)

# 可选：一致性检查（例如基因名是否和 adata.var 索引对齐）
try:
    if 'gene' in (genes_df.columns if 'genes_df' in locals() else []):
        gene_list = genes_df['gene'].astype(str).tolist()
        adata_genes = adata.var_names.astype(str).tolist()
        overlap = len(set(gene_list) & set(adata_genes))
        print(f"\n基因名重叠计数：{overlap} / {len(gene_list)}")
except Exception as e:
    print("一致性检查时出错：", e)

print("\n读取完成。")

文件信息概览：
{'path': 'adata_Training.h5ad', 'exists': True, 'size_bytes': 15482497461, 'size_human': '14.4 GB', 'mtime': 1762960875.5560176}
{'path': 'gene_names.csv', 'exists': True, 'size_bytes': 116023, 'size_human': '113.3 KB', 'mtime': 1762957829.7691019}
{'path': 'pert_counts_Validation.csv', 'exists': True, 'size_bytes': 978, 'size_human': '978.0 B', 'mtime': 1762957829.7691019}

读取 anndata (.h5ad)...
adata: AnnData object with n_obs × n_vars = 221273 × 18080
    obs: 'target_gene', 'guide_id', 'batch'
    var: 'gene_id'
形状: n_obs=221273, n_vars=18080
obs 列（前若干）: ['target_gene', 'guide_id', 'batch']
var 列（前若干）: ['gene_id']
X 子集(3x5)示例：
 [[0. 0. 1. 0. 0.]
 [0. 1. 0. 0. 0.]
 [0. 1. 0. 0. 0.]]

读取 gene_names.csv ...
gene_names.csv 预览：
    SAMD11
0    NOC2L
1   KLHL17
2  PLEKHN1
3    PERM1
4     HES4
列信息： ['SAMD11']
行数/列数： (18079, 1)

读取 pert_counts_Validation.csv ...
pert_counts_Validation.csv 预览：
  target_gene  n_cells  median_umi_per_cell
0      SH3BP4     2925              54551.0
1

In [2]:
print(f"adata: {adata}")
print("obs columns:", adata.obs.columns.tolist())
print("var columns:", adata.var.columns.tolist())
print("layers:", list(adata.layers.keys()) if adata.layers else None)
print("obsm:", list(adata.obsm.keys()) if adata.obsm else None)
print("varm:", list(adata.varm.keys()) if adata.varm else None)
print("uns:", list(adata.uns.keys()) if adata.uns else None)
print("raw:", "Yes" if adata.raw is not None else "No")
print(type(adata.X))
print(adata.X.dtype)
print("是否为稀疏矩阵:", hasattr(adata.X, "toarray"))
for col in adata.obs.columns:
    print(f"{col}: {adata.obs[col].nunique()} unique values")
    print(adata.obs[col].value_counts().head(), "\n")
print(adata.var.head())
if "X_pca" in adata.obsm:
    print("PCA shape:", adata.obsm["X_pca"].shape)

if "X_umap" in adata.obsm:
    print("UMAP shape:", adata.obsm["X_umap"].shape)



adata: AnnData object with n_obs × n_vars = 221273 × 18080
    obs: 'target_gene', 'guide_id', 'batch'
    var: 'gene_id'
obs columns: ['target_gene', 'guide_id', 'batch']
var columns: ['gene_id']
layers: None
obsm: None
varm: None
uns: None
raw: No
<class 'scipy.sparse._csr.csr_matrix'>
float32
是否为稀疏矩阵: True
target_gene: 151 unique values
target_gene
non-targeting    38176
TMSB4X            4760
PRCP              4331
TADA1             4035
HIRA              3407
Name: count, dtype: int64 

guide_id: 189 unique values
guide_id
TADA1_P1P2_A|TADA1_P1P2_B    4035
PRCP_P1_A|PRCP_P1_B          3415
IGF2R_P1P2_A|IGF2R_P1P2_B    3109
NCK2_P1P2_A|NCK2_P1P2_B      2929
HIRA_P1_A|HIRA_P1_B          2888
Name: count, dtype: int64 

batch: 48 unique values
batch
Flex_3_10    5082
Flex_3_06    5068
Flex_3_15    5038
Flex_3_11    4988
Flex_3_08    4988
Name: count, dtype: int64 

                 gene_id
SAMD11   ENSG00000187634
NOC2L    ENSG00000188976
KLHL17   ENSG00000187961
PLEKHN1  ENSG0000018

## split dataset (old version: used for vcc.ipynb)

In [None]:
# used for vcc.ipynb

import numpy as np
import anndata as ad

print("\n读取 anndata (.h5ad)...")
adata = ad.read_h5ad("adata_Training.h5ad")
print(f"adata: {adata}")
# 提取所有 target genes（去除 non-targeting）
all_targets = adata.obs['target_gene'].unique().tolist()
all_targets = [g for g in all_targets if g != "non-targeting"]

print("扰动基因数量：", len(all_targets))  # 应为 150

# 固定随机种子以保证可复现
np.random.seed(42)

# 打乱顺序
np.random.shuffle(all_targets)

# 划分比例，可根据需要调整
val_ratio = 0.3
test_ratio = 0.2

n = len(all_targets)
n_val = int(n * val_ratio)
n_test = int(n * test_ratio)

val_genes = all_targets[:n_val]
test_genes = all_targets[n_val:n_val + n_test]
train_genes = all_targets[n_val + n_test:]

print("Validation genes:", len(val_genes))
print("Test genes:", len(test_genes))
print("Train genes:",len(train_genes))
# -------------------------
# 生成 validation AnnData
# -------------------------
adata_val = adata[adata.obs['target_gene'].isin(val_genes)].copy()
adata_test = adata[adata.obs['target_gene'].isin(test_genes)].copy()
adata_train=adata[adata.obs['target_gene'].isin(train_genes)].copy()
# # 保存文件（英文文件名）
# date="1119"
# adata_val.write_h5ad(f"validation_set_{date}.h5ad")
# adata_test.write_h5ad(f"test_set_{date}.h5ad")
# adata_train.write_h5ad(f"training_set_{date}.h5ad")

# print(f"保存完成：validation_set_{date}.h5ad, test_set_{date}.h5ad, training_set_{date}.h5ad")
# 提取 target_gene 列中的细胞数量
target_gene_counts = adata.obs['target_gene'].value_counts()

# 筛选出 gene_list 中的基因及其对应的细胞数
test_targets = adata_test.obs['target_gene'].unique().tolist()
valid_targets = adata_val.obs['target_gene'].unique().tolist()
test_gene_counts = target_gene_counts[target_gene_counts.index.isin(test_targets)]
valid_gene_counts = target_gene_counts[target_gene_counts.index.isin(valid_targets)]

# 将结果转换为 DataFrame 并保存为 CSV
test_df = test_gene_counts.reset_index()
valid_df = valid_gene_counts.reset_index()
test_df.columns = ['target_gene', 'n_cells']
valid_df.columns = ['target_gene', 'n_cells']
test_df.to_csv(f'test_gene_ncell_{date}.csv', index=False)
valid_df.to_csv(f'valid_gene_ncell_{date}.csv', index=False)
print(test_df)

   target_gene  n_cells
0        MED13     2787
1        STAT1     2493
2        KAT2A     2120
3          KDR     2056
4        HMGN1     2004
5      TSC22D4     1970
6         RNF2     1844
7        ACAT2     1765
8        CREG1     1703
9        HMGB2     1487
10      DHCR24     1420
11      ACVR1B     1400
12       DHX36     1392
13       LZTR1     1381
14       CLDN7     1303
15       CENPO     1193
16        NRAS     1172
17       CHMP3     1168
18        EID2     1167
19      ZNF426     1132
20       PLCB3     1110
21        PMS1     1100
22     ZNF286A     1059
23       PHF14      899
24       C1QBP      711
25       SIN3B      586
26       EWSR1      538
27       MED24      455
28       KDM1A      352
29     SLC39A6      178


## Add control

In [2]:
import numpy as np
import anndata as ad

print("\n读取 anndata (.h5ad)...")
adata = ad.read_h5ad("adata_Training.h5ad")
print(f"adata: {adata}")
# 提取所有 target genes（去除 non-targeting）
all_targets = adata.obs['target_gene'].unique().tolist()
all_targets = [g for g in all_targets if g != "non-targeting"]
ntc_adata = adata[adata.obs["target_gene"] == "non-targeting"] # 添加 non-targeting 控制组

print("扰动基因数量：", len(all_targets))  # 应为 150

# 固定随机种子以保证可复现
np.random.seed(42)

# 打乱顺序
np.random.shuffle(all_targets)

# 划分比例，可根据需要调整
val_ratio = 0.3
test_ratio = 0.2

n = len(all_targets)
n_val = int(n * val_ratio)
n_test = int(n * test_ratio)

val_genes = all_targets[:n_val]
test_genes = all_targets[n_val:n_val + n_test]
train_genes = all_targets[n_val + n_test:]

print("Validation genes:", len(val_genes))
print("Test genes:", len(test_genes))
print("Train genes:",len(train_genes))
# -------------------------
# 生成 validation AnnData
# -------------------------
adata_val = adata[adata.obs['target_gene'].isin(val_genes)].copy()
adata_test = adata[adata.obs['target_gene'].isin(test_genes)].copy()
adata_train=adata[adata.obs['target_gene'].isin(train_genes)].copy()

# Append the non-targeting controls to the example anndata if they're missing
adata_val=ad.concat([adata_val, ntc_adata])
adata_test=ad.concat([adata_test, ntc_adata])
adata_train=ad.concat([adata_train, ntc_adata])

# 保存文件（英文文件名）
date="1119"
adata_val.write_h5ad(f"validation_set_{date}.h5ad")
adata_test.write_h5ad(f"test_set_{date}.h5ad")
adata_train.write_h5ad(f"training_set_{date}.h5ad")

print(f"保存完成：validation_set_{date}.h5ad, test_set_{date}.h5ad, training_set_{date}.h5ad")


读取 anndata (.h5ad)...
adata: AnnData object with n_obs × n_vars = 221273 × 18080
    obs: 'target_gene', 'guide_id', 'batch'
    var: 'gene_id'
扰动基因数量： 150
Validation genes: 45
Test genes: 30
Train genes: 75
保存完成：validation_set_1119.h5ad, test_set_1119.h5ad, training_set_1119.h5ad


## For cell-eval, we need to use log1p data as predicted.

In [4]:
import anndata as ad
import numpy as np
from scipy import sparse
import scanpy as sc

def safe_normalize_log1p(file_path):
    # "test_set_1119.h5ad"
    a = ad.read_h5ad(f"{file_path}.h5ad")
    print("Before:", type(a.X), a.X.dtype, a.X.max())
    #  normalize_total + log1p
    if sparse.issparse(a.X):
        a.X = a.X.astype(np.float32)
    else:
        a.X = np.asarray(a.X, dtype=np.float32)
    sc.pp.normalize_total(a, target_sum=1e4)
    sc.pp.log1p(a)
    print("After:", type(a.X), a.X.dtype, a.X.max())
    a.write_h5ad(f"eval/{file_path}_lognorm.h5ad")
    print("Saved to eval/..._lognorm.h5ad")

#safe_normalize_log1p("test_set_1119")
# safe_normalize_log1p("training_set_1119")
# safe_normalize_log1p("validation_set_1119")
safe_normalize_log1p("small_set")


Before: <class 'scipy.sparse._csr.csr_matrix'> float32 1989.0
After: <class 'scipy.sparse._csr.csr_matrix'> float32 5.1971045
Saved to eval/..._lognorm.h5ad
