In [10]:
# ============================================================
# 1. Paths and extra columns (edit only this section)
# ============================================================
# Directory holding the saved input dataset.
INPUT_DIR = "/home/pj25000162/ku50001814/scRNA-seq_data/Adrenal_scRNA-seq_MC/output_arrow"
# Directory the labelled dataset is written to.
OUTPUT_DIR = "/home/pj25000162/ku50001814/scRNA-seq_data/Adrenal_scRNA-seq_MC/output_arrow_label"

# Column name -> value to attach to every row.
# NOTE: "nan" below is the literal string "nan", not a float NaN.
NEW_COLUMNS = {
    "disease": "MC",
    "cell_types": "nan",
    "organ_major": "nan",
}

# ============================================================
# 2. Implementation (no changes needed below)
# ============================================================
from datasets import load_from_disk, Dataset, DatasetDict
from typing import Dict, Union, Callable, Any

def load_dataset_folder(path: str) -> DatasetDict:
    """Load a saved dataset from ``path`` and return it as a ``DatasetDict``.

    Whatever ``load_from_disk`` returns is passed through unchanged unless it
    is a bare ``Dataset``, in which case it is wrapped under a single
    ``"train"`` split so callers can always iterate over splits.
    """
    loaded = load_from_disk(path)
    if not isinstance(loaded, Dataset):
        # Already split-keyed (or at least not a bare Dataset): return as-is.
        return loaded
    return DatasetDict({"train": loaded})

def add_columns(dsdict: DatasetDict,
                new_columns: Dict[str, Union[Any, Callable[[Dict], Any]]],
                batched: bool = True,
                num_proc: Union[int, None] = None) -> DatasetDict:
    """Add new columns to every split of ``dsdict``.

    Args:
        dsdict: Mapping of split name to dataset; each dataset's ``map``
            method is used to attach the columns.
        new_columns: Column name -> value. A callable value is invoked with
            the object ``map`` passes to the mapper (a batch when
            ``batched=True``, a single example otherwise) and its return
            value is used as-is. Any other value is treated as a constant
            assigned to every row.
        batched: Forwarded to ``map``. Also controls how constants are
            emitted: repeated once per row of the batch when ``True``,
            returned as a plain scalar when ``False``.
        num_proc: Forwarded to ``map``; ``None`` leaves the choice to ``map``.

    Returns:
        A new ``DatasetDict`` with the same split names and the mapped
        datasets.
    """
    def _mapper(ex):
        result = {}
        batch_size = None  # computed lazily, at most once per call
        for col, val in new_columns.items():
            if callable(val):
                result[col] = val(ex)
            elif batched:
                # Constant value in batched mode: expand to one entry per
                # row, sizing the batch from the first existing column.
                if batch_size is None:
                    batch_size = len(next(iter(ex.values())))
                result[col] = [val] * batch_size
            else:
                # Non-batched mode: ``ex`` is a single example, so the
                # constant must be returned as-is rather than as a list
                # sized by the length of some unrelated field.
                result[col] = val
        return result

    out = {}
    for split, ds in dsdict.items():
        out[split] = ds.map(
            _mapper,
            batched=batched,
            num_proc=num_proc,
            desc=f"Adding columns to {split}",
        )
    return DatasetDict(out)

# ============================================================
# 3. Run
# ============================================================
print("入力データセットを読み込み中...")
dsdict = load_dataset_folder(INPUT_DIR)

print("追加前カラム:", {k: v.column_names for k, v in dsdict.items()})

# Attach the configured columns to every split.
dsdict = add_columns(dsdict, NEW_COLUMNS)

print("追加後カラム:", {k: v.column_names for k, v in dsdict.items()})

# Peek at the data (first sample only). Prefer the "train" split, but fall
# back to the first available split so inputs without "train" do not raise
# KeyError.
print("\nデータの最初の1サンプルを表示:")
preview_split = "train" if "train" in dsdict else next(iter(dsdict), None)
if preview_split is not None and len(dsdict[preview_split]) > 0:
    first_sample = dsdict[preview_split][0]
    for key, value in first_sample.items():
        if isinstance(value, list) and len(value) > 10:
            # Abbreviate long arrays so the output stays readable.
            print(f"  {key}: {value[:10]} ... ({len(value)} elements)")
        else:
            print(f"  {key}: {value}")
else:
    # Nothing to preview (no splits, or the chosen split is empty).
    print("  (表示できるサンプルがありません)")

# Save. A lone "train" split is written as a bare Dataset, exactly as before;
# any other split layout is written as a whole DatasetDict so that no split
# is silently dropped.
print("\n新しいデータセットを保存中...")
if list(dsdict) == ["train"]:
    dsdict["train"].save_to_disk(OUTPUT_DIR)
else:
    dsdict.save_to_disk(OUTPUT_DIR)

print(f"完了！\n保存先: {OUTPUT_DIR}")


入力データセットを読み込み中...
追加前カラム: {'train': ['input_ids', 'length']}


Adding columns to train:   0%|          | 0/2547 [00:00<?, ? examples/s]

追加後カラム: {'train': ['input_ids', 'length', 'disease', 'cell_types', 'organ_major']}

データの最初の1サンプルを表示:
  input_ids: [32123, 36856, 17972, 17979, 17978, 17966, 17975, 17962, 17990, 6359] ... (2048 elements)
  length: 2048
  disease: MC
  cell_types: nan
  organ_major: nan

新しいデータセットを保存中...


Saving the dataset (0/1 shards):   0%|          | 0/2547 [00:00<?, ? examples/s]

完了！
保存先: /home/pj25000162/ku50001814/scRNA-seq_data/Adrenal_scRNA-seq_MC/output_arrow_label
