# Adipocyte Perturbation: Quickstart
End-to-end setup from repo clone to training and submission generation. Run cells top-to-bottom on a GPU-enabled machine with sufficient disk (≥100 GB recommended).

## 0. Prerequisites
- Python 3.10+, CUDA-capable GPU, and disk headroom (≥100 GB).
- Raw challenge files placed under `data/raw/Challenge` (the path the cells below read from):
  - obesity_challenge_1.h5ad
  - signature_genes.csv
  - program_proportion.csv
  - program_proportion_local_gtruth.csv
  - predict_perturbations.txt
  - gene_to_predict.txt
- Git access to the repository.
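
A quick way to confirm the raw files are in place before running anything heavy is sketched below. It is a minimal illustration, not part of the repository: the helper names `missing_from` and `missing_raw_files` are hypothetical, and the directory argument should be whichever folder holds your raw files.

```python
from pathlib import Path

# File names copied from the prerequisites list above.
REQUIRED_FILES = [
    "obesity_challenge_1.h5ad",
    "signature_genes.csv",
    "program_proportion.csv",
    "program_proportion_local_gtruth.csv",
    "predict_perturbations.txt",
    "gene_to_predict.txt",
]


def missing_from(present_names):
    """Return required file names that are not in the given collection of names."""
    present = set(present_names)
    return [name for name in REQUIRED_FILES if name not in present]


def missing_raw_files(raw_dir):
    """Return required file names that are absent from raw_dir (hypothetical helper)."""
    raw_dir = Path(raw_dir)
    if not raw_dir.is_dir():
        # A missing directory means every required file is missing.
        return list(REQUIRED_FILES)
    return missing_from(p.name for p in raw_dir.iterdir() if p.is_file())


# Usage (adjust the path to your layout):
# print("Missing:", missing_raw_files("data/raw/Challenge") or "none")
```

An empty list means every file from the list above was found.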

In [None]:
# 1) Clone repo (skip if already inside)
!git clone https://github.com/Koussaisalem/adipocyte-perturbation-prediction.git
%cd adipocyte-perturbation-prediction

!python -m pip install --upgrade pip
!python -m pip install -e ".[dev,notebooks]"

In [None]:
# 2) Install core requirements (recommended for clean envs)
!python -m pip install -r requirements.txt

In [None]:
# 4) Verify GPU availability
!nvidia-smi || echo 'nvidia-smi not available'
import torch
print('CUDA available:', torch.cuda.is_available())
print('CUDA devices:', torch.cuda.device_count())
if torch.cuda.is_available():
    print('Device 0:', torch.cuda.get_device_name(0))

In [None]:
# 5) Verify raw data presence
!ls -lh data/raw/Challenge

# If the data lives elsewhere, uncomment to symlink it into place:
# !ln -s /path/to/challenge/data data/raw/Challenge

In [None]:
# 5b) Unzip the small h5ad if needed
!unzip -o data/raw/Challenge/obesity_challenge_1.h5ad.small.zip -d data/raw/Challenge
!ls -lh data/raw/Challenge | head -n 5

In [None]:
# 6) Run setup helper (checks files, creates all_genes.txt, directories)
!bash setup_codespace.sh

In [None]:
# 7) Build Knowledge Graph (CollecTRI/DoRothEA + STRING)
!python scripts/build_kg.py \
  --gene-list data/processed/all_genes.txt \
  --output data/kg/knowledge_graph.gpickle \
  --dorothea-levels A B \
  --string-threshold 700

## 7b. Fix Geneformer Assets (Git-LFS Workaround)
On CamberCloud and similar environments without `apt-get` access, Git-LFS pointers in the Geneformer package need to be replaced with real `.pkl` files downloaded directly from HuggingFace.
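
For context, a Git-LFS pointer is a short text file whose first line is `version https://git-lfs.github.com/spec/v1`, whereas a real pickle is binary. The sketch below shows one way to tell them apart. The function names are illustrative and not taken from the repository.

```python
# Every Git-LFS pointer file starts with this line (LFS pointer spec v1).
LFS_PREFIX = b"version https://git-lfs.github.com/spec/v1"


def looks_like_lfs_pointer(data: bytes) -> bool:
    """True if the leading bytes of a file match the LFS pointer header."""
    return data.startswith(LFS_PREFIX)


def file_is_lfs_pointer(path) -> bool:
    """Read exactly as many bytes as the prefix needs, then compare."""
    with open(path, "rb") as f:
        return looks_like_lfs_pointer(f.read(len(LFS_PREFIX)))


# A pointer header followed by placeholder metadata (values are illustrative).
pointer_bytes = LFS_PREFIX + b"\noid sha256:<hash>\nsize <bytes>\n"
# Pickles written with protocol 2 or later start with the byte 0x80.
pickle_like_bytes = b"\x80\x04\x95"

print(looks_like_lfs_pointer(pointer_bytes))      # True
print(looks_like_lfs_pointer(pickle_like_bytes))  # False
```

Note that the number of bytes read must be at least the length of the prefix being compared: `startswith` can never match a prefix longer than the data it is given.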

In [None]:
# Fix Geneformer .pkl files (Git-LFS pointers → real data)
# This is needed on CamberCloud where apt-get install git-lfs is not available

import subprocess
import sys
from pathlib import Path

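# Find geneformer package location
gf_result = subprocess.run(
    [sys.executable, "-c", "import geneformer; print(geneformer.__path__[0])"],
    capture_output=True, text=True
)
if gf_result.returncode != 0:
    # Fail early: an empty stdout would otherwise resolve to the current directory
    raise RuntimeError(f"Could not locate geneformer:\n{gf_result.stderr}")
gf_path = Path(gf_result.stdout.strip())
print(f"Geneformer package location: {gf_path}")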

# List of .pkl files that need to be real (not Git-LFS pointers)
pkl_files = [
    "gene_dictionaries_gc104M/gene_name_id_dict_gc104M.pkl",
    "gene_dictionaries_gc104M/gene_median_dict_gc104M.pkl",
    "gene_dictionaries_gc104M/ensembl_mapping_dict_gc104M.pkl",
    "gene_dictionaries_gc104M/token_dict_gc104M.pkl",
]

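# Check if files are Git-LFS pointers (missing files also count as needing a fix)
def is_lfs_pointer(filepath):
    if not filepath.exists():
        return True
    with open(filepath, "rb") as f:
        # Read more bytes than the 23-byte prefix compared below;
        # a shorter read could never match it
        first_bytes = f.read(64)
    return first_bytes.startswith(b"version https://git-lfs")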

needs_fix = any(is_lfs_pointer(gf_path / p) for p in pkl_files)

if needs_fix:
    print("Detected Git-LFS pointers. Downloading real .pkl files from HuggingFace...")
    from huggingface_hub import hf_hub_download
    
    for rel_path in pkl_files:
        local_path = gf_path / rel_path
        local_path.parent.mkdir(parents=True, exist_ok=True)
        
        print(f"  Downloading {rel_path}...")
        downloaded = hf_hub_download(
            repo_id="ctheodoris/Geneformer",
            filename=rel_path,
            local_dir="/tmp/geneformer_assets",
            force_download=True
        )
        # Copy to geneformer package
        import shutil
        shutil.copy2(downloaded, local_path)
        
        # Verify
        with open(local_path, "rb") as f:
            first_bytes = f.read(10)
        print(f"    Verified: first bytes = {first_bytes[:10]}")
    
    print("✓ All .pkl files fixed!")
else:
    print("✓ Geneformer .pkl files are already valid (not LFS pointers)")

## 7c. Download Full Geneformer Model
Downloads the complete Geneformer model including `pytorch_model.bin` (not just .pkl files).

In [None]:
# Download full Geneformer model (including pytorch_model.bin)
from huggingface_hub import snapshot_download
from pathlib import Path

model_dir = Path("models/geneformer_full")
model_dir.mkdir(parents=True, exist_ok=True)

# Check if model already downloaded (weights may ship as .bin or .safetensors,
# matching what the extraction script accepts)
model_subdir = model_dir / "gf-12L-104M-i4096"
weight_files = ["pytorch_model.bin", "model.safetensors"]

def find_weights(directory):
    # Return the first weight file present in the directory, or None
    return next((directory / n for n in weight_files if (directory / n).exists()), None)

if find_weights(model_subdir):
    print(f"✓ Model already exists at {model_subdir}")
else:
    print("Downloading full Geneformer model from HuggingFace...")
    print("This may take a while: the whole repository is fetched, not only this model.")
    
    snapshot_download(
        repo_id="ctheodoris/Geneformer",
        local_dir=str(model_dir),
        # No allow_patterns filter - download everything including model weights
    )
    
    # Verify
    expected_model = find_weights(model_subdir)
    if expected_model:
        print(f"✓ Model downloaded successfully: {expected_model}")
        print(f"  Size: {expected_model.stat().st_size / 1e6:.1f} MB")
    else:
        print("⚠ Expected model file not found. Checking directory...")
        !ls -la {model_dir}
        !ls -la {model_dir}/gf-12L-104M-i4096/ 2>/dev/null || echo "Model subdirectory not found"

## 7d. Prepare h5ad with Required Columns
Geneformer requires `ensembl_id` in `var` and `n_counts` in `obs`. This cell adds these columns to the dataset.
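
The two derived columns can be illustrated on toy data first. The sketch below uses plain Python lists and a made-up mapping (the `ENSG_TOY_*` identifiers and `GENE_*` symbols are placeholders, not real IDs). It mirrors the logic of the cell that follows: unmapped symbols get an `UNKNOWN_` prefix, and `n_counts` is the per-cell row sum.

```python
def map_symbols(symbols, symbol_to_ensembl):
    """Map gene symbols to Ensembl IDs; unmapped symbols become UNKNOWN_<symbol>."""
    ids, missing = [], []
    for symbol in symbols:
        ensembl_id = symbol_to_ensembl.get(symbol)
        if ensembl_id is None:
            ids.append(f"UNKNOWN_{symbol}")
            missing.append(symbol)
        else:
            ids.append(ensembl_id)
    return ids, missing


def total_counts(matrix):
    """Per-cell total counts: the sum of each row (cells are rows, genes are columns)."""
    return [sum(row) for row in matrix]


# Toy inputs; identifiers are placeholders for illustration only.
toy_mapping = {"GENE_A": "ENSG_TOY_0001", "GENE_B": "ENSG_TOY_0002"}
toy_symbols = ["GENE_A", "GENE_X", "GENE_B"]
toy_counts = [[3, 0, 1], [0, 0, 5]]  # 2 cells x 3 genes

ids, missing = map_symbols(toy_symbols, toy_mapping)
print(ids)                       # ['ENSG_TOY_0001', 'UNKNOWN_GENE_X', 'ENSG_TOY_0002']
print(missing)                   # ['GENE_X']
print(total_counts(toy_counts))  # [4, 5]
```

A low mapping rate in the real cell usually deserves a look before continuing, since genes tagged `UNKNOWN_` will not match entries in Geneformer's dictionaries.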

In [None]:
# Prepare h5ad with ensembl_id and n_counts columns for Geneformer
import scanpy as sc
import pickle
import numpy as np
from pathlib import Path
import subprocess
import sys

# Input and output paths
input_h5ad = Path("data/raw/Challenge/obesity_challenge_1.h5ad")
output_h5ad = Path("data/raw/Challenge/obesity_challenge_1.prepared.h5ad")

if output_h5ad.exists():
    print(f"✓ Prepared file already exists: {output_h5ad}")
else:
    print(f"Loading {input_h5ad}...")
    adata = sc.read_h5ad(input_h5ad)
    print(f"  Shape: {adata.shape[0]} cells x {adata.shape[1]} genes")
    
    # 1. Add ensembl_id column using Geneformer's gene name mapping
    if "ensembl_id" not in adata.var.columns:
        print("Adding ensembl_id column...")
        
        # Find geneformer package and load its gene name dict
        gf_result = subprocess.run(
            [sys.executable, "-c", "import geneformer; print(geneformer.__path__[0])"],
            capture_output=True, text=True
        )
        gf_path = Path(gf_result.stdout.strip())
        gene_name_dict_path = gf_path / "gene_dictionaries_gc104M" / "gene_name_id_dict_gc104M.pkl"
        
        with open(gene_name_dict_path, "rb") as f:
            gene_name_to_ensembl = pickle.load(f)
        
        # Map gene symbols to Ensembl IDs
        ensembl_ids = []
        missing_count = 0
        for gene in adata.var_names:
            if gene in gene_name_to_ensembl:
                ensembl_ids.append(gene_name_to_ensembl[gene])
            else:
                ensembl_ids.append(f"UNKNOWN_{gene}")
                missing_count += 1
        
        adata.var["ensembl_id"] = ensembl_ids
        mapped_count = len(adata.var) - missing_count
        print(f"  Mapped {mapped_count}/{len(adata.var)} genes ({100*mapped_count/len(adata.var):.1f}%)")
    else:
        print("✓ ensembl_id column already exists")
    
    # 2. Add n_counts column (total counts per cell)
    if "n_counts" not in adata.obs.columns:
        print("Adding n_counts column...")
        
        # X.sum(axis=1) works for both sparse and dense matrices; flatten to 1-D
        n_counts = np.asarray(adata.X.sum(axis=1)).ravel()
        
        adata.obs["n_counts"] = n_counts
        print(f"  n_counts range: {n_counts.min():.0f} - {n_counts.max():.0f}")
    else:
        print("✓ n_counts column already exists")
    
    # Save prepared file
    print(f"Saving prepared file to {output_h5ad}...")
    adata.write_h5ad(output_h5ad)
    print(f"✓ Prepared file saved: {output_h5ad}")
    print(f"  Size: {output_h5ad.stat().st_size / 1e6:.1f} MB")

## 7e. Create Fixed Embedding Extraction Script
The original extraction script works in `tempfile` directories, which leads to tokenizer path issues. The cell below writes a fixed version that uses a stable work directory instead.

In [None]:
%%writefile scripts/extract_embeddings_fixed.py
#!/usr/bin/env python3
"""
Fixed Geneformer embedding extraction script.
Key fixes:
  - Uses fixed work directory instead of tempfile (avoids tokenizer path issues)
  - Uses input_identifier=chunk_path.stem (no .h5ad extension)
  - Conservative memory settings for GPU environments
"""

import argparse
import logging
import shutil
import sys
from pathlib import Path

import pandas as pd
import torch

sys.path.insert(0, str(Path(__file__).parent.parent))

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


def load_geneformer():
    try:
        from geneformer import TranscriptomeTokenizer, EmbExtractor
    except ImportError:
        logger.error("Geneformer not installed. Install with: pip install git+https://huggingface.co/ctheodoris/Geneformer.git")
        sys.exit(1)
    return TranscriptomeTokenizer, EmbExtractor


def clean_dir(path: Path):
    if path.exists():
        for item in path.iterdir():
            if item.is_dir():
                shutil.rmtree(item)
            else:
                item.unlink()


def main():
    parser = argparse.ArgumentParser(description="Extract Geneformer gene embeddings with chunking (fixed version)")
    parser.add_argument("--h5ad-file", required=True, help="Path to h5ad file (must have ensembl_id and n_counts)")
    parser.add_argument("--model-dir", required=True, help="Path to Geneformer model directory (e.g., models/geneformer_full/gf-12L-104M-i4096)")
    parser.add_argument("--output", default="data/processed/gene_embeddings.pt", help="Path to save embeddings")
    parser.add_argument("--work-dir", default="work/geneformer", help="Working directory for intermediate files")
    parser.add_argument("--gene-list", default=None, help="Optional file with genes to keep in output")
    parser.add_argument("--batch-size", type=int, default=1, help="Forward batch size (use 1 for low memory)")
    parser.add_argument("--max-cells", type=int, default=1000, help="Total cells to process")
    parser.add_argument("--chunk-cells", type=int, default=100, help="Cells per chunk for tokenization")
    parser.add_argument("--emb-layer", type=int, choices=[-1, 0], default=-1, help="Embedding layer: -1 (2nd last) or 0 (last)")
    parser.add_argument("--random", action="store_true", help="Use random embeddings (testing only)")
    args = parser.parse_args()

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Random mode for quick testing
    if args.random:
        import scanpy as sc
        adata = sc.read_h5ad(args.h5ad_file, backed="r")
        genes = list(adata.var_names)
        embedding_dim = 512
        embeddings = {g: torch.randn(embedding_dim) for g in genes}
        torch.save(embeddings, output_path)
        logger.info(f"Saved {len(embeddings)} random embeddings to {output_path}")
        return

    TranscriptomeTokenizer, EmbExtractor = load_geneformer()
    model_dir = args.model_dir

    # Verify model exists
    model_path = Path(model_dir)
    if not (model_path / "pytorch_model.bin").exists() and not (model_path / "model.safetensors").exists():
        logger.error(f"Model not found at {model_dir}. Expected pytorch_model.bin or model.safetensors")
        sys.exit(1)

    import scanpy as sc

    # Set up fixed work directory
    work_path = Path(args.work_dir)
    data_dir = work_path / "data"
    token_dir = work_path / "tokenized"
    emb_dir = work_path / "embeddings"
    
    for d in [data_dir, token_dir, emb_dir]:
        d.mkdir(parents=True, exist_ok=True)
        clean_dir(d)

    # Open backed to avoid loading whole file
    adata = sc.read_h5ad(args.h5ad_file, backed="r")
    n_cells = adata.n_obs
    logger.info(f"Backed dataset: {n_cells} cells x {adata.n_vars} genes")

    # Determine total cells to process
    total_cells = min(args.max_cells, n_cells)
    chunk_size = min(args.chunk_cells, total_cells)
    if chunk_size <= 0:
        logger.error("Chunk size must be positive")
        sys.exit(1)

    agg_sum: dict[str, torch.Tensor] = {}
    counts: dict[str, int] = {}
    processed = 0
    chunk_idx = 0

    tokenizer = TranscriptomeTokenizer(nproc=1, model_input_size=4096, special_token=True, model_version="V2")

    while processed < total_cells:
        start = processed
        end = min(processed + chunk_size, total_cells)
        logger.info(f"Processing chunk {chunk_idx}: cells {start} to {end} (total {total_cells})")

        # Load slice into memory
        slice_indices = list(range(start, end))
        chunk_adata = adata[slice_indices, :].to_memory()
        
        # Use a simple name without .h5ad extension for input_identifier
        chunk_name = f"chunk_{chunk_idx}"
        chunk_path = data_dir / f"{chunk_name}.h5ad"
        chunk_adata.write_h5ad(chunk_path)
        del chunk_adata

        # Tokenize chunk - KEY FIX: use stem (no .h5ad extension)
        clean_dir(token_dir)
        logger.info(f"Tokenizing with input_identifier={chunk_name}")
        tokenizer.tokenize_data(
            data_directory=str(data_dir),
            output_directory=str(token_dir),
            output_prefix=f"tokenized_{chunk_idx}",
            file_format="h5ad",
            input_identifier=chunk_name,  # No .h5ad extension!
        )

        tokenized_files = list(token_dir.glob(f"tokenized_{chunk_idx}*.dataset"))
        if not tokenized_files:
            logger.error(f"No tokenized dataset produced for chunk {chunk_idx}")
            logger.info(f"Token dir contents: {list(token_dir.iterdir())}")
            sys.exit(1)
        tokenized_path = tokenized_files[0]
        logger.info(f"Tokenized dataset: {tokenized_path}")

        # Extract embeddings
        clean_dir(emb_dir)
        emb_extractor = EmbExtractor(
            model_type="Pretrained",
            emb_mode="gene",
            gene_emb_style="mean_pool",
            max_ncells=None,
            emb_layer=args.emb_layer,
            forward_batch_size=args.batch_size,
            nproc=1,
            model_version="V2",
        )

        logger.info(f"Extracting embeddings using model: {model_dir}")
        emb_extractor.extract_embs(
            model_directory=model_dir,
            input_data_file=str(tokenized_path),
            output_directory=str(emb_dir),
            output_prefix=f"gene_embs_{chunk_idx}",
            output_torch_embs=True,
        )

        # Load embeddings from torch if available, else CSV
        embeddings_chunk: dict[str, torch.Tensor] = {}
        torch_files = list(emb_dir.glob("*.pt"))
        if torch_files:
            data_pt = torch.load(torch_files[0], weights_only=False)
            if isinstance(data_pt, dict):
                embeddings_chunk = {k: (v if isinstance(v, torch.Tensor) else torch.tensor(v)) for k, v in data_pt.items()}
        if not embeddings_chunk:
            csv_files = list(emb_dir.glob("*.csv"))
            if csv_files:
                emb_df = pd.read_csv(csv_files[0], index_col=0)
                embeddings_chunk = {g: torch.tensor(emb_df.loc[g].values, dtype=torch.float32) for g in emb_df.index}
        if not embeddings_chunk:
            logger.error("Failed to load embeddings for chunk")
            logger.info(f"Emb dir contents: {list(emb_dir.iterdir())}")
            sys.exit(1)

        logger.info(f"Loaded {len(embeddings_chunk)} gene embeddings from chunk {chunk_idx}")

        # Aggregate mean across chunks
        for gene, vec in embeddings_chunk.items():
            if gene not in agg_sum:
                agg_sum[gene] = vec.clone().float()
                counts[gene] = 1
            else:
                agg_sum[gene] += vec.float()
                counts[gene] += 1

        processed = end
        chunk_idx += 1
        
        # Clean up chunk file to save space
        if chunk_path.exists():
            chunk_path.unlink()

        # Clear CUDA cache to prevent OOM
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    adata.file.close()

    # Finalize embeddings (mean across chunks)
    final_embeddings = {g: agg_sum[g] / counts[g] for g in agg_sum}

    # Filter to gene list if provided
    if args.gene_list:
        with open(args.gene_list) as f:
            target = set(line.strip() for line in f if line.strip())
        final_embeddings = {g: v for g, v in final_embeddings.items() if g in target}
        logger.info(f"Filtered to {len(final_embeddings)} genes from gene_list")

    torch.save(final_embeddings, output_path)
    emb_dim = next(iter(final_embeddings.values())).shape[0] if final_embeddings else 0
    logger.info(f"Saved embeddings to {output_path}")
    logger.info(f"  {len(final_embeddings)} genes x {emb_dim} dimensions")


if __name__ == "__main__":
    main()

### Embedding Extraction Notes
- Extraction needs a GPU and free disk space for intermediate files under `--work-dir`; cells are processed in chunks to bound memory use.
- Tune `--max-cells`, `--chunk-cells`, and `--batch-size` to your hardware.
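
To see how `--max-cells` and `--chunk-cells` interact, the sketch below reproduces the chunk boundaries that the loop in `extract_embeddings_fixed.py` walks through. The function name `chunk_ranges` is illustrative; the arithmetic follows the script above.

```python
def chunk_ranges(n_cells, max_cells, chunk_cells):
    """Return (start, end) cell index pairs, end exclusive, as the script iterates them."""
    total = min(max_cells, n_cells)   # never process more cells than exist
    size = min(chunk_cells, total)    # a chunk cannot exceed the total
    if size <= 0:
        raise ValueError("chunk size must be positive")
    return [(start, min(start + size, total)) for start in range(0, total, size)]


ranges = chunk_ranges(n_cells=10_000, max_cells=2000, chunk_cells=200)
print(len(ranges), ranges[0], ranges[-1])  # 10 (0, 200) (1800, 2000)

# A final chunk may be shorter than the rest.
print(chunk_ranges(n_cells=1000, max_cells=250, chunk_cells=100))
# [(0, 100), (100, 200), (200, 250)]
```

The script averages the per-chunk gene embeddings with equal weight per chunk, so a short final chunk counts as much as a full one. Choosing `--max-cells` as a multiple of `--chunk-cells` avoids that imbalance.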

In [None]:
# 8) Extract Geneformer embeddings using the FIXED script
# Uses the prepared h5ad with ensembl_id and n_counts columns
# Memory-safe settings: batch-size=1, chunk-cells=100-200

# Adjust these based on your GPU memory:
# - L4 (24GB): chunk-cells=100-200, max-cells=2000-5000
# - A100 (40GB): chunk-cells=500, max-cells=10000+

!python scripts/extract_embeddings_fixed.py \
  --h5ad-file data/raw/Challenge/obesity_challenge_1.prepared.h5ad \
  --model-dir models/geneformer_full/gf-12L-104M-i4096 \
  --max-cells 2000 \
  --chunk-cells 200 \
  --batch-size 1 \
  --output data/processed/gene_embeddings.pt

### Training Notes
- Defaults: AdamW lr=1e-4, batch_size=64, epochs=100, precision=16-mixed, early stopping on val/mmd.
- Adjust batch size or `accumulate_grad_batches` if you hit GPU OOM.
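
When lowering the batch size to avoid OOM, gradient accumulation can keep the effective batch size at its default. The sketch below shows the arithmetic, assuming `accumulate_grad_batches` follows the usual convention (effective batch = per-step batch × accumulation steps, per device). The helper name is hypothetical.

```python
import math


def accumulation_steps(target_batch_size, micro_batch_size):
    """Accumulation steps needed so micro_batch_size * steps >= target_batch_size."""
    if micro_batch_size <= 0:
        raise ValueError("micro_batch_size must be positive")
    return math.ceil(target_batch_size / micro_batch_size)


# Keep an effective batch of 64 (the default noted above) with smaller per-step batches.
for micro in (64, 32, 16):
    print(micro, accumulation_steps(64, micro))
# 64 1
# 32 2
# 16 4
```

When the per-step batch does not divide the target evenly, the result rounds up, so the effective batch is slightly larger than the target.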

In [None]:
# 9) Baseline training run
!mkdir -p experiments/logs  # tee fails if the log directory is missing
!python scripts/train.py \
  --config configs/default.yaml \
  --seed 42 \
  2>&1 | tee experiments/logs/baseline_run.log

### Experiment Variants (optional)
- Increase MMD weight: create `configs/high_mmd.yaml` with `losses.mmd_weight: 0.2`, `losses.pearson_weight: 0.1`.
- Deeper GAT: `gat_layers: 4`, `gat_heads: 16`, `gat_hidden_dim: 256` (lower batch if needed).
- Higher PCA: `flow_matching.pca_components: 750` if memory allows.
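
The dotted keys above describe nested config entries. As a rough sketch, a variant config can be built by applying such overrides to a base dictionary. The base values here are placeholders, since the layout of `configs/default.yaml` is not shown in this notebook, and `with_overrides` is a hypothetical helper.

```python
import copy


def with_overrides(base, overrides):
    """Return a deep copy of base with dotted-key overrides applied."""
    cfg = copy.deepcopy(base)
    for dotted_key, value in overrides.items():
        *parents, leaf = dotted_key.split(".")
        node = cfg
        for key in parents:
            # Create intermediate sections when they do not exist yet.
            node = node.setdefault(key, {})
        node[leaf] = value
    return cfg


# Placeholder base config; these are NOT the repository defaults.
base_cfg = {"losses": {"mmd_weight": 0.1, "pearson_weight": 0.2}}

high_mmd = with_overrides(
    base_cfg,
    {"losses.mmd_weight": 0.2, "losses.pearson_weight": 0.1},
)
print(high_mmd)  # {'losses': {'mmd_weight': 0.2, 'pearson_weight': 0.1}}
```

Writing the result to `configs/high_mmd.yaml` would additionally need a YAML library. The base dictionary is left unchanged, so several variants can be derived from it.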

In [None]:
# 10) Generate submission from best checkpoint
!mkdir -p experiments/logs  # tee fails if the log directory is missing
!python scripts/generate_submission.py \
  --checkpoint checkpoints/best.ckpt \
  --output-dir submissions \
  --n-cells 100 \
  --batch-size 10 \
  2>&1 | tee experiments/logs/inference.log

### Quick Validation
- `submissions/expression_matrix.csv` should have 286,301 lines, including the header.
- The expression matrix should contain no NaNs.
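
Both checks can also be done in a single streaming pass that does not load the matrix into memory. The sketch below uses only the standard library. `summarize_csv` is an illustrative helper, and it treats only empty cells and the literal `nan` (any case) as missing, which is narrower than what pandas recognizes.

```python
import csv
import io

EXPECTED_LINES = 286_301  # from the note above, header included


def summarize_csv(handle):
    """Return (record count including header, number of empty or 'nan' data cells)."""
    n_lines = 0
    n_missing = 0
    for i, row in enumerate(csv.reader(handle)):
        n_lines += 1
        if i == 0:
            continue  # skip the header when counting missing values
        n_missing += sum(1 for cell in row if cell.strip().lower() in ("", "nan"))
    return n_lines, n_missing


# Tiny in-memory example: a header plus two data rows, two missing cells.
demo = io.StringIO("gene_a,gene_b\n1.0,2.0\n,nan\n")
print(summarize_csv(demo))  # (3, 2)

# Usage on the real file:
# with open("submissions/expression_matrix.csv", newline="") as f:
#     n_lines, n_missing = summarize_csv(f)
# assert n_lines == EXPECTED_LINES and n_missing == 0
```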

In [None]:
# 11) Validate submission files
!wc -l submissions/expression_matrix.csv
!head submissions/program_proportions.csv
import pandas as pd
df = pd.read_csv('submissions/expression_matrix.csv')
print('NaNs:', df.isna().sum().sum())