## Setup

In [None]:
# Load required modules
import tskit
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display

In [None]:
# Load sample DF: one row per sample with its habitat label
sample_df = pd.read_table(snakemake.config["samples"], sep="\t")
sample_df = sample_df[["Sample", "Habitat"]]

# Load order of BAM files and merge with sample DF.
# how="left" keeps the BAM order, and merge returns a fresh 0..n-1 index, so the
# index of each row below is its position in the BAM list.
bams = pd.read_table(snakemake.input["bams"][0], names=["bam"])
bams["Sample"] = bams["bam"].str.extract("(s_\\d+_\\d+)", expand=False)
n_bams = len(bams)
bams = bams.merge(sample_df, on = "Sample", how = "left")[["Sample", "Habitat"]]
# A sample listed more than once in the samples table would duplicate BAM rows
# and shift every positional index used below, so fail loudly instead.
if len(bams) != n_bams:
    raise ValueError(
        "Samples table contains duplicate 'Sample' entries for samples in the BAM list"
    )

# Map habitat names to integers; every BAM must resolve to a known habitat,
# otherwise the population of its haplotypes would be undefined.
habitat_pop_map = {"Urban": 0, "Suburban": 1, "Rural": 2}
unmatched = bams.loc[~bams["Habitat"].isin(list(habitat_pop_map)), "Sample"]
if not unmatched.empty:
    raise ValueError(
        f"BAM samples without a known habitat: {unmatched.tolist()}"
    )

# Get indices of urban, rural, and suburban samples in BAM list
# This corresponds to their indices in the VCFs used to build the ARGs
urban_sample_indices = bams.index[bams["Habitat"] == "Urban"].tolist()
rural_sample_indices = bams.index[bams["Habitat"] == "Rural"].tolist()
suburban_sample_indices = bams.index[bams["Habitat"] == "Suburban"].tolist()

# Each sample contributes two consecutive haplotypes (2*i and 2*i + 1), both
# assigned the population code of that sample's habitat.
sample_indices = {}
for i, habitat in enumerate(bams["Habitat"]):
    pop = habitat_pop_map[habitat]
    sample_indices[2 * i] = {"population": pop}
    sample_indices[2 * i + 1] = {"population": pop}

# Get haplotype indices for each sample
urban_hap_indices = [k for k, v in sample_indices.items() if v["population"] == 0]
rural_hap_indices = [k for k, v in sample_indices.items() if v["population"] == 2]

## Extract site- and branch-based summary statistics (Fst, divergence, diversity, Tajima's D, segregating sites)

### Estimates from `SINGER` ARGs

- Fst is computed both as Nei's site-based estimator (as described in Bhatia et al. 2013) and as the branch-based estimator of Slatkin (1991).
    - See [tskit GitHub issue #858](https://github.com/tskit-dev/tskit/issues/858) for a discussion of these definitions.

In [None]:
# Parse inputs and parameters supplied by Snakemake
chrom = snakemake.wildcards["chrom"]
region = f"region{snakemake.wildcards['region_id']}"
n_samples = int(snakemake.params["n_samples"])
arg_path = snakemake.params['arg_path']
window_size = snakemake.params['window_size']

## Iterate over posterior ARG samples and compute branch- and site-mode summary
## stats in every window. Each list below receives one element per posterior
## ARG sample (`n_samples` elements in total):
##   - one-way stats (a single sample set): array of shape (num_windows, 1)
##   - two-way stats (urban vs. rural):     array of shape (num_windows,)
# One-way
branch_div_urban = []
branch_div_rural = []
site_div_urban = []
site_div_rural = []

branch_td_urban = []
branch_td_rural = []
site_td_urban = []
site_td_rural = []

branch_ss_urban = []
branch_ss_rural = []
site_ss_urban = []
site_ss_rural = []

# Two-way
branch_fsts = []
site_fsts = []

branch_divergences = []
site_divergences = []

# Sample-set arguments for the tskit statistics; they do not change between
# posterior samples, so build them once.
urban_sets = [urban_hap_indices]
rural_sets = [rural_hap_indices]
urban_rural_sets = [urban_hap_indices, rural_hap_indices]

for s in range(n_samples):
    ts = tskit.load(f"{arg_path}/trees/{chrom}/{region}/{region}_{s}.trees")
    # Equal-width windows spanning the whole sequence. Widths equal window_size
    # only when sequence_length is a multiple of it. At least one window is used
    # so a sequence shorter than window_size still yields valid breakpoints
    # [0, sequence_length] instead of a single-point array.
    num_windows = max(1, int(ts.sequence_length / window_size))
    windows = np.linspace(0, ts.sequence_length, num_windows + 1)

    # One-way stats: diversity (pi)
    branch_diversity_urban = ts.diversity(urban_sets, mode="branch", windows=windows)
    branch_div_urban.append(branch_diversity_urban)
    branch_diversity_rural = ts.diversity(rural_sets, mode="branch", windows=windows)
    branch_div_rural.append(branch_diversity_rural)
    site_diversity_urban = ts.diversity(urban_sets, mode="site", windows=windows)
    # Store the site-mode estimate in the site-mode list (not the branch one)
    site_div_urban.append(site_diversity_urban)
    site_diversity_rural = ts.diversity(rural_sets, mode="site", windows=windows)
    site_div_rural.append(site_diversity_rural)

    # One-way stats: Tajima's D
    branch_taj_urban = ts.Tajimas_D(urban_sets, mode="branch", windows=windows)
    branch_td_urban.append(branch_taj_urban)
    branch_taj_rural = ts.Tajimas_D(rural_sets, mode="branch", windows=windows)
    branch_td_rural.append(branch_taj_rural)
    site_taj_urban = ts.Tajimas_D(urban_sets, mode="site", windows=windows)
    site_td_urban.append(site_taj_urban)
    site_taj_rural = ts.Tajimas_D(rural_sets, mode="site", windows=windows)
    site_td_rural.append(site_taj_rural)

    # One-way stats: segregating sites
    branch_seg_urban = ts.segregating_sites(urban_sets, mode="branch", windows=windows)
    branch_ss_urban.append(branch_seg_urban)
    branch_seg_rural = ts.segregating_sites(rural_sets, mode="branch", windows=windows)
    branch_ss_rural.append(branch_seg_rural)
    site_seg_urban = ts.segregating_sites(urban_sets, mode="site", windows=windows)
    site_ss_urban.append(site_seg_urban)
    site_seg_rural = ts.segregating_sites(rural_sets, mode="site", windows=windows)
    site_ss_rural.append(site_seg_rural)

    # Two-way stats (urban vs. rural): Fst
    branch_fst = ts.Fst(urban_rural_sets, mode="branch", windows=windows)
    branch_fsts.append(branch_fst)
    site_fst = ts.Fst(urban_rural_sets, mode="site", windows=windows)
    site_fsts.append(site_fst)

    # Two-way stats (urban vs. rural): divergence (dxy)
    branch_divergence = ts.divergence(urban_rural_sets, mode="branch", windows=windows)
    branch_divergences.append(branch_divergence)
    site_divergence = ts.divergence(urban_rural_sets, mode="site", windows=windows)
    site_divergences.append(site_divergence)

In [None]:
def _posterior_mean(per_sample_stats):
    """Window-wise mean of a statistic across posterior ARG samples.

    The per-sample arrays are stacked along a new leading axis (one row per
    posterior sample) and averaged over that axis, ignoring NaN entries.
    """
    return np.nanmean(per_sample_stats, axis = 0)


# Diversity (pi)
mean_branch_div_urban = _posterior_mean(branch_div_urban)
mean_site_div_urban = _posterior_mean(site_div_urban)
mean_branch_div_rural = _posterior_mean(branch_div_rural)
mean_site_div_rural = _posterior_mean(site_div_rural)

# Tajima's D
mean_branch_td_urban = _posterior_mean(branch_td_urban)
mean_site_td_urban = _posterior_mean(site_td_urban)
mean_branch_td_rural = _posterior_mean(branch_td_rural)
mean_site_td_rural = _posterior_mean(site_td_rural)

# Segregating sites
mean_branch_ss_urban = _posterior_mean(branch_ss_urban)
mean_site_ss_urban = _posterior_mean(site_ss_urban)
mean_branch_ss_rural = _posterior_mean(branch_ss_rural)
mean_site_ss_rural = _posterior_mean(site_ss_rural)

# Two-way: Fst and divergence
mean_branch_fsts = _posterior_mean(branch_fsts)
mean_site_fsts = _posterior_mean(site_fsts)
mean_branch_divergences = _posterior_mean(branch_divergences)
mean_site_divergences = _posterior_mean(site_divergences)

In [None]:
# Assemble one column per statistic, one row per window. Insertion order of
# the mappings below fixes the CSV column order.
results = {"window_index": [w for w in range(num_windows)]}

# Two-way means are indexed by window only
two_way_columns = {
    "arg_branch_fst": mean_branch_fsts,
    "arg_site_fst": mean_site_fsts,
    "arg_branch_divergence": mean_branch_divergences,
    "arg_site_divergence": mean_site_divergences,
}
for column_name, window_means in two_way_columns.items():
    results[column_name] = [window_means[w] for w in range(num_windows)]

# One-way means carry a trailing sample-set axis of length 1; take entry 0
one_way_columns = {
    "arg_branch_pi_urban": mean_branch_div_urban,
    "arg_site_pi_urban": mean_site_div_urban,
    "arg_branch_pi_rural": mean_branch_div_rural,
    "arg_site_pi_rural": mean_site_div_rural,
    "arg_branch_td_urban": mean_branch_td_urban,
    "arg_site_td_urban": mean_site_td_urban,
    "arg_branch_td_rural": mean_branch_td_rural,
    "arg_site_td_rural": mean_site_td_rural,
    "arg_branch_ss_urban": mean_branch_ss_urban,
    "arg_site_ss_urban": mean_site_ss_urban,
    "arg_branch_ss_rural": mean_branch_ss_rural,
    "arg_site_ss_rural": mean_site_ss_rural,
}
for column_name, window_means in one_way_columns.items():
    results[column_name] = [window_means[w][0] for w in range(num_windows)]

In [None]:
# Build the per-window table, tag every row with this region's ID, and write it out
df_out = pd.DataFrame(results).assign(regionID = snakemake.wildcards["region_id"])
df_out.to_csv(snakemake.output[0], index = False)