# Single-nuclei Pseudobulk Preprocessing (RNA-seq and ATAC-seq) Pipeline

## Overview

This pipeline preprocesses single-nuclei pseudobulk count data (snATAC-seq or snRNA-seq)
for downstream QTL analysis and region-specific studies.

**Goals:**
- Transform raw pseudobulk counts into analysis-ready formats
- Remove technical confounders while preserving biological covariates (sex, age)
- Generate QTL-ready phenotype files or region-specific datasets

## Pipeline Structure

```
Step 0: Sample ID Mapping        [sampleid_mapping]
         ↓
Step 1: Pseudobulk QC            [pseudobulk_qc]
         noBIOvar: regress out technical covariates only
         (msex and age_death deliberately preserved)
         ↓ (optional)
         Batch Correction (ComBat-seq or limma::removeBatchEffect)
         ↓ (optional)
         Quantile Normalization
         ↓
Step 2: Format Output
         ├── Phenotype Reformatting → BED    [phenotype_formatting]  (genome-wide QTL mapping; snATAC-seq only)
         └── Region Peak Filtering  → TSV    [region_filtering]      (locus-specific; gene filtering for snRNA-seq)
```

## Modality Support

| Feature | snATAC-seq | snRNA-seq |
|---------|-----------|-----------|
| Count file auto-detected | ✓ | ✓ |
| Default `tech_vars` | `log_n_nuclei`, `med_nucleosome_signal`, `med_tss_enrich`, `log_med_n_tot_fragment`, `log_total_unique_peaks` | custom via `--tech-vars` |
| Blacklist filtering | ✓ | — |
| `region_filtering` step | ✓ (`--regions`) | ✓ (`--gene-list`) |
| `phenotype_formatting` step | ✓ | — (see the note under Phenotype Reformatting) |

For snRNA-seq, override `tech_vars` to match your metadata columns, e.g.:
```bash
--tech-vars log_n_nuclei percent_mito log_n_genes
```

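Any `tech_var` starting with `log_` that is not already a metadata column is derived via
`log1p()` from the raw column of the same name with `log_` stripped. The one exception is
`log_total_unique_peaks`, which is computed from the count matrix as
`log1p(colSums(counts > 0))`. No code changes are needed across modalities.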
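
The derivation rule can be illustrated with a minimal Python sketch. The pipeline itself does this in R; `derive_log_vars` and the sample values are hypothetical names used only for illustration, and the count-matrix-based `log_total_unique_peaks` case is not covered here.

```python
import math

def derive_log_vars(row, tech_vars):
    """Return a copy of `row` with missing `log_*` tech vars filled in as log1p(raw)."""
    out = dict(row)
    for tv in tech_vars:
        if not tv.startswith("log_") or tv in out:
            continue  # not a derived variable, or already supplied in the metadata
        raw = tv[len("log_"):]
        if raw in out:
            out[tv] = math.log1p(out[raw])
        # If the raw column is absent the variable stays missing
        # (the pipeline prints a warning in that case).
    return out

# Hypothetical snRNA-seq sample: n_genes is absent, so log_n_genes cannot be derived.
sample = {"n_nuclei": 99, "percent_mito": 2.5}
derived = derive_log_vars(sample, ["log_n_nuclei", "percent_mito", "log_n_genes"])
print(derived)
```

Because the rule depends only on the column name, the same `tech_vars` mechanism covers both modalities as long as the raw columns exist in the metadata.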

## Input Files

All input files required to run this pipeline can be downloaded
[here](https://drive.google.com/drive/folders/1UzJuHN8SotMn-PJTBp9uGShD25YxapKr?usp=drive_link).

| File | Used in |
|------|---------|
| `pseudobulk_peaks_counts_{celltype}.csv.gz` *(snATAC-seq)* | Step 0, Step 1 |
| `pseudobulk_counts_{celltype}.csv.gz` *(snRNA-seq)* | Step 0, Step 1 |
| `metadata_{celltype}.csv` | Step 0, Step 1 |
| `rosmap_sample_mapping_data.csv` | Step 0 |
| `rosmap_cov.txt` | Step 1 |
| `hg38-blacklist.v2.bed.gz` | Step 1 (snATAC-seq only) |

Count files are **auto-detected** from `input_dir` — no prefix parameter needed.

## Parameters

### `sampleid_mapping`
| Parameter | Default | Description |
|-----------|---------|-------------|
| `map_file` | *required* | CSV with `individualID` → `sampleid` mapping |
| `input_dir` | *required* | Directory with raw metadata and count files |
| `output_dir` | *required* | Parent output directory; writes to `{output_dir}/1_files_with_sampleid/` |
| `celltype` | `['Ast','Ex','In','Microglia','Oligo','OPC']` | Cell types to process |
| `suffix` | `''` | Optional filename suffix (e.g. `_50nuc`) |

### `pseudobulk_qc`
| Parameter | Default | Description |
|-----------|---------|-------------|
| `input_dir` | *required* | Directory with remapped metadata and count files |
| `output_dir` | *required* | Parent output directory; writes to `{output_dir}/2_residuals/{ct}/` |
| `covariates_file` | *required* | Covariate file with `pmi` and `study` columns |
| `blacklist_file` | `''` | Genomic blacklist BED file (snATAC-seq only) |
| `sample_list` | `''` | Optional file with one sample ID per line to subset |
| `tech_vars` | `['log_n_nuclei','med_nucleosome_signal','med_tss_enrich','log_med_n_tot_fragment','log_total_unique_peaks']` | Technical covariates for the model |
| `batch_correction` | `FALSE` | Apply batch correction (`TRUE`/`FALSE`) |
| `batch_method` | `limma` | Batch correction method (`limma` or `combat`) |
| `quant_norm` | `FALSE` | Apply quantile normalization after residuals |
| `min_count` | `5` | Minimum count (converted to a CPM cutoff by `filterByExpr`) that a feature must reach in a minimum number of samples |
| `min_total_count` | `15` | Min total reads across all samples |
| `min_prop` | `0.1` | Min proportion of samples with expression |
| `min_nuclei` | `20` | Nuclei threshold; a sample is kept only if `n_nuclei > min_nuclei` |
| `celltype` | `['Ast','Ex','In','Microglia','Oligo','OPC']` | Cell types to process |
| `suffix` | `''` | Optional filename suffix |

### `phenotype_formatting`
| Parameter | Default | Description |
|-----------|---------|-------------|
| `input_dir` | *required* | Directory containing `{ct}/{ct}_residuals.txt` |
| `output_dir` | *required* | Parent output directory; writes to `{output_dir}/3_pheno_reformat/` |
| `modality` | `snatac` | Modality label used in output filename (`snatac` or `snrna`) |
| `celltype` | `['Ast','Ex','In','Mic','Oligo','OPC']` | Cell types to process |

### `region_filtering`
| Parameter | Default | Description |
|-----------|---------|-------------|
| `input_dir` | *required* | Directory containing `{ct}/{ct}_filtered_raw_counts.txt` |
| `output_dir` | *required* | Parent output directory; writes to `{output_dir}/3_region_filter/` |
| `regions` | `chr7:28000000-28300000,...` | Comma-separated genomic regions of interest (snATAC-seq) |
| `gene_list` | `''` | Gene ID(s) to keep (snRNA-seq), e.g. `ENSG00000000010` |
| `celltype` | `['Ast','Ex','In','Mic','Oligo','OPC']` | Cell types to process |

## Minimal Working Example

## Step 0: Sample ID Mapping

Maps original sample identifiers (`individualID`) to standardized sample IDs (`sampleid`)
across metadata and count matrix files.

### Input

| File | Description |
|------|-------------|
| `rosmap_sample_mapping_data.csv` | Mapping reference: `individualID → sampleid` |
| `metadata_{celltype}.csv` × 6 | Per-cell-type sample metadata |
| `pseudobulk_peaks_counts_{celltype}.csv.gz` × 6 *(snATAC-seq)* | Per-cell-type peak count matrices |
| `pseudobulk_counts_{celltype}.csv.gz` × 6 *(snRNA-seq)* | Per-cell-type gene count matrices |

Cell types: `Ast`, `Ex`, `In`, `Microglia`, `Oligo`, `OPC`

Count files are **auto-detected** from `input_dir` — any `.csv.gz` file ending with
`{celltype}{suffix}` will be found regardless of prefix (`pseudobulk_peaks_counts_`,
`pseudobulk_counts_`, etc.).
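
As a rough illustration of the auto-detection rule, the sketch below picks a count file from a list of filenames. It is written in Python with a hypothetical helper name (`pick_count_file`); sorting the candidates is an assumption added here to make the choice deterministic.

```python
def pick_count_file(filenames, celltype, suffix=""):
    """Return the count file for `celltype`, or None if nothing matches."""
    tail = f"{celltype}{suffix}.csv.gz"
    candidates = sorted(f for f in filenames if f.endswith(tail))
    # Prefer names where the cell type is preceded by "_", e.g. "..._counts_Ast.csv.gz".
    preferred = [f for f in candidates if f.endswith("_" + tail)]
    if preferred:
        return preferred[0]
    return candidates[0] if candidates else None

# Hypothetical directory listing.
files = [
    "metadata_Ast.csv",
    "pseudobulk_peaks_counts_Ast.csv.gz",
    "pseudobulk_peaks_counts_Ast_50nuc.csv.gz",
]
print(pick_count_file(files, "Ast"))
print(pick_count_file(files, "Ast", suffix="_50nuc"))
print(pick_count_file(files, "Ex"))
```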

### Process

**Part 1 — Metadata files**

For each `metadata_{celltype}.csv`:
1. Look up each `individualID` in the mapping reference
2. Assign `sampleid` — falls back to `individualID` if no mapping found
3. Reorder columns: `sampleid` first, then `individualID`, then the rest
4. Save updated file
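
The metadata steps above can be sketched with the standard library alone. This is only an illustration under assumed inputs: `remap_metadata`, the two-row CSV, and the `SM-A` identifier are hypothetical.

```python
import csv
import io

def remap_metadata(rows, id_map):
    """Prepend `sampleid` to each metadata row, falling back to `individualID`."""
    remapped = []
    for row in rows:
        ind_id = row["individualID"]
        rest = {k: v for k, v in row.items() if k not in ("sampleid", "individualID")}
        # Column order: sampleid, individualID, then every other column as-is.
        remapped.append({"sampleid": id_map.get(ind_id, ind_id),
                         "individualID": ind_id, **rest})
    return remapped

# Hypothetical two-row metadata file; only R001 has a mapping.
raw = "individualID,n_nuclei\nR001,120\nR002,45\n"
rows = list(csv.DictReader(io.StringIO(raw)))
remapped = remap_metadata(rows, {"R001": "SM-A"})
print(list(remapped[0].keys()))
print([r["sampleid"] for r in remapped])
```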

**Part 2 — Count matrix files**

For each count file detected in `input_dir`:
1. Auto-detect filename by scanning for `.csv.gz` files matching `{celltype}{suffix}`
2. Extract the header row (column names only)
3. Keep the first column (peak or gene IDs) unchanged
4. Map remaining column names (`individualID` → `sampleid`) where mapping exists,
   otherwise keep original
5. Write new header and stream data rows unchanged
6. Recompress with gzip
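
For readers who prefer to avoid shell pipes, the header rewrite can also be sketched in pure Python. This is an alternative illustration, not the workflow's own implementation; `remap_count_header` and the file names are hypothetical, and the input and output paths must differ.

```python
import gzip
import os
import tempfile

def remap_count_header(in_path, out_path, id_map):
    """Rewrite only the header of a gzipped CSV; data rows are streamed unchanged."""
    with gzip.open(in_path, "rt") as src, gzip.open(out_path, "wt") as dst:
        header = src.readline().rstrip("\n").split(",")
        # First column holds peak or gene IDs and is kept as-is.
        mapped = [header[0]] + [id_map.get(col, col) for col in header[1:]]
        dst.write(",".join(mapped) + "\n")
        for line in src:
            dst.write(line)

# Small round trip in a temporary directory with made-up sample IDs.
with tempfile.TemporaryDirectory() as tmp:
    src_path = os.path.join(tmp, "pseudobulk_peaks_counts_Ast.csv.gz")
    dst_path = os.path.join(tmp, "remapped.csv.gz")
    with gzip.open(src_path, "wt") as fh:
        fh.write("peak_id,R001,R002\nchr1-100-600,3,0\n")
    remap_count_header(src_path, dst_path, {"R001": "SM-A"})
    with gzip.open(dst_path, "rt") as fh:
        remapped_text = fh.read()
print(remapped_text)
```

Streaming line by line keeps memory use flat even for large peak matrices, which is the same motivation as the header-only approach described above.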

### Output

Output directory: `{output_dir}/1_files_with_sampleid/`

| File | Description |
|------|-------------|
| `metadata_{celltype}.csv` × 6 | Metadata with `sampleid` column prepended |
| `{detected_count_filename}` × 6 | Count matrices with mapped column headers |

**Timing:** < 1 min

In [None]:
sos run pipeline/pseudobulk_preprocessing.ipynb sampleid_mapping \
    --map-file data/atac_seq/rosmap_sample_mapping_data.csv \
    --input-dir data/atac_seq/1_files_with_sampleid \
    --output-dir output/atac_seq \
    --celltype Ast Ex In Microglia Oligo OPC

## Step 1: Pseudobulk QC

Regresses out technical covariates while preserving biological variation (sex, age) for
downstream QTL analysis. Works for both snATAC-seq and snRNA-seq.

### Input

| File | Location |
|------|----------|
| `pseudobulk_*counts_{celltype}.csv.gz` *(auto-detected)* | `1_files_with_sampleid/` |
| `metadata_{celltype}.csv` | `1_files_with_sampleid/` |
| `rosmap_cov.txt` | `data/` |
| `hg38-blacklist.v2.bed.gz` *(snATAC-seq, optional)* | `data/` |

### Process

1. Load metadata per cell type; auto-detect and load count matrix from `input_dir`
2. Standardize metadata column names across datasets
3. Keep only samples with more than `min_nuclei` nuclei (default: 20; the filter is `n_nuclei > min_nuclei`)
4. *(Optional)* Subset to samples listed in `sample_list` file
5. Align samples between metadata and count matrix
6. *(Optional)* Filter blacklisted genomic regions (`blacklist_file`)
7. Merge with demographic covariates (`pmi`, `study`) from `covariates_file`
8. Impute missing `pmi` values with median
9. Load `tech_vars` from parameter. Any variable prefixed with `log_` that is not already
   in the metadata is derived automatically:
   - from the raw column of the same name via `log1p()`, e.g. `log_n_nuclei` ← `log1p(n_nuclei)`
   - `log_total_unique_peaks` is a special case computed from the count matrix:
     `log1p(colSums(counts > 0))`
   - Works for both snATAC-seq and snRNA-seq without code changes
10. Build model variable list — `msex` and `age_death` are **deliberately excluded**
11. Drop samples with NA in any model variable
12. Apply expression filtering (`filterByExpr`):
    - `min_count = 5`: minimum count (converted to a CPM cutoff) that a feature must reach in a minimum number of samples
    - `min_total_count = 15`: minimum total reads across all samples
    - `min_prop = 0.1`: feature expressed in ≥10% of samples
13. TMM normalization
14. *(Optional)* Batch correction (`sequencingBatch` and/or `Library`):
    - `limma::removeBatchEffect` (default)
    - `ComBat-seq`
15. Add `sequencingBatch` and `Library` to model if multi-level
16. Fit linear model (`voom` + `lmFit` + `eBayes`)

**Model formula (default snATAC-seq):**
```
~ log_n_nuclei + med_nucleosome_signal + med_tss_enrich +
  log_med_n_tot_fragment + log_total_unique_peaks +
  [sequencingBatch] + [Library] + pmi + study
```

> `sequencingBatch` and `Library` are included only if present in metadata and have
> more than one level. If batch correction was applied, they are removed from the model.

17. Compute `offset + residuals` as final adjusted values:
    - `offset`: fitted value at reference covariate levels (continuous covariates at their median, indicator columns at 0)
    - `residuals`: unexplained variation after removing all covariate effects
18. *(Optional)* Quantile normalization of final values
19. Save outputs

### Output

Output directory: `{output_dir}/2_residuals/{celltype}/`

| File | Description |
|------|-------------|
| `{celltype}_residuals.txt` | Covariate-adjusted values (log2-CPM) |
| `{celltype}_results.rds` | Full results: DGEList, fit, offset, residuals, design, parameters |
| `{celltype}_filtered_raw_counts.txt` | Filtered raw counts before normalization |

**Variables deliberately NOT regressed out:**
- Sex (`msex`)
- Age at death (`age_death`)

**Timing:** < 5 min per cell type

### Pseudobulk QC


In [None]:
# snATAC-seq
sos run pipeline/pseudobulk_preprocessing.ipynb pseudobulk_qc \
    --input-dir output/atac_seq/1_files_with_sampleid \
    --output-dir output/atac_seq \
    --blacklist-file data/atac_seq/hg38-blacklist.v2.bed.gz \
    --covariates-file data/atac_seq/rosmap_cov.txt \
    --batch-correction FALSE \
    --min-count 5 \
    --celltype Ast Ex In Microglia Oligo OPC

# snRNA-seq
# --quant-norm TRUE is optional: include it only if you want quantile-normalized output
sos run pipeline/pseudobulk_preprocessing.ipynb pseudobulk_qc \
    --input-dir output/snrna_seq/1_files_with_sampleid \
    --output-dir output/snrna_seq \
    --covariates-file data/snrna_seq/covariates.txt \
    --min-count 5 \
    --batch-correction FALSE \
    --quant-norm TRUE \
    --celltype Ast Ex In Microglia Oligo OPC


## Batch Correction (Optional)

Runs after TMM normalization (step 13 above) and before model fitting (step 16 above).
Use when batch effects are severe (e.g., visible batch clusters in PCA, multiple sequencing runs).

> When batch correction is applied, `sequencingBatch` and `Library` are **removed** from
> the model formula since their variance has already been removed from the counts.
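
The rule for which terms enter the model can be summarized in a short Python sketch. It is an illustration of the logic described here rather than the R code itself; `model_terms` and the example metadata are hypothetical.

```python
def model_terms(meta_rows, tech_vars, batch_corrected):
    """Assemble model terms; `msex` and `age_death` are never added."""
    columns = set().union(*(row.keys() for row in meta_rows))
    terms = [v for v in tech_vars if v in columns]
    for batch_var in ("sequencingBatch", "Library"):
        levels = {row.get(batch_var) for row in meta_rows}
        # Batch terms need to be present, multi-level, and not already corrected for.
        if not batch_corrected and batch_var in columns and len(levels) > 1:
            terms.append(batch_var)
    terms += [v for v in ("pmi", "study") if v in columns]
    return terms

meta = [
    {"log_n_nuclei": 5.1, "sequencingBatch": "B1", "Library": "L1",
     "pmi": 6.0, "study": "ROS", "msex": 1},
    {"log_n_nuclei": 4.7, "sequencingBatch": "B2", "Library": "L1",
     "pmi": 8.0, "study": "MAP", "msex": 0},
]
print(model_terms(meta, ["log_n_nuclei", "med_tss_enrich"], batch_corrected=False))
print(model_terms(meta, ["log_n_nuclei", "med_tss_enrich"], batch_corrected=True))
```

In this made-up example `Library` has a single level and `med_tss_enrich` is absent from the metadata, so neither appears in the resulting term list.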

**Method comparison:**

| | ComBat-seq | limma `removeBatchEffect` |
|---|---|---|
| **Operates on** | Raw integer counts | log-CPM values |
| **Mean-variance modelling** | Yes | No |
| **Best for** | Large, balanced batches | Small or fragmented batches |
| **Robustness** | May fail with many small batches | More robust to unbalanced designs |

**ComBat-seq:**
```r
dge$counts <- ComBat_seq(as.matrix(dge$counts), batch = batches)
```

**limma `removeBatchEffect`:**
```r
logCPM     <- cpm(dge, log = TRUE, prior.count = 1)
logCPM     <- removeBatchEffect(logCPM, batch = factor(batches))
dge$counts <- round(pmax(2^logCPM, 0))
```

**Additional filtering applied before correction:**
- Singleton batches (only 1 sample in a batch) are removed prior to correction
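
A minimal Python sketch of the singleton filter follows; the pipeline performs the equivalent subsetting on the `DGEList` in R. `drop_singleton_batches` and the sample and batch labels are hypothetical.

```python
from collections import Counter

def drop_singleton_batches(sample_ids, batches):
    """Keep only samples whose batch contains at least two samples."""
    sizes = Counter(batches)
    kept = [(s, b) for s, b in zip(sample_ids, batches) if sizes[b] > 1]
    return [s for s, _ in kept], [b for _, b in kept]

ids = ["S1", "S2", "S3", "S4"]
batches = ["B1", "B1", "B2", "B1"]
kept_ids, kept_batches = drop_singleton_batches(ids, batches)
print(kept_ids, kept_batches)
```

A batch with a single sample gives the correction nothing to estimate a batch effect from, which is why those samples are dropped rather than corrected.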

**Parameters:**

| Parameter | Default | Description |
|-----------|---------|-------------|
| `batch_correction` | `FALSE` | Enable batch correction |
| `batch_method` | `limma` | Method to use (`limma` or `combat`) |

**Command:**
```bash
sos run pipeline/pseudobulk_preprocessing.ipynb pseudobulk_qc \
    ... \
    --batch-correction TRUE \
    --batch-method limma
```

**Effect on RDS output:**

The `{celltype}_results.rds` file will include:
- `batch_correction = TRUE`
- `batch_method = "limma"` or `"combat"`

### Pseudobulk QC with batch correction


In [None]:
sos run pipeline/pseudobulk_preprocessing.ipynb pseudobulk_qc \
    --input-dir output/atac_seq/1_files_with_sampleid \
    --output-dir output/atac_seq \
    --blacklist-file data/atac_seq/hg38-blacklist.v2.bed.gz \
    --covariates-file data/atac_seq/rosmap_cov.txt \
    --batch-correction TRUE \
    --batch-method limma \
    --min-count 5 \
    --celltype Ast Ex In Microglia Oligo OPC

### Additional parameters


In [None]:
# Additional pseudobulk_qc parameters with their defaults
--min-count 5
--min-total-count 15
--min-prop 0.1
--min-nuclei 20
--sample-list ''   # path to file with one sample ID per line
# snATAC-seq defaults; for snRNA-seq use e.g.: log_n_nuclei percent_mito log_n_genes
--tech-vars log_n_nuclei med_nucleosome_signal med_tss_enrich log_med_n_tot_fragment log_total_unique_peaks

## Step 2: Format Output

### Phenotype Reformatting (exclusively for snATAC-seq)

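Converts residuals into a QTL-ready BED format for genome-wide QTL mapping.
Peak coordinates are parsed from the feature IDs, so this step applies to snATAC-seq
residuals; for snRNA-seq, see the note at the end of this subsection.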

**Input:**

| File | Location |
|------|----------|
| `{celltype}_residuals.txt` | `{output_dir}/2_residuals/{celltype}/` |

**Process:**

1. Read residuals file with proper handling of feature IDs and sample columns
2. Parse peak coordinates from peak IDs (`chr-start-end` format)
3. Convert to midpoint coordinates (standard for QTLtools):
```
start = floor((peak_start + peak_end) / 2)
end   = start + 1
```
4. Build BED format: `#chr`, `start`, `end`, `ID` followed by per-sample values
5. Sort by chromosome and position
6. Compress with `bgzip` and index with `tabix`
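
The coordinate conversion and sorting can be sketched in Python as follows. This is an illustrative sketch only: the function names and residual values are made up, compression and indexing are left out, and plain lexicographic chromosome ordering is an assumption.

```python
def peak_to_bed_row(peak_id, values):
    """Turn 'chr-start-end' plus per-sample values into a midpoint BED row."""
    chrom, start, end = peak_id.rsplit("-", 2)
    mid = (int(start) + int(end)) // 2   # floor of the midpoint
    return [chrom, mid, mid + 1, peak_id] + list(values)

def build_bed(residuals):
    """Build BED rows sorted by chromosome name, then by position."""
    rows = [peak_to_bed_row(pid, vals) for pid, vals in residuals.items()]
    rows.sort(key=lambda r: (r[0], r[1]))
    return rows

# Hypothetical residuals for two samples.
residuals = {
    "chr2-100-201": [0.1, -0.2],
    "chr1-500-1000": [1.5, 0.3],
    "chr1-10-20": [0.0, 0.4],
}
bed_rows = build_bed(residuals)
for row in bed_rows:
    print("\t".join(str(x) for x in row))
```

Splitting from the right with `rsplit("-", 2)` keeps the start and end fields intact even if a contig name itself contains a hyphen.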

**Output:** `{output_dir}/3_pheno_reformat/`

| File | Description |
|------|-------------|
| `{celltype}_{modality}_phenotype.bed.gz` | bgzip-compressed BED with midpoint coordinates |
| `{celltype}_{modality}_phenotype.bed.gz.tbi` | tabix index for random-access queries |

**Use case:** Standard QTL mapping to identify genetic variants affecting chromatin
accessibility (caQTL) or gene expression (eQTL), with biological variation preserved.
Compatible with FastQTL, TensorQTL, and QTLtools.

**Timing:** < 1 min per cell type

**Note:** For snRNA-seq, follow this [pipeline](https://github.com/StatFunGen/xqtl-protocol/blob/main/code/data_preprocessing/phenotype/phenotype_formatting.ipynb) instead.

In [None]:
sos run pipeline/pseudobulk_preprocessing.ipynb phenotype_formatting \
    --input-dir output/atac_seq/2_residuals \
    --output-dir output/atac_seq \
    --celltype Ast Ex In Mic Oligo OPC

### Region Filtering

Filters peak counts to specific genomic regions of interest for locus-specific analysis.

**Input:**

| File | Location |
|------|----------|
| `{celltype}_filtered_raw_counts.txt` | `{output_dir}/2_residuals/{celltype}/` |

**Process:**

1. Read filtered raw counts per cell type
2. Parse peak coordinates from peak IDs (`chr-start-end` format)
3. Calculate per-peak metrics:
   - `peakwidth`: `end - start`
   - `midpoint`: `(start + end) / 2`
4. Filter peaks overlapping any target region — includes peaks that start, end, or span region boundaries
5. Calculate summary statistics per peak:
   - `total_count`: sum of counts across all samples
   - `weighted_count`: `total_count / peakwidth` (normalizes for peak size)
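
The overlap test and per-peak summary can be illustrated with a small Python sketch. The helper names and counts are hypothetical, and treating both peak and region coordinates as closed intervals is an assumption made for this example.

```python
def parse_region(text):
    """Parse 'chr7:28000000-28300000' into (chrom, start, end)."""
    chrom, span = text.split(":")
    start, end = span.split("-")
    return chrom, int(start), int(end)

def summarize_peaks_in_regions(counts, regions):
    """Summarize peaks that overlap any of the comma-separated target regions."""
    targets = [parse_region(r) for r in regions.split(",")]
    summary = []
    for peak_id, values in counts.items():
        chrom, start, end = peak_id.rsplit("-", 2)
        start, end = int(start), int(end)
        # Overlap: same chromosome and the two intervals intersect.
        if not any(chrom == c and start <= e and end >= s for c, s, e in targets):
            continue
        width = end - start
        total = sum(values)
        summary.append({
            "peak_id": peak_id, "chr": chrom, "start": start, "end": end,
            "peakwidth": width, "midpoint": (start + end) / 2,
            "total_count": total, "weighted_count": total / width,
        })
    return summary

# Hypothetical counts for two samples; only the first peak touches the region.
counts = {
    "chr7-27999800-28000300": [4, 6],
    "chr7-28300500-28301000": [1, 1],
    "chr11-28100000-28100500": [2, 2],
}
summary = summarize_peaks_in_regions(counts, "chr7:28000000-28300000")
print(summary)
```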

**Output:** `{output_dir}/3_region_filter/`

| File | Description |
|------|-------------|
| `{celltype}_filtered_regions_of_interest.txt` | Full count matrix for peaks in target regions |
| `{celltype}_filtered_regions_of_interest_summary.txt` | Peak metadata with coordinates and count statistics |

**Use case:** Hypothesis-driven analysis of specific genomic loci (e.g., AD risk loci such as
the APOE or TREM2 regions) where biological variation is preserved for downstream interpretation.

**Timing:** < 1 min per cell type

In [None]:
#snATAC-seq 
sos run pipeline/pseudobulk_preprocessing.ipynb region_filtering \
    --input-dir output/atac_seq/2_residuals \
    --output-dir output/atac_seq \
    --celltype Ast Ex In Mic Oligo OPC \
    --regions "chr7:28000000-28300000,chr11:85050000-86200000"

#snRNA-seq
sos run pipeline/pseudobulk_preprocessing.ipynb region_filtering \
    --input-dir output/snrna_seq/2_residuals \
    --output-dir output/snrna_seq \
    --celltype MIC \
    --gene-list "ENSG00000000010"

## Command interface

In [None]:
sos run pipeline/pseudobulk_preprocessing.ipynb -h

## Setup and global parameters

In [None]:
[global]
parameter: cwd = path("output")
parameter: job_size = 1
parameter: walltime = "5h"
parameter: mem = "16G"
parameter: numThreads = 8
parameter: container = ""

import re
from sos.utils import expand_size

entrypoint = (
    'micromamba run -a "" -n' + ' ' +
    re.sub(r'(_apptainer:latest|_docker:latest|\.sif)$', '', container.split('/')[-1])
) if container else ""

cwd = path(f'{cwd:a}')

```
  usage: sos run pipeline/pseudobulk_preprocessing.ipynb
               [workflow_name | -t targets] [options] [workflow_options]
  workflow_name:        Single or combined workflows defined in this script
  targets:              One or more targets to generate
  options:              Single-hyphen sos parameters (see "sos run -h" for details)
  workflow_options:     Double-hyphen workflow-specific parameters
Workflows:
  sampleid_mapping
  pseudobulk_qc
  phenotype_formatting
  region_filtering
Global Workflow Options:
  --cwd output (as path)
  --job-size 1 (as int)
  --walltime 5h
  --mem 16G
  --numThreads 8 (as int)
  --container ''
Sections
  sampleid_mapping:
    Workflow Options:
      --map-file VAL (as str, required)
      --input-dir VAL (as str, required)
      --output-dir VAL (as str, required)
      --celltype Ast Ex In Microglia Oligo OPC (as list)
      --suffix ''
  pseudobulk_qc:
    Workflow Options:
      --celltype Ast Ex In Microglia Oligo OPC (as list)
      --input-dir VAL (as str, required)
      --output-dir VAL (as str, required)
      --covariates-file VAL (as str, required)
      --blacklist-file ''
      --sample-list ''
      --tech-vars log_n_nuclei med_nucleosome_signal med_tss_enrich log_med_n_tot_fragment log_total_unique_peaks (as list)
      --batch-correction FALSE
      --batch-method limma
      --quant-norm FALSE
      --min-count 5 (as int)
      --min-total-count 15 (as int)
      --min-prop 0.1 (as float)
      --min-nuclei 20 (as int)
      --suffix ''
  phenotype_formatting:
    Workflow Options:
      --celltype Ast Ex In Mic Oligo OPC (as list)
      --input-dir VAL (as str, required)
      --output-dir VAL (as str, required)
  region_filtering:
    Workflow Options:
      --celltype Ast Ex In Mic Oligo OPC (as list)
                        Parameters
      --input-dir VAL (as str, required)
      --output-dir VAL (as str, required)
      --regions ''
      --gene-list ''
```

## `sampleid_mapping`

In [None]:
[sampleid_mapping]
parameter: map_file   = str
parameter: input_dir  = str
parameter: output_dir = str
parameter: celltype   = ['Ast', 'Ex', 'In', 'Microglia', 'Oligo', 'OPC']
parameter: suffix     = ''

input:  [f'{input_dir}/metadata_{ct}{suffix}.csv' for ct in celltype]
output: [f'{output_dir}/1_files_with_sampleid/metadata_{ct}{suffix}.csv' for ct in celltype]

python: expand = "${ }"

import pandas as pd
import gzip
import os
import subprocess
import csv
import numpy as np
import tempfile

map_df = pd.read_csv("${map_file}")
id_map = dict(zip(map_df["individualID"], map_df["sampleid"]))

celltype   = ${celltype}
input_dir  = "${input_dir}"
output_dir = "${output_dir}/1_files_with_sampleid"
suffix     = "${suffix}"

os.makedirs(output_dir, exist_ok=True)

def map_id(ind_id):
    return id_map.get(ind_id, ind_id)

def format_value(val):
    if pd.isna(val):
        return ''
    if isinstance(val, (int, np.integer)):
        return str(val)
    if isinstance(val, (float, np.floating)):
        # is_integer() is False for inf, whereas int(inf) would raise OverflowError
        if float(val).is_integer():
            return str(int(val))
        else:
            return str(val)
    return str(val)

def find_count_file(input_dir, ct, suffix):
    candidates = [
        f for f in os.listdir(input_dir)
        if f.endswith(f"{ct}{suffix}.csv.gz") or f.endswith(f"_{ct}{suffix}.csv.gz")
    ]
    if not candidates:
        return None, None
    preferred = [f for f in candidates if f.endswith(f"_{ct}{suffix}.csv.gz")]
    fname = preferred[0] if preferred else candidates[0]
    return os.path.join(input_dir, fname), fname

# ── Process metadata ───────────────────────────────────────────────────────
for ct in celltype:
    fname    = f"metadata_{ct}{suffix}.csv"
    in_path  = os.path.join(input_dir, fname)
    out_path = os.path.join(output_dir, fname)

    if not os.path.exists(in_path):
        print(f"Skipping metadata (not found): {fname}")
        continue

    meta = pd.read_csv(in_path)

    if "individualID" not in meta.columns:
        print(f"Warning: individualID column not found in {fname}")
        continue

    meta["sampleid"] = meta["individualID"].map(map_id)
    cols = meta.columns.tolist()
    cols.remove("sampleid")
    cols.remove("individualID")
    meta = meta[["sampleid", "individualID"] + cols]

    with open(out_path, 'w', newline='') as f:
        writer = csv.writer(f, quoting=csv.QUOTE_MINIMAL)
        writer.writerow(meta.columns)
        for _, row in meta.iterrows():
            writer.writerow([format_value(val) for val in row])

    print(f"Processed metadata: {fname}")

# ── Process count files ────────────────────────────────────────────────────
for ct in celltype:
    in_path, fname = find_count_file(input_dir, ct, suffix)

    if in_path is None:
        print(f"Skipping counts (not found) for celltype: {ct}")
        continue

    print(f"Detected count file: {fname}")
    out_path = os.path.join(output_dir, fname)

    with gzip.open(in_path, "rt") as fh:
        header_line = fh.readline().rstrip("\n")

    col_names       = header_line.split(",")
    peak_id_col     = col_names[0]
    new_sample_cols = [map_id(s) for s in col_names[1:]]
    new_header      = ",".join([peak_id_col] + new_sample_cols)

    tmp = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt')
    tmp.write(new_header + "\n")
    tmp.close()

    cmd = f"zcat {in_path} | tail -n +2 | cat {tmp.name} - | gzip -6 > {out_path}"
    subprocess.run(cmd, shell=True, check=True)
    os.unlink(tmp.name)

    print(f"Processed counts: {fname}")

print("\nSample ID mapping completed!")

## `pseudobulk_qc`

In [None]:
[pseudobulk_qc]
parameter: celltype         = ['Ast','Ex','In','Microglia','Oligo','OPC']
parameter: input_dir        = str
parameter: output_dir       = str
parameter: covariates_file  = str
parameter: blacklist_file   = ''
parameter: sample_list      = ''
parameter: tech_vars        = ['log_n_nuclei','med_nucleosome_signal','med_tss_enrich','log_med_n_tot_fragment','log_total_unique_peaks']
parameter: batch_correction = "FALSE"
parameter: batch_method     = "limma"
parameter: quant_norm       = "FALSE"
parameter: min_count        = 5
parameter: min_total_count  = 15
parameter: min_prop         = 0.1
parameter: min_nuclei       = 20
parameter: suffix           = ''

input:  [f'{input_dir}/metadata_{ct}{suffix}.csv' for ct in celltype]
output: [f'{output_dir}/2_residuals/{ct}/{ct}_residuals.txt' for ct in celltype]

task: trunk_workers = 1, trunk_size = 1, walltime = '6:00:00', mem = '64G', cores = 4

cts_str = "c(" + ", ".join([f"'{x}'" for x in celltype]) + ")"
tvs_str = "c(" + ", ".join([f"'{x}'" for x in tech_vars]) + ")"

R: expand = "${ }", stdout = f'{_output[0]:n}.stdout', stderr = f'{_output[0]:n}.stderr'

    library(edgeR)
    library(limma)
    library(data.table)
    library(GenomicRanges)
    if (as.logical("${batch_correction}") && "${batch_method}" == "combat") library(sva)

    rename_if_found <- function(dt, target, candidates) {
        found <- intersect(candidates, colnames(dt))[1]
        if (!is.na(found) && found != target) setnames(dt, found, target)
    }

    standardize_meta <- function(meta) {
        rename_if_found(meta, "n_nuclei",              c("n.nuclei","nNuclei","nuclei_count"))
        rename_if_found(meta, "med_nucleosome_signal", c("med.nucleosome_signal.ct","NucleosomeRatio","med_nucleosome_signal.ct"))
        rename_if_found(meta, "med_tss_enrich",        c("med.tss.enrich.ct","TSSEnrichment","med_tss_enrich.ct"))
        rename_if_found(meta, "med_n_tot_fragment",    c("med.n_tot_fragment.ct","med_n_tot_fragment.ct"))
        return(meta)
    }

    find_count_file <- function(input_dir, ct, suffix) {
        all_files  <- list.files(input_dir, pattern="\\.csv\\.gz$", full.names=FALSE)
        pattern    <- paste0(ct, suffix, "\\.csv\\.gz$")
        candidates <- all_files[grepl(pattern, all_files)]
        if (length(candidates) == 0) return(NULL)
        preferred <- candidates[grepl(paste0("_", ct, suffix, "\\.csv\\.gz$"), candidates)]
        if (length(preferred) > 0) return(file.path(input_dir, preferred[1]))
        return(file.path(input_dir, candidates[1]))
    }

    filter_blacklist <- function(mat, bed, feat_label) {
        peaks <- data.table(id = rownames(mat))
        peaks[, c("chr","start","end") := tstrsplit(gsub("_","-",id), "-")]
        peaks[, `:=`(start = as.numeric(start), end = as.numeric(end))]
        bl <- fread(bed)[, 1:3]
        setnames(bl, c("chr","start","end"))
        bl[, `:=`(start = as.numeric(start), end = as.numeric(end))]
        gr1 <- GRanges(peaks$chr, IRanges(peaks$start, peaks$end))
        gr2 <- GRanges(bl$chr,    IRanges(bl$start,    bl$end))
        blacklisted <- unique(queryHits(findOverlaps(gr1, gr2)))
        if (length(blacklisted) > 0) {
            message("Blacklisted ", feat_label, " removed: ", length(blacklisted))
            return(mat[-blacklisted, , drop=FALSE])
        }
        return(mat)
    }

    predictOffset <- function(fit) {
        D  <- fit$design
        Dm <- D
        for (col in colnames(D)) {
            if (col == "(Intercept)") next
            if (is.numeric(D[, col]) && !all(D[, col] %in% c(0, 1)))
                Dm[, col] <- median(D[, col], na.rm=TRUE)
            else
                Dm[, col] <- 0
        }
        B <- fit$coefficients
        B[is.na(B)] <- 0
        B %*% t(Dm)
    }

    cts       <- ${cts_str}
    tech_vars <- ${tvs_str}

    for (ct in cts) {
        message("\n", paste(rep("=", 40), collapse=""))
        message("Processing: ", ct)
        message("Batch correction: ", ifelse(as.logical("${batch_correction}"), "${batch_method}", "none"))
        message("Quantile normalization: ", ifelse(as.logical("${quant_norm}"), "TRUE", "FALSE"))
        message(paste(rep("=", 40), collapse=""))

        outdir <- file.path("${output_dir}/2_residuals", ct)
        dir.create(outdir, recursive=TRUE, showWarnings=FALSE)

        # ── 1. Load data ───────────────────────────────────────────────────
        meta <- fread(sprintf("${input_dir}/metadata_%s${suffix}.csv", ct))

        counts_file <- find_count_file("${input_dir}", ct, "${suffix}")
        if (is.null(counts_file)) stop("No count file found for celltype: ", ct)
        message("Detected count file: ", basename(counts_file))

        counts_raw <- fread(counts_file)
        counts <- as.matrix(counts_raw[, -1, with=FALSE])
        rownames(counts) <- counts_raw[[1]]
        rm(counts_raw)

        # ── Auto-detect modality ───────────────────────────────────────────
        is_atac    <- grepl("^chr.*-[0-9]+-[0-9]+$", rownames(counts)[1])
        feat_label <- ifelse(is_atac, "peaks", "genes")
        message("Detected modality: ", ifelse(is_atac, "snATAC-seq", "snRNA-seq"))
        message("Loaded: ", nrow(counts), " ", feat_label, " x ", ncol(counts), " samples")

        # ── 2. Standardize metadata ────────────────────────────────────────
        meta <- standardize_meta(meta)

        # ── 3. Sample ID column ───────────────────────────────────────────
        idcol <- intersect(c("sampleid","sampleID","individualID","projid"), colnames(meta))[1]
        if (is.na(idcol)) stop("Cannot find sample ID column in metadata.")

        # ── 4. Nuclei filter ──────────────────────────────────────────────
        if ("n_nuclei" %in% colnames(meta)) {
            meta <- meta[meta$n_nuclei > ${min_nuclei}]
            message("Samples after nuclei (>${min_nuclei}) filter: ", nrow(meta))
        }

        # ── 5. Optional sample list filter ────────────────────────────────
        if ("${sample_list}" != "" && file.exists("${sample_list}")) {
            keep_ids <- fread("${sample_list}", header=FALSE)[[1]]
            meta     <- meta[meta[[idcol]] %in% keep_ids]
            message("Samples after sample_list filter: ", nrow(meta))
        }

        # ── 6. Align samples ──────────────────────────────────────────────
        common <- intersect(meta[[idcol]], colnames(counts))
        if (length(common) == 0) stop("Zero sample overlap between metadata and count matrix.")
        meta   <- meta[match(common, meta[[idcol]])]
        counts <- counts[, common, drop=FALSE]
        message("Samples after alignment: ", length(common))

        # ── 7. Blacklist filtering ─────────────────────────────────────────
        if ("${blacklist_file}" != "" && file.exists("${blacklist_file}")) {
            counts <- filter_blacklist(counts, "${blacklist_file}", feat_label)
            message(feat_label, " after blacklist filter: ", nrow(counts))
        } else {
            message("No blacklist file provided - skipping.")
        }

        # ── 8. Load and merge covariates ──────────────────────────────────
        covs      <- fread("${covariates_file}")
        id2       <- intersect(c("#id","id","projid","individualID"), colnames(covs))[1]
        keep_cols <- c(id2, intersect(c("pmi","study"), colnames(covs)))
        covs      <- covs[, ..keep_cols]
        meta      <- merge(meta, covs, by.x=idcol, by.y=id2, all.x=TRUE)
        meta      <- meta[match(common, meta[[idcol]])]

        # ── 9. Impute missing PMI ─────────────────────────────────────────
        if ("pmi" %in% colnames(meta) && any(is.na(meta$pmi))) {
            message("Imputing missing values for: pmi")
            meta$pmi[is.na(meta$pmi)] <- median(meta$pmi, na.rm=TRUE)
        }

        # ── 10. Tech vars ─────────────────────────────────────────────────
        message("Tech vars: ", paste(tech_vars, collapse=", "))

        # ── 11. Compute derived log metrics ───────────────────────────────
        for (tv in tech_vars[startsWith(tech_vars, "log_")]) {
            if (tv %in% colnames(meta)) next
            if (tv == "log_total_unique_peaks") {
                meta$log_total_unique_peaks <- log1p(colSums(counts > 0))
            } else {
                raw_col <- sub("^log_", "", tv)
                if (raw_col %in% colnames(meta)) {
                    meta[[tv]] <- log1p(meta[[raw_col]])
                } else {
                    message("Warning: cannot compute ", tv, " - '", raw_col, "' not in metadata")
                }
            }
        }

        # ── 12. Select model variables ────────────────────────────────────
        # Keep only the requested terms that are actually present in the metadata
        all_vars <- intersect(c(tech_vars, "pmi", "study"), colnames(meta))
        message("Model terms: ", paste(all_vars, collapse=", "))

        # ── 13. Drop samples with NA in model variables ───────────────────
        keep_rows <- complete.cases(meta[, ..all_vars])
        meta      <- meta[keep_rows]
        # Index by name: numeric sample IDs would otherwise be read as column positions
        counts    <- counts[, as.character(meta[[idcol]]), drop=FALSE]
        message("Valid samples for modelling: ", nrow(meta))

        # ── 14. Expression filtering ──────────────────────────────────────
        dge <- DGEList(counts=counts, samples=meta)
        dge$samples$group <- factor(rep("all", ncol(dge)))
        message(feat_label, " before expression filter: ", nrow(dge))

        keep <- filterByExpr(dge, group=dge$samples$group,
                             min.count=${min_count},
                             min.total.count=${min_total_count},
                             min.prop=${min_prop})
        dge <- dge[keep,, keep.lib.sizes=FALSE]
        message(feat_label, " after expression filter: ", nrow(dge))

        # ── Save filtered raw counts ──────────────────────────────────────
        write.table(dge$counts,
                    file.path(outdir, paste0(ct, "_filtered_raw_counts.txt")),
                    sep="\t", quote=FALSE, col.names=NA)

        # ── 15. TMM normalization ─────────────────────────────────────────
        dge <- calcNormFactors(dge, method="TMM")

        # ── 16. Optional batch correction ─────────────────────────────────
        if (as.logical("${batch_correction}") && "sequencingBatch" %in% colnames(dge$samples)) {
            batches       <- dge$samples$sequencingBatch
            batch_counts  <- table(batches)
            valid_batches <- names(batch_counts[batch_counts > 1])
            keep_bc       <- batches %in% valid_batches
            dge           <- dge[, keep_bc, keep.lib.sizes=FALSE]
            batches       <- batches[keep_bc]
            message("Samples after singleton batch removal: ", ncol(dge))

            if ("${batch_method}" == "combat") {
                dge$counts <- ComBat_seq(as.matrix(dge$counts), batch=batches)
                message("ComBat-seq batch correction applied.")
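            } else {
                # removeBatchEffect operates on log2-CPM values. Back-transforming
                # with 2^x yields CPM-scale values, not raw counts, and lib.size
                # is left unchanged; keep this in mind when interpreting dge$counts.
                logCPM     <- cpm(dge, log=TRUE, prior.count=1)
                logCPM     <- removeBatchEffect(logCPM, batch=factor(batches))
                dge$counts <- round(pmax(2^logCPM, 0))
                message("limma removeBatchEffect applied.")
            }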
        }

        # ── 17. Add batch vars to model if multi-level ────────────────────
        other_vars <- setdiff(all_vars, tech_vars)
        batch_vars <- c()
        if ("sequencingBatch" %in% colnames(dge$samples) &&
            length(unique(dge$samples$sequencingBatch)) > 1) {
            dge$samples$sequencingBatch_factor <- factor(dge$samples$sequencingBatch)
            batch_vars <- c(batch_vars, "sequencingBatch_factor")
        }
        if ("Library" %in% colnames(dge$samples) &&
            length(unique(dge$samples$Library)) > 1) {
            dge$samples$Library_factor <- factor(dge$samples$Library)
            batch_vars <- c(batch_vars, "Library_factor")
        }
        all_vars <- intersect(c(tech_vars, batch_vars, other_vars),
                              c(colnames(dge$samples), colnames(meta)))

        # ── 18. Build design matrix ───────────────────────────────────────
        form   <- as.formula(paste("~", paste(all_vars, collapse=" + ")))
        design <- model.matrix(form, data=dge$samples)
        message("Formula: ", deparse(form))

        if (!is.fullrank(design)) {
            message("Design not full rank - trimming.")
            qr_d   <- qr(design)
            design <- design[, qr_d$pivot[seq_len(qr_d$rank)], drop=FALSE]
        }
        message("Design matrix: ", nrow(design), " x ", ncol(design))

        # ── 19. Voom + lmFit + eBayes ────────────────────────────────────
        v   <- voom(dge, design, plot=FALSE)
        fit <- lmFit(v, design)
        fit <- eBayes(fit)

        # ── 20. Offset + residuals ────────────────────────────────────────
        off   <- predictOffset(fit)
        res   <- residuals(fit, v$E)
        final <- off + res

        # ── 21. Save residuals ────────────────────────────────────────────
        out_file <- file.path(outdir, paste0(ct, "_residuals.txt"))

        write.table(final,
            out_file,
            sep="\t", quote=FALSE, col.names=NA)

        feat_label <- if (is_atac) "Peaks" else "Genes"

        message("Saved: ", out_file)
        message("  ", feat_label, ": ", nrow(final), " | Samples: ", ncol(final))

        # ── 22. Optional Quantile Normalization ───────────────────────────
        if (as.logical("${quant_norm}")) {
            message("\n", paste(rep("=", 40), collapse=""))
            message("Applying quantile normalization...")
            message(paste(rep("=", 40), collapse=""))
            
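            # Rank-based inverse normal transform, applied per feature (row):
            # rank the samples (ties averaged), rescale ranks to (0, 1) with
            # rank / (n + 1), then map through the standard normal quantile function.
            final_qn <- t(apply(final, 1, rank, ties.method = "average"))
            final_qn <- stats::qnorm(final_qn / (ncol(final_qn) + 1))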
            
            qn_file <- file.path(outdir, paste0(ct, "_residuals_qn.txt"))
            write.table(final_qn,
                qn_file,
                sep="\t", quote=FALSE, col.names=NA)
            
            message("Saved QN: ", qn_file)
            message("  ", feat_label, ": ", nrow(final_qn), " | Samples: ", ncol(final_qn))
            
            # Save RDS with QN
            saveRDS(list(
                dge              = dge,
                offset           = off,
                residuals        = res,
                final_data       = final,
                final_data_qn    = final_qn,
                valid_samples    = colnames(dge),
                design           = design,
                fit              = fit,
                model            = form,
                mode             = "noBIOvar",
                batch_correction = as.logical("${batch_correction}"),
                batch_method     = ifelse(as.logical("${batch_correction}"), "${batch_method}", "none"),
                quant_norm       = TRUE,
                modality         = ifelse(is_atac, "snATAC-seq", "snRNA-seq")
            ), file.path(outdir, paste0(ct, "_results_qn.rds")))
        } else {
            # Save RDS without QN
            saveRDS(list(
                dge              = dge,
                offset           = off,
                residuals        = res,
                final_data       = final,
                valid_samples    = colnames(dge),
                design           = design,
                fit              = fit,
                model            = form,
                mode             = "noBIOvar",
                batch_correction = as.logical("${batch_correction}"),
                batch_method     = ifelse(as.logical("${batch_correction}"), "${batch_method}", "none"),
                quant_norm       = FALSE,
                modality         = ifelse(is_atac, "snATAC-seq", "snRNA-seq")
            ), file.path(outdir, paste0(ct, "_results.rds")))
        }

        message("Completed: ", ct, " -> ", outdir)
    }
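
The optional transform in step 22 can be illustrated outside the pipeline. Despite the `quant_norm` name, the R code above applies a rank-based inverse normal transform to each feature separately. The snippet below is a minimal standard-library sketch of that idea, not the pipeline's own code; the helper names `average_ranks` and `inverse_normal_transform` and the toy values are assumptions made for illustration.

```python
from statistics import NormalDist


def average_ranks(values):
    """Return 1-based ranks; tied values share the mean of their ranks."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j over the run of values tied with position i.
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        mean_rank = (i + j) / 2 + 1
        for k in range(i, j + 1):
            ranks[order[k]] = mean_rank
        i = j + 1
    return ranks


def inverse_normal_transform(row):
    """Map one feature's values to standard normal quantiles via rank / (n + 1)."""
    n = len(row)
    normal = NormalDist()
    return [normal.inv_cdf(r / (n + 1)) for r in average_ranks(row)]


# One hypothetical feature measured in five samples (two samples tied).
row = [0.3, -1.2, 2.5, 0.3, 0.9]
ranks = average_ranks(row)        # [2.5, 1.0, 5.0, 2.5, 4.0]
z = inverse_normal_transform(row)
print(ranks)
print([round(v, 3) for v in z])
```

Dividing by `n + 1` rather than `n` keeps every rescaled rank strictly between 0 and 1, so the normal quantile function never returns an infinite value.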

## `phenotype_formatting`

In [None]:
[phenotype_formatting]
parameter: celltype   = ['Ast','Ex','In','Mic','Oligo','OPC']
parameter: input_dir  = str
parameter: output_dir = str

input:  [f'{input_dir}/{ct}/{ct}_residuals.txt' for ct in celltype]
output: [f'{output_dir}/3_pheno_reformat/{ct}_phenotype.bed.gz' for ct in celltype]

task: trunk_workers = 1, trunk_size = 1, walltime = '2:00:00', mem = '16G', cores = 2

python: expand = "${ }", stderr = f'{_output[0]:n}.stderr', stdout = f'{_output[0]:n}.stdout'

    import os
    import subprocess
    import pandas as pd

    celltypes  = ${celltype}
    input_dir  = "${input_dir}"
    output_dir = "${output_dir}"

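    def read_residuals(path):
        """Read a residuals table written by R's write.table(col.names=NA)."""
        with open(path) as fh:
            col_names = fh.readline().rstrip("\n").split("\t")
        df = pd.read_csv(path, sep="\t", header=None, skiprows=1)
        # The first data column always holds the feature (peak) IDs.
        peak_ids = df.iloc[:, 0].values
        df       = df.iloc[:, 1:]
        # The header either lists sample IDs only (one field fewer than the
        # data rows) or starts with a blank/ID field that must be dropped.
        if df.shape[1] == len(col_names):
            df.columns = col_names
        else:
            df.columns = col_names[1:]
        return peak_ids, df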

    def to_midpoint_bed(peak_ids, residuals):
        """Convert snATAC-seq peak IDs (chr-start-end) to midpoint BED format."""
        # Split from the right so contig names that contain "-" stay intact.
        parts  = pd.Series(peak_ids).str.rsplit("-", n=2, expand=True)
        chrs   = parts[0].values
        starts = parts[1].astype(int).values
        ends   = parts[2].astype(int).values
        mids   = ((starts + ends) // 2).astype(int)
        bed = pd.DataFrame({
            "#chr":  chrs,
            "start": mids,
            "end":   mids + 1,
            "ID":    peak_ids
        })
        bed = pd.concat([bed, residuals.reset_index(drop=True)], axis=1)
        return bed.sort_values(["#chr", "start"]).reset_index(drop=True)

    def run_cmd(cmd, label):
        r = subprocess.run(cmd, capture_output=True)
        if r.returncode != 0:
            print(f"WARNING: {label} failed: {r.stderr.decode()}")
        else:
            print(f"{label}: OK")

    for ct in celltypes:
        print(f"\n{'='*40}\nPhenotype Formatting: {ct}\n{'='*40}")

        out_dir = os.path.join(output_dir, "3_pheno_reformat")
        os.makedirs(out_dir, exist_ok=True)

        res_path = os.path.join(input_dir, ct, f"{ct}_residuals.txt")
        if not os.path.exists(res_path):
            print(f"WARNING: {res_path} not found, skipping.")
            continue

        peak_ids, residuals = read_residuals(res_path)
        print(f"Loaded {len(peak_ids)} peaks x {residuals.shape[1]} samples")

        bed     = to_midpoint_bed(peak_ids, residuals)
        out_bed = os.path.join(out_dir, f"{ct}_phenotype.bed")
        bed.to_csv(out_bed, sep="\t", index=False, float_format="%.15f")
        print(f"Written: {out_bed}")

        run_cmd(["bgzip", "-f", out_bed],                "bgzip")
        run_cmd(["tabix", "-p", "bed", f"{out_bed}.gz"], "tabix")
        print(f"Completed: {ct} -> {out_dir}")
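
To make the midpoint conversion concrete, here is a small standalone sketch that uses only the standard library instead of pandas. It assumes peak IDs of the form `chr-start-end`, as in the step above; the function name `peak_to_midpoint_bed` and the example peaks are hypothetical.

```python
def peak_to_midpoint_bed(peak_id):
    """Turn 'chr-start-end' into a 1-bp BED interval at the peak midpoint."""
    # Split from the right so a chromosome name containing '-' is preserved.
    chrom, start, end = peak_id.rsplit("-", 2)
    mid = (int(start) + int(end)) // 2
    return (chrom, mid, mid + 1, peak_id)


# Hypothetical peaks, deliberately unsorted.
peaks = ["chr2-500-901", "chr1-1000-1500", "chr1-200-400"]

# Sort by chromosome, then start, which is what tabix indexing requires.
rows = sorted((peak_to_midpoint_bed(p) for p in peaks), key=lambda r: (r[0], r[1]))
for row in rows:
    print("\t".join(str(field) for field in row))
```

Each peak collapses to a single base at its midpoint, so a QTL tool that defines cis windows around the phenotype position treats the window as centred on the peak rather than on its left edge.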

## `region_filtering`

In [None]:
[region_filtering]
# Parameters
parameter: celltype   = ['Ast','Ex','In','Mic','Oligo','OPC']
parameter: input_dir  = str
parameter: output_dir = str
parameter: regions    = ""
parameter: gene_list  = ""  # Note: Use --gene_list in command line

# SoS Input/Output logic
input:  [f'{input_dir}/{ct}/{ct}_filtered_raw_counts.txt' for ct in (celltype if isinstance(celltype, list) else [celltype])]
output: [f'{output_dir}/3_region_filter/{ct}_filtered_regions_of_interest.txt' for ct in (celltype if isinstance(celltype, list) else [celltype])]

task: trunk_workers = 1, trunk_size = 1, walltime = '1:00:00', mem = '16G', cores = 2

python: expand = "${ }", stderr = f'{_output[0]:n}.stderr', stdout = f'{_output[0]:n}.stdout'
    import os
    import pandas as pd

    # Handle SoS passing single strings vs lists
    raw_ct = ${celltype!r}
    celltypes = [raw_ct] if isinstance(raw_ct, str) else raw_ct
    
    input_dir  = "${input_dir}"
    output_dir = "${output_dir}"
    regions_str = "${regions}"
    gene_list_str = "${gene_list}"

    def parse_regions(region_str):
        if not region_str or region_str.strip() == "":
            return []
        result = []
        for r in region_str.split(","):
            chrom, coords = r.strip().split(":")
            start, end    = coords.split("-")
            result.append({"chr": chrom, "start": int(start), "end": int(end)})
        return result

    def parse_peak_ids(peak_ids):
        # Split from the right so contig names that contain "-" stay intact.
        parts = pd.Series(peak_ids).str.rsplit("-", n=2, expand=True)
        return pd.DataFrame({
            "chr":   parts[0].values,
            "start": parts[1].astype(int).values,
            "end":   parts[2].astype(int).values
        })

    def overlaps_region(chr_col, start_col, end_col, reg):
        return (
            (chr_col   == reg["chr"]) &
            (start_col <   reg["end"]) &
            (end_col   >   reg["start"])
        )

    regions = parse_regions(regions_str)
    
    genes_to_filter = None
    if gene_list_str and gene_list_str.strip():
        # Ignore empty entries such as a trailing comma in --gene_list.
        genes_to_filter = {g.strip() for g in gene_list_str.split(",") if g.strip()}

    for ct in celltypes:
        reg_dir = os.path.join(output_dir, "3_region_filter")
        os.makedirs(reg_dir, exist_ok=True)

        counts_path = os.path.join(input_dir, ct, f"{ct}_filtered_raw_counts.txt")
        if not os.path.exists(counts_path):
            print(f"WARNING: {counts_path} not found, skipping.")
            continue

        df = pd.read_csv(counts_path, sep="\t", index_col=0)
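        first_id = str(df.index[0])
        # Peak IDs look like chr-start-end. Gene symbols can contain hyphens too,
        # so require two trailing integer fields instead of counting hyphens.
        id_parts = first_id.rsplit("-", 2)
        is_atac  = len(id_parts) == 3 and id_parts[1].isdigit() and id_parts[2].isdigit()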
        
        # Consistent output name to match SoS 'output' definition
        full_out = os.path.join(reg_dir, f"{ct}_filtered_regions_of_interest.txt")

        if is_atac:
            if not regions:
                print(f"WARNING: no --regions given for snATAC-seq input, skipping {ct}.")
                continue
            df.index.name = "peak_id"
            df = df.reset_index()
            coords = parse_peak_ids(df["peak_id"].values)
            df["chr"], df["start"], df["end"] = coords["chr"], coords["start"], coords["end"]
            df["peakwidth"] = df["end"] - df["start"]
            
            mask = pd.Series(False, index=df.index)
            for reg in regions:
                mask |= overlaps_region(df["chr"], df["start"], df["end"], reg)

            region_df = df[mask].copy()
            region_df.to_csv(full_out, sep="\t", index=False)
        
        else:
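            if not genes_to_filter:
                print(f"WARNING: no --gene_list given for snRNA-seq input, skipping {ct}.")
                continue
            df.index.name = "gene_name"
            # Boolean indexing keeps the genes in their original file order.
            region_df = df.loc[df.index.isin(genes_to_filter)].copy()
            if region_df.empty:
                print(f"WARNING: none of the requested genes found for {ct}, skipping.")
                continue

            # Write to the filename declared in the SoS 'output' block.
            region_df.to_csv(full_out, sep="\t")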

        print(f"Completed: {ct}")
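
The region logic above can be checked on a toy example. The sketch below re-implements the `chr:start-end` parsing and the overlap test with plain tuples, assuming half-open intervals as the strict `<` and `>` comparisons in `overlaps_region` imply. The function names and coordinates are hypothetical and chosen only to show the boundary behaviour.

```python
def parse_region(region):
    """Parse 'chr:start-end' into a (chrom, start, end) tuple."""
    chrom, coords = region.strip().split(":")
    start, end = coords.split("-")
    return chrom, int(start), int(end)


def overlaps(peak, region):
    """True if the peak shares at least one base with the region."""
    p_chr, p_start, p_end = peak
    r_chr, r_start, r_end = region
    return p_chr == r_chr and p_start < r_end and p_end > r_start


region = parse_region("chr1:1000-2000")

# Hypothetical peaks covering the interesting boundary cases.
peaks = {
    "inside":    ("chr1", 1200, 1700),
    "straddles": ("chr1", 1900, 2100),
    "touching":  ("chr1", 2000, 2500),  # starts exactly where the region ends
    "other_chr": ("chr2", 1200, 1700),
}
for name, peak in peaks.items():
    print(name, overlaps(peak, region))
```

A peak that merely touches the region boundary is excluded, while any peak that extends into the region by at least one base is kept in full, including its flanking portion outside the region.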