In [3]:
# %% [markdown]
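# # Jupyter version: optional simultaneous multi-gene deletion → UMAP (PDF output only; dataset columns unchanged)
# - Uses the existing Arrow file as-is (columns: `input_ids`, `cell_types`, `organ_major`, `disease`, `length`)
# - Assumes `disease` takes the values Cop1_WT / Cop1_KO
# - For the rows of the deletion base (KO by default; selectable via `BASE_FOR_DELETION`), the tokens of the specified genes are **deleted simultaneously** (columns unchanged; only `length` is recomputed)
# - WT, KO, and each deletion scenario are projected into **one shared UMAP space** and saved as a PDF
# - EmbExtractor: MODEL_TYPE="Pretrained", num_classes=0 (int); `model_directory` is passed to extract_embs()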

# %%
# === Cell 1: 依存関係の確認 / 必要ならインストール ===
import sys, subprocess, importlib

def _ensure(pkg, pip_name=None):
    """Import ``pkg``, installing it with pip first if the import fails.

    Args:
        pkg: Module name passed to ``importlib.import_module``.
        pip_name: Distribution name handed to ``pip install -U``; falls back
            to ``pkg`` when omitted or empty.

    Raises:
        subprocess.CalledProcessError: If the pip command exits non-zero.
        Exception: Whatever the retry import raises after the install.
    """
    distribution = pip_name if pip_name else pkg
    try:
        importlib.import_module(pkg)
    except Exception:
        # Any import failure (not only ImportError) triggers an upgrade-install
        # into the running interpreter, followed by exactly one retry.
        install_cmd = [sys.executable, "-m", "pip", "install", "-U", distribution]
        subprocess.check_call(install_cmd)
        importlib.import_module(pkg)

# Work around tqdm's IProgress problem on environments without ipywidgets:
# rebind the tqdm.notebook progress bars to the tqdm.auto implementations.
try:
    import ipywidgets  # noqa
except Exception:
    from tqdm import auto as tqdmauto
    import tqdm.notebook as tqdmb
    # (attribute on tqdm.notebook, attribute taken from tqdm.auto)
    for _notebook_attr, _auto_attr in (
        ("tqdm", "tqdm"),
        ("trange", "trange"),
        ("tnrange", "trange"),
    ):
        setattr(tqdmb, _notebook_attr, getattr(tqdmauto, _auto_attr))

# Import name -> pip distribution name, checked in this order.
_REQUIRED_PACKAGES = {
    "datasets": "datasets",
    "umap": "umap-learn",
    "matplotlib": "matplotlib",
    "numpy": "numpy",
    "pandas": "pandas",
}
for pkg, pipn in _REQUIRED_PACKAGES.items():
    _ensure(pkg, pipn)

# Geneformer はリポジトリから利用想定（環境依存のためpipインストールは行わない）

# %%
# === Cell 2: インポート & ロギング設定 ===
import os, pickle, json, shutil, time, math, random
from pathlib import Path
from typing import List, Optional, Dict, Iterable, Tuple
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datasets import load_dataset, Dataset
import umap
import logging

# Geneformer EmbExtractor
from geneformer.emb_extractor import EmbExtractor

# Root logging configuration. ``force=True`` replaces handlers that an earlier
# run of this cell (or the notebook environment) may already have installed.
_LOGGING_SETUP = {
    "level": logging.INFO,
    "format": "%(asctime)s | %(levelname)s | %(message)s",
    "datefmt": "%Y-%m-%d %H:%M:%S",
    "force": True,
}
logging.basicConfig(**_LOGGING_SETUP)

# Logger shared by every step of this notebook.
log = logging.getLogger("UMAP-MultiDel")

def log_step(msg: str):
    """Emit ``msg`` at INFO level on the notebook logger ``log``."""
    log.log(logging.INFO, msg)

# %%
# === Cell 3: parameters ===
# ----- Input Arrow file -----
ARROW_PATH = "/work/eval_dataset/Cop1KO_isp_mouse_tokenize_dataset/data-00000-of-00001.arrow"  # <- path to your Arrow file
# Alternative inputs kept for reference:
# /work/scRNA-seq_data/Adrenal_scRNA-seq_MS/output_arrow/data-00000-of-00001.arrow
# /work/eval_dataset/Cop1KO_isp_mouse_tokenize_dataset/data-00000-of-00001.arrow
# Column whose value separates the WT rows from the KO rows.
CONDITION_COL = "disease"


# ==== Cell 3: added / changed parameters ====
# Which subset the gene deletions are applied to: "WT" or "KO".
# The run cell compares this case-insensitively; any value other than "WT"
# ends up using the KO subset.
BASE_FOR_DELETION = "KO"

VALUE_KO = "Cop1_KO"  # value of CONDITION_COL that selects the KO rows
VALUE_WT = "Cop1_WT"  # value of CONDITION_COL that selects the WT rows (the target state deletions should approach)

# Deletion scenarios: each inner list is one set of gene symbols whose tokens
# are removed together. The commented-out entries are examples.
MULTI_DELETE_GENE_SETS = [
    ["Apoe"],
    # ["Apoe", "Lrp1"],
    # ["Apoe", "Lrp1", "Ldlr"],
]


# ----- Dictionaries (provided) -----
# Pickle read as a gene symbol -> Ensembl ID lookup.
PATH_GENE_SYMBOL2ENS = "/work/mouse-geneformer/dictionary_pickle/MLM-re_token_dictionary_v1_GeneSymbol_to_EnsemblID.pkl"
# Pickle read as an Ensembl ID -> token id lookup; also passed to EmbExtractor.
TOKEN_DICTIONARY_FILE = "/work/mouse-geneformer/dictionary_pickle/MLM-re_token_dictionary_v1.pkl"


# ----- Working / output locations -----
WORK_DIR = Path("/work/kyushu_univ/Umap")
OUT_DIR  = WORK_DIR / "out"
EMB_CSV_DIR = OUT_DIR / "emb_csv"
# Stem of the output PDF file name.
# NOTE(review): "Cpo1" looks like a typo for "Cop1" — confirm before renaming,
# since changing it changes the output file name.
OUT_PREFIX = "UMAP_Cpo1KO"

# ----- Temporary .dataset outputs (scenario ones are created dynamically) -----
TMP_WT = WORK_DIR / "tmp_WT.dataset"
TMP_KO = WORK_DIR / "tmp_KO.dataset"
TMP_SCENARIO_ROOT = WORK_DIR / "tmp_scenarios"  # each scenario's .dataset is created inside

# ----- Geneformer (EmbExtractor) settings -----
FINETUNED_MODEL_DIR = "/work/kyushu_univ/cop1_pretrain"  # <- path to your model
# Alternative model directories kept for reference:
# /work/results/251021_mouse-geneformer_CellClassifier_nan_L2048_B12_LR5e-05_LSlinear_WU500_E20_OadamW_F0_ISP-nan
# /work/kyushu_univ/cop1_pretrain
MODEL_TYPE = "Pretrained"   # {"Pretrained","GeneClassifier","CellClassifier"}
EMB_LAYER = 0  # passed to EmbExtractor as emb_layer
FORWARD_BATCH_SIZE = 10  # passed to EmbExtractor as forward_batch_size
NPROC = 8  # passed to EmbExtractor as nproc
MAX_NCELLS_PER_STATE = 5000  # passed to EmbExtractor as max_ncells; caps the UMAP workload

# ----- UMAP settings -----
UMAP_N_NEIGHBORS = 15
UMAP_MIN_DIST    = 0.1
UMAP_METRIC      = "cosine"
UMAP_RANDOM_STATE= 42

# ----- Plot settings -----
POINT_SIZE = 1.5
ALPHA = 0.85
# NOTE(review): SAVE_PDF is not read anywhere in the visible code; the PDF is
# always written.
SAVE_PDF = True

# Seed NumPy's and Python's global RNGs with the same value used for UMAP.
np.random.seed(UMAP_RANDOM_STATE)
random.seed(UMAP_RANDOM_STATE)

# %%
# === Cell 4: ユーティリティ ===

# Module-level caches so each pickled dictionary is read from disk at most once.
_SYM2ENS_CACHE = None
_ENS2TOK_CACHE = None


def _load_dicts():
    """Populate the two dictionary caches from their pickle files if unset.

    ``_SYM2ENS_CACHE`` is loaded from ``PATH_GENE_SYMBOL2ENS`` and
    ``_ENS2TOK_CACHE`` from ``TOKEN_DICTIONARY_FILE``. A cache that is already
    non-``None`` is left untouched and its file is not opened.
    """
    global _SYM2ENS_CACHE, _ENS2TOK_CACHE

    def _unpickle(path):
        # Both files are local, project-provided pickles.
        with open(path, "rb") as handle:
            return pickle.load(handle)

    if _SYM2ENS_CACHE is None:
        _SYM2ENS_CACHE = _unpickle(PATH_GENE_SYMBOL2ENS)
    if _ENS2TOK_CACHE is None:
        _ENS2TOK_CACHE = _unpickle(TOKEN_DICTIONARY_FILE)

def ensure_dirs(clean_tmp=True):
    """Create the output/work directories and optionally wipe temporary data.

    Args:
        clean_tmp: When true, ``TMP_WT``, ``TMP_KO`` and ``TMP_SCENARIO_ROOT``
            are removed recursively if they exist.

    ``TMP_SCENARIO_ROOT`` is always (re)created at the end.
    """
    for directory in (OUT_DIR, EMB_CSV_DIR, WORK_DIR):
        directory.mkdir(parents=True, exist_ok=True)

    stale_candidates = (TMP_WT, TMP_KO, TMP_SCENARIO_ROOT) if clean_tmp else ()
    for leftover in stale_candidates:
        if leftover.exists():
            shutil.rmtree(leftover)

    TMP_SCENARIO_ROOT.mkdir(parents=True, exist_ok=True)

def resolve_token_ids(symbols: Iterable[str]) -> Dict[str, int]:
    """Map gene symbols to token ids via the two cached dictionaries.

    Each symbol is looked up in the symbol -> Ensembl ID dictionary and the
    result in the Ensembl ID -> token dictionary. A symbol counts as resolved
    only if the final value is an ``int`` / ``numpy`` integer.

    Args:
        symbols: Gene symbols to resolve.

    Returns:
        ``{symbol: token_id}`` with plain ``int`` values.

    Raises:
        ValueError: If at least one symbol cannot be resolved; the message
            lists every such symbol.
    """
    _load_dicts()
    resolved = {}
    unresolved = []
    for symbol in symbols:
        ensembl_id = _SYM2ENS_CACHE.get(symbol)
        # A missing Ensembl ID short-circuits to "no token".
        token = None if ensembl_id is None else _ENS2TOK_CACHE.get(ensembl_id)
        if isinstance(token, (int, np.integer)):
            resolved[symbol] = int(token)
        else:
            unresolved.append(symbol)

    if unresolved:
        raise ValueError(f"辞書に見つからない/解決できない遺伝子がありました: {unresolved}")

    # Log every mapping once the whole set is known to be resolvable.
    for symbol, token_id in resolved.items():
        log_step(f"[辞書] {symbol} → token_id={token_id}")
    return resolved

def load_arrow_dataset(path: str) -> Dataset:
    """Load an Arrow file through ``datasets`` and return its ``"train"`` split.

    Args:
        path: Arrow file passed as ``data_files`` to ``load_dataset("arrow")``.

    Returns:
        The loaded dataset; asserted to contain the ``CONDITION_COL`` column.
    """
    log_step(f"[読込] Arrow を読み込み中: {path}")
    splits = load_dataset("arrow", data_files=path)
    ds = splits["train"]
    columns = ds.column_names
    log_step(f"[読込] 行数={len(ds)}, 列={columns}")
    assert CONDITION_COL in ds.column_names, f"{CONDITION_COL} 列が見つかりません"
    return ds

def split_dataset(ds: Dataset):
    """Split ``ds`` into WT and KO subsets by the ``CONDITION_COL`` value.

    Args:
        ds: Dataset containing the ``CONDITION_COL`` column.

    Returns:
        ``(wt_subset, ko_subset)``: rows equal to ``VALUE_WT`` and to
        ``VALUE_KO`` respectively. Both are asserted to be non-empty.
    """
    log_step(f"[分割] disease == '{VALUE_WT}' / '{VALUE_KO}' でフィルタ")

    def _is_wt(row):
        return row[CONDITION_COL] == VALUE_WT

    def _is_ko(row):
        return row[CONDITION_COL] == VALUE_KO

    wt_subset = ds.filter(_is_wt, desc="filter WT")
    ko_subset = ds.filter(_is_ko, desc="filter KO")
    log_step(f"[分割] WT={len(wt_subset)} 行, KO={len(ko_subset)} 行")
    assert len(wt_subset) > 0 and len(ko_subset) > 0, "WT/KO が見つかりませんでした"
    return wt_subset, ko_subset

def remove_tokens_from_ids(ids: List[int], token_ids: set) -> List[int]:
    """Return a new list of ``ids`` without any member of ``token_ids``.

    Order and duplicates of the remaining ids are preserved. When
    ``token_ids`` is empty (or otherwise falsy) a plain copy is returned.
    """
    if token_ids:
        return [tok for tok in ids if tok not in token_ids]
    return list(ids)

def make_del_dataset_in_memory(ds_ko: Dataset, del_token_ids: set) -> Dataset:
    """Build a copy of the dataset with the given tokens removed from each row.

    Every token in ``del_token_ids`` is dropped from ``input_ids`` in one
    pass, and ``length`` is recomputed when the row has that field. No column
    is added or removed.

    Args:
        ds_ko: Dataset to edit. The log text and progress description always
            say "KO", whichever subset is actually passed in.
        del_token_ids: Token ids to delete from every row.

    Returns:
        The mapped dataset (the ``datasets`` cache is bypassed).
    """
    n_targets = len(del_token_ids)
    log_step(f"[DelMap] KOに対し {n_targets} 遺伝子のtokenを同時削除 → 仮想KO_Del")

    def _strip_tokens(row):
        kept = remove_tokens_from_ids(row["input_ids"], del_token_ids)
        row["input_ids"] = kept
        if "length" in row:
            row["length"] = len(kept)
        return row

    progress_desc = f"remove {n_targets} tokens from KO"
    return ds_ko.map(_strip_tokens, desc=progress_desc, load_from_cache_file=False)

def save_to_disk(ds: Dataset, out_dir: Path):
    """Log the destination and write ``ds`` there via ``Dataset.save_to_disk``."""
    log_step(f"[保存] save_to_disk → {out_dir}")
    destination = str(out_dir)
    ds.save_to_disk(destination)

def scenario_label_from_genes(genes: List[str], base_label: str) -> str:
    """Build the legend label for one deletion scenario.

    The gene symbols are sorted and joined with ``+``, e.g.
    ``Cop1_KO_Del[Apoe+Lrp1]``. ``genes`` itself is not modified.
    """
    joined = "+".join(sorted(genes))
    return "{}_Del[{}]".format(base_label, joined)

def extract_embs_from_dir(dataset_dir: Path, state_tag: str) -> np.ndarray:
    """Run Geneformer's ``EmbExtractor`` on a saved dataset directory.

    Args:
        dataset_dir: Directory written by ``save_to_disk``.
        state_tag: Passed as ``output_prefix`` to ``extract_embs``.

    Returns:
        The extracted embeddings as a float array; expected to be
        ``(n_cells, emb_dim)`` — TODO confirm against EmbExtractor.
    """
    log_step("[埋込] EmbExtractor 初期化")
    extractor_options = dict(
        model_type=MODEL_TYPE,
        num_classes=0,
        emb_mode="cell",
        cell_emb_style="mean_pool",
        filter_data=None,
        max_ncells=MAX_NCELLS_PER_STATE,
        emb_layer=EMB_LAYER,
        emb_label=None,
        labels_to_plot=None,
        forward_batch_size=FORWARD_BATCH_SIZE,
        nproc=NPROC,
        summary_stat=None,
        token_dictionary_file=TOKEN_DICTIONARY_FILE,
    )
    extractor = EmbExtractor(**extractor_options)

    # The output directory must exist before extraction writes into it.
    EMB_CSV_DIR.mkdir(parents=True, exist_ok=True)
    log_step(f"[埋込] extract_embs(model_directory=..., input_data_file=..., output_directory=..., output_prefix={state_tag})")
    frame = extractor.extract_embs(
        model_directory=FINETUNED_MODEL_DIR,
        input_data_file=str(dataset_dir),
        output_directory=str(EMB_CSV_DIR),
        output_prefix=state_tag,
    )

    matrix = frame.to_numpy(dtype=float, copy=False)
    log_step(f"[埋込] shape={matrix.shape}, saved CSV under {EMB_CSV_DIR}")
    return matrix

def run_umap(X: np.ndarray) -> np.ndarray:
    """Fit a 2-component UMAP on ``X`` and return the embedded coordinates.

    Neighbours, minimum distance, metric and random state come from the
    ``UMAP_*`` module constants.
    """
    log_step("[UMAP] fit_transform 開始 …")
    umap_options = {
        "n_neighbors": UMAP_N_NEIGHBORS,
        "min_dist": UMAP_MIN_DIST,
        "metric": UMAP_METRIC,
        "random_state": UMAP_RANDOM_STATE,
        "n_components": 2,
    }
    embedding = umap.UMAP(**umap_options).fit_transform(X)
    log_step("[UMAP] 完了")
    return embedding

def _auto_colors(n: int) -> List[str]:
    """Return ``n`` colours from matplotlib's current property cycle.

    Falls back to ``"C0"`` … ``"C9"`` when the cycle defines no colours. If
    ``n`` exceeds the palette size, colours repeat cyclically.
    """
    palette = plt.rcParams['axes.prop_cycle'].by_key().get('color', [])
    if not palette:
        palette = [f"C{i}" for i in range(10)]

    size = len(palette)
    if n <= size:
        return palette[:n]
    return [palette[idx % size] for idx in range(n)]

def plot_pdf(coords: np.ndarray, labels: np.ndarray, legend_order: List[str], colors_map: Dict[str, str],
             title: str, out_pdf: Path):
    """Scatter the UMAP coordinates by state and save the figure to ``out_pdf``.

    Args:
        coords: Coordinates; columns 0 and 1 are plotted as x and y.
        labels: Per-row state labels, compared element-wise with each legend entry.
        legend_order: States in drawing/legend order; states with no points
            are skipped.
        colors_map: State -> colour; a missing state gets matplotlib's default.
        title: Figure title.
        out_pdf: Destination path handed to ``savefig``.
    """
    log_step(f"[描画] PDF 保存: {out_pdf}")
    fig, ax = plt.subplots(figsize=(7.8, 6.8))
    for state in legend_order:
        selected = coords[labels == state]
        if len(selected) == 0:
            continue
        ax.scatter(
            selected[:, 0],
            selected[:, 1],
            s=POINT_SIZE,
            alpha=ALPHA,
            label=state,
            c=colors_map.get(state),
        )
    ax.legend(title="state", markerscale=3, loc="best", frameon=True)
    ax.set_title(title)
    ax.set_xlabel("UMAP-1")
    ax.set_ylabel("UMAP-2")
    fig.tight_layout()
    fig.savefig(out_pdf)
    plt.close(fig)

# %%
# === Cell 5: run the pipeline ===
# Steps: load + split → save WT/KO → build one deleted dataset per scenario →
# extract embeddings for every state → joint UMAP → PDF.
start_ts = time.time()
ensure_dirs(clean_tmp=True)

# Load the Arrow file and split it into the WT / KO subsets.
ds = load_arrow_dataset(ARROW_PATH)
ds_wt, ds_ko = split_dataset(ds)

# Persist the two baseline states (WT / KO) for the embedding extractor.
save_to_disk(ds_wt, TMP_WT)
save_to_disk(ds_ko, TMP_KO)

# Choose the subset the deletions are applied to. Anything other than "WT"
# uses KO, as before, but an unrecognised value is now reported instead of
# being accepted silently.
base_key = BASE_FOR_DELETION.upper()
if base_key == "WT":
    ds_base = ds_wt
    base_label = VALUE_WT
else:
    if base_key != "KO":
        log.warning(
            "BASE_FOR_DELETION=%r is neither 'WT' nor 'KO'; falling back to KO",
            BASE_FOR_DELETION,
        )
        base_key = "KO"
    ds_base = ds_ko
    base_label = VALUE_KO

# Resolve the token id of every gene used in any scenario (fails fast when a
# symbol cannot be resolved).
all_symbols = sorted({s for subset in MULTI_DELETE_GENE_SETS for s in subset})
sym2tok = resolve_token_ids(all_symbols) if all_symbols else {}

# One .dataset directory per state; list order is also the legend order.
scenario_dirs: List[Tuple[str, Path]] = []  # (label, dir)
scenario_dirs.append((VALUE_WT, TMP_WT))
scenario_dirs.append((VALUE_KO, TMP_KO))

for genes in MULTI_DELETE_GENE_SETS:
    label = scenario_label_from_genes(genes, base_label=base_label)
    out_dir = TMP_SCENARIO_ROOT / f"{label}.dataset"
    if out_dir.exists():
        shutil.rmtree(out_dir)
    token_ids = {sym2tok[g] for g in genes} if genes else set()
    ds_del = make_del_dataset_in_memory(ds_base, token_ids)  # deletion applied to the chosen base (WT or KO)
    save_to_disk(ds_del, out_dir)
    scenario_dirs.append((label, out_dir))

# Extract embeddings state by state, keeping a label for every row.
emb_arrays = []
labels_vec = []
legend_order = [lbl for (lbl, _) in scenario_dirs]

for label, ddir in scenario_dirs:
    arr = extract_embs_from_dir(ddir, state_tag=label)
    emb_arrays.append(arr)
    labels_vec.extend([label]*len(arr))

# Fit one UMAP on all states together so they share a single embedding space.
X_all = np.vstack(emb_arrays)
y = np.array(labels_vec)
coords = run_umap(X_all)

# One colour per state.
palette = _auto_colors(len(legend_order))
color_map = {lbl: palette[i] for i, lbl in enumerate(legend_order)}

# Save the PDF. The title names the deletion base actually used; it was
# previously hard-coded to "KO base" even when the deletions ran on WT.
OUT_DIR.mkdir(parents=True, exist_ok=True)
out_pdf = OUT_DIR / f"{OUT_PREFIX}.pdf"
plot_pdf(
    coords, y, legend_order, color_map,
    title=f"UMAP: {VALUE_WT} vs {VALUE_KO} vs Multi-Gene Deletions ({base_key} base)",
    out_pdf=out_pdf
)

elapsed = time.time() - start_ts
log_step(f"[完了] 出力: {out_pdf}")
log_step(f"[完了] 所要時間: {elapsed/60:.2f} 分")


2025-10-30 08:39:34 | INFO | [読込] Arrow を読み込み中: /work/eval_dataset/Cop1KO_isp_mouse_tokenize_dataset/data-00000-of-00001.arrow


FileNotFoundError: Unable to find '/work/eval_dataset/Cop1KO_isp_mouse_tokenize_dataset/data-00000-of-00001.arrow'