# mtDNA Heteroplasmy Pipeline — v2 (Comprehensive)

## Overview
This notebook contains the **comprehensive mtDNA heteroplasmy pipeline**, with the intent to upgrade Stage01–Stage03 to full production versions.

### Current Stages (v2)
1. **Stage01 (chrM extraction — to be upgraded)**  
   Subset WGS CRAM → chrM BAM/BAI/SAM  
   Output filename: `<sample>_<age>_<sex>_chrM.bam`

2. **Stage02 (mtDNA variant calling)**  
   Call mtDNA variants in mitochondria mode  
   Output filename: `<sample>_<age>_<sex>_mt.filtered.vcf`

3. **Stage03 (filter + summary)**  
   Filter by VAF threshold and emit TSV  
   Output filename: `<sample>_<age>_<sex>_mt.vaf0.01.tsv`

## Output Summary
The final TSVs contain:  
`CHROM, POS, REF, ALT, QUAL, FILTER, DP, AF, HET`

## Goal
Aggregate per‑sample heteroplasmy metrics and plot **heteroplasmy vs age** across age bins.

## Notes
- We will progressively **upgrade Stage01–Stage03** to the full analysis workflow (coverage, contamination, annotation, etc.).
- The current implementation is a stable baseline for scaling and validation.


In [None]:
# SET GLOBAL VARIABLES
# NOTE:
# The Cloud Life Sciences (GLS) API has expired.
# The All of Us (AOU) Workbench migrated to Batch (GCB) on July 8, 2025.
# For migration details:
# https://cloud.google.com/batch/docs/migrate-to-batch-from-cloud-life-sciences

import os
from pathlib import Path

# Workspace identity, read from the environment; each value is "" when unset.
# The trailing "/" is removed from the bucket so paths can be appended with "/".
WORKSPACE_BUCKET = os.getenv("WORKSPACE_BUCKET", "").rstrip("/")
GOOGLE_PROJECT = os.getenv("GOOGLE_PROJECT", "")
PET_SA_EMAIL = os.getenv("PET_SA_EMAIL", "")

# Run settings; every one can be overridden through an environment variable
# of the same name.
outputFold = os.getenv("outputFold", "mtDNA_v25_pilot_5")
PORTID = int(os.getenv("PORTID", "8094"))
USE_MEM = int(os.getenv("USE_MEM", "32"))
SQL_DB_NAME = os.getenv("SQL_DB_NAME", "local_cromwell_run.db")

# Absolute path of the project checkout (resolved once, here).
PROJECT_ROOT = Path(os.getenv("PROJECT_ROOT", "/mnt/f/research_drive/mtdna/leelab/mtDNA-analysis")).resolve()

# Echo the effective settings so a misconfigured session is visible immediately.
print("WORKSPACE_BUCKET:", WORKSPACE_BUCKET)
print("GOOGLE_PROJECT:", GOOGLE_PROJECT)
print("PET_SA_EMAIL:", PET_SA_EMAIL)
print("PROJECT_ROOT:", PROJECT_ROOT)
print("PORTID:", PORTID, "USE_MEM:", USE_MEM)


In [None]:
import os
import shutil
import subprocess
import sys
from pathlib import Path

def run(cmd, check=True):
    """Echo a shell command, run it, and return its output.

    Args:
        cmd: Command line, interpreted by the shell.
        check: When true, a non-zero exit status raises
            ``subprocess.CalledProcessError``.

    Returns:
        The command's stdout and stderr merged into one string, with
        leading/trailing whitespace removed.
    """
    print(f"$ {cmd}")
    capture_opts = {
        "stdout": subprocess.PIPE,
        "stderr": subprocess.STDOUT,  # fold stderr into the captured stream
        "text": True,
    }
    completed = subprocess.run(cmd, shell=True, check=check, **capture_opts)
    return completed.stdout.strip()

# --- Dependency checks ---
# `sdk` is a shell function that only exists after sourcing sdkman-init.sh, so
# shutil.which("sdk") can never find it. Detect SDKMAN by its init script.
_sdkman_init_script = Path.home() / ".sdkman" / "bin" / "sdkman-init.sh"

# Tool name -> resolved location ("" when the tool was not found).
checks = {
    "python": sys.executable,
    "pip": shutil.which("pip") or shutil.which("pip3") or "",
    "gsutil": shutil.which("gsutil") or "",
    "gcloud": shutil.which("gcloud") or "",
    "java": shutil.which("java") or "",
    "sdkman": str(_sdkman_init_script) if _sdkman_init_script.exists() else "",
}

print("Dependency check:")
for k, v in checks.items():
    print(f"  {k:8s} -> {v if v else 'MISSING'}")

# The version/auth probes are informational only, hence check=False.
if checks["java"]:
    print("\nJava version:")
    print(run("java -version", check=False))

if checks["gcloud"]:
    print("\nGcloud auth:")
    print(run("gcloud auth list --format='value(account)'", check=False))

# Optional: install pyhocon if missing
try:
    import pyhocon  # noqa: F401
    print("\npyhocon: OK")
except ImportError:
    print("\npyhocon: missing")
    # Uncomment to install
    # print(run(f"{sys.executable} -m pip install pyhocon", check=True))

# --- Optional bootstrap tools (comment out if not needed) ---
home = Path.home()
sdkman_dir = home / ".sdkman"

if not sdkman_dir.exists():
    print("\nInstalling SDKMAN...")
    # -f makes curl exit non-zero on an HTTP error instead of saving the error
    # page, which would otherwise be handed to bash as the "installer".
    run("curl -fsS https://get.sdkman.io -o install_sdkman.sh", check=True)
    run("bash install_sdkman.sh", check=True)
else:
    print("\nSDKMAN already installed.")

sdkman_init = sdkman_dir / "bin" / "sdkman-init.sh"
if sdkman_init.exists():
    # `|| true` tolerates a failing install command; the `sdk use` call below
    # is the step that must succeed (check=True).
    run(f"bash -lc 'source {sdkman_init} && sdk install java 17.0.8-tem || true'", check=True)
    # Each run() spawns its own shell, so `sdk use` does not change the Java
    # seen by later cells; here it only confirms that the version is usable.
    run(f"bash -lc 'source {sdkman_init} && sdk use java 17.0.8-tem'", check=True)
else:
    print("SDKMAN init script not found; skipping Java install.")

# Cromwell/WOMtool 91
_cromwell_release_url = "https://github.com/broadinstitute/cromwell/releases/download/91"
for _jar_name in ("cromwell-91.jar", "womtool-91.jar"):
    if Path(_jar_name).exists():
        print(f"{_jar_name} already present.")
        continue
    print(f"Downloading {_jar_name}")
    # -f: fail on HTTP errors rather than saving the error page as a "jar".
    # Download to a .part file and rename only on success, so an interrupted
    # transfer is never mistaken for a complete jar on the next run.
    run(
        f"curl -fL {_cromwell_release_url}/{_jar_name} -o {_jar_name}.part "
        f"&& mv {_jar_name}.part {_jar_name}",
        check=True,
    )

# Heap size to use in start_cromwell()
CROMWELL_HEAP_GB = 32
print("CROMWELL_HEAP_GB set to", CROMWELL_HEAP_GB)


In [None]:
import json
import time
import subprocess
import os
from pathlib import Path
import requests

# ---- Config ----
# Java runtime and Cromwell build used to launch the server.
JAVA_VER = "17.0.8-tem"
SDKMAN_INIT = "/home/jupyter/.sdkman/bin/sdkman-init.sh"
CROMWELL_JAR = Path("cromwell-91.jar")  # upgraded (81 was buggy and slow)
CROMWELL_CONF = Path("/home/jupyter") / "cromwell.conf"
CROMWELL_HEAP_GB = 32  # new

# Local server endpoints, all derived from the port.
CROMWELL_PORT = 8094
CROMWELL_STATUS_URL = "http://localhost:{}/engine/v1/status".format(CROMWELL_PORT)
CROMWELL_API = "http://localhost:{}/api/workflows/v1".format(CROMWELL_PORT)

# Files written next to the notebook by start_cromwell().
PID_FILE = Path("cromwell_server.pid")
STDOUT_LOG = Path("cromwell_server_stdout.log")
STDERR_LOG = Path("cromwell_server_stderr.log")

# Database base path; the ".data" sibling is the file the persistence
# helpers inspect.
DB_BASE = Path("/home/jupyter/cromwell_db/local_cromwell_run.db")
DB_DATA = DB_BASE.with_name(DB_BASE.name + ".data")

# ---- Variables ----
# Epoch seconds of the most recent restart attempt made by wait_for_wf().
_last_restart = 0

# ---- Helpers ----
def cromwell_up():
    """Return True when the Cromwell status endpoint answers successfully.

    Any failure to obtain a response (within the 2-second timeout) is
    reported as "not up" rather than raised.
    """
    try:
        response = requests.get(CROMWELL_STATUS_URL, timeout=2)
    except Exception:
        return False
    return response.ok

def cromwell_pid_running():
    """Return True when the PID stored in PID_FILE can be signalled.

    A missing PID file, unparsable contents, or a failing signal all yield
    False.
    """
    if not PID_FILE.exists():
        return False
    try:
        recorded_pid = int(PID_FILE.read_text().strip())
        # Signal 0 performs the existence/permission check without
        # delivering a signal.
        os.kill(recorded_pid, 0)
    except Exception:
        return False
    return True

def cromwell_healthy():
    """Return True only when the recorded PID is alive and the server answers."""
    if not cromwell_pid_running():
        return False
    return cromwell_up()

def cromwell_persistent_ok():
    """Return True when the database data file exists and is non-empty."""
    if not DB_DATA.exists():
        return False
    return DB_DATA.stat().st_size > 0

def cromwell_persistent_recent(max_age_s=300):
    """Return True when the database data file was modified recently.

    Args:
        max_age_s: Largest accepted age of the file, in seconds.

    Returns:
        False when the data file is missing or empty; otherwise whether its
        modification time lies within ``max_age_s`` seconds of now.
    """
    if cromwell_persistent_ok():
        seconds_since_write = time.time() - DB_DATA.stat().st_mtime
        return seconds_since_write <= max_age_s
    return False

def _report_persistence():
    """Print the database-file checks shared by both success paths."""
    if cromwell_persistent_ok():
        print("Persistence check: DB data file OK.")
        if cromwell_persistent_recent():
            print("DB was updated recently.")


def start_cromwell(attempts=30, sleep_s=2):
    """Start the local Cromwell server unless a healthy one is already running.

    Args:
        attempts: Number of readiness probes made after launching.
        sleep_s: Pause between probes, in seconds.

    Raises:
        RuntimeError: If the jar or config file is missing, if the launched
            process exits before the server answers, or if the server does
            not answer within the allotted probes.
        subprocess.CalledProcessError: If the launch command itself fails.
    """
    if cromwell_healthy():
        print("Cromwell already running and healthy.")
        _report_persistence()
        return

    # Fail immediately with a precise message instead of launching a JVM that
    # cannot start and then waiting out every probe.
    missing = [str(p) for p in (CROMWELL_JAR, CROMWELL_CONF) if not Path(p).exists()]
    if missing:
        raise RuntimeError(
            "Cannot start Cromwell; missing required file(s): " + ", ".join(missing)
        )

    # The `source && sdk use && java` list is backgrounded as a whole, so the
    # PID written by `echo $!` is that of the background job running the list.
    cmd = (
        f"bash -lc 'source {SDKMAN_INIT} && sdk use java {JAVA_VER} && "
        f"nohup java -Xmx{CROMWELL_HEAP_GB}g "
        f"-Dconfig.file={CROMWELL_CONF} -Dwebservice.port={CROMWELL_PORT} "
        f"-jar {CROMWELL_JAR} server > {STDOUT_LOG} 2> {STDERR_LOG} & "
        f"echo $! > {PID_FILE} && disown'"
    )
    print("$", cmd)
    subprocess.run(cmd, shell=True, check=True)

    for _ in range(attempts):
        if cromwell_up():
            print("Cromwell is up.")
            _report_persistence()
            return
        # If the launched job is already gone the server will never answer;
        # stop waiting and point at the logs.
        if not cromwell_pid_running():
            raise RuntimeError(
                f"Cromwell exited during startup. Check {STDERR_LOG} and {STDOUT_LOG}."
            )
        time.sleep(sleep_s)

    raise RuntimeError("Cromwell did not start. Check stderr/stdout logs.")

def tail_logs(n=50):
    """Print the last ``n`` lines of the Cromwell stdout and stderr logs.

    Args:
        n: Number of trailing lines to show per log. Zero or a negative
            value shows no lines.
    """
    from collections import deque

    for log_path in (STDOUT_LOG, STDERR_LOG):
        if not log_path.exists():
            print(f"{log_path} not found.")
            continue
        print(f"--- {log_path} (last {n}) ---")
        # A bounded deque keeps only the tail while streaming the file, so a
        # large server log is never held in memory in full. It also makes
        # n=0 mean "no lines" (a [-0:] slice would return every line).
        with log_path.open() as fh:
            last_lines = deque(fh, maxlen=max(n, 0))
        print("\n".join(line.rstrip("\n") for line in last_lines))

def pretty(obj):
    """Print ``obj`` as JSON indented by two spaces."""
    rendered = json.dumps(obj, indent=2)
    print(rendered)

def get_wf_status(wf_id, retries=10, sleep_s=2, timeout=30):
    """Fetch the status document of a workflow, retrying while it is unknown.

    Args:
        wf_id: Cromwell workflow ID.
        retries: Maximum number of requests made while the server answers 404.
        sleep_s: Pause between 404 retries, in seconds.
        timeout: Per-request timeout in seconds, so a hung server cannot
            block the caller indefinitely.

    Returns:
        The decoded JSON body of the status response.

    Raises:
        requests.HTTPError: For a 4xx/5xx response other than 404.
        RuntimeError: If every attempt returned 404, or the server answered
            with an unexpected non-error status code.
    """
    url = f"{CROMWELL_API}/{wf_id}/status"
    for attempt in range(retries):
        r = requests.get(url, timeout=timeout)
        if r.status_code == 200:
            return r.json()
        if r.status_code != 404:
            r.raise_for_status()
            # raise_for_status() only raises for 4xx/5xx; report anything
            # else accurately rather than as "not found".
            raise RuntimeError(f"Unexpected HTTP {r.status_code} from {url}")
        # 404 — keep retrying, but do not sleep after the final attempt.
        if attempt < retries - 1:
            time.sleep(sleep_s)
    raise RuntimeError(f"Workflow {wf_id} not found after {retries} retries.")

def get_wf_metadata(wf_id, include_keys=None, timeout=120):
    """Fetch workflow metadata, optionally restricted to selected keys.

    Args:
        wf_id: Cromwell workflow ID.
        include_keys: Metadata keys to request; each is sent as a repeated
            ``includeKey`` query parameter. A single string is treated as
            one key. ``None`` or empty requests the full metadata.
        timeout: Per-request timeout in seconds.

    Returns:
        The decoded JSON metadata document.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
    """
    url = f"{CROMWELL_API}/{wf_id}/metadata"
    params = None
    if include_keys:
        # Guard against a bare string being iterated character by character.
        keys = [include_keys] if isinstance(include_keys, str) else list(include_keys)
        # requests URL-encodes the values and repeats the parameter per item.
        params = {"includeKey": keys}
    r = requests.get(url, params=params, timeout=timeout)
    r.raise_for_status()
    return r.json()

def wait_for_wf(wf_id, poll_s=5, timeout_s=600):
    """Poll a workflow until it reaches a terminal status.

    Args:
        wf_id: Cromwell workflow ID.
        poll_s: Pause between polls, in seconds.
        timeout_s: Overall time budget, in seconds.

    Returns:
        The terminal status: "Succeeded", "Failed" or "Aborted".

    Raises:
        TimeoutError: If no terminal status is seen within ``timeout_s``.
        RuntimeError: Propagated from start_cromwell() when a restart fails.
    """
    global _last_restart
    terminal_states = ("Succeeded", "Failed", "Aborted")
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        try:
            status = get_wf_status(wf_id).get("status")
        except Exception as err:
            # Report the failure instead of swallowing it silently.
            print(f"Status check failed: {err!r}")
            now = time.time()
            # Restart only when the server really is not answering, and at
            # most once every 30 seconds.
            if not cromwell_up() and now - _last_restart > 30:
                print("Cromwell not reachable; restarting...")
                start_cromwell()
                _last_restart = now
        else:
            print("Status:", status)
            if status in terminal_states:
                return status
        time.sleep(poll_s)
    raise TimeoutError(f"Workflow {wf_id} did not finish within {timeout_s}s")

def latest_workflow_id(wdl_name=None, status=None):
    """Return the ID of the most recently submitted workflow, if any.

    Args:
        wdl_name: Optional workflow name filter.
        status: Optional workflow status filter.

    Returns:
        The ID of the result with the greatest ``submission`` value among the
        first page (20 entries) of query results, or None when nothing matches.

    Raises:
        requests.HTTPError: If both the GET and the POST query fail.
    """
    params = {"page": 1, "pagesize": 20}
    if wdl_name:
        params["name"] = wdl_name
    if status:
        params["status"] = status

    query_url = f"{CROMWELL_API}/query"
    r = requests.get(query_url, params=params, timeout=30)
    if r.status_code != 200:
        # The POST form of /query takes a JSON list of single-key objects,
        # not one dict.
        payload = [{key: str(value)} for key, value in params.items()]
        r = requests.post(query_url, json=payload, timeout=30)

    r.raise_for_status()
    results = r.json().get("results", [])
    if not results:
        return None
    # Submission timestamps are compared as strings; entries without one
    # sort first.
    newest = max(results, key=lambda wf: wf.get("submission", ""))
    return newest.get("id")

def get_callroots(wf_id):
    """Collect the call root of every call attempt in a workflow.

    Args:
        wf_id: Cromwell workflow ID.

    Returns:
        A list of ``(call_name, call_root)`` tuples, one for each metadata
        entry that carries a ``callRoot``.
    """
    metadata = get_wf_metadata(wf_id, include_keys=["callRoot", "calls"])
    return [
        (call_name, attempt["callRoot"])
        for call_name, attempts in metadata.get("calls", {}).items()
        for attempt in attempts
        if "callRoot" in attempt
    ]

def fetch_task_logs_from_gcs(wf_id, call_name=None):
    """Print the tail of each call's stdout and stderr using gsutil.

    Args:
        wf_id: Cromwell workflow ID.
        call_name: When given, only calls with exactly this name are shown.
    """
    roots = get_callroots(wf_id)
    if not roots:
        print("No callRoot entries found.")
        return

    selected = [
        (name, root) for name, root in roots if not call_name or name == call_name
    ]
    for name, root in selected:
        log_uris = {stream: f"{root}/{stream}" for stream in ("stdout", "stderr")}
        print(f"\nCall: {name}")
        for stream, uri in log_uris.items():
            print(f"{stream}:", uri)
        # Best effort: a missing log must not abort the remaining calls.
        for uri in log_uris.values():
            subprocess.run(f"gsutil cat {uri} | tail -n 50", shell=True, check=False)

def latest_workflow_id_gcs(workspace_bucket, workflow_name):
    """Return the workflow ID with the most recently written object in GCS.

    Workflow directories are named by ID, so sorting the directory names does
    not order them by time. Instead, every object below the workflow's
    execution root is listed with its timestamp, and the ID owning the newest
    object is returned.

    Args:
        workspace_bucket: Bucket URL such as ``gs://my-bucket``.
        workflow_name: Name of the workflow directory under
            ``workflows/cromwell-executions``.

    Returns:
        The workflow ID, or None when the listing contains no objects.

    Raises:
        subprocess.CalledProcessError: If gsutil exits with an error.
    """
    root = (
        f"{workspace_bucket.rstrip('/')}/workflows/cromwell-executions/{workflow_name}/"
    )
    # Argument list (no shell): the "**" wildcard reaches gsutil unexpanded.
    listing = subprocess.check_output(["gsutil", "ls", "-l", f"{root}**"], text=True)

    newest_stamp = None
    newest_id = None
    for line in listing.splitlines():
        # Object lines look like "<size>  <ISO-8601 time>  <url>"; the
        # trailing "TOTAL:" summary and any other line fail the URL check.
        fields = line.split(None, 2)
        if len(fields) != 3 or not fields[2].startswith(root):
            continue
        _, stamp, url = fields
        wf_id = url[len(root):].split("/", 1)[0]
        # Timestamps of this form order correctly as plain strings.
        if wf_id and (newest_stamp is None or stamp > newest_stamp):
            newest_stamp, newest_id = stamp, wf_id
    return newest_id


# Warm Up Cromwell Server (Start up/ Validate)

In [None]:
# Liveness before and after the (idempotent) start, printed as "name(): value".
_liveness_probes = (cromwell_up, cromwell_pid_running, cromwell_healthy)

for _probe in _liveness_probes:
    print(f"{_probe.__name__}():", _probe())

start_cromwell()

for _probe in _liveness_probes:
    print(f"{_probe.__name__}():", _probe())

# State of the on-disk database.
for _probe in (cromwell_persistent_ok, cromwell_persistent_recent):
    print(f"{_probe.__name__}():", _probe())

tail_logs(20)

# Optional: latest workflow ID (if any exist)
print("latest_workflow_id():", latest_workflow_id())



In [None]:
# Show the raw body returned by the engine status endpoint.
status_response = requests.get(CROMWELL_STATUS_URL)
print("status endpoint:", status_response.text)


In [None]:
# Show the last 20 lines of the Cromwell server stdout/stderr logs.
tail_logs(20)

# Cromwell Configuration 

## Define cromwell.conf

In [None]:
# upgrades: add system tuning for throughput
from pathlib import Path

CROMWELL_DB = "/home/jupyter/cromwell_db/local_cromwell_run.db"
Path("/home/jupyter/cromwell_db").mkdir(parents=True, exist_ok=True)

# These settings default to "" when their environment variables are unset,
# which would silently produce a config with an empty project, bucket root
# or service account. Surface that here rather than at job-submission time.
_empty_settings = [
    _name
    for _name, _value in (
        ("GOOGLE_PROJECT", GOOGLE_PROJECT),
        ("WORKSPACE_BUCKET", WORKSPACE_BUCKET),
        ("PET_SA_EMAIL", PET_SA_EMAIL),
    )
    if not _value
]
if _empty_settings:
    print(
        "WARNING: empty setting(s):",
        ", ".join(_empty_settings),
        "- the generated cromwell.conf will contain blank values.",
    )

# HOCON text; literal braces are doubled because this is an f-string.
cromwell_conf = f"""include required(classpath("application"))

google {{
  application-name = "cromwell"
  auths = [{{
    name = "application_default"
    scheme = "application_default"
  }}]
}}

system {{
  new-workflow-poll-rate = 1
  max-concurrent-workflows = 50
  max-workflow-launch-count = 400
  job-rate-control {{
    jobs = 100
    per = "3 seconds"
  }}
}}

backend {{
  default = "GCPBATCH"
  providers {{
    Local.config.root = "/dev/null"

    GCPBATCH {{
      actor-factory = "cromwell.backend.google.batch.GcpBatchBackendLifecycleActorFactory"
      config {{
        project = "{GOOGLE_PROJECT}"
        concurrent-job-limit = 20
        root = "{WORKSPACE_BUCKET}/workflows/cromwell-executions"

        virtual-private-cloud {{
          network-name = "projects/{GOOGLE_PROJECT}/global/networks/network"
          subnetwork-name = "projects/{GOOGLE_PROJECT}/regions/us-central1/subnetworks/subnetwork"
        }}

        batch {{
          auth = "application_default"
          compute-service-account = "{PET_SA_EMAIL}"
          location = "us-central1"
        }}

        default-runtime-attributes {{
          noAddress: true
        }}

        filesystems {{
          gcs {{
            auth = "application_default"
          }}
        }}
      }}
    }}
  }}
}}

database {{
  profile = "slick.jdbc.HsqldbProfile$"
  insert-batch-size = 6000
  db {{
    driver = "org.hsqldb.jdbcDriver"
    url = "jdbc:hsqldb:file:{CROMWELL_DB};shutdown=false;hsqldb.default_table_type=cached;hsqldb.tx=mvcc;hsqldb.large_data=true;hsqldb.lob_compressed=true;hsqldb.script_format=3;hsqldb.result_max_memory_rows=20000"
    connectionTimeout = 300000
  }}
}}
"""

Path("/home/jupyter/cromwell.conf").write_text(cromwell_conf)
print("Wrote /home/jupyter/cromwell.conf")


# Download or Generate Metadata 

## Helper Functions

In [None]:
# Download mtdna metadata 
# TODO: replace with call to RScript to regenerate this 
# Alternatively make 
import gzip
import subprocess
from pathlib import Path

def ensure_mtdna_tsv(download=False, print_head=False):
    """Make sure the gzipped metadata TSV exists locally and is readable.

    Args:
        download: Force a fresh download even when the file already exists.
        print_head: Also print the first data row.

    Returns:
        A ``(path, header)`` tuple: the local file path and the stripped
        header line.

    Raises:
        subprocess.CalledProcessError: If the download fails (including HTTP
            errors).
        OSError: If the file is not valid gzip (``gzip.BadGzipFile``).
    """
    out_dir = Path("data/metadata")
    out_dir.mkdir(parents=True, exist_ok=True)

    url = "https://raw.githubusercontent.com/Kaychewe/mtDNA-analysis/meta/mtdna_mitoclock_aou_dataset_36246309_person_age_gender_crams.tsv.gz"
    path = out_dir / "mtdna_mitoclock_aou_dataset_36246309_person_age_gender_crams.tsv.gz"

    if download or not path.exists():
        # Download next to the target and move into place only after the
        # payload has been confirmed to be gzip. A failed or bogus download
        # can then never be picked up as "existing" by a later call.
        partial = path.with_name(path.name + ".part")
        try:
            # -f: exit non-zero on HTTP errors instead of saving the error page.
            subprocess.run(["curl", "-fL", url, "-o", str(partial)], check=True)
            with gzip.open(partial, "rt") as f:
                f.readline()
        except Exception:
            if partial.exists():
                partial.unlink()
            raise
        partial.replace(path)
        print("Downloaded:", path.resolve())
    else:
        print("Using existing:", path.resolve())

    # gzip check + header
    with gzip.open(path, "rt") as f:
        header = f.readline().strip()
        first = f.readline().strip() if print_head else None

    print("Gzip OK.")
    print("Header:", header)
    if print_head:
        print("First row:", first)

    return path, header

# Example usage:
#ensure_mtdna_tsv(download=False, print_head=True)

def load_and_sort_by_person_id(path, limit=None):
    """Read a gzipped TSV and return its rows sorted by numeric person_id.

    Args:
        path: Path to the gzip-compressed, tab-separated file.
        limit: When truthy, keep only this many rows after sorting.

    Returns:
        A ``(header, rows)`` tuple: the list of column names and a list of
        rows, each row being a list of string fields.

    Raises:
        ValueError: If there is no ``person_id`` column or an ID is not an
            integer.
    """
    with gzip.open(path, "rt") as f:
        # Strip only the line terminator. A plain strip() would also eat
        # trailing tabs, dropping empty trailing columns and misaligning
        # the row against the header.
        header = f.readline().rstrip("\r\n").split("\t")
        rows = [line.rstrip("\r\n").split("\t") for line in f if line.strip()]

    # find person_id column
    pid_idx = header.index("person_id")
    rows.sort(key=lambda r: int(r[pid_idx]))

    if limit:
        rows = rows[:limit]

    return header, rows


## Load Dataframe

In [None]:
# Columns read later in this notebook: person_id, age, sex_at_birth,
# cram_uri, cram_index_uri.
path = ensure_mtdna_tsv(download=False, print_head=True)[0]
header, rows = load_and_sort_by_person_id(path, limit=2)
print("Header:", header)
# Report the number of rows actually loaded rather than a fixed count.
print(f"First {len(rows)} rows:")
for r in rows:
    print(r)

In [None]:
import pandas as pd

# Load the metadata table, ordered by person_id with a fresh 0..n-1 index.
path = ensure_mtdna_tsv(download=False, print_head=True)[0]
df = (
    pd.read_csv(path, sep="\t", compression="gzip")
    .sort_values("person_id")
    .reset_index(drop=True)
)
print(df.head(2))


# Stage01 (Comprehensive): CRAM → chrM + NUMT BAM/SAM

**Goal**  
Start from a WGS CRAM, subset to chrM + NUMT intervals, clean/standardize reads, and produce final BAM/BAI/SAM plus key QC.

**Inputs**
- CRAM + CRAI
- Reference FASTA + index + dict
- chrM interval list + NUMT interval list
- sample metadata (age/sex) for naming

**Outputs**
- `out/<sample>_<age>_<sex>_chrM.proc.bam`
- `out/<sample>_<age>_<sex>_chrM.proc.bai`
- `out/<sample>_<age>_<sex>_chrM.proc.sam`
- `out/<sample>_<age>_<sex>_chrM.unmap.bam`
- `out/<sample>_<age>_<sex>_chrM.duplicate.metrics`
- `out/<sample>_<age>_<sex>_chrM.mean_coverage.txt`
- `out/<sample>_<age>_<sex>_chrM.ct_failed.txt`

**Key Steps**
1. Subset CRAM to chrM + NUMT (`gatk PrintReads`)
2. Validate + remove broken mates
3. Remove malformed XQ tag (if present)
4. Revert to unmapped
5. Coverage metrics (mean coverage)
6. Mark duplicates
7. Sort + index
8. Export SAM


## Configure Stage01 Full GATK run 

In [None]:
from pathlib import Path

# Stage01 WDL source, kept inline so the notebook can regenerate the file.
# NOTE: this is a regular (non-raw) triple-quoted string, so Python processes
# the backslashes before the text reaches disk: "\\t" is written as "\t",
# "\\+" as "\+", and a single backslash at the end of a line joins that line
# with the next one (the multi-line gatk commands land in the .wdl as single
# long lines).
# NOTE(review): STEP 4 and STEP 5 read "out/<prefix>.bam" (the PrintReads
# output), not rescued.bam/cleaned.bam, and STEP 5 passes
# ASSUME_SORT_ORDER="queryname" — confirm that this input and sort order are
# what CollectWgsMetrics/MarkDuplicates are meant to see.
wdl_text = """\
version 1.0

workflow stage01_SubsetCramChrM {
  meta {
    description: "Stage01 (comprehensive): subset to chrM + NUMT, clean, mark duplicates, and emit BAM/BAI/SAM."
  }

  input {
    File input_cram
    File? input_crai
    String sample_id
    String? age
    String? sex

    File mt_interval_list
    File numt_interval_list
    File ref_fasta
    File ref_fasta_index
    File ref_dict

    String docker

    Int? mem_gb
    Int? n_cpu
    Int? preemptible_tries
    String? requester_pays_project
  }

  call SubsetAndProcessChrM {
    input:
      input_cram = input_cram,
      input_crai = input_crai,
      sample_id = sample_id,
      age = age,
      sex = sex,
      mt_interval_list = mt_interval_list,
      numt_interval_list = numt_interval_list,
      ref_fasta = ref_fasta,
      ref_fasta_index = ref_fasta_index,
      ref_dict = ref_dict,
      docker = docker,
      mem_gb = mem_gb,
      n_cpu = n_cpu,
      preemptible_tries = preemptible_tries,
      requester_pays_project = requester_pays_project
  }

  output {
    File final_bam = SubsetAndProcessChrM.final_bam
    File final_bai = SubsetAndProcessChrM.final_bai
    File final_sam = SubsetAndProcessChrM.final_sam
    File unmapped_bam = SubsetAndProcessChrM.unmapped_bam
    File duplicate_metrics = SubsetAndProcessChrM.duplicate_metrics
    Int reads_dropped = SubsetAndProcessChrM.reads_dropped
    Int mean_coverage = SubsetAndProcessChrM.mean_coverage
  }
}

task SubsetAndProcessChrM {
  input {
    File input_cram
    File? input_crai
    String sample_id
    String? age
    String? sex

    File mt_interval_list
    File numt_interval_list
    File ref_fasta
    File ref_fasta_index
    File ref_dict

    String docker

    Int? mem_gb
    Int? n_cpu
    Int? preemptible_tries
    String? requester_pays_project
  }

  String age_label = select_first([age, "NA"])
  String sex_label = select_first([sex, "NA"])
  String prefix = sample_id + "_" + age_label + "_" + sex_label + "_chrM"

  Float ref_size = size(ref_fasta, "GB") + size(ref_fasta_index, "GB") + size(ref_dict, "GB")
  Int disk_size = ceil(ref_size) + ceil(size(input_cram, "GB")) + 20
  Int machine_mem = select_first([mem_gb, 8])
  Int command_mem = (machine_mem * 1000) - 500
  String appended_crai = input_cram + ".crai"

  String d = "$"

  command <<<
    set -euo pipefail

    mkdir -p out

    this_cram="~{input_cram}"
    this_crai="~{select_first([input_crai, appended_crai])}"

    echo "STEP 1: Subset CRAM to chrM + NUMT (PrintReads)"
    gatk --java-options "-Xmx~{command_mem}m" PrintReads \
      -R ~{ref_fasta} \
      -L ~{mt_interval_list} \
      -L ~{numt_interval_list} \
      ~{"--gcs-project-for-requester-pays " + requester_pays_project} \
      -I ~{d}{this_cram} --read-index ~{d}{this_crai} \
      -O "out/~{prefix}.bam"

    echo "STEP 2: Validate + remove broken mates"
    set +e
    gatk --java-options "-Xmx~{command_mem}m" ValidateSamFile \
      -INPUT "out/~{prefix}.bam" \
      -O output.txt \
      -M VERBOSE \
      -IGNORE_WARNINGS true \
      -MAX_OUTPUT 9999999
    cat output.txt | \
      grep "ERROR.*Mate not found for paired read" | \
      sed -e "s/ERROR::MATE_NOT_FOUND:Read name //g" | \
      sed -e "s/, Mate not found for paired read//g" > read_list.txt
    cat read_list.txt | wc -l | sed "s/^ *//g" > "out/~{prefix}.ct_failed.txt"
    if [[ $(tr -d "\\r\\n" < read_list.txt|wc -c) -eq 0 ]]; then
      cp "out/~{prefix}.bam" rescued.bam
    else
      gatk --java-options "-Xmx~{command_mem}m" FilterSamReads \
        -I "out/~{prefix}.bam" \
        -O rescued.bam \
        -READ_LIST_FILE read_list.txt \
        -FILTER excludeReadList
    fi
    set -e

    echo "STEP 2.5: Remove malformed XQ tag"
    samtools view -h rescued.bam \
      | sed 's/\\tXQ:i:[0-9]\\+//g' \
      | samtools view -b -o cleaned.bam

    echo "STEP 3: Revert to unmapped (cleaned)"
    gatk --java-options "-Xmx~{command_mem}m" RevertSam \
      -INPUT cleaned.bam \
      -OUTPUT_BY_READGROUP false \
      -OUTPUT "out/~{prefix}.unmap.bam" \
      -VALIDATION_STRINGENCY LENIENT \
      -ATTRIBUTE_TO_CLEAR FT \
      -ATTRIBUTE_TO_CLEAR CO \
      -ATTRIBUTE_TO_CLEAR XQ \
      -SORT_ORDER queryname \
      -RESTORE_ORIGINAL_QUALITIES false

    echo "STEP 4: Collect WGS metrics"
    gatk --java-options "-Xmx~{command_mem}m" CollectWgsMetrics \
      INPUT="out/~{prefix}.bam" \
      INTERVALS=~{mt_interval_list} \
      VALIDATION_STRINGENCY=SILENT \
      REFERENCE_SEQUENCE=~{ref_fasta} \
      OUTPUT="out/~{prefix}.wgs_metrics.txt" \
      USE_FAST_ALGORITHM=true \
      READ_LENGTH=151 \
      COVERAGE_CAP=100000 \
      INCLUDE_BQ_HISTOGRAM=true \
      THEORETICAL_SENSITIVITY_OUTPUT="out/~{prefix}.theoretical_sensitivity.txt"

    R --vanilla <<CODE
      df = read.table("out/~{prefix}.wgs_metrics.txt",skip=6,header=TRUE,stringsAsFactors=FALSE,sep='\\t',nrows=1)
      write.table(floor(df[,"MEAN_COVERAGE"]), "out/~{prefix}.mean_coverage.txt", quote=F, col.names=F, row.names=F)
    CODE

    echo "STEP 5: Mark duplicates"
    gatk --java-options "-Xmx~{command_mem}m" MarkDuplicates \
      INPUT="out/~{prefix}.bam" \
      OUTPUT=md.bam \
      METRICS_FILE="out/~{prefix}.duplicate.metrics" \
      VALIDATION_STRINGENCY=SILENT \
      OPTICAL_DUPLICATE_PIXEL_DISTANCE=2500 \
      ASSUME_SORT_ORDER="queryname" \
      CLEAR_DT="false" \
      ADD_PG_TAG_TO_READS=false

    echo "STEP 6: Sort + index"
    gatk --java-options "-Xmx~{command_mem}m" SortSam \
      INPUT=md.bam \
      OUTPUT="out/~{prefix}.proc.bam" \
      SORT_ORDER="coordinate" \
      CREATE_INDEX=true \
      MAX_RECORDS_IN_RAM=300000

    echo "STEP 7: Export SAM"
    samtools view -h "out/~{prefix}.proc.bam" > "out/~{prefix}.proc.sam"
  >>>

  runtime {
    memory: machine_mem + " GB"
    disks: "local-disk " + disk_size + " HDD"
    docker: docker
    preemptible: select_first([preemptible_tries, 5])
    cpu: select_first([n_cpu, 1])
  }

  output {
    File final_bam = "out/~{prefix}.proc.bam"
    File final_bai = "out/~{prefix}.proc.bai"
    File final_sam = "out/~{prefix}.proc.sam"
    File unmapped_bam = "out/~{prefix}.unmap.bam"
    File duplicate_metrics = "out/~{prefix}.duplicate.metrics"
    Int reads_dropped = read_int("out/~{prefix}.ct_failed.txt")
    Int mean_coverage = read_int("out/~{prefix}.mean_coverage.txt")
  }
}
"""

# Write the WDL under ./WDL/s001 (created on demand), relative to the
# notebook's working directory.
out_dir = Path("./WDL/s001")
out_dir.mkdir(parents=True, exist_ok=True)
wdl_path = out_dir / "stage01_SubsetCramChrM_v2.wdl"
wdl_path.write_text(wdl_text)

print(f"Wrote: {wdl_path.resolve()}")


In [None]:
! cat /home/jupyter/workspaces/mtdnaheteroplasmyandaginganalysis/WDL/s001/stage01_SubsetCramChrM.wdl

In [None]:
import json
from pathlib import Path

# ---- Select sample ----
# SELECT_PERSON_ID takes precedence; set it to None to pick by row position.
SELECT_PERSON_ID = 1000696
SELECT_ROW_INDEX = 0

if SELECT_PERSON_ID is not None:
    _matches = df.loc[df["person_id"] == SELECT_PERSON_ID]
    if _matches.empty:
        # Name the missing ID instead of failing with a bare IndexError.
        raise ValueError(f"person_id {SELECT_PERSON_ID} not found in the metadata table")
    row = _matches.iloc[0]
else:
    row = df.iloc[SELECT_ROW_INDEX]

sample_id = str(row["person_id"])
# Optional labels: None (JSON null) when the column is absent or the value is missing.
age = str(row["age"]) if "age" in row and not pd.isna(row["age"]) else None
sex = str(row["sex_at_birth"]) if "sex_at_birth" in row and not pd.isna(row["sex_at_birth"]) else None

# ---- Reference paths ----
mt_interval_list = "gs://gcp-public-data--broad-references/hg38/v0/chrM/chrM.hg38.interval_list"
numt_interval_list = "gs://fc-secure-76d68a64-00aa-40a7-b2c5-ca956db2719b/intervals/NUMTv3_all385.hg38.interval_list"

ref_fasta = "gs://gcp-public-data--broad-references/hg38/v0/Homo_sapiens_assembly38.fasta"
ref_fasta_index = "gs://gcp-public-data--broad-references/hg38/v0/Homo_sapiens_assembly38.fasta.fai"
ref_dict = "gs://gcp-public-data--broad-references/hg38/v0/Homo_sapiens_assembly38.dict"

# Keys are "<workflow name>.<input name>" for the stage01 workflow.
inputs = {
    "stage01_SubsetCramChrM.input_cram": row["cram_uri"],
    "stage01_SubsetCramChrM.input_crai": row["cram_index_uri"],
    "stage01_SubsetCramChrM.sample_id": sample_id,
    "stage01_SubsetCramChrM.age": age,
    "stage01_SubsetCramChrM.sex": sex,
    "stage01_SubsetCramChrM.mt_interval_list": mt_interval_list,
    "stage01_SubsetCramChrM.numt_interval_list": numt_interval_list,
    "stage01_SubsetCramChrM.ref_fasta": ref_fasta,
    "stage01_SubsetCramChrM.ref_fasta_index": ref_fasta_index,
    "stage01_SubsetCramChrM.ref_dict": ref_dict,
    "stage01_SubsetCramChrM.docker": "kchewe/mtdna-tools:0.1.0",
    "stage01_SubsetCramChrM.requester_pays_project": GOOGLE_PROJECT,
    "stage01_SubsetCramChrM.mem_gb": 16,
    "stage01_SubsetCramChrM.n_cpu": 4,
}

out_path = Path("./WDL/s001/stage01_SubsetCramChrM.inputs_v2.json")
out_path.write_text(json.dumps(inputs, indent=2) + "\n")
print("Wrote:", out_path.resolve())
print("Selected sample:", sample_id)


In [None]:
!cat /home/jupyter/workspaces/mtdnaheteroplasmyandaginganalysis/WDL/s001/stage01_SubsetCramChrM.inputs_v2.json

## Submit Stage01

In [None]:
import subprocess
from pathlib import Path
import requests

# Submit stage01 (comprehensive)
start_cromwell()

wdl_path = Path("./WDL/s001/stage01_SubsetCramChrM_v2.wdl")
json_path = Path("./WDL/s001/stage01_SubsetCramChrM.inputs_v2.json")

# Validate WDL
# Reuse the SDKMAN/Java settings that start_cromwell() uses so the two cannot
# drift apart.
validate_cmd = (
    f"bash -lc 'source {SDKMAN_INIT} "
    f"&& sdk use java {JAVA_VER} "
    f"&& java -jar womtool-91.jar validate {wdl_path}'"
)
print("$", validate_cmd)
subprocess.run(validate_cmd, shell=True, check=True)

# Submit to Cromwell
# Same endpoint the polling helpers use, so submission and monitoring always
# talk to the same server.
cromwell_url = CROMWELL_API
print("Submitting to:", cromwell_url)
# Open both uploads in a `with` block so the handles are closed once the
# request has been sent.
with wdl_path.open("rb") as _wdl_fh, json_path.open("rb") as _inputs_fh:
    files = {
        "workflowSource": _wdl_fh,
        "workflowInputs": _inputs_fh,
    }
    resp = requests.post(
        cromwell_url,
        files=files,
        headers={"accept": "application/json"},
        timeout=60,
    )
resp.raise_for_status()
wf = resp.json()
print("Response:", wf)


## Monitor Stage 01 

In [None]:
# Monitor + logs
wf_id = wf.get("id")
print(wf_id)

In [None]:
# Poll every 30 s, for up to 2 h, until the workflow reports
# Succeeded, Failed or Aborted.
print(wf_id)
status = wait_for_wf(wf_id, poll_s=30, timeout_s=7200)
print("Final status:", status)

# Show the failure/callRoot metadata, then the tail of each call's
# stdout/stderr.
pretty(get_wf_metadata(wf_id, include_keys=["failures", "callRoot"]))
fetch_task_logs_from_gcs(wf_id)

# Stage02 (Comprehensive): 

#### Configuration (WDL and JSON)

In [None]:
from pathlib import Path

wdl_text = """\
version 1.0

workflow stage02_MtOnly {
  meta {
    description: "Stage02 (full, single-file): align/call mt + nuc, contamination, haplochecker."
  }

  input {
    File input_bam
    File input_bai
    String sample_id
    String? age
    String? sex

    File ref_dict
    File ref_fasta
    File ref_fasta_index

    File mt_dict
    File mt_fasta
    File mt_fasta_index
    File blacklisted_sites
    File blacklisted_sites_index

    File nuc_interval_list
    File mt_interval_list

    Int mt_mean_coverage

    Boolean use_haplotype_caller_nucdna = true
    Int hc_dp_lower_bound = 10

    String docker
    File? gatk_override
    String gatk_version = "4.2.6.0"
    String? m2_extra_args
    String? m2_filter_extra_args
    Float? vaf_filter_threshold
    Float? f_score_beta
    Boolean compress_output_vcf = false
    Float? verifyBamID

    Int? max_read_length
    File haplocheck_zip

    Int? preemptible_tries
    Int? n_cpu
  }

  String age_label = select_first([age, "NA"])
  String sex_label = select_first([sex, "NA"])
  String sample_label = sample_id + "_" + age_label + "_" + sex_label

  if (use_haplotype_caller_nucdna) {
    call MongoHC as CallNucHCIntegrated {
      input:
        input_bam = input_bam,
        input_bai = input_bai,
        sample_name = sample_label,
        nuc_interval_list = nuc_interval_list,
        ref_fasta = ref_fasta,
        ref_fai = ref_fasta_index,
        ref_dict = ref_dict,
        suffix = ".nuc",
        compress = compress_output_vcf,
        gatk_override = gatk_override,
        gatk_docker_override = docker,
        gatk_version = gatk_version,
        hc_dp_lower_bound = hc_dp_lower_bound,
        mem = 4,
        preemptible_tries = preemptible_tries,
        n_cpu = n_cpu
    }
  }

  if (!use_haplotype_caller_nucdna) {
    call MongoNucM2 as CallNucM2Integrated {
      input:
        input_bam = input_bam,
        input_bai = input_bai,
        sample_name = sample_label,

        ref_fasta = ref_fasta,
        ref_fai = ref_fasta_index,
        ref_dict = ref_dict,
        suffix = ".nuc",
        mt_interval_list = nuc_interval_list,

        m2_extra_args = select_first([m2_extra_args, ""]),

        max_alt_allele_count = 4,
        vaf_filter_threshold = 0.95,
        verifyBamID = verifyBamID,
        compress = compress_output_vcf,

        gatk_override = gatk_override,
        gatk_docker_override = docker,
        gatk_version = gatk_version,
        mem = 4,
        preemptible_tries = preemptible_tries,
        n_cpu = n_cpu
    }
  }

  Int M2_mem = if mt_mean_coverage > 25000 then 14 else 7

  call MongoRunM2InitialFilterSplit as CallMt {
    input:
      sample_name = sample_label,
      input_bam = input_bam,
      input_bai = input_bai,
      verifyBamID = verifyBamID,
      mt_interval_list = mt_interval_list,
      ref_fasta = mt_fasta,
      ref_fai = mt_fasta_index,
      ref_dict = mt_dict,
      suffix = "",
      compress = compress_output_vcf,
      m2_extra_filtering_args = select_first([m2_filter_extra_args, ""]) + " --min-median-mapping-quality 0",
      max_alt_allele_count = 4,
      vaf_filter_threshold = 0,
      blacklisted_sites = blacklisted_sites,
      blacklisted_sites_index = blacklisted_sites_index,
      f_score_beta = f_score_beta,
      gatk_override = gatk_override,
      gatk_docker_override = docker,
      gatk_version = gatk_version,
      m2_extra_args = select_first([m2_extra_args, ""]),
      mem = M2_mem,
      preemptible_tries = preemptible_tries,
      n_cpu = n_cpu
  }

  call GetContamination {
    input:
      input_vcf = CallMt.vcf_for_haplochecker,
      sample_name = sample_label,
      mean_coverage = mt_mean_coverage,
      preemptible_tries = preemptible_tries,
      haplochecker_docker = docker,
      haplocheck_zip = haplocheck_zip
  }

  call MongoM2FilterContaminationSplit as FilterContamination {
    input:
      raw_vcf = CallMt.filtered_vcf,
      raw_vcf_index = CallMt.filtered_vcf_idx,
      raw_vcf_stats = CallMt.stats,
      sample_name = sample_label,
      hasContamination = GetContamination.hasContamination,
      contamination_major = GetContamination.major_level,
      contamination_minor = GetContamination.minor_level,
      suffix = "",
      run_contamination = true,
      verifyBamID = verifyBamID,
      ref_fasta = mt_fasta,
      ref_fai = mt_fasta_index,
      ref_dict = mt_dict,
      compress = compress_output_vcf,
      gatk_override = gatk_override,
      gatk_docker_override = docker,
      gatk_version = gatk_version,
      m2_extra_filtering_args = select_first([m2_filter_extra_args, ""]) + " --min-median-mapping-quality 0",
      max_alt_allele_count = 4,
      vaf_filter_threshold = vaf_filter_threshold,
      blacklisted_sites = blacklisted_sites,
      blacklisted_sites_index = blacklisted_sites_index,
      f_score_beta = f_score_beta,
      preemptible_tries = preemptible_tries
  }

  output {
    File out_vcf = FilterContamination.filtered_vcf
    File out_vcf_index = FilterContamination.filtered_vcf_idx
    File split_vcf = FilterContamination.split_vcf
    File split_vcf_index = FilterContamination.split_vcf_index
    File nuc_vcf = select_first([CallNucHCIntegrated.full_pass_vcf, CallNucM2Integrated.full_pass_vcf])
    File nuc_vcf_index = select_first([CallNucHCIntegrated.full_pass_vcf_index, CallNucM2Integrated.full_pass_vcf_index])
    File nuc_vcf_unfiltered = select_first([CallNucHCIntegrated.filtered_vcf, CallNucM2Integrated.filtered_vcf])
    File split_nuc_vcf = select_first([CallNucHCIntegrated.split_vcf, CallNucM2Integrated.split_vcf])
    File split_nuc_vcf_index = select_first([CallNucHCIntegrated.split_vcf_index, CallNucM2Integrated.split_vcf_index])
    Int nuc_variants_pass = select_first([CallNucHCIntegrated.post_filt_vars, CallNucM2Integrated.post_filt_vars])
    File input_vcf_for_haplochecker = CallMt.vcf_for_haplochecker
    File contamination_metrics = GetContamination.contamination_file
    String major_haplogroup = GetContamination.major_hg
    Float contamination = FilterContamination.contamination
    String hasContamination = GetContamination.hasContamination
    Float contamination_major = GetContamination.major_level
    Float contamination_minor = GetContamination.minor_level
  }
}

# Runs haplocheck on the input VCF and extracts the contamination call, the
# major/minor haplogroups and the mean heteroplasmy levels. Falls back to
# "NO" / "NONE" / 0.000 when haplocheck yields no data row, when mean_coverage
# is 0, or when the VCF has no chrM records.
task GetContamination {
  input {
    File input_vcf
    String sample_name
    Int mean_coverage
    File haplocheck_zip
    String haplochecker_docker
    Int? preemptible_tries
  }

  Int disk_size = ceil(size(input_vcf, "GB")) + 20
  String d = "$"

  command <<<
  set -e

  mkdir out
  this_basename=out/"~{sample_name}"
  this_mean_cov="~{mean_coverage}"
  this_vcf="~{input_vcf}"

  this_vcf_nvar=$(cat "~{d}{this_vcf}" | grep ^chrM | wc -l | sed 's/^ *//g')
  echo "~{sample_name} has VCF with ~{d}{this_vcf_nvar} variants for contamination."

  zip_path="~{haplocheck_zip}"
  jar xf "${zip_path}"
  chmod +x haplocheck

  ./haplocheck --out output "~{d}{this_vcf}"

  if [ -s output ]; then
    sed 's/\"//g' output > output-noquotes
  else
    : > output-noquotes
  fi

  # Relabel the sample column with the workflow sample name. OFS must be a tab:
  # assigning to a field makes awk rebuild the record with OFS (a single space by
  # default), which would break the tab-delimited column extraction further down.
  if grep -q "SampleID" output-noquotes; then
    awk -F "\t" -v OFS="\t" 'NR==1{print;next}{$1="~{sample_name}";print}' output-noquotes > output-noquotes.fixed
    mv output-noquotes.fixed output-noquotes
  fi

  cp 'output-noquotes' "~{d}{this_basename}_output_noquotes"

  # Header sanity check: mismatches are only reported, they do not fail the task.
  FORMAT_ERROR="Bad contamination file format"
  if grep -q "SampleID" output-noquotes; then
    grep "SampleID" output-noquotes > headers
    if [ `awk '{print $2}' headers` != "Contamination" ]; then echo $FORMAT_ERROR; fi
    if [ `awk '{print $6}' headers` != "HgMajor" ]; then echo $FORMAT_ERROR; fi
    if [ `awk '{print $8}' headers` != "HgMinor" ]; then echo $FORMAT_ERROR; fi
    if [ `awk '{print $14}' headers` != "MeanHetLevelMajor" ]; then echo $FORMAT_ERROR; fi
    if [ `awk '{print $15}' headers` != "MeanHetLevelMinor" ]; then echo $FORMAT_ERROR; fi
  else
    echo $FORMAT_ERROR
  fi

  if grep -q "SampleID" output-noquotes && grep -v "SampleID" output-noquotes | grep -q . && [ "~{d}{this_mean_cov}" -gt 0 ] && [ "~{d}{this_vcf_nvar}" -gt 0 ]; then
    grep -v "SampleID" output-noquotes > output-data
    awk -F "\t" '{print $2}' output-data > "~{d}{this_basename}.contamination.txt"
    awk -F "\t" '{print $6}' output-data > "~{d}{this_basename}.major_hg.txt"
    awk -F "\t" '{print $8}' output-data > "~{d}{this_basename}.minor_hg.txt"
    awk -F "\t" '{print $14}' output-data > "~{d}{this_basename}.mean_het_major.txt"
    awk -F "\t" '{print $15}' output-data > "~{d}{this_basename}.mean_het_minor.txt"
  else
    echo "NO" > "~{d}{this_basename}.contamination.txt"
    echo "NONE" > "~{d}{this_basename}.major_hg.txt"
    echo "NONE" > "~{d}{this_basename}.minor_hg.txt"
    echo "0.000" > "~{d}{this_basename}.mean_het_major.txt"
    echo "0.000" > "~{d}{this_basename}.mean_het_minor.txt"
  fi
  >>>
  runtime {
    preemptible: select_first([preemptible_tries, 5])
    memory: "3 GB"
    disks: "local-disk " + disk_size + " HDD"
    docker: haplochecker_docker
  }
  output {
    File contamination_file = "out/~{sample_name}_output_noquotes"
    String hasContamination = read_string("out/~{sample_name}.contamination.txt")
    String major_hg = read_string("out/~{sample_name}.major_hg.txt")
    String minor_hg = read_string("out/~{sample_name}.minor_hg.txt")
    Float major_level = read_float("out/~{sample_name}.mean_het_major.txt")
    Float minor_level = read_float("out/~{sample_name}.mean_het_minor.txt")
  }
}

task MongoHC {
  input {
    File ref_fasta
    File ref_fai
    File ref_dict
    File input_bam
    File input_bai

    String sample_name
    String suffix = ""

    Int max_reads_per_alignment_start = 75
    String? hc_extra_args
    Boolean make_bamout = false

    File? nuc_interval_list
    File? force_call_vcf
    File? force_call_vcf_index

    Boolean compress
    String gatk_version
    File? gatk_override
    String? gatk_docker_override
    Float? contamination

    Int hc_dp_lower_bound

    Int mem
    Int? preemptible_tries
    Int? n_cpu
  }

  Int machine_mem = if defined(mem) then mem * 1000 else 3500
  Int command_mem = machine_mem - 500

  Float ref_size = size(ref_fasta, "GB") + size(ref_fai, "GB") + size(ref_dict, "GB")
  Int disk_size = ceil((size(input_bam, "GB") * 2) + ref_size) + 22

  String d = "$"

  command <<<
    set -e

    mkdir out
    this_sample=out/"~{sample_name}"
    this_basename="~{d}{this_sample}""~{suffix}"
    bamoutfile="~{d}{this_basename}.bamout.bam"
    touch "~{d}{bamoutfile}"

    if [[ ~{make_bamout} == 'true' ]]; then bamoutstr="--bam-output ~{d}{this_basename}.bamout.bam"; else bamoutstr=""; fi

    gatk --java-options "-Xmx~{command_mem}m" HaplotypeCaller \
      -R ~{ref_fasta} \
      -I ~{input_bam} \
      ~{"-L " + nuc_interval_list} \
      -O "~{d}{this_basename}.raw.vcf" \
      ~{hc_extra_args} \
      -contamination ~{default="0" contamination} \
      ~{"--genotype-filtered-alleles --alleles " + force_call_vcf} \
      --max-reads-per-alignment-start ~{max_reads_per_alignment_start} \
      --max-mnp-distance 0 \
      --annotation StrandBiasBySample \
      -G StandardAnnotation -G StandardHCAnnotation \
      -GQB 10 -GQB 20 -GQB 30 -GQB 40 -GQB 50 -GQB 60 -GQB 70 -GQB 80 -GQB 90 ~{d}{bamoutstr}

    gatk --java-options "-Xmx~{command_mem}m" SelectVariants -V "~{d}{this_basename}.raw.vcf" -select-type SNP -O snps.vcf
    gatk --java-options "-Xmx~{command_mem}m" VariantFiltration -V snps.vcf \
      -R ~{ref_fasta} \
      -O snps_filtered.vcf \
      -filter "QD < 2.0" --filter-name "QD2" \
      -filter "QUAL < 30.0" --filter-name "QUAL30" \
      -filter "SOR > 3.0" --filter-name "SOR3" \
      -filter "FS > 60.0" --filter-name "FS60" \
      -filter "MQ < 40.0" --filter-name "MQ40" \
      -filter "MQRankSum < -12.5" --filter-name "MQRankSum-12.5" \
      -filter "ReadPosRankSum < -8.0" --filter-name "ReadPosRankSum-8" \
      --genotype-filter-expression "isHet == 1" --genotype-filter-name "isHetFilt" \
      --genotype-filter-expression "isHomRef == 1" --genotype-filter-name "isHomRefFilt" \
      ~{'--genotype-filter-expression "DP < ' + hc_dp_lower_bound + '" --genotype-filter-name "genoDP' + hc_dp_lower_bound + '"'}

    gatk --java-options "-Xmx~{command_mem}m" SelectVariants -V "~{d}{this_basename}.raw.vcf" -select-type INDEL -O indels.vcf
    gatk --java-options "-Xmx~{command_mem}m" VariantFiltration -V indels.vcf \
      -R ~{ref_fasta} \
      -O indels_filtered.vcf \
      -filter "QD < 2.0" --filter-name "QD2" \
      -filter "QUAL < 30.0" --filter-name "QUAL30" \
      -filter "FS > 200.0" --filter-name "FS200" \
      -filter "SOR > 10.0" --filter-name "SOR10" \
      -filter "ReadPosRankSum < -20.0" --filter-name "ReadPosRankSum-20" \
      --genotype-filter-expression "isHet == 1" --genotype-filter-name "isHetFilt" \
      --genotype-filter-expression "isHomRef == 1" --genotype-filter-name "isHomRefFilt" \
      ~{'--genotype-filter-expression "DP < ' + hc_dp_lower_bound + '" --genotype-filter-name "genoDP' + hc_dp_lower_bound + '"'}

    gatk --java-options "-Xmx~{command_mem}m" MergeVcfs -I snps_filtered.vcf -I indels_filtered.vcf -O "~{d}{this_basename}.vcf"

    gatk --java-options "-Xmx~{command_mem}m" SelectVariants \
      -V "~{d}{this_basename}.vcf" \
      --exclude-filtered \
      --set-filtered-gt-to-nocall \
      --exclude-non-variants \
      -O "~{d}{this_basename}.pass.vcf"

    gatk --java-options "-Xmx~{command_mem}m" CountVariants -V $this_basename.pass.vcf | tail -n1 > "~{d}{this_basename}.passvars.txt"

    gatk --java-options "-Xmx~{command_mem}m" LeftAlignAndTrimVariants \
      -R ~{ref_fasta} \
      -V "~{d}{this_basename}.pass.vcf" \
      -O "~{d}{this_basename}.pass.split.vcf" \
      --split-multi-allelics \
      --dont-trim-alleles \
      --keep-original-ac \
      --create-output-variant-index
  >>>

  runtime {
    docker: select_first([gatk_docker_override, "us.gcr.io/broad-gatk/gatk:"+gatk_version])
    memory: machine_mem + " MB"
    disks: "local-disk " + disk_size + " HDD"
    preemptible: select_first([preemptible_tries, 5])
    cpu: select_first([n_cpu,1])
  }
  output {
    File raw_vcf = "out/~{sample_name}~{suffix}.raw.vcf"
    File raw_vcf_idx = "out/~{sample_name}~{suffix}.raw.vcf.idx"
    File output_bamOut = "out/~{sample_name}~{suffix}.bamout.bam"
    File filtered_vcf = "out/~{sample_name}~{suffix}.vcf"
    File filtered_vcf_idx = "out/~{sample_name}~{suffix}.vcf.idx"
    File full_pass_vcf = "out/~{sample_name}~{suffix}.pass.vcf"
    File full_pass_vcf_index = "out/~{sample_name}~{suffix}.pass.vcf.idx"
    Int post_filt_vars = read_int("out/~{sample_name}~{suffix}.passvars.txt")
    File split_vcf = "out/~{sample_name}~{suffix}.pass.split.vcf"
    File split_vcf_index = "out/~{sample_name}~{suffix}.pass.split.vcf.idx"
  }
}

task MongoNucM2 {
  input {
    File ref_fasta
    File ref_fai
    File ref_dict
    File input_bam
    File input_bai

    String sample_name
    String suffix = ""

    Int max_reads_per_alignment_start = 75
    String? m2_extra_args
    Boolean make_bamout = false
    Boolean compress

    File? mt_interval_list

    Float? vaf_cutoff
    String? m2_extra_filtering_args
    Int max_alt_allele_count
    Float? vaf_filter_threshold
    Float? f_score_beta
    Float? verifyBamID
    File? blacklisted_sites
    File? blacklisted_sites_index

    File? gatk_override
    String gatk_version
    String? gatk_docker_override
    Int mem
    Int? preemptible_tries
    Int? n_cpu
  }

  Float ref_size = size(ref_fasta, "GB") + size(ref_fai, "GB")
  Int disk_size = ceil(size(input_bam, "GB")*2 + ref_size) + 20
  Float defval = 0.0

  Int machine_mem = if defined(mem) then mem * 1000 else 3500
  Int command_mem = machine_mem - 500

  String d = "$"

  command <<<
    set -e

    mkdir out
    this_sample=out/"~{sample_name}"
    this_contamination="~{select_first([verifyBamID, defval])}"
    this_bam="~{input_bam}"
    this_basename="~{d}{this_sample}~{suffix}"
    bamoutfile="~{d}{this_basename}.bamout.bam"
    touch "~{d}{bamoutfile}"
    if [[ ~{make_bamout} == 'true' ]]; then bamoutstr="--bam-output ~{d}{bamoutfile}"; else bamoutstr=""; fi

    gatk --java-options "-Xmx~{command_mem}m" Mutect2 \
      -R ~{ref_fasta} \
      -I "~{d}{this_bam}" \
      ~{"-L " + mt_interval_list} \
      -O "~{d}{this_basename}.raw.vcf" \
      ~{m2_extra_args} \
      ~{"--minimum-allele-fraction " + vaf_filter_threshold} \
      --annotation StrandBiasBySample \
      --max-reads-per-alignment-start ~{max_reads_per_alignment_start} \
      --max-mnp-distance 0 ~{d}{bamoutstr}

    gatk --java-options "-Xmx~{command_mem}m" FilterMutectCalls -V "~{d}{this_basename}.raw.vcf" \
      -R ~{ref_fasta} \
      -O filtered.vcf \
      --stats "~{d}{this_basename}.raw.vcf.stats" \
      ~{m2_extra_filtering_args} \
      --max-alt-allele-count ~{max_alt_allele_count} \
      ~{"--min-allele-fraction " + vaf_filter_threshold} \
      ~{"--f-score-beta " + f_score_beta} \
      --contamination-estimate "~{d}{this_contamination}"

    ~{"gatk IndexFeatureFile -I " + blacklisted_sites}

    gatk --java-options "-Xmx~{command_mem}m" VariantFiltration -V filtered.vcf \
      -O "~{d}{this_basename}.vcf" \
      --apply-allele-specific-filters \
      ~{"--mask-name 'blacklisted_site' --mask " + blacklisted_sites}

    gatk --java-options "-Xmx~{command_mem}m" SelectVariants \
      -V "~{d}{this_basename}.vcf" \
      --exclude-filtered \
      -O "~{d}{this_basename}.pass.vcf"

    gatk CountVariants -V "~{d}{this_basename}.pass.vcf" | tail -n1 > "~{d}{this_basename}.passvars.txt"

    gatk --java-options "-Xmx~{command_mem}m" LeftAlignAndTrimVariants \
      -R ~{ref_fasta} \
      -V "~{d}{this_basename}.pass.vcf" \
      -O "~{d}{this_basename}.pass.split.vcf" \
      --split-multi-allelics \
      --dont-trim-alleles \
      --keep-original-ac \
      --create-output-variant-index
  >>>
  runtime {
    docker: select_first([gatk_docker_override, "us.gcr.io/broad-gatk/gatk:"+gatk_version])
    memory: machine_mem + " MB"
    disks: "local-disk " + disk_size + " HDD"
    preemptible: select_first([preemptible_tries, 5])
    cpu: select_first([n_cpu,2])
  }
  output {
    File raw_vcf = "out/~{sample_name}~{suffix}.raw.vcf"
    File raw_vcf_idx = "out/~{sample_name}~{suffix}.raw.vcf.idx"
    File stats = "out/~{sample_name}~{suffix}.raw.vcf.stats"
    File output_bamOut = "out/~{sample_name}~{suffix}.bamout.bam"

    File filtered_vcf = "out/~{sample_name}~{suffix}.vcf"
    File filtered_vcf_idx = "out/~{sample_name}~{suffix}.vcf.idx"

    File full_pass_vcf = "out/~{sample_name}~{suffix}.pass.vcf"
    File full_pass_vcf_index = "out/~{sample_name}~{suffix}.pass.vcf.idx"
    Int post_filt_vars = read_int("out/~{sample_name}~{suffix}.passvars.txt")

    File split_vcf = "out/~{sample_name}~{suffix}.pass.split.vcf"
    File split_vcf_index = "out/~{sample_name}~{suffix}.pass.split.vcf.idx"
  }
}

task MongoRunM2InitialFilterSplit {
  input {
    String sample_name
    File input_bam
    File input_bai
    Float? verifyBamID
    String suffix

    File ref_fasta
    File ref_fai
    File ref_dict
    Int max_reads_per_alignment_start = 75
    String? m2_extra_args
    Boolean make_bamout = false
    Boolean compress

    File? mt_interval_list

    Float? vaf_cutoff
    String? m2_extra_filtering_args
    Int max_alt_allele_count
    Float? vaf_filter_threshold
    Float? f_score_beta

    File? blacklisted_sites
    File? blacklisted_sites_index

    String? gatk_docker_override
    File? gatk_override
    String gatk_version
    Int mem
    Int? preemptible_tries
    Int? n_cpu
  }

  Float ref_size = size(ref_fasta, "GB") + size(ref_fai, "GB")
  Int disk_size = (ceil(size(input_bam, "GB") + ref_size) * 2) + 20
  Float defval = 0.0

  Int machine_mem = if defined(mem) then mem * 1000 else 3500
  Int command_mem = machine_mem - 500

  String d = "$"

  command <<<
    set -e

    mkdir out
    this_sample=out/"~{sample_name}"
    this_contamination="~{select_first([verifyBamID, defval])}"
    this_basename="~{d}{this_sample}~{suffix}"
    bamoutfile="~{d}{this_basename}.bamout.bam"
    touch "~{d}{bamoutfile}"
    if [[ ~{make_bamout} == 'true' ]]; then bamoutstr="--bam-output ~{d}{bamoutfile}"; else bamoutstr=""; fi

    gatk --java-options "-Xmx~{command_mem}m" Mutect2 \
      -R ~{ref_fasta} \
      -I ~{input_bam} \
      ~{"-L " + mt_interval_list} \
      -O "~{d}{this_basename}.raw.vcf" \
      ~{m2_extra_args} \
      --annotation StrandBiasBySample \
      --read-filter MateOnSameContigOrNoMappedMateReadFilter \
      --read-filter MateUnmappedAndUnmappedReadFilter \
      --mitochondria-mode \
      --max-reads-per-alignment-start ~{max_reads_per_alignment_start} \
      --max-mnp-distance 0 ~{d}{bamoutstr}

    gatk --java-options "-Xmx~{command_mem}m" FilterMutectCalls -V "~{d}{this_basename}.raw.vcf" \
      -R ~{ref_fasta} \
      -O filtered.vcf \
      --stats "~{d}{this_basename}.raw.vcf.stats" \
      ~{m2_extra_filtering_args} \
      --max-alt-allele-count ~{max_alt_allele_count} \
      --mitochondria-mode \
      ~{"--min-allele-fraction " + vaf_filter_threshold} \
      ~{"--f-score-beta " + f_score_beta} \
      --contamination-estimate "~{d}{this_contamination}"

    ~{"gatk IndexFeatureFile -I " + blacklisted_sites}

    gatk --java-options "-Xmx~{command_mem}m" VariantFiltration -V filtered.vcf \
      -O "~{d}{this_basename}.filtered.vcf" \
      --apply-allele-specific-filters \
      ~{"--mask-name 'blacklisted_site' --mask " + blacklisted_sites}

    gatk --java-options "-Xmx~{command_mem}m" LeftAlignAndTrimVariants \
      -R ~{ref_fasta} \
      -V "~{d}{this_basename}.filtered.vcf" \
      -O split.vcf \
      --split-multi-allelics \
      --dont-trim-alleles \
      --keep-original-ac

    gatk --java-options "-Xmx~{command_mem}m" SelectVariants \
      -V split.vcf \
      -O "~{d}{this_basename}.splitAndPassOnly.vcf" \
      --exclude-filtered
  >>>
  runtime {
    docker: select_first([gatk_docker_override, "us.gcr.io/broad-gatk/gatk:"+gatk_version])
    memory: machine_mem + " MB"
    disks: "local-disk " + disk_size + " HDD"
    preemptible: select_first([preemptible_tries, 5])
    cpu: select_first([n_cpu,2])
  }
  output {
    File raw_vcf = "out/~{sample_name}~{suffix}.raw.vcf"
    File raw_vcf_idx = "out/~{sample_name}~{suffix}.raw.vcf.idx"
    File stats = "out/~{sample_name}~{suffix}.raw.vcf.stats"
    File output_bamOut = "out/~{sample_name}~{suffix}.bamout.bam"

    File filtered_vcf = "out/~{sample_name}~{suffix}.filtered.vcf"
    File filtered_vcf_idx = "out/~{sample_name}~{suffix}.filtered.vcf.idx"

    File vcf_for_haplochecker = "out/~{sample_name}~{suffix}.splitAndPassOnly.vcf"
  }
}

task MongoM2FilterContaminationSplit {
  input {
    File raw_vcf
    File raw_vcf_index
    File raw_vcf_stats
    String sample_name
    String hasContamination
    Float contamination_major
    Float contamination_minor
    Float? verifyBamID

    Boolean run_contamination
    File ref_fasta
    File ref_fai
    File ref_dict

    Boolean compress
    Float? vaf_cutoff
    String suffix

    String? m2_extra_filtering_args
    Int max_alt_allele_count
    Float? vaf_filter_threshold
    Float? f_score_beta

    File? blacklisted_sites
    File? blacklisted_sites_index

    File? gatk_override
    String? gatk_docker_override
    String gatk_version

    Int? preemptible_tries
  }

  Float ref_size = size(ref_fasta, "GB") + size(ref_fai, "GB")
  Int disk_size = ceil(size(raw_vcf, "GB") + ref_size) + 20
  Float defval = 0.0
  String d = "$"

  command <<<
    set -e

    mkdir out

    this_sample=out/"~{sample_name}"
    this_raw_vcf="~{raw_vcf}"
    this_raw_stats="~{raw_vcf_stats}"
    this_has_contam="~{hasContamination}"
    this_verifybam="~{select_first([verifyBamID, defval])}"
    this_contam_major="~{contamination_major}"
    this_contam_minor="~{contamination_minor}"

    this_basename="~{d}{this_sample}~{suffix}"
    bamoutfile="~{d}{this_basename}.bamout.bam"
    touch "~{d}{bamoutfile}"

    if [[ "~{d}{this_has_contam}" == 'YES' ]]; then
      if (( $(echo "~{d}{this_contam_major} == 0.0"|bc -l) )); then
        this_hc_contamination="~{d}{this_contam_minor}"
      else
        this_hc_contamination=$( bc <<< "1-~{d}{this_contam_major}" )
      fi
    else
      this_hc_contamination=0.0
    fi

    echo "~{d}{this_hc_contamination}" > "~{d}{this_basename}.hc_contam.txt"

    if (( $(echo "~{d}{this_verifybam} > ~{d}{this_hc_contamination}"|bc -l) )); then
      this_max_contamination="~{d}{this_verifybam}"
    else
      this_max_contamination="~{d}{this_hc_contamination}"
    fi

    gatk --java-options "-Xmx2500m" FilterMutectCalls \
      -V "~{d}{this_raw_vcf}" \
      -R ~{ref_fasta} \
      -O filtered.vcf \
      --stats "~{d}{this_raw_stats}" \
      ~{m2_extra_filtering_args} \
      --max-alt-allele-count ~{max_alt_allele_count} \
      --mitochondria-mode \
      ~{"--min-allele-fraction " + vaf_filter_threshold} \
      ~{"--f-score-beta " + f_score_beta} \
      --contamination-estimate "~{d}{this_max_contamination}"

    ~{"gatk IndexFeatureFile -I " + blacklisted_sites}

    gatk --java-options "-Xmx2500m" VariantFiltration \
      -V filtered.vcf \
      -O "~{d}{this_basename}.vcf" \
      --apply-allele-specific-filters \
      ~{"--mask-name 'blacklisted_site' --mask " + blacklisted_sites}

    gatk --java-options "-Xmx2500m" LeftAlignAndTrimVariants \
      -R ~{ref_fasta} \
      -V "~{d}{this_basename}.vcf" \
      -O "~{d}{this_basename}.split.vcf" \
      --split-multi-allelics \
      --dont-trim-alleles \
      --keep-original-ac \
      --create-output-variant-index
  >>>
  runtime {
    docker: select_first([gatk_docker_override, "us.gcr.io/broad-gatk/gatk:"+gatk_version])
    # The GATK steps in this task run with -Xmx2500m, so the VM needs gigabytes,
    # not megabytes, of memory.
    memory: "4 GB"
    disks: "local-disk " + disk_size + " HDD"
    preemptible: select_first([preemptible_tries, 5])
    cpu: 2
  }
  output {
    File filtered_vcf = "out/~{sample_name}~{suffix}.vcf"
    File filtered_vcf_idx = "out/~{sample_name}~{suffix}.vcf.idx"
    File split_vcf = "out/~{sample_name}~{suffix}.split.vcf"
    File split_vcf_index = "out/~{sample_name}~{suffix}.split.vcf.idx"
    Float contamination = read_float("out/~{sample_name}~{suffix}.hc_contam.txt")
  }
}
"""

# Persist the WDL text under ./WDL/s002, creating the directory if needed,
# and report the absolute path that was written.
wdl_path = Path("./WDL/s002", "stage02_MtOnly_v2.wdl")
out_dir = wdl_path.parent
out_dir.mkdir(exist_ok=True, parents=True)
wdl_path.write_text(wdl_text)

print("Wrote: {}".format(wdl_path.resolve()))


In [None]:
! cat /home/jupyter/workspaces/mtdnaheteroplasmyandaginganalysis/WDL/s002/stage02_MtOnly_v2.wdl

In [None]:
import json
from pathlib import Path

# Build the Stage02 inputs JSON from the outputs of the latest Stage01 run.

# ---- choose stage01 workflow id (latest) ----
WF_ID = latest_workflow_id("stage01_SubsetCramChrM")
print("Stage01 WF_ID:", WF_ID)

# ---- pull outputs from Cromwell metadata ----
meta = get_wf_metadata(WF_ID, include_keys=["outputs"])
outputs = meta.get("outputs") or {}

input_bam = outputs.get("stage01_SubsetCramChrM.final_bam")
input_bai = outputs.get("stage01_SubsetCramChrM.final_bai")

# Validate before deriving anything from the BAM path: a missing or null
# output must surface as this ValueError, not as an AttributeError on None.
if not input_bam or not input_bai:
    raise ValueError("Could not find Stage01 outputs in metadata.")

sample_label = input_bam.split("/")[-1].replace(".proc.bam", "")

# parse sample_id / age / sex from filename
# format: <sample_id>_<age>_<sex>_chrM.proc.bam
# NOTE(review): assumes sample_id itself contains no underscore — confirm.
parts = sample_label.split("_")
sample_id = parts[0]
age = parts[1] if len(parts) > 1 else None
sex = parts[2] if len(parts) > 2 else None

if not sample_id:
    raise ValueError(f"Could not parse sample_id from Stage01 BAM name: {input_bam}")

# The WDL declares `Int mt_mean_coverage`; coerce so a null or non-integer
# value reported by Stage01 does not break input parsing.
# NOTE(review): Stage01 may emit this as a Float — confirm against its WDL.
raw_mean_coverage = outputs.get("stage01_SubsetCramChrM.mean_coverage")
mt_mean_coverage = 0 if raw_mean_coverage is None else int(round(float(raw_mean_coverage)))

# ---- mt reference paths ----
mt_fasta = "gs://gcp-public-data--broad-references/hg38/v0/chrM/Homo_sapiens_assembly38.chrM.fasta"
mt_fasta_index = "gs://gcp-public-data--broad-references/hg38/v0/chrM/Homo_sapiens_assembly38.chrM.fasta.fai"
mt_dict = "gs://gcp-public-data--broad-references/hg38/v0/chrM/Homo_sapiens_assembly38.chrM.dict"
mt_interval_list = "gs://gcp-public-data--broad-references/hg38/v0/chrM/chrM.hg38.interval_list"

# ---- nuc + blacklist ----
ref_fasta = "gs://gcp-public-data--broad-references/hg38/v0/Homo_sapiens_assembly38.fasta"
ref_fasta_index = "gs://gcp-public-data--broad-references/hg38/v0/Homo_sapiens_assembly38.fasta.fai"
ref_dict = "gs://gcp-public-data--broad-references/hg38/v0/Homo_sapiens_assembly38.dict"
nuc_interval_list = "gs://gcp-public-data--broad-references/hg38/v0/wgs_calling_regions.hg38.interval_list"

blacklisted_sites = "gs://gcp-public-data--broad-references/hg38/v0/mitochondria/blacklisted_sites.vcf"
blacklisted_sites_index = "gs://gcp-public-data--broad-references/hg38/v0/mitochondria/blacklisted_sites.vcf.idx"

# ---- haplochecker zip (from mtSwirl or your workspace) ----
haplocheck_zip = "gs://fc-secure-76d68a64-00aa-40a7-b2c5-ca956db2719b/refs/haplocheck.zip"

# ---- build inputs ----
# age / sex may be None; they serialize to JSON null for the optional WDL inputs.
inputs = {
    "stage02_MtOnly.input_bam": input_bam,
    "stage02_MtOnly.input_bai": input_bai,
    "stage02_MtOnly.sample_id": sample_id,
    "stage02_MtOnly.age": age,
    "stage02_MtOnly.sex": sex,

    "stage02_MtOnly.ref_dict": ref_dict,
    "stage02_MtOnly.ref_fasta": ref_fasta,
    "stage02_MtOnly.ref_fasta_index": ref_fasta_index,

    "stage02_MtOnly.mt_dict": mt_dict,
    "stage02_MtOnly.mt_fasta": mt_fasta,
    "stage02_MtOnly.mt_fasta_index": mt_fasta_index,
    "stage02_MtOnly.blacklisted_sites": blacklisted_sites,
    "stage02_MtOnly.blacklisted_sites_index": blacklisted_sites_index,

    "stage02_MtOnly.nuc_interval_list": nuc_interval_list,
    "stage02_MtOnly.mt_interval_list": mt_interval_list,

    "stage02_MtOnly.mt_mean_coverage": mt_mean_coverage,

    "stage02_MtOnly.docker": "kchewe/mtdna-tools:0.1.0",
    "stage02_MtOnly.haplocheck_zip": haplocheck_zip,

    "stage02_MtOnly.use_haplotype_caller_nucdna": True,
    "stage02_MtOnly.hc_dp_lower_bound": 10,
    "stage02_MtOnly.vaf_filter_threshold": 0.01,
    "stage02_MtOnly.f_score_beta": 1.0,
    "stage02_MtOnly.compress_output_vcf": False,
    "stage02_MtOnly.n_cpu": 2,
    "stage02_MtOnly.preemptible_tries": 5,
}

out_path = Path("./WDL/s002/stage02_MtOnly.inputs_v2.json")
out_path.parent.mkdir(parents=True, exist_ok=True)
out_path.write_text(json.dumps(inputs, indent=2) + "\n")
print("Wrote:", out_path.resolve())
print("sample_id:", sample_id, "age:", age, "sex:", sex)


#### Submit and Monitor Stage02

In [None]:
import subprocess
from pathlib import Path
import requests

# Submit stage02: validate the WDL with womtool, then POST the workflow source
# and inputs JSON to the local Cromwell server.
start_cromwell()

wdl_path = Path("./WDL/s002/stage02_MtOnly_v2.wdl")
json_path = Path("./WDL/s002/stage02_MtOnly.inputs_v2.json")

# Validate WDL (check=True stops the cell before submission if validation fails)
validate_cmd = (
    "bash -lc 'source /home/jupyter/.sdkman/bin/sdkman-init.sh "
    "&& sdk use java 17.0.8-tem "
    f"&& java -jar womtool-91.jar validate {wdl_path}'"
)
print("$", validate_cmd)
subprocess.run(validate_cmd, shell=True, check=True)

# Submit to Cromwell on the port configured in the global-variables cell.
# NOTE(review): assumes start_cromwell() serves on PORTID (default 8094) — confirm.
cromwell_url = f"http://localhost:{PORTID}/api/workflows/v1"
print("Submitting to:", cromwell_url)

# Open both uploads in a context manager so the handles are closed even if the
# request raises; the timeout keeps the cell from hanging on a dead server.
with wdl_path.open("rb") as wdl_fh, json_path.open("rb") as json_fh:
    files = {
        "workflowSource": wdl_fh,
        "workflowInputs": json_fh,
    }
    resp = requests.post(
        cromwell_url,
        files=files,
        headers={"accept": "application/json"},
        timeout=120,
    )
resp.raise_for_status()
wf = resp.json()
print("Response:", wf)
print("WF_ID:", wf.get("id"))


In [None]:
# Monitor + logs: wait for the submitted Stage02 workflow, then show failure
# metadata and pull the task logs.
wf_id = wf.get("id")
if not wf_id:
    # Without an id, wait_for_wf would poll a non-existent workflow until timeout.
    raise RuntimeError(f"Cromwell submission response has no workflow id: {wf!r}")
print(wf_id)

status = wait_for_wf(wf_id, poll_s=30, timeout_s=7200)
print("Final status:", status)

pretty(get_wf_metadata(wf_id, include_keys=["failures", "callRoot"]))
fetch_task_logs_from_gcs(wf_id)