## mtDNA notebook
### mitoClock 

## Configuration and Setup

In [None]:
# SET GLOBAL VARIABLES
# NOTE:
# The Cloud Life Sciences (GLS) API has expired.
# The All of Us (AoU) Workbench migrated to Google Cloud Batch (GCB) on July 8, 2025.
# For migration details, refer to:
# https://cloud.google.com/batch/docs/migrate-to-batch-from-cloud-life-sciences

import os
from pathlib import Path

# Workspace identity, taken from the environment ("" when a variable is unset).
# A trailing slash is dropped from the bucket so path suffixes can be appended.
WORKSPACE_BUCKET = os.environ.get("WORKSPACE_BUCKET", "").rstrip("/")
GOOGLE_PROJECT = os.environ.get("GOOGLE_PROJECT", "")
PET_SA_EMAIL = os.environ.get("PET_SA_EMAIL", "")

# Run settings; each one can be overridden by an environment variable of the
# same name, otherwise the default shown here is used.
outputFold = os.environ.get("outputFold", "mtDNA_v25_pilot_5")
PORTID = int(os.environ.get("PORTID", "8094"))
USE_MEM = int(os.environ.get("USE_MEM", "32"))
SQL_DB_NAME = os.environ.get("SQL_DB_NAME", "local_cromwell_run.db")

# Absolute location of the project checkout.
PROJECT_ROOT = Path(
    os.environ.get(
        "PROJECT_ROOT",
        "/mnt/f/research_drive/mtdna/leelab/mtDNA-analysis",
    )
).resolve()

# Echo the effective configuration so it is recorded in the cell output.
print("WORKSPACE_BUCKET:", WORKSPACE_BUCKET)
print("GOOGLE_PROJECT:", GOOGLE_PROJECT)
print("PET_SA_EMAIL:", PET_SA_EMAIL)
print("PROJECT_ROOT:", PROJECT_ROOT)
print("PORTID:", PORTID, "USE_MEM:", USE_MEM)


In [None]:
import os
import shutil
import subprocess
import sys
from pathlib import Path

def run(cmd, check=True):
    """Echo a shell command, execute it, and return what it printed.

    Args:
        cmd: Command line, interpreted by the shell.
        check: When true, a non-zero exit status raises
            ``subprocess.CalledProcessError``.

    Returns:
        The command's stdout and stderr merged into one string, with
        surrounding whitespace removed.
    """
    print(f"$ {cmd}")
    completed = subprocess.run(
        cmd,
        shell=True,
        check=check,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # fold stderr into the captured stream
        text=True,
    )
    captured = completed.stdout
    return captured.strip()

# --- Dependency checks ---
# SDKMAN locations are resolved first because the dependency table needs them.
home = Path.home()
sdkman_dir = home / ".sdkman"
sdkman_init = sdkman_dir / "bin" / "sdkman-init.sh"

checks = {
    "python": sys.executable,
    "pip": shutil.which("pip") or shutil.which("pip3") or "",
    "gsutil": shutil.which("gsutil") or "",
    "gcloud": shutil.which("gcloud") or "",
    "java": shutil.which("java") or "",
    # `sdk` is a shell function defined by sdkman-init.sh, not an executable on
    # PATH, so shutil.which("sdk") can never find it. Look for the init script.
    "sdkman": str(sdkman_init) if sdkman_init.exists() else "",
}

print("Dependency check:")
for k, v in checks.items():
    print(f"  {k:8s} -> {v if v else 'MISSING'}")

if checks["java"]:
    print("\nJava version:")
    print(run("java -version", check=False))

if checks["gcloud"]:
    print("\nGcloud auth:")
    print(run("gcloud auth list --format='value(account)'", check=False))

# Optional: install pyhocon if missing
try:
    import pyhocon  # noqa: F401
    print("\npyhocon: OK")
except ImportError:
    print("\npyhocon: missing")
    # Uncomment to install
    # print(run(f"{sys.executable} -m pip install pyhocon", check=True))

# --- Optional bootstrap tools (comment out if not needed) ---
if not sdkman_dir.exists():
    print("\nInstalling SDKMAN...")
    # -f makes curl fail on an HTTP error instead of saving the error page as
    # the installer script.
    run("curl -fsS https://get.sdkman.io -o install_sdkman.sh", check=True)
    run("bash install_sdkman.sh", check=True)
else:
    print("\nSDKMAN already installed.")

if sdkman_init.exists():
    run(f"bash -lc 'source {sdkman_init} && sdk install java 17.0.8-tem || true'", check=True)
    # Each `bash -lc` is a separate shell, so this `sdk use` only verifies that
    # the version can be selected; it does not affect later commands.
    run(f"bash -lc 'source {sdkman_init} && sdk use java 17.0.8-tem'", check=True)
else:
    print("SDKMAN init script not found; skipping Java install.")

# Cromwell/WOMtool 91
# Each jar is downloaded to a ".part" file and renamed only on success, so a
# failed or interrupted transfer is never mistaken for a usable jar next run.
_cromwell_release_url = "https://github.com/broadinstitute/cromwell/releases/download/91"
for _jar_name in ("cromwell-91.jar", "womtool-91.jar"):
    if Path(_jar_name).exists():
        print(f"{_jar_name} already present.")
        continue
    print(f"Downloading {_jar_name}")
    run(
        f"curl -fL {_cromwell_release_url}/{_jar_name} -o {_jar_name}.part "
        f"&& mv {_jar_name}.part {_jar_name}",
        check=True,
    )

# Heap size to use in start_cromwell()
CROMWELL_HEAP_GB = 32
print("CROMWELL_HEAP_GB set to", CROMWELL_HEAP_GB)


## CROMWELL SERVER FUNCTIONS 
- START UP
- CHECK STATUS
- GET LOGS

In [None]:
# some very helpful wrappers 
import json
import time
import subprocess
import os
from pathlib import Path
import requests

# ---- Config ----
# Port the local Cromwell server listens on; it is baked into the two URLs below.
# NOTE(review): the setup cell reads PORTID from the environment (default 8094),
# but this constant is hard-coded — confirm the two are meant to stay in sync.
CROMWELL_PORT = 8094
# Engine health endpoint polled by cromwell_up().
CROMWELL_STATUS_URL = f"http://localhost:{CROMWELL_PORT}/engine/v1/status"
# Base URL for the workflow endpoints (status, metadata, query).
CROMWELL_API = f"http://localhost:{CROMWELL_PORT}/api/workflows/v1"
# Configuration file handed to the JVM via -Dconfig.file.
CROMWELL_CONF = Path("/home/jupyter/cromwell.conf")
CROMWELL_JAR = Path("cromwell-91.jar")  # upgraded from 81, which was buggy and slow
# Script sourced before `sdk use java` when launching the server.
SDKMAN_INIT = "/home/jupyter/.sdkman/bin/sdkman-init.sh"
JAVA_VER = "17.0.8-tem"
CROMWELL_HEAP_GB = 32  # JVM heap in GB, passed as -Xmx; also set in the setup cell

# Server output is redirected to these files, relative to the working directory.
STDOUT_LOG = Path("cromwell_server_stdout.log")
STDERR_LOG = Path("cromwell_server_stderr.log")
# Holds the PID written by the launch command in start_cromwell().
PID_FILE = Path("cromwell_server.pid")

# Base path of the file-backed Cromwell database; the persistence checks look
# at the sibling file with a ".data" suffix.
DB_BASE = Path("/home/jupyter/cromwell_db/local_cromwell_run.db")
DB_DATA = Path(str(DB_BASE) + ".data")

# ---- Variables ----
# time.time() of the last restart triggered by wait_for_wf(); 0 means never.
_last_restart = 0

# ---- Helpers ----
def cromwell_up():
    """Return True when the Cromwell engine status endpoint responds successfully.

    A GET is sent to ``CROMWELL_STATUS_URL`` with a 2-second timeout. Any
    network-level failure (connection refused, timeout, ...) counts as "down";
    unrelated programming errors are no longer swallowed.

    Returns:
        bool: ``True`` if the response status is below 400, else ``False``.
    """
    try:
        return requests.get(CROMWELL_STATUS_URL, timeout=2).ok
    except requests.RequestException:
        return False

def cromwell_pid_running(pid_file=None):
    """Report whether the process recorded in the PID file is alive.

    Args:
        pid_file: Optional path of the PID file to inspect. Defaults to the
            module-level ``PID_FILE``.

    Returns:
        bool: ``True`` if the file holds a positive PID and signal 0 can be
        delivered to it, or the OS reports the process exists but belongs to
        another user. ``False`` if the file is missing, unreadable, or
        malformed, or no such process exists.
    """
    path = PID_FILE if pid_file is None else Path(pid_file)
    try:
        pid = int(path.read_text().strip())
    except (OSError, ValueError):
        # Missing/unreadable file, or content that is not an integer.
        return False

    # PID 0 and negative PIDs address process groups in os.kill, which would
    # "succeed" without telling us anything about the server.
    if pid <= 0:
        return False

    try:
        os.kill(pid, 0)  # signal 0 only checks for existence and permission
    except PermissionError:
        # The process exists but is owned by someone else.
        return True
    except (OSError, OverflowError):
        # ProcessLookupError (no such process) or a PID out of range.
        return False
    return True

def cromwell_healthy():
    """Return True only if the recorded PID is alive and the API answers."""
    if not cromwell_pid_running():
        return False
    return cromwell_up()

def cromwell_persistent_ok():
    """Return True when the database data file exists and is not empty."""
    if not DB_DATA.exists():
        return False
    return DB_DATA.stat().st_size > 0

def cromwell_persistent_recent(max_age_s=300):
    """Return True if the database data file was modified recently.

    Args:
        max_age_s: Largest acceptable age of the file's modification time,
            in seconds.

    Returns:
        bool: ``False`` when the data file is missing or empty; otherwise
        whether its age is at most ``max_age_s``.
    """
    if cromwell_persistent_ok():
        seconds_since_write = time.time() - DB_DATA.stat().st_mtime
        return seconds_since_write <= max_age_s
    return False

def _report_persistence():
    """Print whether the on-disk Cromwell database exists and was updated recently."""
    if cromwell_persistent_ok():
        print("Persistence check: DB data file OK.")
        if cromwell_persistent_recent():
            print("DB was updated recently.")


def start_cromwell(retries=30, sleep_s=2):
    """Start the local Cromwell server unless a healthy one is already running.

    The server is launched in the background through a login bash shell (so
    SDKMAN can select the Java version). Its output goes to ``STDOUT_LOG`` /
    ``STDERR_LOG`` and its PID to ``PID_FILE``. The function then polls the
    status endpoint until the server answers.

    Args:
        retries: Maximum number of status polls after launching.
        sleep_s: Seconds to wait between polls.

    Raises:
        subprocess.CalledProcessError: If the launch command itself fails.
        RuntimeError: If the launched process exits before the API comes up,
            or the API does not answer within ``retries`` polls.
    """
    if cromwell_healthy():
        print("Cromwell already running and healthy.")
        _report_persistence()
        return

    # NOTE(review): the redirections below truncate the log files, so launching
    # while another server still has them open would clobber its logs — confirm
    # this cannot happen when only the PID file is stale.
    cmd = (
        f"bash -lc 'source {SDKMAN_INIT} && sdk use java {JAVA_VER} && "
        f"nohup java -Xmx{CROMWELL_HEAP_GB}g "
        f"-Dconfig.file={CROMWELL_CONF} -Dwebservice.port={CROMWELL_PORT} "
        f"-jar {CROMWELL_JAR} server > {STDOUT_LOG} 2> {STDERR_LOG} & "
        f"echo $! > {PID_FILE} && disown'"
    )
    print("$", cmd)
    subprocess.run(cmd, shell=True, check=True)

    for _ in range(retries):
        if cromwell_up():
            print("Cromwell is up.")
            _report_persistence()
            return
        if not cromwell_pid_running():
            # The JVM has already exited (bad config, port in use, ...), so
            # waiting out the remaining polls cannot succeed.
            raise RuntimeError(
                "Cromwell did not start. Check stderr/stdout logs. "
                f"(process exited early; see {STDERR_LOG} and {STDOUT_LOG})"
            )
        time.sleep(sleep_s)

    raise RuntimeError(
        "Cromwell did not start. Check stderr/stdout logs. "
        f"(no response after {retries} polls; see {STDERR_LOG} and {STDOUT_LOG})"
    )

def tail_logs(n=50):
    """Print the last ``n`` lines of the Cromwell stdout and stderr logs.

    Each log is streamed through a bounded buffer, so large files are not
    loaded into memory. Undecodable bytes are replaced rather than raising.

    Args:
        n: Number of trailing lines to show per log. Values of 0 or less show
            no lines (previously ``n=0`` printed the whole file).
    """
    from collections import deque

    keep = max(n, 0)
    for log_path in (STDOUT_LOG, STDERR_LOG):
        if not log_path.exists():
            print(f"{log_path} not found.")
            continue
        print(f"--- {log_path} (last {n}) ---")
        with log_path.open("r", errors="replace") as fh:
            last_lines = deque(fh, maxlen=keep)
        print("\n".join(line.rstrip("\n") for line in last_lines))

def pretty(obj):
    """Print ``obj`` as JSON indented by two spaces."""
    rendered = json.dumps(obj, indent=2)
    print(rendered)

def get_wf_status(wf_id, retries=10, sleep_s=2, timeout_s=30):
    """Fetch the status document of a workflow, retrying while it is unknown.

    A 404 is treated as "not registered yet" and retried.

    Args:
        wf_id: Cromwell workflow ID.
        retries: Maximum number of requests.
        sleep_s: Seconds to wait between 404 responses.
        timeout_s: Per-request timeout in seconds, so a stalled server cannot
            block the caller indefinitely.

    Returns:
        dict: Parsed JSON body of the 200 response.

    Raises:
        requests.HTTPError: For a 4xx/5xx response other than 404.
        requests.RequestException: For connection failures and timeouts.
        RuntimeError: If the workflow is still not found after ``retries``
            attempts, or the server answers with an unexpected non-error status.
    """
    url = f"{CROMWELL_API}/{wf_id}/status"
    for attempt in range(retries):
        r = requests.get(url, timeout=timeout_s)
        if r.status_code == 200:
            return r.json()
        if r.status_code != 404:
            r.raise_for_status()
            # raise_for_status() is silent for non-error codes (e.g. 204, 3xx).
            raise RuntimeError(
                f"Unexpected HTTP {r.status_code} while fetching status of workflow {wf_id}."
            )
        if attempt < retries - 1:
            # No point sleeping after the final attempt.
            time.sleep(sleep_s)
    raise RuntimeError(f"Workflow {wf_id} not found after {retries} retries.")

def get_wf_metadata(wf_id, include_keys=None, timeout_s=120):
    """Fetch workflow metadata, optionally restricted to selected keys.

    Args:
        wf_id: Cromwell workflow ID.
        include_keys: Optional iterable of metadata keys; each is sent as a
            repeated ``includeKey`` query parameter. A single string is
            treated as one key.
        timeout_s: Request timeout in seconds.

    Returns:
        dict: Parsed JSON metadata document.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: For connection failures and timeouts.
    """
    url = f"{CROMWELL_API}/{wf_id}/metadata"
    if isinstance(include_keys, str):
        include_keys = [include_keys]
    # A list of tuples lets requests repeat the parameter and URL-encode values.
    params = [("includeKey", key) for key in include_keys] if include_keys else None
    r = requests.get(url, params=params, timeout=timeout_s)
    r.raise_for_status()
    return r.json()

def wait_for_wf(wf_id, poll_s=5, timeout_s=600):
    """Poll a workflow until it reaches a terminal state.

    Failed status checks are reported instead of being silently swallowed.
    The server is restarted only when its status endpoint is unreachable, and
    at most once every 30 seconds.

    Args:
        wf_id: Cromwell workflow ID.
        poll_s: Seconds between status checks.
        timeout_s: Overall time budget in seconds.

    Returns:
        str: ``"Succeeded"``, ``"Failed"`` or ``"Aborted"``.

    Raises:
        TimeoutError: If no terminal state is reached within ``timeout_s``.
        RuntimeError: Propagated from ``start_cromwell`` if a restart fails.
    """
    global _last_restart
    terminal_states = ("Succeeded", "Failed", "Aborted")
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        try:
            status = get_wf_status(wf_id).get("status")
        except Exception as err:  # boundary: keep polling, but say what failed
            print(f"Status check failed: {type(err).__name__}: {err}")
            # Restart only if the server itself is unreachable; an HTTP error
            # or an unknown workflow ID does not mean Cromwell is down.
            if not cromwell_up() and time.time() - _last_restart > 30:
                print("Cromwell not reachable; restarting...")
                start_cromwell()
                _last_restart = time.time()
        else:
            print("Status:", status)
            if status in terminal_states:
                return status
        time.sleep(poll_s)
    raise TimeoutError(f"Workflow {wf_id} did not finish within {timeout_s}s")

def latest_workflow_id(wdl_name=None, status=None):
    """Return the ID of the most recently submitted workflow known to Cromwell.

    The first page (20 entries) of the query endpoint is requested with GET.
    If that does not return 200, the same query is retried with POST.

    Args:
        wdl_name: Optional workflow name filter.
        status: Optional workflow status filter (e.g. ``"Succeeded"``).

    Returns:
        str | None: ID of the result with the greatest ``submission``
        timestamp, or ``None`` when the query matches nothing.

    Raises:
        requests.HTTPError: If the final response has a 4xx/5xx status.
        requests.RequestException: For connection failures and timeouts.
    """
    params = {"page": 1, "pagesize": 20}
    if wdl_name:
        params["name"] = wdl_name
    if status:
        params["status"] = status

    query_url = f"{CROMWELL_API}/query"
    r = requests.get(query_url, params=params, timeout=30)
    if r.status_code != 200:
        # The POST variant expects a JSON *list* of single-key objects with
        # string values, not one flat object.
        payload = [{key: str(value)} for key, value in params.items()]
        r = requests.post(query_url, json=payload, timeout=30)

    r.raise_for_status()
    results = r.json().get("results", [])
    if not results:
        return None
    # Submission timestamps are compared as strings, as before.
    newest = max(results, key=lambda wf: wf.get("submission", ""))
    return newest.get("id")

def get_callroots(wf_id):
    """Collect the execution root of every call attempt in a workflow.

    Args:
        wf_id: Cromwell workflow ID.

    Returns:
        list[tuple]: ``(call_name, callRoot)`` pairs, one per call entry that
        carries a ``callRoot`` field, in metadata order.
    """
    meta = get_wf_metadata(wf_id, include_keys=["callRoot", "calls"])
    return [
        (call_name, attempt["callRoot"])
        for call_name, attempts in meta.get("calls", {}).items()
        for attempt in attempts
        if "callRoot" in attempt
    ]

def fetch_task_logs_from_gcs(wf_id, call_name=None, n=50):
    """Print the tail of each call's stdout and stderr, read with ``gsutil``.

    Args:
        wf_id: Cromwell workflow ID.
        call_name: If given, only the call with exactly this name is shown.
        n: Number of trailing lines to print per log.

    Log URIs are built as ``<callRoot>/stdout`` and ``<callRoot>/stderr``.
    ``gsutil`` failures (e.g. a log that does not exist yet) are not raised.
    """
    import shlex

    callroots = get_callroots(wf_id)
    if not callroots:
        print("No callRoot entries found.")
        return

    tail_lines = int(n)
    for name, root in callroots:
        if call_name and call_name != name:
            continue
        stdout = f"{root}/stdout"
        stderr = f"{root}/stderr"
        print(f"\nCall: {name}")
        print("stdout:", stdout)
        print("stderr:", stderr)
        for uri in (stdout, stderr):
            # Quote the URI: it comes from server metadata and is interpolated
            # into a shell pipeline.
            subprocess.run(
                f"gsutil cat {shlex.quote(uri)} | tail -n {tail_lines}",
                shell=True,
                check=False,
            )

def latest_workflow_id_gcs(workspace_bucket, workflow_name):
    """Find the most recently written run of a workflow in the bucket.

    ``gsutil ls -l`` prints timestamps for objects only, not for directory
    prefixes. The previous implementation sorted the prefix lines, which
    ordered runs by ID text rather than by time. This version lists every
    object under the workflow's execution folder and picks the run whose
    newest object has the latest timestamp.

    Args:
        workspace_bucket: Bucket URL without a trailing slash
            (e.g. ``gs://my-bucket``).
        workflow_name: Workflow folder name under
            ``workflows/cromwell-executions/``.

    Returns:
        str | None: The workflow ID (folder name) of the latest run, or
        ``None`` if no objects are listed.

    Raises:
        subprocess.CalledProcessError: If ``gsutil`` exits non-zero.
    """
    base = f"{workspace_bucket}/workflows/cromwell-executions/{workflow_name}/"
    # "**" is expanded by gsutil itself, so no shell is needed.
    out = subprocess.check_output(["gsutil", "ls", "-l", f"{base}**"], text=True)

    newest_by_run = {}
    for line in out.splitlines():
        # Object lines look like "<size>  <timestamp>  gs://..."; anything else
        # (summary line, prefixes) is skipped by the checks below.
        fields = line.split(None, 2)
        if len(fields) != 3 or not fields[2].startswith(base):
            continue
        _size, stamp, url = fields
        run_id = url[len(base):].split("/", 1)[0]
        # Timestamps are compared as strings; assumes gsutil's fixed-width
        # ISO-8601 format — TODO confirm.
        if run_id and stamp > newest_by_run.get(run_id, ""):
            newest_by_run[run_id] = stamp

    if not newest_by_run:
        return None
    return max(newest_by_run, key=newest_by_run.get)


### Unit Tests 

In [None]:
start_cromwell()

In [None]:
# Smoke test of the server helpers: report state, start the server, report again.
print("cromwell_up():", cromwell_up())
print("cromwell_pid_running():", cromwell_pid_running())
print("cromwell_healthy():", cromwell_healthy())

# Returns immediately if the server is already healthy.
start_cromwell()

print("cromwell_up():", cromwell_up())
print("cromwell_pid_running():", cromwell_pid_running())
print("cromwell_healthy():", cromwell_healthy())

# State of the file-backed Cromwell database.
print("cromwell_persistent_ok():", cromwell_persistent_ok())
print("cromwell_persistent_recent():", cromwell_persistent_recent())

tail_logs(20)

# Optional: latest workflow ID (None if no workflows exist)
print("latest_workflow_id():", latest_workflow_id())



In [None]:
print("latest_workflow_id():", latest_workflow_id())

In [None]:
# search the bucket for latest ID by name 
latest_workflow_id_gcs(WORKSPACE_BUCKET, "stage01_SubsetCramChrM")

In [None]:
# get logs from tail 
N=50
tail_logs(n=N)

 ##  Run mode using Google Batch   
 

#### [IMPORTANT] Cromwell configuration

In [None]:
# upgrades: add system tuning for throughput
from pathlib import Path

# Base path of the file-backed HSQLDB database used by Cromwell; the parent
# directory is created up front.
CROMWELL_DB = "/home/jupyter/cromwell_db/local_cromwell_run.db"
Path("/home/jupyter/cromwell_db").mkdir(parents=True, exist_ok=True)

# HOCON configuration rendered as an f-string: doubled braces are literal
# braces, single-brace fields are filled from the notebook globals
# GOOGLE_PROJECT, WORKSPACE_BUCKET, PET_SA_EMAIL and CROMWELL_DB.
# NOTE(review): those globals default to "" when the environment variables are
# unset, in which case this writes a config with empty project/bucket values —
# confirm they are populated before running this cell.
cromwell_conf = f"""include required(classpath("application"))

google {{
  application-name = "cromwell"
  auths = [{{
    name = "application_default"
    scheme = "application_default"
  }}]
}}

system {{
  new-workflow-poll-rate = 1
  max-concurrent-workflows = 50
  max-workflow-launch-count = 400
  job-rate-control {{
    jobs = 100
    per = "3 seconds"
  }}
}}

backend {{
  default = "GCPBATCH"
  providers {{
    Local.config.root = "/dev/null"

    GCPBATCH {{
      actor-factory = "cromwell.backend.google.batch.GcpBatchBackendLifecycleActorFactory"
      config {{
        project = "{GOOGLE_PROJECT}"
        concurrent-job-limit = 20
        root = "{WORKSPACE_BUCKET}/workflows/cromwell-executions"

        virtual-private-cloud {{
          network-name = "projects/{GOOGLE_PROJECT}/global/networks/network"
          subnetwork-name = "projects/{GOOGLE_PROJECT}/regions/us-central1/subnetworks/subnetwork"
        }}

        batch {{
          auth = "application_default"
          compute-service-account = "{PET_SA_EMAIL}"
          location = "us-central1"
        }}

        default-runtime-attributes {{
          noAddress: true
        }}

        filesystems {{
          gcs {{
            auth = "application_default"
          }}
        }}
      }}
    }}
  }}
}}

database {{
  profile = "slick.jdbc.HsqldbProfile$"
  insert-batch-size = 6000
  db {{
    driver = "org.hsqldb.jdbcDriver"
    url = "jdbc:hsqldb:file:{CROMWELL_DB};shutdown=false;hsqldb.default_table_type=cached;hsqldb.tx=mvcc;hsqldb.large_data=true;hsqldb.lob_compressed=true;hsqldb.script_format=3;hsqldb.result_max_memory_rows=20000"
    connectionTimeout = 300000
  }}
}}
"""

# Same path as CROMWELL_CONF in the server-helper cell; overwrites any existing file.
Path("/home/jupyter/cromwell.conf").write_text(cromwell_conf)
print("Wrote /home/jupyter/cromwell.conf")


In [None]:
!ls ~/cromwell.conf


In [None]:
!cat ~/cromwell.conf

## WDL Testing
### Hello World 

In [None]:
# a small tutorial on how to submit jobs for processing 
from pathlib import Path
# Working directory for the hello-world WDL and its inputs; creating it is a
# no-op if it already exists, so the message is printed either way.
test_dir = Path("./WDL/test")
test_dir.mkdir(parents=True, exist_ok=True)
print(f"Created: {test_dir.resolve()}")


In [None]:
from pathlib import Path

# Minimal WDL 1.0 workflow: one task that echoes a greeting inside an Ubuntu
# container and returns its stdout as the workflow output `msg`.
# The backslash after the opening quotes suppresses the leading newline, so the
# file starts directly with "version 1.0".
wdl_text = """\
version 1.0

workflow HelloWorld {
  call HelloTask
  output {
    String msg = HelloTask.out
  }
}

task HelloTask {
  input {
    String name
  }
  command <<<
    echo "Hello, ~{name}!"
  >>>
  output {
    String out = read_string(stdout())
  }
  runtime {
    docker: "ubuntu:22.04"
  }
}
"""

# ./WDL/test must already exist; it is created by the preceding cell.
wdl_path = Path("./WDL/test/hello.wdl")
wdl_path.write_text(wdl_text)
print(f"Wrote: {wdl_path.resolve()}")


In [None]:
import json
from pathlib import Path

# Inputs for the hello-world workflow. The task input is addressed through its
# fully qualified name (<workflow>.<task>.<input>) because the workflow itself
# declares no inputs.
inputs = {
  "HelloWorld.HelloTask.name": "World"
}

# ./WDL/test must already exist; it is created by an earlier cell.
json_path = Path("./WDL/test/hello.inputs.json")
json_path.write_text(json.dumps(inputs, indent=2) + "\n")
print(f"Wrote: {json_path.resolve()}")


In [None]:
# import subprocess
# from pathlib import Path
# import requests

# wdl_path = Path("./WDL/test/hello.wdl")
# json_path = Path("./WDL/test/hello.inputs.json")

# # Validate WDL with womtool (Java 17)
# validate_cmd = (
#     "bash -lc 'source /home/jupyter/.sdkman/bin/sdkman-init.sh "
#     "&& sdk use java 17.0.8-tem "
#     "&& java -jar womtool-91.jar validate "
#     f"{wdl_path}'"
# )
# print("$", validate_cmd)
# subprocess.run(validate_cmd, shell=True, check=True)

# # Submit to Cromwell
# cromwell_url = "http://localhost:8094/api/workflows/v1"
# files = {
#     "workflowSource": wdl_path.open("rb"),
#     "workflowInputs": json_path.open("rb"),
# }
# print("Submitting to:", cromwell_url)
# resp = requests.post(cromwell_url, files=files, headers={"accept": "application/json"})
# resp.raise_for_status()
# wf = resp.json()
# print("Response:", wf)


In [None]:
# import time 
# wf_id = wf.get("id")

# start_time = time.perf_counter()
# status = wait_for_wf(wf_id)
# print("Final status:", status)
# fetch_task_logs_from_gcs(wf_id)
# end_time = time.perf_counter()

# elapsed_time = end_time - start_time
# print(f"Elapsed time: {elapsed_time:.4f} seconds")


## Download or Generate Metadata 

In [None]:
# Download mtdna metadata 
# Contains age, gender, sex and .cram/ .crai paths in google workspace
import gzip
import subprocess
from pathlib import Path

def ensure_mtdna_tsv(download=False, print_head=False):
    """Make sure the gzipped mtDNA metadata table is available locally.

    The file is stored under ``data/metadata``. It is fetched with ``curl``
    when it is missing or when ``download`` is true. The download goes to a
    temporary ``.part`` file and replaces the target only on success, so an
    HTTP error or interrupted transfer cannot leave a corrupt file that later
    runs would silently reuse.

    Args:
        download: Force a fresh download even if the file exists.
        print_head: Also print the first data row.

    Returns:
        tuple: ``(path, header)`` — the local ``Path`` and the raw header line.

    Raises:
        subprocess.CalledProcessError: If ``curl`` fails (including HTTP errors).
        OSError: If the file cannot be read as gzip.
    """
    out_dir = Path("data/metadata")
    out_dir.mkdir(parents=True, exist_ok=True)

    url = "https://raw.githubusercontent.com/Kaychewe/mtDNA-analysis/meta/mtdna_mitoclock_aou_dataset_36246309_person_age_gender_crams.tsv.gz"
    path = out_dir / "mtdna_mitoclock_aou_dataset_36246309_person_age_gender_crams.tsv.gz"

    if download or not path.exists():
        partial = path.with_name(path.name + ".part")
        try:
            # -f: fail on HTTP errors instead of saving the error page.
            subprocess.run(["curl", "-fL", url, "-o", str(partial)], check=True)
            partial.replace(path)
        finally:
            # Still present only if the download or the rename failed.
            if partial.exists():
                partial.unlink()
        print("Downloaded:", path.resolve())
    else:
        print("Using existing:", path.resolve())

    # gzip check + header: reading the first line fails if the file is not gzip.
    with gzip.open(path, "rt") as f:
        header = f.readline().strip()
        first = f.readline().strip() if print_head else None

    print("Gzip OK.")
    print("Header:", header)
    if print_head:
        print("First row:", first)

    return path, header

# Example usage:
#ensure_mtdna_tsv(download=False, print_head=True)

def load_and_sort_by_person_id(path, limit=None):
    """Load a gzipped TSV and return its rows sorted by numeric ``person_id``.

    Only the line terminator is stripped from each line, so empty trailing
    columns are kept and every row has the same number of fields as the
    header. Blank lines are skipped.

    Args:
        path: Path of the gzip-compressed, tab-separated file.
        limit: Maximum number of rows to return after sorting; ``None`` means
            all rows. ``0`` returns an empty list.

    Returns:
        tuple: ``(header, rows)`` where ``header`` is a list of column names
        and ``rows`` is a list of lists of strings.

    Raises:
        ValueError: If there is no ``person_id`` column, or a ``person_id``
            value is not an integer.
    """
    with gzip.open(path, "rt") as f:
        header = f.readline().rstrip("\r\n").split("\t")
        rows = [line.rstrip("\r\n").split("\t") for line in f if line.strip()]

    if "person_id" not in header:
        raise ValueError(f"No 'person_id' column in {path}; found columns: {header}")
    pid_idx = header.index("person_id")
    rows.sort(key=lambda r: int(r[pid_idx]))

    if limit is not None:
        rows = rows[:limit]

    return header, rows


In [None]:
# Preview the metadata table; expecting age, gender, sex, cram, cram_id columns.
N_PREVIEW = 2
path = ensure_mtdna_tsv(download=False, print_head=True)[0]
header, rows = load_and_sort_by_person_id(path, limit=N_PREVIEW)
print("Header:", header)
# Report the number of rows actually shown (the label used to claim 10).
print(f"First {len(rows)} rows:")
for r in rows:
    print(r)

In [None]:
import pandas as pd
# Load the same metadata table into pandas, sorted by person_id with a fresh
# 0..n-1 index. `df` is used by the Stage 01 configuration cell.
path = ensure_mtdna_tsv(download=False, print_head=True)[0]
df = pd.read_csv(path, sep="\t", compression="gzip")
df = df.sort_values("person_id").reset_index(drop=True)
print(df.head(2))


## Step 01: CRAM -> BAM/SAM (chrM + NUMT)
Goal:
- Start from a WGS CRAM and subset to chrM + NUMT intervals.
- Clean and standardize reads (remove broken mates, revert, mark duplicates, sort).
- Produce final outputs for downstream analysis.

Inputs:
- CRAM/CRAI for one sample.
- Reference FASTA + index + dict.
- chrM interval list + NUMT interval list.
- Optional metadata (age/sex) to label outputs.

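Outputs:
- Final processed BAM + BAI
- Final SAM (from processed BAM)
- Unmapped BAM (not emitted by the current samtools-only WDL below)
- Duplicate metrics + coverage stats (not emitted by the current samtools-only WDL below)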

#### Configuration

In [None]:
import json
from pathlib import Path

# ---- Select sample ----
# Choose by person_id when SELECT_PERSON_ID is set; otherwise by row position.
SELECT_PERSON_ID = 1000004
SELECT_ROW_INDEX = 0

if SELECT_PERSON_ID is not None:
    matches = df.loc[df["person_id"] == SELECT_PERSON_ID]
    if matches.empty:
        raise ValueError(f"person_id {SELECT_PERSON_ID} not found in the metadata table.")
    row = matches.iloc[0]
else:
    row = df.iloc[SELECT_ROW_INDEX]

sample_id = str(row["person_id"])
# Optional labels: left as None (JSON null) when the column is absent or NaN.
age = str(row["age"]) if "age" in row and not pd.isna(row["age"]) else None
sex = str(row["sex_at_birth"]) if "sex_at_birth" in row and not pd.isna(row["sex_at_birth"]) else None

# ---- Required reference paths ----
ref_fasta = "gs://gcp-public-data--broad-references/hg38/v0/Homo_sapiens_assembly38.fasta"

# ---- Build inputs ----
inputs = {
    "stage01_SubsetCramChrM.sample_id": sample_id,
    "stage01_SubsetCramChrM.age": age,
    "stage01_SubsetCramChrM.sex": sex,
    "stage01_SubsetCramChrM.input_cram": row["cram_uri"],
    "stage01_SubsetCramChrM.input_crai": row["cram_index_uri"],
    "stage01_SubsetCramChrM.ref_fasta": ref_fasta,
    "stage01_SubsetCramChrM.docker": "kchewe/mtdna-samtools:conda-latest",
}

out_path = Path("./WDL/s001/stage01_SubsetCramChrM.inputs.json")
# The WDL cell that also creates ./WDL/s001 runs after this one, so create the
# directory here (as the stage02/stage03 configuration cells do).
out_path.parent.mkdir(parents=True, exist_ok=True)
out_path.write_text(json.dumps(inputs, indent=2) + "\n")
print("Wrote:", out_path.resolve())
print("Selected sample:", sample_id, "age:", age, "sex:", sex)


In [None]:
# inspect JSON 
!cat /home/jupyter/workspaces/mtdnaheteroplasmyandaginganalysis/WDL/s001/stage01_SubsetCramChrM.inputs.json

In [None]:
! gsutil -u $GOOGLE_PROJECT ls -l gs://gcp-public-data--broad-references/hg38/v0/chrM/chrM.hg38.interval_list

In [None]:
from pathlib import Path

# Stage 01 WDL: a single samtools task that extracts reads on contig "chrM"
# from the input CRAM and emits BAM, BAI and SAM. Output files are named
# <sample_id>_<age>_<sex>_chrM.*, with "NA" substituted for a missing age/sex.
# Only the reference FASTA itself is declared as an input; no .fai/.dict is
# passed alongside it.
# NOTE(review): presumably samtools builds or locates the FASTA index itself —
# confirm this works in the container.
wdl_text = """\
version 1.0

workflow stage01_SubsetCramChrM {
  meta {
    description: "Lightweight samtools chrM-only subset: emit BAM/BAI/SAM."
  }

  input {
    String sample_id
    String? age
    String? sex
    File input_cram
    File input_crai
    File ref_fasta
    String docker
  }

  call SubsetChrM_Samtools {
    input:
      sample_id = sample_id,
      age = age,
      sex = sex,
      input_cram = input_cram,
      input_crai = input_crai,
      ref_fasta = ref_fasta,
      docker = docker
  }

  output {
    File final_bam = SubsetChrM_Samtools.final_bam
    File final_bai = SubsetChrM_Samtools.final_bai
    File final_sam = SubsetChrM_Samtools.final_sam
  }
}

task SubsetChrM_Samtools {
  input {
    String sample_id
    String? age
    String? sex
    File input_cram
    File input_crai
    File ref_fasta
    String docker
  }

  String age_label = select_first([age, "NA"])
  String sex_label = select_first([sex, "NA"])
  String prefix = sample_id + "_" + age_label + "_" + sex_label + "_chrM"

  command <<<
    set -e
    mkdir -p out
    
    # simplify the header 
    # samtools reheader command 

    # Subset chrM by contig name
    samtools view -T "~{ref_fasta}" -b "~{input_cram}" chrM -o "out/~{prefix}.bam"

    # Index BAM
    samtools index "out/~{prefix}.bam"

    # Export SAM
    samtools view -h "out/~{prefix}.bam" > "out/~{prefix}.sam"
  >>>

  runtime {
    docker: "~{docker}"
    memory: "8 GB"
    disks: "local-disk 200 HDD"
    bootDiskSizeGb: 50
  }

  output {
    File final_bam = "out/~{prefix}.bam"
    File final_bai = "out/~{prefix}.bam.bai"
    File final_sam = "out/~{prefix}.sam"
  }
}
"""

# Write the WDL next to its inputs JSON, creating ./WDL/s001 if needed.
out_dir = Path("./WDL/s001")
out_dir.mkdir(parents=True, exist_ok=True)
wdl_path = out_dir / "stage01_SubsetCramChrM.wdl"
wdl_path.write_text(wdl_text)

print(f"Wrote: {wdl_path.resolve()}")


In [None]:
!cat /home/jupyter/workspaces/mtdnaheteroplasmyandaginganalysis/WDL/s001/stage01_SubsetCramChrM.wdl


#### Submit Stage01

In [None]:
import subprocess
from pathlib import Path
import requests

# Submit stage01
# Extract chrM
# Updates from meeting 2/13
# Simplified file naming convention (matches `prefix` in the stage01 WDL):
# <sample_id>_<age>_<sex>_chrM.<bam|sam>

# Make sure a server is available before validating and submitting.
start_cromwell()

wdl_path = Path("./WDL/s001/stage01_SubsetCramChrM.wdl")
json_path = Path("./WDL/s001/stage01_SubsetCramChrM.inputs.json")

# Validate WDL with womtool under the SDKMAN-managed Java; a validation
# failure raises CalledProcessError and stops the cell before submission.
validate_cmd = (
    "bash -lc 'source /home/jupyter/.sdkman/bin/sdkman-init.sh "
    "&& sdk use java 17.0.8-tem "
    f"&& java -jar womtool-91.jar validate {wdl_path}'"
)
print("$", validate_cmd)
subprocess.run(validate_cmd, shell=True, check=True)

# Submit to Cromwell
cromwell_url = "http://localhost:8094/api/workflows/v1"
print("Submitting to:", cromwell_url)
# Open both files in a `with` block so the handles are closed after the upload
# instead of leaking for the lifetime of the kernel.
with wdl_path.open("rb") as wdl_fh, json_path.open("rb") as json_fh:
    files = {
        "workflowSource": wdl_fh,
        "workflowInputs": json_fh,
    }
    resp = requests.post(cromwell_url, files=files, headers={"accept": "application/json"})
resp.raise_for_status()
# `wf` is the parsed JSON response; the monitor cell reads its "id" field.
wf = resp.json()
print("Response:", wf)


In [None]:
# Monitor + logs
# `wf` is the JSON response from the submit cell; wf_id is None if that
# response has no "id" field.
wf_id = wf.get("id")
print(wf_id)


#### Monitor Stage 01

In [None]:
# RUN TIME MUST BE RECORDED
# WRITE WDL TO INSPECT THE OUTPUTS
# OR GSUTIL
# OR PYTHON COMMANDS IN JUPYTER
# Note: the dynamic Cromwell PID resolver works correctly
print(wf_id)
# Poll every 30 s for up to 2 hours; raises TimeoutError if the run is still going.
status = wait_for_wf(wf_id, poll_s=30, timeout_s=7200)
print("Final status:", status)

# Show failure details and call roots, then the tail of each task's stdout/stderr.
pretty(get_wf_metadata(wf_id, include_keys=["failures", "callRoot"]))
fetch_task_logs_from_gcs(wf_id)

In [None]:
# uncomment to inspect outputs for stage01
# expected BAMs, SAMs
# Note: use the prior wf_id
# stdout and stderr may be available
# ! gsutil ls -lh gs://fc-secure-76d68a64-00aa-40a7-b2c5-ca956db2719b/workflows/cromwell-executions/stage01_SubsetCramChrM/f38d54dd-1c21-481f-88c8-2582feede570/**

### Validations 

In [None]:
# get logs from tail
# 
N=50
tail_logs(n=N)

In [None]:
## LOGS 
## RESOLVED
## Failed run due to params 
## WF_ID=af51294d-673e-4100-a3ff-742d992417f3
! gsutil -u $GOOGLE_PROJECT cat gs://fc-secure-76d68a64-00aa-40a7-b2c5-ca956db2719b/workflows/cromwell-executions/stage01_SubsetCramChrM/af51294d-673e-4100-a3ff-742d992417f3/call-SubsetAndProcessChrM/stderr

In [None]:
! gsutil -u $GOOGLE_PROJECT cat gs://fc-secure-76d68a64-00aa-40a7-b2c5-ca956db2719b/workflows/cromwell-executions/stage01_SubsetCramChrM/1c02f1eb-3144-4afb-941c-bb18b8f6de0b/call-SubsetAndProcessChrM/script
    

In [None]:
! gsutil -u $GOOGLE_PROJECT ls gs://fc-secure-76d68a64-00aa-40a7-b2c5-ca956db2719b/workflows/cromwell-executions/stage01_SubsetCramChrM/cb919b1f-204e-43ec-9d58-e23ca3c42b92/call-SubsetAndProcessChrM/**

In [None]:
! gsutil -u $GOOGLE_PROJECT ls gs://fc-secure-76d68a64-00aa-40a7-b2c5-ca956db2719b/workflows/cromwell-executions/stage01_SubsetCramChrM/56f3c33b-2fcd-49d7-98fe-3b9ce4256695/**

### Review BAMS and SAMs 

In [None]:
wf_id

## Stage 02 mtdna VCFs

#### Configuration

In [None]:
import json
from pathlib import Path
import requests

sep = "\n # --------------------------------- "

# ---- choose stage01 workflow id (latest) ----
# This is the most recently *submitted* stage01 run, whatever its status.
WF_ID = latest_workflow_id("stage01_SubsetCramChrM")
print("Stage01 WF_ID:", WF_ID)
if WF_ID is None:
    # Without this guard the metadata request below would be made for the
    # literal ID "None" and fail with an unrelated HTTP error.
    raise ValueError("No stage01_SubsetCramChrM workflow found in Cromwell; submit Stage01 first.")

# ---- pull outputs from Cromwell metadata ----
meta = get_wf_metadata(WF_ID, include_keys=["outputs"])
outputs = meta.get("outputs", {})
print(sep)
print(outputs)
print(sep)

# These keys match the samtools-only stage01 WDL
input_bam = outputs.get("stage01_SubsetCramChrM.final_bam")
input_bai = outputs.get("stage01_SubsetCramChrM.final_bai")

if not input_bam or not input_bai:
    raise ValueError("Could not find stage01 outputs in metadata. Check output key names.")

# Extract sample_id, age, sex from filename: <sample>_<age>_<sex>_chrM.bam
# NOTE(review): splitting on "_" assumes none of the three fields contains an
# underscore — confirm against the sex_at_birth values in the metadata.
fname = input_bam.split("/")[-1]                # 1000004_85_Male_chrM.bam
parts = fname.split("_")
sample_id = parts[0] if len(parts) > 0 else ""
age = parts[1] if len(parts) > 1 else None
sex = parts[2] if len(parts) > 2 else None

print(input_bam)
print(input_bai)
print(sample_id, age, sex)

# ---- mt reference paths ----
mt_fasta = "gs://gcp-public-data--broad-references/hg38/v0/chrM/Homo_sapiens_assembly38.chrM.fasta"
mt_fasta_index = "gs://gcp-public-data--broad-references/hg38/v0/chrM/Homo_sapiens_assembly38.chrM.fasta.fai"
mt_dict = "gs://gcp-public-data--broad-references/hg38/v0/chrM/Homo_sapiens_assembly38.chrM.dict"
mt_interval_list = "gs://gcp-public-data--broad-references/hg38/v0/chrM/chrM.hg38.interval_list"

# ---- build inputs ----
inputs = {
    "stage02_MtOnly.input_bam": input_bam,
    "stage02_MtOnly.input_bai": input_bai,
    "stage02_MtOnly.sample_id": sample_id,
    "stage02_MtOnly.age": age,
    "stage02_MtOnly.sex": sex,
    "stage02_MtOnly.mt_fasta": mt_fasta,
    "stage02_MtOnly.mt_fasta_index": mt_fasta_index,
    "stage02_MtOnly.mt_dict": mt_dict,
    "stage02_MtOnly.mt_interval_list": mt_interval_list,
    "stage02_MtOnly.gatk_docker": "kchewe/mtdna-stage04:0.1.3",
    "stage02_MtOnly.mem_gb": 8,
    "stage02_MtOnly.n_cpu": 2,
}

out_path = Path("./WDL/s002/stage02_MtOnly.inputs.json")
out_path.parent.mkdir(parents=True, exist_ok=True)
out_path.write_text(json.dumps(inputs, indent=2) + "\n")
print("Wrote:", out_path.resolve())


In [None]:
! cat /home/jupyter/workspaces/mtdnaheteroplasmyandaginganalysis/WDL/s002/stage02_MtOnly.inputs.json

In [None]:
from pathlib import Path

# Stage 02 WDL: one GATK task that runs Mutect2 in mitochondria mode on the
# stage01 BAM and then FilterMutectCalls. It emits
# <sample_id>_<age>_<sex>_mt.filtered.vcf plus its .idx index.
#
# This is a regular (non-raw) Python string, so each backslash at the end of a
# line inside the command block is consumed together with the newline. In the
# written .wdl file every gatk invocation therefore ends up on a single line.
#
# NOTE(review): Mutect2 is given --mitochondria-mode but FilterMutectCalls is
# not — confirm that filtering with the non-mitochondrial defaults is intended.
wdl_text = """\
version 1.0

workflow stage02_MtOnly {
  meta {
    description: "Stage02 simplified: mtDNA-only variant calling (Mutect2 + FilterMutectCalls)."
  }

  input {
    File input_bam
    File input_bai
    String sample_id
    String? age
    String? sex

    File mt_fasta
    File mt_fasta_index
    File mt_dict
    File mt_interval_list

    String gatk_docker
    Int? mem_gb
    Int? n_cpu
  }

  call CallMtVariants {
    input:
      input_bam = input_bam,
      input_bai = input_bai,
      sample_id = sample_id,
      age = age,
      sex = sex,
      mt_fasta = mt_fasta,
      mt_fasta_index = mt_fasta_index,
      mt_dict = mt_dict,
      mt_interval_list = mt_interval_list,
      gatk_docker = gatk_docker,
      mem_gb = mem_gb,
      n_cpu = n_cpu
  }

  output {
    File out_vcf = CallMtVariants.out_vcf
    File out_vcf_index = CallMtVariants.out_vcf_index
  }
}

task CallMtVariants {
  input {
    File input_bam
    File input_bai
    String sample_id
    String? age
    String? sex

    File mt_fasta
    File mt_fasta_index
    File mt_dict
    File mt_interval_list

    String gatk_docker
    Int? mem_gb
    Int? n_cpu
  }

  String age_label = select_first([age, "NA"])
  String sex_label = select_first([sex, "NA"])
  String prefix = sample_id + "_" + age_label + "_" + sex_label + "_mt"

  command <<<
    set -e
    mkdir -p out

    gatk Mutect2 \
      -R "~{mt_fasta}" \
      -I "~{input_bam}" \
      -L "~{mt_interval_list}" \
      --mitochondria-mode \
      -O "out/~{prefix}.raw.vcf"

    gatk FilterMutectCalls \
      -R "~{mt_fasta}" \
      -V "out/~{prefix}.raw.vcf" \
      -O "out/~{prefix}.filtered.vcf"
  >>>

  runtime {
    docker: "~{gatk_docker}"
    memory: select_first([mem_gb, 8]) + " GB"
    cpu: select_first([n_cpu, 2])
    disks: "local-disk 200 HDD"
    bootDiskSizeGb: 50
  }

  output {
    File out_vcf = "out/~{prefix}.filtered.vcf"
    File out_vcf_index = "out/~{prefix}.filtered.vcf.idx"
  }
}
"""

# Write the WDL next to its inputs JSON, creating ./WDL/s002 if needed.
out_dir = Path("./WDL/s002")
out_dir.mkdir(parents=True, exist_ok=True)
wdl_path = out_dir / "stage02_MtOnly.wdl"
wdl_path.write_text(wdl_text)

print(f"Wrote: {wdl_path.resolve()}")


In [None]:
! cat /home/jupyter/workspaces/mtdnaheteroplasmyandaginganalysis/WDL/s002/stage02_MtOnly.wdl

#### Submit Stage02

In [None]:
import subprocess
from pathlib import Path
import requests

# Submit stage02
# Simplified file naming convention (matches `prefix` in the stage02 WDL):
# <sample_id>_<age>_<sex>_mt.filtered.vcf

# Make sure a server is available before validating and submitting.
start_cromwell()

wdl_path = Path("./WDL/s002/stage02_MtOnly.wdl")
json_path = Path("./WDL/s002/stage02_MtOnly.inputs.json")

# Validate WDL with womtool under the SDKMAN-managed Java; a validation
# failure raises CalledProcessError and stops the cell before submission.
validate_cmd = (
    "bash -lc 'source /home/jupyter/.sdkman/bin/sdkman-init.sh "
    "&& sdk use java 17.0.8-tem "
    f"&& java -jar womtool-91.jar validate {wdl_path}'"
)
print("$", validate_cmd)
subprocess.run(validate_cmd, shell=True, check=True)

# Submit to Cromwell
cromwell_url = "http://localhost:8094/api/workflows/v1"
print("Submitting to:", cromwell_url)
# Open both files in a `with` block so the handles are closed after the upload
# instead of leaking for the lifetime of the kernel.
with wdl_path.open("rb") as wdl_fh, json_path.open("rb") as json_fh:
    files = {
        "workflowSource": wdl_fh,
        "workflowInputs": json_fh,
    }
    resp = requests.post(cromwell_url, files=files, headers={"accept": "application/json"})
resp.raise_for_status()
# `wf` is the parsed JSON response; the monitor cell reads its "id" field.
wf = resp.json()
print("Response:", wf)


#### Monitor Stage02

In [None]:
# Monitor + logs
# `wf` is the JSON response from the stage02 submit cell; wf_id is None if
# that response has no "id" field.
wf_id = wf.get("id")
print(wf_id)

In [None]:
print(wf_id)
# Poll every 30 s for up to 2 hours; raises TimeoutError if the run is still going.
status = wait_for_wf(wf_id, poll_s=30, timeout_s=7200)
print("Final status:", status)

# Show failure details and call roots, then the tail of each task's stdout/stderr.
pretty(get_wf_metadata(wf_id, include_keys=["failures", "callRoot"]))
fetch_task_logs_from_gcs(wf_id)

In [None]:
! gsutil ls gs://fc-secure-76d68a64-00aa-40a7-b2c5-ca956db2719b/workflows/cromwell-executions/**/a5f85f2f-65c6-442f-b21c-72624f3c3762/**

In [None]:
! gsutil cat gs://fc-secure-76d68a64-00aa-40a7-b2c5-ca956db2719b/workflows/cromwell-executions/stage02_MtOnly/a5f85f2f-65c6-442f-b21c-72624f3c3762/call-CallMtVariants/out/1000004_85_Male_mt.filtered.vcf | awk 'BEGIN{OFS="\t"} !/^#/ {print $1,$2,$4,$5,$6,$7,$8,$9,$10}'


In [None]:
! gsutil cat gs://fc-secure-76d68a64-00aa-40a7-b2c5-ca956db2719b/workflows/cromwell-executions/stage02_MtOnly/a5f85f2f-65c6-442f-b21c-72624f3c3762/call-CallMtVariants/out/1000004_85_Male_mt.filtered.vcf  | grep -v '^#' | wc -l


## Stage 03 Normalize/annotate/filter mt variants, apply VAF thresholds

#### Configuration

In [None]:
import json
from pathlib import Path

# ---- Stage02 outputs ----
# Look up a stage02_MtOnly workflow id via latest_workflow_id (defined in an
# earlier cell) and read that run's "outputs" from the Cromwell metadata.
WF_ID = latest_workflow_id("stage02_MtOnly")
meta = get_wf_metadata(WF_ID, include_keys=["outputs"])
outputs = meta.get("outputs", {})

input_vcf = outputs.get("stage02_MtOnly.out_vcf")
input_vcf_index = outputs.get("stage02_MtOnly.out_vcf_index")

if not input_vcf or not input_vcf_index:
    raise ValueError("Could not find stage02 outputs in metadata. Check output key names.")

# Extract sample_id/age/sex from filename
fname = input_vcf.split("/")[-1]   # e.g. 1000004_85_Male_mt.filtered.vcf
# Cut the "_mt.<suffix>" tail off before splitting, so a name that lacks age or
# sex yields None for them instead of picking up "mt.filtered.vcf" as a value.
stem = fname.split("_mt.", 1)[0]
parts = stem.split("_")
sample_id = parts[0]  # str.split always returns at least one element
age = parts[1] if len(parts) > 1 else None
sex = parts[2] if len(parts) > 2 else None

# Keys follow the "<workflow>.<input>" naming of the stage03 WDL; None values
# are serialized as JSON null.
inputs = {
    "stage03_MtFilter.input_vcf": input_vcf,
    "stage03_MtFilter.input_vcf_index": input_vcf_index,
    "stage03_MtFilter.sample_id": sample_id,
    "stage03_MtFilter.age": age,
    "stage03_MtFilter.sex": sex,
    "stage03_MtFilter.vaf_min": 0.01,
    "stage03_MtFilter.docker": "kchewe/mtdna-tools:0.1.0",
}

out_path = Path("./WDL/s003/stage03_MtFilter.inputs.json")
out_path.parent.mkdir(parents=True, exist_ok=True)
out_path.write_text(json.dumps(inputs, indent=2) + "\n")
print("Wrote:", out_path.resolve())


In [None]:
# Show the stage03 inputs JSON.
# NOTE(review): absolute path; it only matches the ./WDL/s003 file written by the
# configuration cell if the notebook runs from this workspace directory -- confirm.
! cat /home/jupyter/workspaces/mtdnaheteroplasmyandaginganalysis/WDL/s003/stage03_MtFilter.inputs.json

In [None]:
# List all *.vcf.gz objects found under any "out/" directory of cromwell-executions.
! gsutil ls gs://fc-secure-76d68a64-00aa-40a7-b2c5-ca956db2719b/workflows/cromwell-executions/**/out/**.vcf.gz

In [None]:
# Show the first 10 alignment records of the stage01 SAM, skipping "@" header lines.
!gsutil cat gs://fc-secure-76d68a64-00aa-40a7-b2c5-ca956db2719b/workflows/cromwell-executions/stage01_SubsetCramChrM/3cb6f132-8c37-4b31-a460-289d537506ee/call-SubsetChrM_Samtools/out/1000406_25_Female_chrM.sam | grep -v "^@" | head -n 10

In [None]:
# TODO: directly compute heteroplasmy
# TODO: collect into one dir by type

from pathlib import Path

# WDL source for stage03; this cell writes it to ./WDL/s003/stage03_MtFilter.wdl.
# Escaping notes (this is a non-raw Python string):
#   * "\\t" / "\\n" reach the file as the two-character sequences \t / \n that
#     bcftools expands in its -f format string.
#   * "\\" at the end of a line writes a literal backslash, i.e. a shell line
#     continuation. A single trailing backslash would be consumed by Python and
#     silently join the two lines in the written file.
# The per-sample AF is queried as "[%AF]": bcftools query requires FORMAT fields
# to be enclosed in square brackets.
wdl_text = """\
version 1.0

workflow stage03_MtFilter {
  meta {
    description: "Stage03 simplified: filter mtDNA variants by VAF and emit VCF + TSV."
  }

  input {
    File input_vcf
    File input_vcf_index
    String sample_id
    String? age
    String? sex
    Float vaf_min = 0.01
    String docker
  }

  call FilterToTsv {
    input:
      input_vcf = input_vcf,
      input_vcf_index = input_vcf_index,
      sample_id = sample_id,
      age = age,
      sex = sex,
      vaf_min = vaf_min,
      docker = docker
  }

  output {
    File out_vcf = FilterToTsv.out_vcf
    File out_tsv = FilterToTsv.out_tsv
  }
}

task FilterToTsv {
  input {
    File input_vcf
    File input_vcf_index
    String sample_id
    String? age
    String? sex
    Float vaf_min
    String docker
  }

  String age_label = select_first([age, "NA"])
  String sex_label = select_first([sex, "NA"])
  String prefix = sample_id + "_" + age_label + "_" + sex_label + "_mt"

  command <<<
    set -e
    mkdir -p out

    # Filter by VAF (AF in FORMAT)
    bcftools view -i "FORMAT/AF>=~{vaf_min}" "~{input_vcf}" -Oz -o "out/~{prefix}.vaf~{vaf_min}.vcf.gz"
    tabix -p vcf "out/~{prefix}.vaf~{vaf_min}.vcf.gz"

    # TSV output
    bcftools query -f '%CHROM\\t%POS\\t%REF\\t%ALT\\t%QUAL\\t%FILTER\\t%INFO/DP\\t[%AF]\\n' \\
      "out/~{prefix}.vaf~{vaf_min}.vcf.gz" > "out/~{prefix}.vaf~{vaf_min}.tsv"
  >>>

  runtime {
    docker: "~{docker}"
    memory: "8 GB"
    cpu: 2
    disks: "local-disk 200 HDD"
    bootDiskSizeGb: 50
  }

  output {
    File out_vcf = "out/~{prefix}.vaf~{vaf_min}.vcf.gz"
    File out_tsv = "out/~{prefix}.vaf~{vaf_min}.tsv"
  }
}
"""

# Create the stage directory if needed and (over)write the WDL file.
out_dir = Path("./WDL/s003")
out_dir.mkdir(parents=True, exist_ok=True)
wdl_path = out_dir / "stage03_MtFilter.wdl"
wdl_path.write_text(wdl_text)

print(f"Wrote: {wdl_path.resolve()}")


In [None]:
# Show the stage03 WDL.
# NOTE(review): absolute path; it only matches the ./WDL/s003 file written by the
# previous cell if the notebook runs from this workspace directory -- confirm.
! cat /home/jupyter/workspaces/mtdnaheteroplasmyandaginganalysis/WDL/s003/stage03_MtFilter.wdl

#### Submit Stage03

In [None]:
import subprocess
from pathlib import Path
import requests

# Submit stage03
# Outputs:
# <sample_id>_<age>_<sex>_mt.vaf0.01.vcf.gz
# <sample_id>_<age>_<sex>_mt.vaf0.01.tsv

start_cromwell()

wdl_path = Path("./WDL/s003/stage03_MtFilter.wdl")
json_path = Path("./WDL/s003/stage03_MtFilter.inputs.json")

# Validate WDL
# womtool runs in a login shell that sources sdkman and selects Java 17.0.8-tem;
# check=True aborts the cell before submission if validation fails.
validate_cmd = (
    "bash -lc 'source /home/jupyter/.sdkman/bin/sdkman-init.sh "
    "&& sdk use java 17.0.8-tem "
    f"&& java -jar womtool-91.jar validate {wdl_path}'"
)
print("$", validate_cmd)
subprocess.run(validate_cmd, shell=True, check=True)

# Submit to Cromwell
# NOTE(review): the port is hard-coded; the setup cell defines PORTID
# (default 8094) -- confirm the two stay in sync.
cromwell_url = "http://localhost:8094/api/workflows/v1"
print("Submitting to:", cromwell_url)
# Upload both files inside a context manager so the handles are closed as soon
# as the request has been sent (they were previously left open).
with wdl_path.open("rb") as wdl_fh, json_path.open("rb") as json_fh:
    files = {
        "workflowSource": wdl_fh,
        "workflowInputs": json_fh,
    }
    resp = requests.post(cromwell_url, files=files, headers={"accept": "application/json"})
resp.raise_for_status()
wf = resp.json()
print("Response:", wf)


#### Monitor Stage03 

In [None]:
# Monitor + logs
# Keep the workflow id from the submission response; the next cell reads wf_id.
# wf.get returns None (no error) when the response has no "id" key.
wf_id = wf.get("id")
print(wf_id)

In [None]:
print(wf_id)
# wait_for_wf is defined in an earlier cell; presumably it polls every poll_s
# seconds and gives up after timeout_s seconds (here 30 s / 7200 s) -- confirm.
status = wait_for_wf(wf_id, poll_s=30, timeout_s=7200)
print("Final status:", status)

# Show only the "failures" and "callRoot" metadata keys, then fetch task logs.
pretty(get_wf_metadata(wf_id, include_keys=["failures", "callRoot"]))
fetch_task_logs_from_gcs(wf_id)

In [None]:
# List everything under the FilterToTsv call directory of this stage03 run.
! gsutil ls gs://fc-secure-76d68a64-00aa-40a7-b2c5-ca956db2719b/workflows/cromwell-executions/stage03_MtFilter/dc972013-ceb9-45a3-b4f6-76ef369887a2/call-FilterToTsv/**

In [None]:
# Print the stage03 TSV produced for sample 1000004.
! gsutil cat gs://fc-secure-76d68a64-00aa-40a7-b2c5-ca956db2719b/workflows/cromwell-executions/stage03_MtFilter/dc972013-ceb9-45a3-b4f6-76ef369887a2/call-FilterToTsv/out/1000004_85_Male_mt.vaf0.01.tsv 

In [None]:
# Count the lines of the same TSV (the WDL writes it without a header row).
! gsutil cat gs://fc-secure-76d68a64-00aa-40a7-b2c5-ca956db2719b/workflows/cromwell-executions/stage03_MtFilter/dc972013-ceb9-45a3-b4f6-76ef369887a2/call-FilterToTsv/out/1000004_85_Male_mt.vaf0.01.tsv | wc -l

## BATCH RUN (10 samples per age group)

In [None]:
# start_cromwell is defined in an earlier cell; presumably it launches the local
# Cromwell server that the batch cells submit to -- confirm.
start_cromwell()

In [None]:
# cromwell_up is defined in an earlier cell; presumably a check that the
# Cromwell server is reachable -- confirm.
cromwell_up()

#### Test Batch Submission 

In [None]:
import json
import time
from pathlib import Path

# ---- Todo -------
# compute heteroplasmy

# ---- age bins ----
# Inclusive (lo, hi) age ranges processed by the main loop; the full set of
# bins is kept commented out for reference.
# age_bins = [(18,39), (40,59), (60,79), (80,120)]
age_bins = [(60,79), (80,120)]

# ---- config ----
# NOTE(review): the section heading says 10 samples per age group, but
# N_PER_BIN is 100 -- confirm which is intended.
N_PER_BIN = 100   # maximum samples taken per age bin
BATCH_SIZE = 20   # samples submitted together for each stage

# Container images used by stage01 / stage02 / stage03 respectively.
DOCKER_SAMTOOLS = "kchewe/mtdna-samtools:conda-latest"
DOCKER_GATK = "kchewe/mtdna-stage04:0.1.3"
DOCKER_TOOLS = "kchewe/mtdna-tools:0.1.0"

# Reference files from the public Broad hg38 bucket.
ref_fasta = "gs://gcp-public-data--broad-references/hg38/v0/Homo_sapiens_assembly38.fasta"
mt_fasta = "gs://gcp-public-data--broad-references/hg38/v0/chrM/Homo_sapiens_assembly38.chrM.fasta"
mt_fasta_index = "gs://gcp-public-data--broad-references/hg38/v0/chrM/Homo_sapiens_assembly38.chrM.fasta.fai"
mt_dict = "gs://gcp-public-data--broad-references/hg38/v0/chrM/Homo_sapiens_assembly38.chrM.dict"
mt_interval_list = "gs://gcp-public-data--broad-references/hg38/v0/chrM/chrM.hg38.interval_list"

# ---- output log ----
# One TSV row per sample; the header is written only when the log does not
# exist yet, so earlier runs are preserved and can be skipped on re-run.
log_path = Path("./WDL/batch_run_log.tsv")
if not log_path.exists():
    # Make sure ./WDL exists so the first write cannot fail on a fresh checkout.
    log_path.parent.mkdir(parents=True, exist_ok=True)
    log_path.write_text(
        "bin\tperson_id\t"
        "s1_wf\ts1_status\ts1_bam\ts1_bai\t"
        "s2_wf\ts2_status\ts2_vcf\ts2_vcf_idx\t"
        "s3_wf\ts3_status\ts3_vcf\ts3_tsv\n"
    )

def read_processed_ids():
    """Return the set of person ids already recorded in the batch log.

    Reads ``log_path``, skips the first (header) line and collects the second
    tab-separated field of every remaining line that has at least two fields.
    """
    data_rows = log_path.read_text().strip().splitlines()[1:]
    split_rows = (row.split("\t") for row in data_rows)
    return {fields[1] for fields in split_rows if len(fields) > 1}

def submit_workflow(wdl_path, inputs_path, cromwell_url="http://localhost:8094/api/workflows/v1"):
    """Submit a workflow to Cromwell and return the id of the new run.

    Args:
        wdl_path: Path of the WDL file, uploaded as ``workflowSource``.
        inputs_path: Path of the inputs JSON, uploaded as ``workflowInputs``.
        cromwell_url: Submission endpoint; defaults to the local server on
            port 8094 that this function previously hard-coded.

    Returns:
        The ``id`` field of the JSON response.

    Raises:
        requests.HTTPError: If Cromwell answers with an error status.
        KeyError: If the response JSON has no ``id`` field.
    """
    # Both files are read while the request is sent, so the handles only need
    # to live for the duration of the POST; previously they were never closed.
    with open(wdl_path, "rb") as wdl_fh, open(inputs_path, "rb") as inputs_fh:
        files = {
            "workflowSource": wdl_fh,
            "workflowInputs": inputs_fh,
        }
        r = requests.post(cromwell_url, files=files, headers={"accept": "application/json"})
    r.raise_for_status()
    return r.json()["id"]

def wait_many(wf_map, label):
    """Poll a group of workflows until each has reached a terminal status.

    wf_map maps sample_id -> wf_id. Every poll prints one line per unfinished
    workflow, prefixed with ``label``. Returns sample_id -> final status, where
    the status is one of "Succeeded", "Failed" or "Aborted". Sleeps 30 seconds
    between polling rounds while anything is unfinished; there is no timeout.
    """
    terminal = ("Succeeded", "Failed", "Aborted")
    remaining = dict(wf_map)
    finished = {}
    while remaining:
        # Iterate over a snapshot because finished entries are removed in-loop.
        for sample, workflow in tuple(remaining.items()):
            state = get_wf_status(workflow).get("status")
            print(f"[{label}] {sample} status={state}")
            if state not in terminal:
                continue
            finished[sample] = state
            del remaining[sample]
        if not remaining:
            break
        time.sleep(30)
    return finished

def chunks(lst, n):
    """Lazily yield consecutive slices of ``lst`` holding at most ``n`` items each."""
    yield from (lst[start:start + n] for start in range(0, len(lst), n))

def _sample_fields(row):
    """Return (sample_id, age, sex) for one record; sex is None when missing or NaN."""
    has_sex = "sex_at_birth" in row and not pd.isna(row["sex_at_birth"])
    sex = str(row["sex_at_birth"]) if has_sex else None
    return str(row["person_id"]), str(row["age"]), sex


processed_ids = read_processed_ids()
print("Already processed:", len(processed_ids))

# ---- main loop ----
# For every age bin: select samples, then run stage01 -> stage02 -> stage03 in
# batches. Within a batch each stage is submitted for all samples and awaited
# before the next stage starts; a sample only advances if its previous stage
# succeeded. One log row per sample is written after stage03 has finished.
for lo, hi in age_bins:
    print(f"\n=== AGE BIN {lo}-{hi} ===")

    # Up to N_PER_BIN samples with lo <= age <= hi that are not in the log yet.
    group = df[(df["age"] >= lo) & (df["age"] <= hi)]
    group = group[~group["person_id"].astype(str).isin(processed_ids)]
    group = group.head(N_PER_BIN)
    print(f"Selected {len(group)} samples after skipping processed")

    rows = group.to_dict(orient="records")

    for batch in chunks(rows, BATCH_SIZE):
        print(f"\n--- Submitting batch of {len(batch)} samples ---")

        # ---------- Stage01 batch ----------
        s1_wf_map = {}
        for row in batch:
            sample_id, age, sex = _sample_fields(row)

            s1_inputs = {
                "stage01_SubsetCramChrM.sample_id": sample_id,
                "stage01_SubsetCramChrM.age": age,
                "stage01_SubsetCramChrM.sex": sex,
                "stage01_SubsetCramChrM.input_cram": row["cram_uri"],
                "stage01_SubsetCramChrM.input_crai": row["cram_index_uri"],
                "stage01_SubsetCramChrM.ref_fasta": ref_fasta,
                "stage01_SubsetCramChrM.docker": DOCKER_SAMTOOLS,
            }
            # The same inputs file is overwritten for every sample; it is
            # uploaded by submit_workflow before the next overwrite happens.
            s1_json = Path("./WDL/s001/stage01_SubsetCramChrM.inputs.json")
            s1_json.write_text(json.dumps(s1_inputs, indent=2) + "\n")
            print(f"[Stage01] submitting {sample_id}")
            s1_wf = submit_workflow("./WDL/s001/stage01_SubsetCramChrM.wdl", s1_json)
            s1_wf_map[sample_id] = s1_wf
            print(f"[Stage01] {sample_id} WF_ID={s1_wf}")

        s1_status = wait_many(s1_wf_map, "Stage01")

        # ---------- Stage02 batch ----------
        s2_wf_map = {}
        s1_outputs = {}
        for row in batch:
            sample_id, age, sex = _sample_fields(row)

            if s1_status.get(sample_id) != "Succeeded":
                print(f"[Stage02] skip {sample_id} (Stage01 {s1_status.get(sample_id)})")
                continue

            s1_out = get_wf_metadata(s1_wf_map[sample_id], include_keys=["outputs"])["outputs"]
            bam = s1_out["stage01_SubsetCramChrM.final_bam"]
            bai = s1_out["stage01_SubsetCramChrM.final_bai"]
            s1_outputs[sample_id] = (bam, bai)

            s2_inputs = {
                "stage02_MtOnly.input_bam": bam,
                "stage02_MtOnly.input_bai": bai,
                "stage02_MtOnly.sample_id": sample_id,
                "stage02_MtOnly.age": age,
                "stage02_MtOnly.sex": sex,
                "stage02_MtOnly.mt_fasta": mt_fasta,
                "stage02_MtOnly.mt_fasta_index": mt_fasta_index,
                "stage02_MtOnly.mt_dict": mt_dict,
                "stage02_MtOnly.mt_interval_list": mt_interval_list,
                "stage02_MtOnly.gatk_docker": DOCKER_GATK,
                "stage02_MtOnly.mem_gb": 8,
                "stage02_MtOnly.n_cpu": 2,
            }
            s2_json = Path("./WDL/s002/stage02_MtOnly.inputs.json")
            s2_json.write_text(json.dumps(s2_inputs, indent=2) + "\n")
            print(f"[Stage02] submitting {sample_id}")
            s2_wf = submit_workflow("./WDL/s002/stage02_MtOnly.wdl", s2_json)
            s2_wf_map[sample_id] = s2_wf
            print(f"[Stage02] {sample_id} WF_ID={s2_wf}")

        s2_status = wait_many(s2_wf_map, "Stage02")

        # ---------- Stage03 batch ----------
        s3_wf_map = {}
        s2_outputs = {}
        for row in batch:
            sample_id, age, sex = _sample_fields(row)

            if s2_status.get(sample_id) != "Succeeded":
                print(f"[Stage03] skip {sample_id} (Stage02 {s2_status.get(sample_id)})")
                continue

            s2_out = get_wf_metadata(s2_wf_map[sample_id], include_keys=["outputs"])["outputs"]
            vcf = s2_out["stage02_MtOnly.out_vcf"]
            vcf_idx = s2_out["stage02_MtOnly.out_vcf_index"]
            s2_outputs[sample_id] = (vcf, vcf_idx)

            s3_inputs = {
                "stage03_MtFilter.input_vcf": vcf,
                "stage03_MtFilter.input_vcf_index": vcf_idx,
                "stage03_MtFilter.sample_id": sample_id,
                "stage03_MtFilter.age": age,
                "stage03_MtFilter.sex": sex,
                "stage03_MtFilter.vaf_min": 0.01,
                "stage03_MtFilter.docker": DOCKER_TOOLS,
            }
            s3_json = Path("./WDL/s003/stage03_MtFilter.inputs.json")
            s3_json.write_text(json.dumps(s3_inputs, indent=2) + "\n")
            print(f"[Stage03] submitting {sample_id}")
            s3_wf = submit_workflow("./WDL/s003/stage03_MtFilter.wdl", s3_json)
            s3_wf_map[sample_id] = s3_wf
            print(f"[Stage03] {sample_id} WF_ID={s3_wf}")

        s3_status = wait_many(s3_wf_map, "Stage03")

        # ---------- Logging ----------
        # Append one row per sample. Appending replaces the previous
        # read-everything/rewrite-everything update, which grew quadratically
        # with the log size and could truncate the log if interrupted.
        with log_path.open("a") as log_fh:
            for row in batch:
                sample_id = str(row["person_id"])
                s1_wf = s1_wf_map.get(sample_id, "")
                s2_wf = s2_wf_map.get(sample_id, "")
                s3_wf = s3_wf_map.get(sample_id, "")

                s1_stat = s1_status.get(sample_id, "")
                s2_stat = s2_status.get(sample_id, "")
                s3_stat = s3_status.get(sample_id, "")

                # Empty strings mark stages that were skipped for this sample.
                bam, bai = s1_outputs.get(sample_id, ("", ""))
                vcf, vcf_idx = s2_outputs.get(sample_id, ("", ""))

                out_vcf = ""
                out_tsv = ""
                if s3_stat == "Succeeded":
                    s3_out = get_wf_metadata(s3_wf_map[sample_id], include_keys=["outputs"])["outputs"]
                    out_vcf = s3_out.get("stage03_MtFilter.out_vcf", "")
                    out_tsv = s3_out.get("stage03_MtFilter.out_tsv", "")

                line = f"{lo}-{hi}\t{sample_id}\t{s1_wf}\t{s1_stat}\t{bam}\t{bai}\t{s2_wf}\t{s2_stat}\t{vcf}\t{vcf_idx}\t{s3_wf}\t{s3_stat}\t{out_vcf}\t{out_tsv}\n"
                log_fh.write(line)

print("All bins done.")


#### Batch Submission 10 samples per age group

In [None]:
import json
import time
from pathlib import Path

# ---- age bins ----
# Inclusive (lo, hi) ranges by decade: 18-19, 20s, 30s, ... 90s, 100-120
age_bins = [(18,19)] + [(d, d+9) for d in range(20, 100, 10)] + [(100,120)]

# ---- config ----
# NOTE(review): the section heading says 10 samples per age group, but
# N_PER_BIN is 100 -- confirm which is intended.
N_PER_BIN = 100   # maximum samples taken per age bin
BATCH_SIZE = 20   # samples submitted together for each stage
POLL_S = 30       # seconds between status polls
HEARTBEAT_S = 300  # print a heartbeat every 5 min even if no status change

# Container images used by stage01 / stage02 / stage03 respectively.
DOCKER_SAMTOOLS = "kchewe/mtdna-samtools:conda-latest"
DOCKER_GATK = "kchewe/mtdna-stage04:0.1.3"
DOCKER_TOOLS = "kchewe/mtdna-tools:0.1.0"

# Reference files from the public Broad hg38 bucket.
ref_fasta = "gs://gcp-public-data--broad-references/hg38/v0/Homo_sapiens_assembly38.fasta"
mt_fasta = "gs://gcp-public-data--broad-references/hg38/v0/chrM/Homo_sapiens_assembly38.chrM.fasta"
mt_fasta_index = "gs://gcp-public-data--broad-references/hg38/v0/chrM/Homo_sapiens_assembly38.chrM.fasta.fai"
mt_dict = "gs://gcp-public-data--broad-references/hg38/v0/chrM/Homo_sapiens_assembly38.chrM.dict"
mt_interval_list = "gs://gcp-public-data--broad-references/hg38/v0/chrM/chrM.hg38.interval_list"

# ---- output log ----
# One TSV row per sample; the header is written only when the log does not
# exist yet, so earlier runs are preserved and can be skipped on re-run.
log_path = Path("./WDL/batch_run_log.tsv")
if not log_path.exists():
    # Make sure ./WDL exists so the first write cannot fail on a fresh checkout.
    log_path.parent.mkdir(parents=True, exist_ok=True)
    log_path.write_text(
        "bin\tperson_id\t"
        "s1_wf\ts1_status\ts1_bam\ts1_bai\t"
        "s2_wf\ts2_status\ts2_vcf\ts2_vcf_idx\t"
        "s3_wf\ts3_status\ts3_vcf\ts3_tsv\n"
    )

def read_processed_ids():
    """Return the set of person ids already recorded in the batch log.

    Reads ``log_path``, skips the first (header) line and collects the second
    tab-separated field of every remaining line that has at least two fields.
    """
    data_rows = log_path.read_text().strip().splitlines()[1:]
    split_rows = (row.split("\t") for row in data_rows)
    return {fields[1] for fields in split_rows if len(fields) > 1}

def submit_workflow(wdl_path, inputs_path, cromwell_url="http://localhost:8094/api/workflows/v1"):
    """Submit a workflow to Cromwell and return the id of the new run.

    Args:
        wdl_path: Path of the WDL file, uploaded as ``workflowSource``.
        inputs_path: Path of the inputs JSON, uploaded as ``workflowInputs``.
        cromwell_url: Submission endpoint; defaults to the local server on
            port 8094 that this function previously hard-coded.

    Returns:
        The ``id`` field of the JSON response.

    Raises:
        requests.HTTPError: If Cromwell answers with an error status.
        KeyError: If the response JSON has no ``id`` field.
    """
    # Both files are read while the request is sent, so the handles only need
    # to live for the duration of the POST; previously they were never closed.
    with open(wdl_path, "rb") as wdl_fh, open(inputs_path, "rb") as inputs_fh:
        files = {
            "workflowSource": wdl_fh,
            "workflowInputs": inputs_fh,
        }
        r = requests.post(cromwell_url, files=files, headers={"accept": "application/json"})
    r.raise_for_status()
    return r.json()["id"]

def wait_many(wf_map, label):
    """Poll a group of workflows until each has reached a terminal status.

    wf_map maps sample_id -> wf_id. A status line is printed only when a
    workflow's status differs from the one seen on the previous poll, and a
    heartbeat line is printed once at least HEARTBEAT_S seconds have passed
    since the last one. Returns sample_id -> final status ("Succeeded",
    "Failed" or "Aborted"). Sleeps POLL_S seconds between rounds; no timeout.
    """
    terminal = ("Succeeded", "Failed", "Aborted")
    remaining = dict(wf_map)
    previous = dict.fromkeys(remaining)  # last status printed per sample
    finished = {}
    heartbeat_at = time.time()

    while remaining:
        # Iterate over a snapshot because finished entries are removed in-loop.
        for sample, workflow in tuple(remaining.items()):
            state = get_wf_status(workflow).get("status")
            if previous[sample] != state:
                print(f"[{label}] {sample} status={state}")
                previous[sample] = state
            if state in terminal:
                finished[sample] = state
                del remaining[sample]

        if not remaining:
            break

        current = time.time()
        if current - heartbeat_at >= HEARTBEAT_S:
            print(f"[{label}] heartbeat: {len(remaining)} still running")
            heartbeat_at = current

        time.sleep(POLL_S)
    return finished

def chunks(lst, n):
    """Lazily yield consecutive slices of ``lst`` holding at most ``n`` items each."""
    yield from (lst[start:start + n] for start in range(0, len(lst), n))

def _sample_fields(row):
    """Return (sample_id, age, sex) for one record; sex is None when missing or NaN."""
    has_sex = "sex_at_birth" in row and not pd.isna(row["sex_at_birth"])
    sex = str(row["sex_at_birth"]) if has_sex else None
    return str(row["person_id"]), str(row["age"]), sex


processed_ids = read_processed_ids()
print("Already processed:", len(processed_ids))

# ---- main loop ----
# For every age bin: select samples, then run stage01 -> stage02 -> stage03 in
# batches. Within a batch each stage is submitted for all samples and awaited
# before the next stage starts; a sample only advances if its previous stage
# succeeded. One log row per sample is written after stage03 has finished.
for lo, hi in age_bins:
    print(f"\n=== AGE BIN {lo}-{hi} ===")

    # Up to N_PER_BIN samples with lo <= age <= hi that are not in the log yet.
    group = df[(df["age"] >= lo) & (df["age"] <= hi)]
    group = group[~group["person_id"].astype(str).isin(processed_ids)]
    group = group.head(N_PER_BIN)
    print(f"Selected {len(group)} samples after skipping processed")

    rows = group.to_dict(orient="records")

    for batch in chunks(rows, BATCH_SIZE):
        print(f"\n--- Submitting batch of {len(batch)} samples ---")

        # ---------- Stage01 batch ----------
        s1_wf_map = {}
        for row in batch:
            sample_id, age, sex = _sample_fields(row)

            s1_inputs = {
                "stage01_SubsetCramChrM.sample_id": sample_id,
                "stage01_SubsetCramChrM.age": age,
                "stage01_SubsetCramChrM.sex": sex,
                "stage01_SubsetCramChrM.input_cram": row["cram_uri"],
                "stage01_SubsetCramChrM.input_crai": row["cram_index_uri"],
                "stage01_SubsetCramChrM.ref_fasta": ref_fasta,
                "stage01_SubsetCramChrM.docker": DOCKER_SAMTOOLS,
            }
            # The same inputs file is overwritten for every sample; it is
            # uploaded by submit_workflow before the next overwrite happens.
            s1_json = Path("./WDL/s001/stage01_SubsetCramChrM.inputs.json")
            s1_json.write_text(json.dumps(s1_inputs, indent=2) + "\n")
            print(f"[Stage01] submitting {sample_id}")
            s1_wf = submit_workflow("./WDL/s001/stage01_SubsetCramChrM.wdl", s1_json)
            s1_wf_map[sample_id] = s1_wf
            print(f"[Stage01] {sample_id} WF_ID={s1_wf}")

        s1_status = wait_many(s1_wf_map, "Stage01")

        # ---------- Stage02 batch ----------
        s2_wf_map = {}
        s1_outputs = {}
        for row in batch:
            sample_id, age, sex = _sample_fields(row)

            if s1_status.get(sample_id) != "Succeeded":
                print(f"[Stage02] skip {sample_id} (Stage01 {s1_status.get(sample_id)})")
                continue

            s1_out = get_wf_metadata(s1_wf_map[sample_id], include_keys=["outputs"])["outputs"]
            bam = s1_out["stage01_SubsetCramChrM.final_bam"]
            bai = s1_out["stage01_SubsetCramChrM.final_bai"]
            s1_outputs[sample_id] = (bam, bai)

            s2_inputs = {
                "stage02_MtOnly.input_bam": bam,
                "stage02_MtOnly.input_bai": bai,
                "stage02_MtOnly.sample_id": sample_id,
                "stage02_MtOnly.age": age,
                "stage02_MtOnly.sex": sex,
                "stage02_MtOnly.mt_fasta": mt_fasta,
                "stage02_MtOnly.mt_fasta_index": mt_fasta_index,
                "stage02_MtOnly.mt_dict": mt_dict,
                "stage02_MtOnly.mt_interval_list": mt_interval_list,
                "stage02_MtOnly.gatk_docker": DOCKER_GATK,
                "stage02_MtOnly.mem_gb": 8,
                "stage02_MtOnly.n_cpu": 2,
            }
            s2_json = Path("./WDL/s002/stage02_MtOnly.inputs.json")
            s2_json.write_text(json.dumps(s2_inputs, indent=2) + "\n")
            print(f"[Stage02] submitting {sample_id}")
            s2_wf = submit_workflow("./WDL/s002/stage02_MtOnly.wdl", s2_json)
            s2_wf_map[sample_id] = s2_wf
            print(f"[Stage02] {sample_id} WF_ID={s2_wf}")

        s2_status = wait_many(s2_wf_map, "Stage02")

        # ---------- Stage03 batch ----------
        s3_wf_map = {}
        s2_outputs = {}
        for row in batch:
            sample_id, age, sex = _sample_fields(row)

            if s2_status.get(sample_id) != "Succeeded":
                print(f"[Stage03] skip {sample_id} (Stage02 {s2_status.get(sample_id)})")
                continue

            s2_out = get_wf_metadata(s2_wf_map[sample_id], include_keys=["outputs"])["outputs"]
            vcf = s2_out["stage02_MtOnly.out_vcf"]
            vcf_idx = s2_out["stage02_MtOnly.out_vcf_index"]
            s2_outputs[sample_id] = (vcf, vcf_idx)

            s3_inputs = {
                "stage03_MtFilter.input_vcf": vcf,
                "stage03_MtFilter.input_vcf_index": vcf_idx,
                "stage03_MtFilter.sample_id": sample_id,
                "stage03_MtFilter.age": age,
                "stage03_MtFilter.sex": sex,
                "stage03_MtFilter.vaf_min": 0.01,
                "stage03_MtFilter.docker": DOCKER_TOOLS,
            }
            s3_json = Path("./WDL/s003/stage03_MtFilter.inputs.json")
            s3_json.write_text(json.dumps(s3_inputs, indent=2) + "\n")
            print(f"[Stage03] submitting {sample_id}")
            s3_wf = submit_workflow("./WDL/s003/stage03_MtFilter.wdl", s3_json)
            s3_wf_map[sample_id] = s3_wf
            print(f"[Stage03] {sample_id} WF_ID={s3_wf}")

        s3_status = wait_many(s3_wf_map, "Stage03")

        # ---------- Logging ----------
        # Append one row per sample. Appending replaces the previous
        # read-everything/rewrite-everything update, which grew quadratically
        # with the log size and could truncate the log if interrupted.
        with log_path.open("a") as log_fh:
            for row in batch:
                sample_id = str(row["person_id"])
                s1_wf = s1_wf_map.get(sample_id, "")
                s2_wf = s2_wf_map.get(sample_id, "")
                s3_wf = s3_wf_map.get(sample_id, "")

                s1_stat = s1_status.get(sample_id, "")
                s2_stat = s2_status.get(sample_id, "")
                s3_stat = s3_status.get(sample_id, "")

                # Empty strings mark stages that were skipped for this sample.
                bam, bai = s1_outputs.get(sample_id, ("", ""))
                vcf, vcf_idx = s2_outputs.get(sample_id, ("", ""))

                out_vcf = ""
                out_tsv = ""
                if s3_stat == "Succeeded":
                    s3_out = get_wf_metadata(s3_wf_map[sample_id], include_keys=["outputs"])["outputs"]
                    out_vcf = s3_out.get("stage03_MtFilter.out_vcf", "")
                    out_tsv = s3_out.get("stage03_MtFilter.out_tsv", "")

                line = f"{lo}-{hi}\t{sample_id}\t{s1_wf}\t{s1_stat}\t{bam}\t{bai}\t{s2_wf}\t{s2_stat}\t{vcf}\t{vcf_idx}\t{s3_wf}\t{s3_stat}\t{out_vcf}\t{out_tsv}\n"
                log_fh.write(line)

print("All bins done.")


In [None]:
# List every object under cromwell-executions whose path contains this workflow id.
! gsutil ls  gs://fc-secure-76d68a64-00aa-40a7-b2c5-ca956db2719b/workflows/cromwell-executions/**/6b17c985-d6c2-46b0-9d00-ece5058d5d3a/**

In [None]:
# cromwell_up is defined in an earlier cell; presumably a check that the
# Cromwell server is reachable -- confirm.
cromwell_up()

### mtDNA Variant EDA 

In [None]:
# load tsv 
# parse age, sex, gender 
# parse vcfs 
# parse tsv 
mtDNA_data_path = "./WDL/batch_run_log.tsv"