## 1. Identify potential regions from aDNA VCF files

- We have calculated this in the adna/_species_/adna_regions_contig.bed file.
- These are regions the aDNA is confident enough to call as Reference Calls or Variant Calls.
- These are mapped and align to the putative ancestral genome.
- We have lifted them over to the a9 Hoiho genome, but don't need that. We can use the ancestral genome and extract that.
- Find the regions where waitaha and richdalei (the aDNA samples) intersect, so we can use them in BEAST.

In [None]:
import polars as pl
from cyvcf2 import VCF

# Per-sample metadata table; the file is tab-separated despite its .csv extension.
metadata = pl.read_csv("../Hoiho_Genomes_24Feb2024_JGG_3Pops.csv", separator="\t")

In [None]:
# Map every raw sample ID to a display label "{ID}_{Location}_{Population3}".
sample_id_map = {
    row["ID"]: f"{row['ID']}_{row['Location']}_{row['Population3']}"
    for row in metadata.iter_rows(named=True)
}

# The aDNA samples have no metadata row: repeat the name three times so the label
# keeps the same three-part shape as the modern samples.
sample_id_map.update(
    {
        "waitaha": "waitaha_waitaha_waitaha",
        "richdalei": "richdalei_richdalei_richdalei",
    }
)

In [None]:
# Get all genome names from the halStats table (the first 4 lines before the header are skipped).
seabirds = pl.read_csv("../seabird_alignment_halstats", skip_lines=4)["GenomeName"].to_list()

# Genomes we do not want as extra samples:
# - "Anc*": ancestral nodes of the alignment
# - "c90": subantarctic islands genome (we have them from the regular population)
# - "a9": already present in the SNP set
# - "Megadyptesantipodes": Megadyptes antipodes
excluded_prefixes = ("Anc", "c90", "a9", "Megadyptesantipodes")
seabirds = [s for s in seabirds if s is not None and not s.startswith(excluded_prefixes)]
assert "a9" not in seabirds, "a9 should not be in seabirds"

penguin_prefixes = (
    "Aptenodytes",   # king & emperor
    "Spheniscus",    # banded penguins
    "Pygoscelis",    # brush-tails
    "Eudyptula",     # little penguins
    "Eudyptes",      # crested penguins (includes Eudyptesmoseleyi)
)

penguins = [sp for sp in seabirds if sp.startswith(penguin_prefixes)]

# Register each seabird genome in sample_id_map. The label is the genome name with
# any "_genomic" suffix removed, repeated 3 times (species_species_species).
genomic_suffix = "_genomic"
for sp in seabirds:
    # Slice by the suffix's real length (8); a fixed [:-9] also ate the last ID character.
    mapped_id = sp[: -len(genomic_suffix)] if sp.endswith(genomic_suffix) else sp
    sample_id_map[sp] = f"{mapped_id}_{mapped_id}_{mapped_id}"

In [None]:
# Column names for a PSL-derived region table: strand flags, then the ancestral
# (query) coordinates, then the modern (target) coordinates.
new_columns = [
    "Strands",
    "AncContig", "AncLength", "AncStart", "AncEnd",
    "ModernContig", "ModernLength", "ModernStart", "ModernEnd",
]

In [None]:
import polars as pl

# If your PSL has the 5-line BLAT header, set skip_rows=5; otherwise keep 0.
skip_rows = 0  # or 5

# PSL columns we need (0-based indices): 8 strand, 9-12 qName/qSize/qStart/qEnd,
# 13-14 tName/tSize, 17-20 blockCount/blockSizes/qStarts/tStarts.
cols = [8, 9, 10, 11, 12, 13, 14, 17, 18, 19, 20]
names = [
    "Strands", "AncContig", "AncLength", "AncStart", "AncEnd",
    "ModernContig", "ModernLength", "BlockCount", "BlockSizes", "QStarts", "TStarts"
]

# infer_schema_length=0 reads every column as a string; the block columns used below
# are cast to integers explicitly.
psl = pl.read_csv(
    "adna/waitaha/adna_regions_contig_liftover_a9.psl",
    separator="\t",
    has_header=False,
    skip_rows=skip_rows,
    columns=cols,
    new_columns=names,
    infer_schema_length=0,
    ignore_errors=True,  # tolerate occasional short lines
    # compression="gzip",  # uncomment if your PSL is gzipped
)

# Clean trailing commas, split lists, cast to int, and explode to one row per block.
# Use only features present across older Polars versions.
waitaha_blocks = (
    psl
    .with_columns([
        pl.col("BlockCount").cast(pl.Int64, strict=False),
        pl.col("BlockSizes").fill_null("").str.replace(r',+$', '', literal=False).str.split(",").alias("bs_str"),
        pl.col("QStarts").fill_null("").str.replace(r',+$', '', literal=False).str.split(",").alias("qs_str"),
        pl.col("TStarts").fill_null("").str.replace(r',+$', '', literal=False).str.split(",").alias("ts_str"),
    ])
    # If your Polars is older (<0.19), swap `.list.eval(...)` for `.arr.eval(...)` in the next 3 lines.
    # strict=False: empty or non-numeric entries (e.g. a null column filled with "")
    # become null instead of raising, so the null filter below can drop those blocks.
    .with_columns([
        pl.col("bs_str").list.eval(pl.element().cast(pl.Int64, strict=False)).alias("bs"),
        pl.col("qs_str").list.eval(pl.element().cast(pl.Int64, strict=False)).alias("qs"),
        pl.col("ts_str").list.eval(pl.element().cast(pl.Int64, strict=False)).alias("ts"),
    ])
    .drop(["bs_str", "qs_str", "ts_str"])
    .with_row_index(name="_rid")          # keep the source PSL row id (handy for grouping/joins)
    .explode(["bs", "qs", "ts"])          # one row per ungapped block
    .filter(
        pl.all_horizontal(
            pl.col("bs").is_not_null(),
            pl.col("qs").is_not_null(),
            pl.col("ts").is_not_null()
        )
    )
    # NOTE(review): in the UCSC PSL format, qStarts/tStarts for a '-' strand are given on
    # the reverse-complemented sequence; they are used unchanged here as forward-strand
    # coordinates. Confirm this is correct for '+-' rows from the liftover tool.
    .with_columns([
        # 0-based, half-open block intervals (PSL/BED style)
        pl.col("qs").alias("AncStart0"),
        (pl.col("qs") + pl.col("bs")).alias("AncEnd0"),
        pl.col("ts").alias("ModernStart0"),
        (pl.col("ts") + pl.col("bs")).alias("ModernEnd0"),
        # 1-based coordinates (useful for VCF POS). The *End1 columns are exclusive,
        # so a 1-based position p lies in a block iff Start1 <= p < End1.
        (pl.col("qs") + 1).alias("AncStart1"),
        (pl.col("ts") + 1).alias("ModernStart1"),
        (pl.col("qs") + pl.col("bs") + 1).alias("AncEnd1"),
        (pl.col("ts") + pl.col("bs") + 1).alias("ModernEnd1"),
    ])
    .select([
        "Strands", "AncContig", "AncLength", "AncStart", "AncEnd",
        "ModernContig", "ModernLength", "BlockCount",
        "AncStart0", "AncEnd0", "ModernStart0", "ModernEnd0",
        "AncStart1", "AncEnd1", "ModernStart1", "ModernEnd1",
        "bs", "_rid",
    ])
)

In [None]:
import polars as pl

# If your PSL has the 5-line BLAT header, set skip_rows=5; otherwise keep 0.
skip_rows = 0  # or 5

# PSL columns we need (0-based indices): 8 strand, 9-12 qName/qSize/qStart/qEnd,
# 13-14 tName/tSize, 17-20 blockCount/blockSizes/qStarts/tStarts.
cols = [8, 9, 10, 11, 12, 13, 14, 17, 18, 19, 20]
names = [
    "Strands", "AncContig", "AncLength", "AncStart", "AncEnd",
    "ModernContig", "ModernLength", "BlockCount", "BlockSizes", "QStarts", "TStarts"
]

# infer_schema_length=0 reads every column as a string; the block columns used below
# are cast to integers explicitly.
psl = pl.read_csv(
    "adna/richdalei/adna_regions_contig_liftover_a9.psl",
    separator="\t",
    has_header=False,
    skip_rows=skip_rows,
    columns=cols,
    new_columns=names,
    infer_schema_length=0,
    ignore_errors=True,  # tolerate occasional short lines
    # compression="gzip",  # uncomment if your PSL is gzipped
)

# Clean trailing commas, split lists, cast to int, and explode to one row per block.
# Use only features present across older Polars versions.
richdalei_blocks = (
    psl
    .with_columns([
        pl.col("BlockCount").cast(pl.Int64, strict=False),
        pl.col("BlockSizes").fill_null("").str.replace(r',+$', '', literal=False).str.split(",").alias("bs_str"),
        pl.col("QStarts").fill_null("").str.replace(r',+$', '', literal=False).str.split(",").alias("qs_str"),
        pl.col("TStarts").fill_null("").str.replace(r',+$', '', literal=False).str.split(",").alias("ts_str"),
    ])
    # If your Polars is older (<0.19), swap `.list.eval(...)` for `.arr.eval(...)` in the next 3 lines.
    # strict=False: empty or non-numeric entries (e.g. a null column filled with "")
    # become null instead of raising, so the null filter below can drop those blocks.
    .with_columns([
        pl.col("bs_str").list.eval(pl.element().cast(pl.Int64, strict=False)).alias("bs"),
        pl.col("qs_str").list.eval(pl.element().cast(pl.Int64, strict=False)).alias("qs"),
        pl.col("ts_str").list.eval(pl.element().cast(pl.Int64, strict=False)).alias("ts"),
    ])
    .drop(["bs_str", "qs_str", "ts_str"])
    .with_row_index(name="_rid")          # keep the source PSL row id (handy for grouping/joins)
    .explode(["bs", "qs", "ts"])          # one row per ungapped block
    .filter(
        pl.all_horizontal(
            pl.col("bs").is_not_null(),
            pl.col("qs").is_not_null(),
            pl.col("ts").is_not_null()
        )
    )
    # NOTE(review): in the UCSC PSL format, qStarts/tStarts for a '-' strand are given on
    # the reverse-complemented sequence; they are used unchanged here as forward-strand
    # coordinates. Confirm this is correct for '+-' rows from the liftover tool.
    .with_columns([
        # 0-based, half-open block intervals (PSL/BED style)
        pl.col("qs").alias("AncStart0"),
        (pl.col("qs") + pl.col("bs")).alias("AncEnd0"),
        pl.col("ts").alias("ModernStart0"),
        (pl.col("ts") + pl.col("bs")).alias("ModernEnd0"),
        # 1-based coordinates (useful for VCF POS). The *End1 columns are exclusive,
        # so a 1-based position p lies in a block iff Start1 <= p < End1.
        (pl.col("qs") + 1).alias("AncStart1"),
        (pl.col("ts") + 1).alias("ModernStart1"),
        (pl.col("qs") + pl.col("bs") + 1).alias("AncEnd1"),
        (pl.col("ts") + pl.col("bs") + 1).alias("ModernEnd1"),
    ])
    .select([
        "Strands", "AncContig", "AncLength", "AncStart", "AncEnd",
        "ModernContig", "ModernLength", "BlockCount",
        "AncStart0", "AncEnd0", "ModernStart0", "ModernEnd0",
        "AncStart1", "AncEnd1", "ModernStart1", "ModernEnd1",
        "bs", "_rid",
    ])
)

# Waitaha Export

In [None]:
# Find every lifted-over block (modern a9 coordinates) containing at least one record
# of the waitaha isec VCF. Each region is stored as (contig, ModernStart1, ModernEnd1).
isec_vcf_file = "adna/waitaha/isec/0002.vcf"
isec_vcf = VCF(isec_vcf_file)

waitaha_isec_regions = []
for record in isec_vcf:
    chrom, position = record.CHROM, record.POS

    # Blocks are 1-based half-open: ModernStart1 <= POS < ModernEnd1.
    matches = waitaha_blocks.filter(
        (pl.col("ModernContig") == chrom)
        & (pl.col("ModernStart1") <= position)
        & (position < pl.col("ModernEnd1"))
    )
    if matches.height == 0:
        continue

    # Keep only the first matching block (only one is expected to match).
    first = matches.row(0, named=True)
    waitaha_isec_regions.append((chrom, first["ModernStart1"], first["ModernEnd1"]))

# Deduplicate: many variants can fall in the same block.
waitaha_isec_regions = list(set(waitaha_isec_regions))
print(len(waitaha_isec_regions), "unique regions found in isec VCF for waitaha")

In [None]:
# Peek at the first few (contig, ModernStart1, ModernEnd1) tuples.
waitaha_isec_regions[0:5]

In [None]:
# Drop short blocks: keep regions whose (end - start) span is at least 200 bp.
waitaha_isec_regions = list(filter(lambda r: r[2] - r[1] >= 200, waitaha_isec_regions))
print(len(waitaha_isec_regions), "regions >= 200bp")

In [None]:
import random

# Randomly choose 80 regions. The candidate list comes from a set, whose iteration
# order changes between sessions (string hash randomisation), so sort it first and
# use a seeded RNG; reruns then select the same regions.
REGION_SAMPLE_SEED = 42
waitaha_regions = random.Random(REGION_SAMPLE_SEED).sample(sorted(waitaha_isec_regions), 80)

## Let's extract some regions and export

In [None]:
from needletail import (NeedletailError, Record, parse_fastx_file, reverse_complement)

# # We need to get regions that are:
# # - Present in both waitaha and richdalei (intersections)
# # - Regions found in waitaha
# # - Regions found in richdalei
# # - Then any others can be from anywhere on the modern genome (provided they have SNPs)

# num_regions = 400

# n_waitaha = int(num_regions * 0.2)
# n_richdalei = int(num_regions * 0.2)
# n_other = num_regions - (n_waitaha + n_richdalei)
# [n_waitaha, n_richdalei, n_other]

In [None]:
# Modern population VCF (the filename suggests QUAL>=20, F_MISSING<=0.2, biallelic
# SNPs only; verify against the filtering pipeline). Its sample list drives the
# per-sample consensus runs.
vcf_file = "../merged.a9.filtered.qual20_fmissing0.2.2alleles.snpsonly.pp6.19.removed.vcf.gz"
vcf = VCF(vcf_file)
samples = vcf.samples
len(samples)

In [None]:
# We do it this way because, in aDNA, we can confidently call even reference alleles only in a very limited set of regions

In [None]:
# SNP density of the modern genome, from VCFtools (version 0.1.16) SNP density output.
# Bins are 5 kbp.
snpden = pl.read_csv("out.snpden", separator="\t", has_header=True)

# SNP_COUNT is used (VARIANTS/KB is equivalent for our purposes with a fixed bin size).
snp_count = pl.col("SNP_COUNT")
snp_count_summary = [
    snp_count.min().alias("min_SNP_COUNT"),
    snp_count.max().alias("max_SNP_COUNT"),
    snp_count.mean().alias("mean_SNP_COUNT"),
    snp_count.std().alias("std_SNP_COUNT"),
]
snpden_stats = snpden.select(snp_count_summary)

# Keep bins with more than 2 SNPs and fewer than mean + 2.5 * stddev SNPs.
snp_count_upper = snpden_stats["mean_SNP_COUNT"][0] + 2.5 * snpden_stats["std_SNP_COUNT"][0]
snpden_filtered = snpden.filter((snp_count > 2) & (snp_count < snp_count_upper))

# Bin counts before/after filtering, the unfiltered summary, the filtered summary,
# and the upper (mean + 2.5 * stddev) threshold.
[
    len(snpden),
    len(snpden_filtered),
    snpden_stats["min_SNP_COUNT"][0],
    snpden_stats["max_SNP_COUNT"][0],
    snpden_stats["mean_SNP_COUNT"][0],
    snpden_stats["std_SNP_COUNT"][0],
    snpden_filtered.select(snp_count_summary),
    snp_count_upper,
]

In [None]:
modern_lines = []

for (contig, start, end) in waitaha_regions:  # or .iter_rows(named=True) for pandas
    modern = f"{contig}:{start + 1}-{end}"
    modern_lines.append(modern)

with open("modern_regions.txt", "w") as f: f.write("\n".join(modern_lines) + "\n")

# Generates a multi-FASTA with one entry per region in anc_regions.txt (same order)
! samtools faidx -r modern_regions.txt ../a9_genome_masked.fa \
  | bcftools consensus -s filtered2 -H I -M N adna/waitaha/aDNA_on_mod_sorted.bcf \
  > out/waitaha.multi.fa

In [None]:
# Number of sampled waitaha regions (one multi-FASTA record per region).
len(waitaha_regions)

In [None]:
import os
from pathlib import Path
import signal
import subprocess
from concurrent.futures import ThreadPoolExecutor, as_completed

# ---- inputs ----
regions_file = "modern_regions.txt"
reference_fa = "../a9_genome_masked.fa"
out_dir = Path("out")
max_workers = 8  # same as -j 8

# ---- setup ----
out_dir.mkdir(parents=True, exist_ok=True)

# samtools is killed by SIGPIPE (negative return code) when bcftools stops reading
# early; that is a symptom, not a samtools error.
_SAMTOOLS_OK_CODES = (0, -getattr(signal, "SIGPIPE", 13))


def run_consensus_for_sample(sample: str) -> str:
    r"""Write the consensus of every region in ``regions_file`` for one modern sample.

    Equivalent to::

        samtools faidx -r modern_regions.txt ../a9_genome_masked.fa \
          | bcftools consensus -s {sample} -H I -M N {vcf_file} \
          > out/{sample}.multi.fa

    Reads the module-level ``regions_file``, ``reference_fa``, ``out_dir`` and
    ``vcf_file``. Returns the output path as a string.

    Raises:
        RuntimeError: if bcftools fails, or samtools fails for a reason other than
            SIGPIPE. The partial output file is removed first, so later globbing of
            ``*.multi.fa`` does not pick up a truncated FASTA.
    """
    out_path = out_dir / f"{sample}.multi.fa"

    sam_cmd = ["samtools", "faidx", "-r", regions_file, reference_fa]
    bcf_cmd = ["bcftools", "consensus", "-s", sample, "-H", "I", "-M", "N", vcf_file]

    try:
        with open(out_path, "wb") as fout:
            # samtools produces FASTA on stdout; bcftools reads it on stdin.
            p1 = subprocess.Popen(sam_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            p2 = subprocess.Popen(bcf_cmd, stdin=p1.stdout, stdout=fout, stderr=subprocess.PIPE)
            # Drop our copy of the pipe so samtools gets SIGPIPE if bcftools exits early.
            p1.stdout.close()  # type: ignore[union-attr]

            _, err2 = p2.communicate()
            _, err1 = p1.communicate()

        # Report bcftools first: when it dies early, samtools' failure is only a symptom.
        problems = []
        if p2.returncode != 0:
            problems.append(
                f"[{sample}] bcftools consensus failed (code {p2.returncode}).\n{err2.decode(errors='ignore')}"
            )
        if p1.returncode not in _SAMTOOLS_OK_CODES:
            problems.append(
                f"[{sample}] samtools faidx failed (code {p1.returncode}).\n{err1.decode(errors='ignore')}"
            )
        if problems:
            raise RuntimeError("\n".join(problems))
    except BaseException:
        out_path.unlink(missing_ok=True)
        raise

    return str(out_path)

# ---- run in parallel ----
errors = []
print(f"Running consensus for {len(samples)} samples with {max_workers} workers…")
with ThreadPoolExecutor(max_workers=max_workers) as ex:
    futs = {ex.submit(run_consensus_for_sample, s): s for s in samples}
    for fut in as_completed(futs):
        sample = futs[fut]
        try:
            outp = fut.result()
            print(f"✓ {sample} → {outp}")
        except Exception as e:
            errors.append((sample, str(e)))
            print(f"✗ {sample} failed: {e}")

if errors:
    print("\nSome samples failed:")
    for s, msg in errors:
        print(f"- {s}: {msg}")
    raise SystemExit(1)
else:
    print("\nAll samples completed.")


In [None]:
import os, json, glob

# Collate the waitaha set into per-region folders. Assumes these already exist:
# - out/waitaha.multi.fa: aDNA consensus, one record per entry of `waitaha_regions`
#   (same order)
# - out/{sample}.multi.fa: consensus per modern sample over the same regions (same order)

# Load Waitaha multi-FASTA in order
waitaha_records = list(parse_fastx_file("out/waitaha.multi.fa"))

# Load modern multi-FASTAs, one per sample
modern_records = {}  # sample -> list of records (same length/order as waitaha_records)
for p in glob.glob("out/*.multi.fa"):
    name = os.path.basename(p).replace(".multi.fa", "")
    if name.lower() == "waitaha":
        continue
    modern_records[name] = list(parse_fastx_file(p))

# Sanity checks: every multi-FASTA must hold exactly one record per region. Otherwise
# region_i would silently pair sequences from different regions (e.g. stale files).
n = len(waitaha_regions)
if len(waitaha_records) != n:
    raise ValueError(
        f"out/waitaha.multi.fa has {len(waitaha_records)} records but there are {n} "
        "waitaha regions; regenerate the consensus FASTAs."
    )
mismatched = {s: len(recs) for s, recs in modern_records.items() if len(recs) != n}
if mismatched:
    raise ValueError(f"Modern multi-FASTAs with a record count other than {n}: {mismatched}")

os.makedirs("regions", exist_ok=True)

for i, (contig, start, end) in enumerate(waitaha_regions):
    # NOTE(review): start is ModernStart1 (already 1-based), so `start + 1` shifts the
    # region one base right. It matches the region list used to build the FASTAs.
    modern_str = f"{contig}:{start + 1}-{end}"
    d = f"regions/region_{i}"
    os.makedirs(d, exist_ok=True)

    # coordinates.txt holds only the modern region string.
    with open(f"{d}/coordinates.txt", "w") as f:
        f.write(modern_str + "\n")

    # aDNA sequence, written unwrapped and as-is (no reverse complement here).
    with open(f"{d}/waitaha.fasta", "w") as f:
        f.write(">waitaha_waitaha_waitaha\n")
        f.write(waitaha_records[i].seq + "\n")

    # Modern sequences (no RC), one file per sample, named and headed by the mapped label.
    for sample, recs in modern_records.items():
        label = sample_id_map.get(sample, sample)  # fall back to the raw name if unmapped
        with open(f"{d}/{label}.fasta", "w") as f:
            f.write(f">{label}\n")
            f.write(recs[i].seq + "\n")


# Richdalei Export

In [None]:
# Open up the isec file
isec_vcf_file = "adna/richdalei/isec/0002.vcf"
isec_vcf = VCF(isec_vcf_file)

# As a tuple (chr, start, end)
richdalei_isec_regions = []

for variant in isec_vcf:
    contig = variant.CHROM
    pos = variant.POS

    # If the variant is in a region for richdalei_blocks, add it to richdalei_isec_regions
    hit_blocks = richdalei_blocks.filter(
        (pl.col("ModernContig") == contig) &
        (pl.col("ModernStart1") <= pos) &
        (pos < pl.col("ModernEnd1"))
    )
    if hit_blocks.height > 0:
        # Get the first hit block (there should only be one)
        hit_block = hit_blocks[0]
        start = hit_block["ModernStart1"][0]
        end = hit_block["ModernEnd1"][0]
        richdalei_isec_regions.append((contig, start, end))

# Get unique
richdalei_isec_regions = list(set(richdalei_isec_regions))
print(len(richdalei_isec_regions), "unique regions found in isec VCF for richdalei")

In [None]:
      
import polars as pl
from cyvcf2 import VCF

# Vectorised alternative to the per-variant filter loop for richdalei.
# NOTE(review): its result (richdalei_isec_regions_df) is not consumed by the later
# cells, which use the list `richdalei_isec_regions` built by the loop.

# --- 1. Prepare the Blocks DataFrame ---
# Assuming richdalei_blocks is already a Polars DataFrame like this:
# richdalei_blocks = pl.DataFrame({
#     "ModernContig": ["chr1", "chr1", "chr2"],
#     "ModernStart1": [100, 500, 200],
#     "ModernEnd1": [200, 600, 300],
# })

# Rename the block columns to match the variants DataFrame, and sort by contig and start,
# since join_asof needs the join key sorted.
blocks_df = richdalei_blocks.rename({
    "ModernContig": "CHROM",
    "ModernStart1": "START",
    "ModernEnd1": "END"
}).sort("CHROM", "START")


# --- 2. Load VCF Variants into a Polars DataFrame ---
# Only CHROM and POS are needed per variant.
isec_vcf_file = "adna/richdalei/isec/0002.vcf"
isec_vcf = VCF(isec_vcf_file)

# One dict per VCF record; the whole file is materialised in memory.
variants_data = [
    {"CHROM": variant.CHROM, "POS": variant.POS}
    for variant in isec_vcf
]

# NOTE(review): with an empty VCF, pl.DataFrame([]) presumably has no CHROM/POS columns
# and the sort would fail; confirm if empty isec files are possible.
variants_df = pl.DataFrame(variants_data).sort("CHROM", "POS")


# --- 3. Perform the `join_asof` ---
# For each variant, join_asof picks the block with the greatest START <= POS on the
# same CHROM.
# NOTE(review): only that single block is considered. If blocks overlap, a variant inside
# an earlier, longer block but past the END of the later-starting block is dropped here,
# whereas the per-variant filter loop would still find it.
found_blocks = variants_df.join_asof(
    blocks_df,
    left_on="POS",
    right_on="START",
    by="CHROM",
    strategy="backward" # "backward": last block whose START is <= POS
)

# --- 4. Filter and Get Unique Regions ---
# join_asof only gives the candidate block; keep it only if POS < END (blocks are
# 1-based half-open), then deduplicate the (CHROM, START, END) regions.
richdalei_isec_regions_df = (
    found_blocks
    .filter(pl.col("END").is_not_null()) # Remove variants that found no block
    .filter(pl.col("POS") < pl.col("END")) # The crucial interval condition
    .select(["CHROM", "START", "END"])   # Select the columns defining a region
    .unique()                            # Unique regions; row order is not guaranteed
)

print(f"{richdalei_isec_regions_df.height} unique regions found in isec VCF for richdalei")

# If you absolutely need the final output as a list of tuples:
# richdalei_isec_regions = list(richdalei_isec_regions_df.iter_rows())
# print(richdalei_isec_regions[:5])

    

In [None]:
# Peek at the first few (contig, ModernStart1, ModernEnd1) tuples from the loop-based search.
richdalei_isec_regions[0:5]

In [None]:
# Keep only regions spanning at least 2000 bp (end - start >= 2000).
richdalei_isec_regions = [
    region for region in richdalei_isec_regions if (region[2] - region[1]) >= 2000
]
print(len(richdalei_isec_regions), "regions >= 2000bp")

import random

# Randomly choose 80 regions. The candidate list comes from a set, whose iteration
# order changes between sessions (string hash randomisation), so sort it first and
# use a seeded RNG; reruns then select the same regions.
REGION_SAMPLE_SEED = 42
richdalei_regions = random.Random(REGION_SAMPLE_SEED).sample(sorted(richdalei_isec_regions), 80)

In [None]:
modern_lines = []

for (contig, start, end) in richdalei_regions:  # or .iter_rows(named=True) for pandas
    modern = f"{contig}:{start + 1}-{end}"
    modern_lines.append(modern)

with open("modern_regions.txt", "w") as f: f.write("\n".join(modern_lines) + "\n")

# Generates a multi-FASTA with one entry per region in anc_regions.txt (same order)
! samtools faidx -r modern_regions.txt ../a9_genome_masked.fa \
  | bcftools consensus -s with_rg -H I -M N adna/richdalei/richdalei_sorted.bcf \
  > out_richdalei/richdalei.multi.fa

len(richdalei_regions)

In [None]:
import os
from pathlib import Path
import signal
import subprocess
from concurrent.futures import ThreadPoolExecutor, as_completed

# ---- inputs ----
regions_file = "modern_regions.txt"
reference_fa = "../a9_genome_masked.fa"
out_dir = Path("out_richdalei")
max_workers = 8  # same as -j 8

# ---- setup ----
out_dir.mkdir(parents=True, exist_ok=True)

# samtools is killed by SIGPIPE (negative return code) when bcftools stops reading
# early; that is a symptom, not a samtools error.
_SAMTOOLS_OK_CODES = (0, -getattr(signal, "SIGPIPE", 13))


def run_consensus_for_sample(sample: str) -> str:
    r"""Write the consensus of every region in ``regions_file`` for one modern sample.

    Equivalent to::

        samtools faidx -r modern_regions.txt ../a9_genome_masked.fa \
          | bcftools consensus -s {sample} -H I -M N {vcf_file} \
          > out_richdalei/{sample}.multi.fa

    Reads the module-level ``regions_file``, ``reference_fa``, ``out_dir`` and
    ``vcf_file``. Returns the output path as a string.

    Raises:
        RuntimeError: if bcftools fails, or samtools fails for a reason other than
            SIGPIPE. The partial output file is removed first, so later globbing of
            ``*.multi.fa`` does not pick up a truncated FASTA.
    """
    out_path = out_dir / f"{sample}.multi.fa"

    sam_cmd = ["samtools", "faidx", "-r", regions_file, reference_fa]
    bcf_cmd = ["bcftools", "consensus", "-s", sample, "-H", "I", "-M", "N", vcf_file]

    try:
        with open(out_path, "wb") as fout:
            # samtools produces FASTA on stdout; bcftools reads it on stdin.
            p1 = subprocess.Popen(sam_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            p2 = subprocess.Popen(bcf_cmd, stdin=p1.stdout, stdout=fout, stderr=subprocess.PIPE)
            # Drop our copy of the pipe so samtools gets SIGPIPE if bcftools exits early.
            p1.stdout.close()  # type: ignore[union-attr]

            _, err2 = p2.communicate()
            _, err1 = p1.communicate()

        # Report bcftools first: when it dies early, samtools' failure is only a symptom.
        problems = []
        if p2.returncode != 0:
            problems.append(
                f"[{sample}] bcftools consensus failed (code {p2.returncode}).\n{err2.decode(errors='ignore')}"
            )
        if p1.returncode not in _SAMTOOLS_OK_CODES:
            problems.append(
                f"[{sample}] samtools faidx failed (code {p1.returncode}).\n{err1.decode(errors='ignore')}"
            )
        if problems:
            raise RuntimeError("\n".join(problems))
    except BaseException:
        out_path.unlink(missing_ok=True)
        raise

    return str(out_path)

# ---- run in parallel ----
errors = []
print(f"Running consensus for {len(samples)} samples with {max_workers} workers…")
with ThreadPoolExecutor(max_workers=max_workers) as ex:
    futs = {ex.submit(run_consensus_for_sample, s): s for s in samples}
    for fut in as_completed(futs):
        sample = futs[fut]
        try:
            outp = fut.result()
            print(f"✓ {sample} → {outp}")
        except Exception as e:
            errors.append((sample, str(e)))
            print(f"✗ {sample} failed: {e}")

if errors:
    print("\nSome samples failed:")
    for s, msg in errors:
        print(f"- {s}: {msg}")
    raise SystemExit(1)
else:
    print("\nAll samples completed.")


In [None]:
import os, json, glob

# Collate the richdalei set into per-region folders. Assumes these already exist:
# - out_richdalei/richdalei.multi.fa: aDNA consensus, one record per entry of
#   `richdalei_regions` (same order)
# - out_richdalei/{sample}.multi.fa: consensus per modern sample over the same regions

# The waitaha export writes region_0 .. region_79 (it samples 80 regions), so the
# richdalei folders are numbered after them.
REGION_OFFSET = 80

# Load richdalei multi-FASTA in order
richdalei_records = list(parse_fastx_file("out_richdalei/richdalei.multi.fa"))

# Load modern multi-FASTAs, one per sample
modern_records = {}  # sample -> list of records (same length/order as richdalei_records)
for p in glob.glob("out_richdalei/*.multi.fa"):
    name = os.path.basename(p).replace(".multi.fa", "")
    if name.lower() == "richdalei":
        continue
    modern_records[name] = list(parse_fastx_file(p))

# Sanity checks: every multi-FASTA must hold exactly one record per region. Otherwise
# region_i would silently pair sequences from different regions (e.g. stale files).
n = len(richdalei_regions)
if len(richdalei_records) != n:
    raise ValueError(
        f"out_richdalei/richdalei.multi.fa has {len(richdalei_records)} records but there "
        f"are {n} richdalei regions; regenerate the consensus FASTAs."
    )
mismatched = {s: len(recs) for s, recs in modern_records.items() if len(recs) != n}
if mismatched:
    raise ValueError(f"Modern multi-FASTAs with a record count other than {n}: {mismatched}")

os.makedirs("regions", exist_ok=True)

for i, (contig, start, end) in enumerate(richdalei_regions):
    # NOTE(review): start is ModernStart1 (already 1-based), so `start + 1` shifts the
    # region one base right. It matches the region list used to build the FASTAs.
    modern_str = f"{contig}:{start + 1}-{end}"
    d = f"regions/region_{i + REGION_OFFSET}"
    os.makedirs(d, exist_ok=True)

    # coordinates.txt holds only the modern region string.
    with open(f"{d}/coordinates.txt", "w") as f:
        f.write(modern_str + "\n")

    # aDNA sequence, written unwrapped and as-is (no reverse complement here).
    with open(f"{d}/richdalei.fasta", "w") as f:
        f.write(">richdalei_richdalei_richdalei\n")
        f.write(richdalei_records[i].seq + "\n")

    # Modern sequences (no RC), one file per sample, named and headed by the mapped label.
    for sample, recs in modern_records.items():
        label = sample_id_map.get(sample, sample)  # fall back to the raw name if unmapped
        with open(f"{d}/{label}.fasta", "w") as f:
            f.write(f">{label}\n")
            f.write(recs[i].seq + "\n")


# STOP


In [None]:
import os, json, glob, subprocess
from pathlib import Path

def write_regions_files(df, anc_path, modern_path, *, one_based=True):
    """Write ancestral and modern region lists for ``samtools faidx -r``, preserving row order.

    A list is written only when ``df`` has the matching columns
    (AncContig/AncStart/AncEnd or ModernContig/ModernStart/ModernEnd). When
    ``one_based`` is true, starts are shifted by +1 and ends are kept, turning
    0-based half-open intervals into 1-based inclusive ones. Start and end values
    go through ``int()``, so string-typed columns (e.g. from a CSV read with
    ``infer_schema_length=0``) are accepted.

    Returns:
        (anc_path or None, modern_path or None, coords, strands):
        - ``coords`` has one (anc_str or None, modern_str or None) pair per row.
        - ``strands`` is the ``Strands`` column as a list, or None if it is absent.
    """
    start_offset = 1 if one_based else 0
    columns = set(df.columns)
    has_anc = {"AncContig", "AncStart", "AncEnd"} <= columns
    has_modern = {"ModernContig", "ModernStart", "ModernEnd"} <= columns
    strands = df["Strands"].to_list() if "Strands" in columns else None

    def _region(contig, start, end):
        # samtools region string; start shifted per `one_based`, end used as-is.
        return f"{contig}:{int(start) + start_offset}-{int(end)}"

    anc_lines, modern_lines, coords = [], [], []
    for row in df.to_dicts():
        anc_str = _region(row["AncContig"], row["AncStart"], row["AncEnd"]) if has_anc else None
        modern_str = (
            _region(row["ModernContig"], row["ModernStart"], row["ModernEnd"]) if has_modern else None
        )
        if anc_str is not None:
            anc_lines.append(anc_str)
        if modern_str is not None:
            modern_lines.append(modern_str)
        coords.append((anc_str, modern_str))

    if anc_lines:
        Path(anc_path).write_text("\n".join(anc_lines) + "\n")
    if modern_lines:
        Path(modern_path).write_text("\n".join(modern_lines) + "\n")

    return (
        anc_path if anc_lines else None,
        modern_path if modern_lines else None,
        coords,
        strands,
    )


import subprocess, signal
from pathlib import Path

def run_pipe_to_file(cmd1, cmd2, out_path):
    """Run ``cmd1 | cmd2 > out_path``, reporting cmd2's failure in preference to cmd1's.

    Parent directories of ``out_path`` are created as needed.

    Raises:
        RuntimeError: if cmd2 exits non-zero, or if cmd1 exits non-zero for any
            reason other than SIGPIPE (cmd2 stopping reading early).
    """
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    sigpipe_status = -getattr(signal, "SIGPIPE", 13)

    with open(out_path, "wb") as sink:
        producer = subprocess.Popen(cmd1, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        consumer = subprocess.Popen(cmd2, stdin=producer.stdout, stdout=sink, stderr=subprocess.PIPE)
        # Release our end of the pipe so the producer sees SIGPIPE if the consumer quits.
        if producer.stdout is not None:
            producer.stdout.close()

        consumer_err = consumer.communicate()[1]
        producer_err = producer.communicate()[1]

        # The consumer's error is usually the root cause, so it wins.
        if consumer.returncode != 0:
            raise RuntimeError(
                f"{cmd2[0]} failed (exit {consumer.returncode}).\n{consumer_err.decode(errors='ignore')}"
            )

        # A SIGPIPE-killed producer is expected once the consumer has finished.
        if producer.returncode not in (0, sigpipe_status):
            raise RuntimeError(
                f"{cmd1[0]} failed (exit {producer.returncode}).\n{producer_err.decode(errors='ignore')}"
            )


In [None]:
from concurrent.futures import ThreadPoolExecutor, as_completed

def run_adna_plus_modern(
    set_name: str,
    regions_df,
    adna_label: str,            # e.g., "richdalei"
    adna_sample: str,           # sample name in the aDNA BCF to use (e.g., "filtered2" or actual)
    adna_bcf: str,              # e.g., "richdalei_norm.bcf"
    anc_ref_fa: str,            # e.g., "Anc01.fasta"
    modern_ref_fa: str,         # e.g., "../a9_genome_masked.fa"
    modern_vcf: str,            # your modern population VCF/BCF
    samples,                    # iterable of modern sample names
    max_workers: int = 8,
    region_start: int = 0,
):
    """Build aDNA and modern consensus multi-FASTAs for one region set, then collate them.

    Output goes to ``out/{set_name}``:
    - anc_regions.txt, modern_regions.txt
    - coords.json, strand_flags.json
    - ``{adna_label}.multi.fa`` and ``{sample}.multi.fa`` per modern sample

    Finally, ``waitaha_like_collate`` writes the per-region folders, numbered from
    ``region_start``.

    Raises:
        ValueError: if ``regions_df`` does not yield both ancestral and modern regions.
        RuntimeError: propagated from ``run_pipe_to_file`` if a consensus pipeline fails.
    """
    out_dir = Path(f"out/{set_name}")
    out_dir.mkdir(parents=True, exist_ok=True)

    # 1) Region lists (1-based inclusive)
    anc_list, modern_list, coords, strands = write_regions_files(
        regions_df,
        anc_path=out_dir / "anc_regions.txt",
        modern_path=out_dir / "modern_regions.txt",
        one_based=True,
    )
    if not (anc_list and modern_list):
        raise ValueError(f"[{set_name}] regions_df must provide both ancestral and modern coordinates")

    # Persist metadata for collation (closed explicitly so it is flushed before reading back).
    with open(out_dir / "coords.json", "w") as fh:
        json.dump(coords, fh)
    with open(out_dir / "strand_flags.json", "w") as fh:
        json.dump(strands, fh)

    # 2) aDNA consensus (single run over all anc regions)
    adna_multi = out_dir / f"{adna_label}.multi.fa"
    cmd1 = ["samtools", "faidx", "-r", str(anc_list), anc_ref_fa]
    cmd2 = ["bcftools", "consensus", "-s", adna_sample, "-H", "I", "-M", "N", adna_bcf]
    run_pipe_to_file(cmd1, cmd2, adna_multi)

    # 3) Modern consensus, one run per sample over all modern regions, in parallel.
    def do_one(sample):
        outp = out_dir / f"{sample}.multi.fa"
        c1 = ["samtools", "faidx", "-r", str(modern_list), modern_ref_fa]
        c2 = ["bcftools", "consensus", "-s", sample, "-H", "I", "-M", "N", modern_vcf]
        run_pipe_to_file(c1, c2, outp)
        return str(outp)

    with ThreadPoolExecutor(max_workers=max_workers) as ex:
        futs = {ex.submit(do_one, s): s for s in samples}
        for fut in as_completed(futs):
            fut.result()  # raise on error

    # 4) Collate per-region folders; RC only the aDNA when '+-'
    waitaha_like_collate(
        set_name=set_name,
        adna_label=adna_label,
        out_dir=str(out_dir),
        rc_strands=True,  # enable RC per '+-' flag
        region_start=region_start,  # first region folder index
    )


def waitaha_like_collate(set_name: str, adna_label: str, out_dir: str, rc_strands: bool, region_start: int = 0):
    """
    Split per-sample consensus multi-FASTAs into one ``regions/region_<N>`` folder per region.

    Inputs (all under *out_dir*):
      - ``coords.json``: one ``(anc_str, modern_str)`` pair per region;
      - ``strand_flags.json`` (read only when *rc_strands*): one flag per region;
      - ``{adna_label}.multi.fa``: aDNA consensus, one record per region;
      - every other ``*.multi.fa``: one modern sample each, one record per region.

    For region index ``i`` the folder ``regions/region_{i + region_start}`` receives
    ``coordinates.txt`` (the non-empty coordinate strings), ``{adna_label}.fasta`` and one
    ``{sample}.fasta`` per modern sample. FASTA headers use ``sample_id_map`` (falling back
    to the raw label); file names always use the raw label. The aDNA sequence is
    reverse-complemented when its strand flag is ``"+-"``; modern sequences never are.
    A region whose flag is neither ``"++"`` nor ``"+-"`` is reported and skipped without
    creating its folder.

    Raises:
        AssertionError: if record or strand-flag counts do not match ``coords.json``.
    """
    out_dir = Path(out_dir)
    with open(out_dir / "coords.json") as fh:
        coords = json.load(fh)
    strands = None
    if rc_strands:
        with open(out_dir / "strand_flags.json") as fh:
            strands = json.load(fh)

    # aDNA consensus: one record per region, in coords order.
    adna_records = list(parse_fastx_file(out_dir / f"{adna_label}.multi.fa"))

    # Modern consensus: every other *.multi.fa, keyed by sample name. Path.glob is used so this
    # works whether the notebook's `glob` name is bound to the module or to glob.glob.
    modern_records = {}
    for path in sorted(out_dir.glob("*.multi.fa")):
        name = path.stem.replace(".multi", "")
        if name == adna_label:
            continue
        modern_records[name] = list(parse_fastx_file(str(path)))

    n = len(adna_records)
    assert n == len(coords), "Count mismatch (coords vs aDNA records)"
    if strands is not None:
        assert len(strands) == n, "Count mismatch (strand flags vs aDNA records)"
    for s, recs in modern_records.items():
        assert len(recs) == n, f"Modern count mismatch for {s}"

    # Resolve header names once. adna_label itself must not be rebound: it also names the
    # per-region aDNA file, which previously changed name after the first region.
    adna_header = sample_id_map.get(adna_label, adna_label)
    modern_headers = {s: sample_id_map.get(s, s) for s in modern_records}

    for i, (anc_str, modern_str) in enumerate(coords):
        region_id = i + region_start

        # Decide the aDNA orientation first so a skipped region leaves no partial folder.
        seq = adna_records[i].seq
        if strands:
            flag = strands[i]
            if flag == "+-":
                seq = reverse_complement(seq)
            elif flag != "++":
                print(f"[{set_name}] Unexpected strand {flag} at region_{region_id}; skipping")
                continue

        d = Path("regions") / f"region_{region_id}"
        d.mkdir(parents=True, exist_ok=True)

        with open(d / "coordinates.txt", "w") as f:
            for coord in (anc_str, modern_str):
                if coord:
                    f.write(coord + "\n")

        with open(d / f"{adna_label}.fasta", "w") as f:
            f.write(f">{adna_header}\n{seq}\n")

        # Modern samples: never reverse-complemented.
        for s, recs in modern_records.items():
            with open(d / f"{s}.fasta", "w") as f:
                f.write(f">{modern_headers[s]}\n{recs[i].seq}\n")


In [None]:
# richdalei: aDNA consensus on the ancestral genome (Anc01) plus modern-population consensus
# on the a9 reference. Region folders are numbered after the waitaha regions.
run_adna_plus_modern(
    # Output set and numbering
    set_name="richdalei",
    region_start=len(waitaha_regions),
    regions_df=richdalei_regions,      # your DataFrame
    # aDNA side
    adna_label="richdalei",
    adna_sample="with_rg",             # or the exact sample name in the BCF
    adna_bcf="richdalei_norm.bcf",
    anc_ref_fa="Anc01.fasta",
    # Modern side
    modern_ref_fa="../a9_genome_masked.fa",
    modern_vcf=vcf_file,               # your modern pop VCF/BCF
    samples=samples,
    max_workers=8,
)


In [None]:
!pixi run python realign_regions.py

In [None]:
# This should make all sequences the same length as the modern chromosome.

In [None]:
def run_modern_only(
    set_name: str,
    regions_df,
    modern_ref_fa: str,
    modern_vcf: str,
    samples,
    max_workers: int = 8,
    region_start: int = 0,
    bin_size: int = 5000,  # 5kb bins
):
    """
    Build per-region consensus FASTAs for modern samples only (no aDNA).

    *regions_df* must have ``CHROM`` and ``BIN_START`` columns (SNP-density bins). Each bin is
    handed to ``write_regions_files`` as ModernContig/ModernStart/ModernEnd =
    ``CHROM, BIN_START, BIN_START + bin_size``. For every sample, ``samtools faidx -r`` on
    *modern_ref_fa* is piped into ``bcftools consensus -s <sample>`` on *modern_vcf*, giving
    ``out/{set_name}/{sample}.multi.fa``. The records are then split into
    ``regions/region_{i + region_start}`` folders holding ``coordinates.txt`` and one
    ``{sample}.fasta`` per sample. Headers come from ``sample_id_map``, falling back to the
    raw sample name.

    Only the outputs produced for *samples* in this call are collated. Stale ``*.multi.fa``
    files from earlier runs in the same ``out/{set_name}`` are ignored.

    Raises:
        ValueError: if *samples* is empty.
        AssertionError: if no modern regions were written or record counts disagree.
    """
    from concurrent.futures import ThreadPoolExecutor, as_completed

    out_dir = Path(f"out/{set_name}")
    out_dir.mkdir(parents=True, exist_ok=True)
    reg_dir = Path("regions")
    reg_dir.mkdir(parents=True, exist_ok=True)

    # write_regions_files expects ModernContig/ModernStart/ModernEnd columns.
    # NOTE(review): BIN_START is passed through unchanged (no 0->1-based shift happens here)
    # and ModernEnd = BIN_START + bin_size; confirm this matches write_regions_files(one_based=True).
    regions_df = regions_df.select([
        pl.col("CHROM").alias("ModernContig"),
        pl.col("BIN_START").alias("ModernStart"),
        (pl.col("BIN_START") + bin_size).alias("ModernEnd"),
    ])

    # Only modern regions are required; anc columns are absent here.
    anc_list, modern_list, coords, _ = write_regions_files(
        regions_df,
        anc_path=out_dir / "anc_regions.txt",       # may not be used
        modern_path=out_dir / "modern_regions.txt",
        one_based=True,
    )
    assert modern_list, "modern_regions set must have modern coordinates"
    with open(out_dir / "coords.json", "w") as fh:
        json.dump(coords, fh)

    def do_one(sample):
        outp = out_dir / f"{sample}.multi.fa"
        c1 = ["samtools", "faidx", "-r", str(modern_list), modern_ref_fa]
        c2 = ["bcftools", "consensus", "-s", sample, "-H", "I", "-M", "N", modern_vcf]
        run_pipe_to_file(c1, c2, outp)
        return str(outp)

    # sample -> multi-FASTA path, collected from the futures so only this run's outputs are used.
    outputs = {}
    with ThreadPoolExecutor(max_workers=max_workers) as ex:
        futs = {ex.submit(do_one, s): s for s in samples}
        for fut in as_completed(futs):
            outputs[futs[fut]] = fut.result()  # raise on error
    if not outputs:
        raise ValueError("run_modern_only: no samples given")

    modern_records = {s: list(parse_fastx_file(p)) for s, p in sorted(outputs.items())}

    n = len(coords)
    for s, recs in modern_records.items():
        assert len(recs) == n, f"Region count mismatch for sample {s}"

    headers = {s: sample_id_map.get(s, s) for s in modern_records}

    # Collate: per-region folders with coordinates.txt and {sample}.fasta files
    for i, (anc_str, modern_str) in enumerate(coords):
        d = reg_dir / f"region_{i + region_start}"
        d.mkdir(parents=True, exist_ok=True)
        with open(d / "coordinates.txt", "w") as f:
            if anc_str:
                f.write(anc_str + "\n")    # present only if regions_df had ancient cols
            if modern_str:
                f.write(modern_str + "\n")

        for s, recs in modern_records.items():
            with open(d / f"{s}.fasta", "w") as f:
                f.write(f">{headers[s]}\n{recs[i].seq}\n")


In [None]:
# Modern-only regions: a random subset of SNP-density bins, with no aDNA.
# A fixed seed keeps the chosen bins reproducible between notebook runs.
n_other = 200
modern_only_regions = snpden_filtered.sample(n_other, with_replacement=False, seed=42)

# Modern-only folders start at region_500. The richdalei folders start at
# len(waitaha_regions) (see its call above), so the waitaha + richdalei regions must fit
# below 500 or folders would be overwritten. (Assumes waitaha starts at region_0.)
if len(waitaha_regions) + len(richdalei_regions) > 500:
    raise ValueError(
        "aDNA regions would overlap the modern-only region_500+ folders; raise region_start"
    )

run_modern_only(
    set_name="modern_only",
    regions_df=modern_only_regions,
    modern_ref_fa="../a9_genome_masked.fa",
    modern_vcf=vcf_file,               # your modern pop VCF/BCF
    samples=samples,
    max_workers=8,
    region_start=500,
    bin_size=2000,  # 2kbp
)

In [None]:
# Modern population samples come straight from the VCF header.
vcf = VCF(vcf_file)
a9_samples = vcf.samples

HALSTATS_PATH = "../seabird_alignment_halstats"

# aDNA Samples
adna_samples = ["waitaha", "richdalei"]

# Genome names from the halstats table (after its 4 preamble lines). Drop empty rows,
# ancestral nodes ("Anc*"), c90 (subantarctic islands, already in the modern population),
# a9 (already present via the SNPs) and Megadyptes antipodes.
seabirds_df = [
    genome
    for genome in pl.read_csv(HALSTATS_PATH, skip_lines=4)["GenomeName"].to_list()
    if genome is not None
    and not genome.startswith(("Anc", "c90", "a9", "Megadyptesantipodes"))
]
all_seabirds = seabirds_df

# Penguin genomes: the subset of seabirds with a penguin genus prefix.
penguin_prefixes = ("Aptenodytes", "Spheniscus", "Pygoscelis", "Eudyptula", "Eudyptes")
penguins = list(filter(lambda genome: genome.startswith(penguin_prefixes), all_seabirds))

# --- Define the sets of samples to analyze ---
cactus_penguins = penguins
cactus_seabirds = all_seabirds

default_outgroups = ["Eudyptesmoseleyi_genomic", "Spheniscushumboldti_genomic", "Eudyptesfilholi_genomic"]

# Every set = modern VCF samples + aDNA samples + a set-specific list of Cactus genomes.
TARGET_SAMPLE_SETS = {
    set_label: a9_samples + adna_samples + extra_genomes
    for set_label, extra_genomes in (
        ("penguins", cactus_penguins),
        ("all_seabirds", cactus_seabirds),
        ("vcf_and_adna_only", []),
        ("just_a_few_outgroups", default_outgroups),
    )
}

## Add the other species (from the Cactus alignment)

In [None]:
# Range errors such as
#   RuntimeError: invalid range start=25620000 end=25622000 exceeds ptg000013l sequence length of 25621374
# occur when a region runs past the end of its contig. Load contig lengths from the
# a9 reference index (../a9_genome_masked.fa.fai) so the extraction can be clipped.
fai_file = "../a9_genome_masked.fa.fai"
fai = pl.read_csv(
    fai_file,
    separator="\t",
    has_header=False,
    new_columns=["Contig", "Length", "Offset", "LineBases", "LineWidth"],
).with_columns(pl.col("Length").cast(pl.Int64))

# Contig name -> contig length, for quick lookups when clipping regions.
fai_dict = dict(zip(fai["Contig"].to_list(), fai["Length"].to_list()))

In [None]:
#!/usr/bin/env python3
"""
Run cactus-hal2maf and then oxid_maf remove-ref-indels for each region in parallel.

Requirements from user:
- Jobstore: unique per run & region, MUST NOT exist beforehand, and NOT under /tmp.
- WorkDir: unique per run & region, MUST EXIST beforehand, on the same drive as before (/mnt/data).

Other features:
- Up to MAX_WORKERS regions in flight.
- Per-region MAF output (region/hal2maf.maf).
- Renames resulting FASTA headers using sample_id_map.
- Adjusts extraction length if it would exceed contig length (fai_dict).
"""

from glob import glob
import subprocess
import os
import sys
import uuid
from datetime import datetime
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed

# ------------ config ------------
# Maximum number of regions processed at the same time.
MAX_WORKERS = 6

# Cactus install and the HAL alignment to extract from.
HAL_PATH = "/mnt/data/seabirds.hal"
CACTUS_BIN = "/mnt/data/development/hoiho/wga/cactus/cactus-bin-v2.9.7"
CACTUS_VENV = f"{CACTUS_BIN}/venv-cactus-v2.9.7"
CACTUS_HAL2MAF = f"{CACTUS_VENV}/bin/cactus-hal2maf"

# Per-run subdirectories are created under these bases; both live on /mnt/data.
BASE_JOBSTORE = Path("/mnt/data/hal2maf_jobstores")
BASE_WORKDIR = Path("/mnt/data/workdir_hal2maf")

# oxid_maf binary, subcommand, and genomes excluded (already represented by the modern pop).
OXID_MAF = "/home/joseph/development/OxidMAF/target/release/oxid_maf"
OXID_SUBCOMMAND = "remove-ref-indels"
OXID_EXCLUDE = ",".join(["c90", "a9", "Megadyptesantipodesantipodes_genomic"])

# Child-process environment: Cactus bin dirs are placed first on PATH.
subprocess_env = dict(os.environ)
subprocess_env["PATH"] = ":".join(
    [f"{CACTUS_BIN}/bin", f"{CACTUS_VENV}/bin", subprocess_env.get("PATH", "")]
)

# Unique per invocation: timestamp plus 8 random hex characters.
RUN_ID = f"{datetime.now():%Y%m%d-%H%M%S}-{uuid.uuid4().hex[:8]}"

# ------------ helpers ------------

def parse_ptg_coords(coords_path: Path):
    """
    Parse the modern (``ptg*``) coordinate line of a region's coordinates.txt.

    The first line starting with ``ptg`` is used, e.g. ``ptg000001l:42594242-42599241``
    (1-based start, inclusive end).

    Returns:
        ``(contig, start0, length)``: the 0-based start and the number of bases in the span,
        or ``None`` when the file has no ``ptg*`` line. A 1-based start of 0 (as produced by
        bins that begin at position 0) is clamped to the first base, so ``start0`` is
        never negative.

    Raises:
        ValueError: if the line is malformed or the span is empty or reversed.
    """
    with coords_path.open() as f:
        ptg_line = next((ln.strip() for ln in f if ln.strip().startswith("ptg")), None)
    if ptg_line is None:
        return None

    contig, sep, span = ptg_line.rpartition(":")
    start_s, dash, end_s = span.partition("-")
    if not sep or not dash:
        raise ValueError(f"Malformed ptg coordinate line in {coords_path}: {ptg_line}")

    # 1-based inclusive [start, end]  ->  0-based start, length = end - start0.
    start0 = max(int(start_s) - 1, 0)
    length = int(end_s) - start0
    if length <= 0:
        raise ValueError(f"Non-positive length parsed from {coords_path}: {ptg_line}")
    return contig, start0, length


def _rename_fasta_headers(region: Path, sample_id_map) -> None:
    """Rewrite the first-line header of each ``*.fasta`` in *region* through *sample_id_map*.

    Each file is treated as one record: line 1 is the header and the remaining lines are kept
    as the sequence. Files that are empty, or whose name has no mapping, are left untouched.
    """
    for fasta_path in region.glob("*.fasta"):
        with fasta_path.open() as fh:
            lines = fh.readlines()
        if not lines:
            continue
        header = lines[0].strip()
        seq = "".join(lines[1:]).strip()
        sample_name = header[1:] if header.startswith(">") else header
        new_name = sample_id_map.get(sample_name, sample_name)
        if new_name != sample_name:
            with fasta_path.open("w") as out:
                out.write(f">{new_name}\n{seq}\n")


def process_one_region(region: Path, fai_dict, sample_id_map):
    """
    Run cactus-hal2maf and then ``oxid_maf remove-ref-indels`` for one region directory.

    The function runs these steps:
      1. Read the ``ptg*`` span from ``region/coordinates.txt``.
      2. Clip the span to the contig bounds in *fai_dict*.
      3. Extract ``region/hal2maf.maf`` from HAL_PATH with a9 as the reference genome.
      4. Run oxid_maf with *region* as the output directory.
      5. Rename the headers of the resulting ``*.fasta`` files through *sample_id_map*.

    Jobstore rules:
      - The jobstore path (BASE_JOBSTORE/RUN_ID/<region>) must NOT exist before starting.
      - We do not create it; cactus-hal2maf will.

    WorkDir rules:
      - The workdir path (BASE_WORKDIR/RUN_ID/<region>) must EXIST before starting.
      - We create it, and fail if it already exists (that would violate uniqueness).

    Args:
        region: region directory, e.g. ``regions/region_12``.
        fai_dict: contig name -> contig length for the a9 reference.
        sample_id_map: FASTA header renames; unmapped names are kept.

    Returns:
        ``(region, ok, message)``. The function does not raise: unexpected errors are
        reported as ``ok=False`` with the exception and its traceback in ``message``.
    """
    try:
        coords_path = region / "coordinates.txt"
        if not coords_path.exists():
            return (region, False, "coordinates.txt not found")

        parsed = parse_ptg_coords(coords_path)
        if not parsed:
            return (region, False, "No ptg* coordinates found")
        contig, start0, length = parsed

        # Clip the span to [0, contig_len). A 1-based start of 0 can parse to start0 == -1,
        # which must not reach hal2maf as a negative --start.
        contig_len = fai_dict.get(contig)
        if contig_len is None:
            return (region, False, f"Contig {contig} not found in fai_dict")
        if start0 < 0:
            length += start0
            start0 = 0
        if start0 >= contig_len:
            return (region, False, f"Start {start0} beyond contig length {contig_len}")
        length = min(length, contig_len - start0)
        if length <= 0:
            return (region, False, f"Empty span after clipping to contig {contig}")

        # Per-run, per-region paths
        region_name = region.name
        jobstore = BASE_JOBSTORE / RUN_ID / region_name      # MUST NOT exist
        workdir = BASE_WORKDIR / RUN_ID / region_name        # MUST exist

        # Ensure parents exist
        jobstore.parent.mkdir(parents=True, exist_ok=True)
        workdir.parent.mkdir(parents=True, exist_ok=True)

        # Enforce jobstore non-existence
        if jobstore.exists():
            return (region, False, f"Jobstore already exists (must not): {jobstore}")

        # Enforce workdir existence (unique per run, so it should not exist yet)
        if workdir.exists():
            return (region, False, f"WorkDir already exists (violates uniqueness): {workdir}")
        workdir.mkdir(parents=False, exist_ok=False)  # create; must exist before run

        maf_output = region / "hal2maf.maf"  # per-region MAF; overwrite each run
        if maf_output.exists():
            maf_output.unlink()

        print(f"▶ {region_name}: {contig}:{start0}-{start0+length} ({length}bp)")
        print(f"   jobstore={jobstore}")
        print(f"   workdir={workdir}")

        cmd_hal = [
            CACTUS_HAL2MAF,
            str(jobstore),
            HAL_PATH,
            str(maf_output),
            "--noAncestors",
            "--clean", "always",
            "--workDir", str(workdir),
            "--chunkSize", "20000",
            "--binariesMode", "local",
            "--dupeMode", "single",
            "--refSequence", contig,
            "--refGenome", "a9",
            "--start", str(start0),
            "--length", str(length),
        ]
        res = subprocess.run(cmd_hal, capture_output=True, text=True, env=subprocess_env)
        if res.returncode != 0:
            return (region, False, f"cactus-hal2maf failed:\nSTDERR:\n{res.stderr}\nSTDOUT:\n{res.stdout}")

        # oxid_maf: remove ref indels
        cmd_oxid = [
            OXID_MAF,
            OXID_SUBCOMMAND,
            str(maf_output),
            str(region) + "/",   # output dir
            "--exclude",
            OXID_EXCLUDE,
        ]
        res = subprocess.run(cmd_oxid, capture_output=True, text=True, env=subprocess_env)
        if res.returncode != 0:
            return (region, False, f"oxid_maf failed:\nSTDERR:\n{res.stderr}\nSTDOUT:\n{res.stdout}")

        _rename_fasta_headers(region, sample_id_map)

        # Success; do NOT delete jobstore/workdir automatically (kept for provenance/resume)
        return (region, True, "ok")

    except Exception as e:
        # Worker-thread boundary: report instead of raising so one bad region does not abort
        # the batch, but keep the traceback so the failure can be diagnosed.
        import traceback
        return (region, False, f"Exception: {e}\n{traceback.format_exc()}")


# ------------ main ------------

if __name__ == "__main__":
    # fai_dict (contig lengths) and sample_id_map (header renames) must already exist in this
    # namespace, e.g. from earlier notebook cells.
    try:
        fai_dict
        sample_id_map
    except NameError:
        print("ERROR: 'fai_dict' and 'sample_id_map' must be defined in scope before running this script.")
        sys.exit(1)

    # Only directories count as regions. Sorting by (length, name) gives numeric order for
    # region_<N> (region_2 before region_10). Path.glob avoids relying on the `glob` binding.
    regions = sorted(
        (p for p in Path("regions").glob("region_*") if p.is_dir()),
        key=lambda p: (len(p.name), p.name),
    )
    if not regions:
        print("No regions found.")
        sys.exit(0)

    print(f"RUN_ID={RUN_ID}")
    print(f"Submitting {len(regions)} regions with up to {MAX_WORKERS} concurrent jobs…")
    failures = []

    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex:
        futs = {ex.submit(process_one_region, r, fai_dict, sample_id_map): r for r in regions}
        for fut in as_completed(futs):
            name = futs[fut].name
            _, ok, msg = fut.result()
            if ok:
                print(f"✓ {name}: {msg}")
            else:
                print(f"✗ {name}: {msg}")
                failures.append((name, msg))

    if failures:
        print("\nSummary of failures:")
        for name, msg in sorted(failures, key=lambda item: (len(item[0]), item[0])):
            print(f" - {name}: {msg}")
        sys.exit(1)

    print("\nAll regions completed successfully.")


In [None]:
# Let's save the waitaha_regions, richdalei_regions, and modern_only_regions to a file
#waitaha_regions.write_parquet("waitaha_regions.parquet")
#richdalei_regions.write_parquet("richdalei_regions.parquet")
#modern_only_regions.write_parquet("modern_only_regions.parquet")

In [None]:
# Build the sample groups used by the per-set alignment cell below.
vcf = VCF(vcf_file)
a9_samples = vcf.samples                      # modern population, from the VCF header

HALSTATS_PATH = "../seabird_alignment_halstats"

adna_samples = ["waitaha", "richdalei"]       # ancient DNA samples

# Cactus genomes from the halstats table (4 preamble lines skipped). Null rows are dropped,
# then ancestral nodes (Anc*), c90 (subantarctic islands, already in the modern population),
# a9 (already present via the SNPs) and Megadyptes antipodes are removed.
seabirds_df = (
    pl.read_csv(HALSTATS_PATH, skip_lines=4)
    .get_column("GenomeName")
    .drop_nulls()
    .to_list()
)
seabirds_df = [
    genome for genome in seabirds_df
    if not genome.startswith(("Anc", "c90", "a9", "Megadyptesantipodes"))
]
all_seabirds = seabirds_df

# Penguins: seabird genomes whose name begins with a penguin genus.
penguin_prefixes = ("Aptenodytes", "Spheniscus", "Pygoscelis", "Eudyptula", "Eudyptes")
penguins = [g for g in all_seabirds if any(g.startswith(p) for p in penguin_prefixes)]

cactus_penguins = penguins
cactus_seabirds = all_seabirds

default_outgroups = ["Eudyptesmoseleyi_genomic", "Spheniscushumboldti_genomic", "Eudyptesfilholi_genomic"]

TARGET_SAMPLE_SETS = dict(
    penguins=a9_samples + adna_samples + cactus_penguins,
    all_seabirds=a9_samples + adna_samples + cactus_seabirds,
    vcf_and_adna_only=a9_samples + adna_samples,
    just_a_few_outgroups=a9_samples + adna_samples + default_outgroups,
)

In [None]:
# For every sample set, gather each region's per-sample FASTAs into
# processed/{sample_set}/{region}.fasta and align that file in place with MAFFT.
# The loop variable is `set_samples`, not `samples`, so this cell does not overwrite the
# global modern `samples` list passed to run_adna_plus_modern / run_modern_only.
region_dirs = sorted(
    (p for p in Path("regions").glob("region_*") if p.is_dir()),
    key=lambda p: (len(p.name), p.name),  # region_2 before region_10
)

for sample_set, set_samples in TARGET_SAMPLE_SETS.items():
    out_dir = Path(f"processed/{sample_set}")
    out_dir.mkdir(parents=True, exist_ok=True)

    for region in region_dirs:
        region_name = region.name
        out_fasta = out_dir / f"{region_name}.fasta"
        aligned_fasta = out_dir / f"{region_name}.aligned.fasta"
        print(f"Processing {region_name} for sample set {sample_set}")

        # 1) Concatenate the available per-sample FASTAs (raw name first, then mapped id).
        n_written = 0
        with open(out_fasta, "w") as fh:
            for sample in set_samples:
                sample_fasta = region / f"{sample}.fasta"
                if not sample_fasta.exists():
                    sample_fasta = region / f"{sample_id_map.get(sample, sample)}.fasta"
                if not sample_fasta.exists():
                    continue
                with open(sample_fasta, "r") as sfh:
                    lines = sfh.readlines()
                if not lines:
                    continue
                header = lines[0].strip()
                seq = "".join(lines[1:]).strip()
                fh.write(f"{header}\n{seq}\n")
                n_written += 1

        # Nothing to align: skip MAFFT and drop the empty file so it isn't used downstream.
        if n_written == 0:
            print(f"  No sequences for {region_name} in {sample_set}; skipping")
            out_fasta.unlink()
            continue

        # 2) Align (MAFFT stdout -> aligned file), then replace the unaligned file with it.
        try:
            with open(aligned_fasta, "w") as output_file:
                subprocess.run(
                    ["pixi", "run", "mafft", "--quiet", str(out_fasta)],
                    check=True,
                    stdout=output_file,
                    stderr=subprocess.PIPE,
                    text=True,
                )
        except subprocess.CalledProcessError as err:
            print(f"mafft failed for {out_fasta}:\n{err.stderr}")
            raise
        aligned_fasta.replace(out_fasta)

In [None]:
# Sanity check for indels that snuck in: within each processed (aligned) FASTA, every
# modern-population and aDNA sequence must have the same length. This cell only reports
# problems; it does not modify any file.

# Header names expected for modern + aDNA samples (the raw name when there is no mapping).
sample_id_map_values = [sample_id_map.get(s, s) for s in a9_samples + adna_samples]


def _fasta_records(path):
    """Return ``[(header_without_gt, sequence), ...]``; a sequence may span several lines."""
    records = []
    name, chunks = None, []
    with open(path) as fh:
        for raw in fh:
            line = raw.strip()
            if not line:
                continue
            if line.startswith(">"):
                if name is not None:
                    records.append((name, "".join(chunks)))
                name, chunks = line[1:], []
            elif name is not None:
                chunks.append(line)
    if name is not None:
        records.append((name, "".join(chunks)))
    return records


expected_ids = set(sample_id_map_values)
processed_files = sorted(Path("processed").glob("*/*.fasta"))
mismatched_files = []
for processed_file in processed_files:
    lengths = {
        name: len(seq)
        for name, seq in _fasta_records(processed_file)
        if name in expected_ids
    }
    distinct_lengths = set(lengths.values())
    if len(distinct_lengths) > 1:
        print(f"Error in {processed_file}: sequences have different lengths: {distinct_lengths}")
        mismatched_files.append(processed_file)

if mismatched_files:
    print(f"{len(mismatched_files)} of {len(processed_files)} processed FASTAs have inconsistent lengths")
else:
    print(f"All {len(processed_files)} processed FASTAs have consistent lengths")


    
