In [None]:
import tskit
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display

In [None]:
# Sample metadata: one row per BAM (in BAM-list order) with its habitat, plus
# haplotype-id lists per habitat used as Fst sample sets in the next cell.
sample_df = pd.read_table(snakemake.config["samples"], sep="\t")
sample_df = sample_df[["Sample", "Habitat"]]

bams = pd.read_table(snakemake.input["bams"][0], names=["bam"])
n_bams = len(bams)
bams["Sample"] = bams["bam"].str.extract(r"(s_\d+_\d+)", expand=False)
bams = bams.merge(sample_df, on="Sample", how="left")[["Sample", "Habitat"]]

habitat_pop_map = {"Urban": 0, "Suburban": 1, "Rural": 2}

# Validate before building haplotype indices. A duplicated sample-sheet row
# would shift every later BAM's haplotype ids. An unmatched sample or unknown
# habitat would otherwise only surface as a bare KeyError further down.
if len(bams) != n_bams:
    raise ValueError(
        f"Merging BAMs with the sample sheet changed the row count ({n_bams} -> {len(bams)}); "
        "check for duplicate 'Sample' entries in the sample sheet."
    )
unmapped = bams[~bams["Habitat"].isin(list(habitat_pop_map))]
if not unmapped.empty:
    raise ValueError(
        f"BAMs without a recognised habitat (expected one of {sorted(habitat_pop_map)}):\n"
        f"{unmapped.to_string()}"
    )

urban_sample_indices = bams.index[bams["Habitat"] == "Urban"].tolist()
rural_sample_indices = bams.index[bams["Habitat"] == "Rural"].tolist()
suburban_sample_indices = bams.index[bams["Habitat"] == "Suburban"].tolist()

# Haplotype id -> {"population": code}; each BAM contributes two consecutive ids.
# NOTE(review): assumes individual i owns haplotype nodes 2*i and 2*i + 1 in the
# ARG -- confirm against the ARG inference step.
hap_populations = bams["Habitat"].map(habitat_pop_map).repeat(2).tolist()
sample_indices = {hap: {"population": pop} for hap, pop in enumerate(hap_populations)}

urban_hap_indices = [k for k, v in sample_indices.items() if v["population"] == habitat_pop_map["Urban"]]
rural_hap_indices = [k for k, v in sample_indices.items() if v["population"] == habitat_pop_map["Rural"]]

In [None]:
# Per-window comparison of genotype-based Fst (per-site values from the
# `win_gt_fst` table) with branch-mode Fst averaged over the ARG samples for
# this region. One row per full-length window that contains ARG sites.
results = {"win_id": [], "gt_fst": [], "arg_fst": [], "n_sites": []}
region = f"region{snakemake.wildcards['n']}"
arg_dir = f"{snakemake.params['arg_path']}/trees/{region}"
window_size = int(snakemake.params["win_size"])
n_samples = int(snakemake.params["n_samples"])

# Per-site genotype Fst keyed by the POS column (kept as a string). Header rows
# start with "CHROM"; blank lines are skipped; "nan"-style values parse to NaN.
vcf_site_fst = {}
with open(snakemake.input["win_gt_fst"][0], "r") as win_gt_fst_in:
    for raw_line in win_gt_fst_in:
        if raw_line.startswith("CHROM") or not raw_line.strip():
            continue
        fields = raw_line.split("\t")
        vcf_site_fst[fields[1]] = float(fields[2].strip())

# NOTE(review): table row k is paired with ARG site id k (by order, not by
# position) -- assumes both describe the same sites in the same order.
vcf_fst_values = list(vcf_site_fst.values())

# ARG sample 0 defines the site set used to pick each window's genotype-Fst
# values. It keeps its own name so later loads cannot silently replace it.
ts = tskit.load(f"{arg_dir}/{region}_0.trees")
sl = int(ts.sequence_length)
site_positions = ts.tables.sites.position  # tskit requires sites sorted by position
windows = [[val, val + window_size] for val in range(0, sl, window_size)]

# Branch-mode Fst for every window of every ARG sample. Each tree sequence is
# loaded once and all windows are evaluated in one windowed call. Fst is a
# ratio of span-normalised diversities/divergences, so each value equals
# keep_intervals([win]) followed by a whole-sequence Fst.
window_starts = [start for start, _ in windows]
arg_fst_per_sample = []
for s in range(n_samples):
    sample_ts = tskit.load(f"{arg_dir}/{region}_{s}.trees")
    arg_fst_per_sample.append(
        sample_ts.Fst(
            [urban_hap_indices, rural_hap_indices],
            windows=window_starts + [sample_ts.sequence_length],
            mode="branch",
        )
    )
arg_fst_by_window = np.nanmean(
    np.asarray(arg_fst_per_sample, dtype=float).reshape(n_samples, len(windows)),
    axis=0,
)

for i, (start, end) in enumerate(windows):
    if end > sl:
        continue  # the trailing partial window is not reported
    # Sites in [start, end) -- the same half-open rule keep_intervals applies --
    # form a contiguous run of site ids because sites are position-sorted.
    first_site, stop_site = np.searchsorted(site_positions, [start, end], side="left")
    n_sites = int(stop_site - first_site)
    if n_sites == 0:
        continue
    # All four columns are appended together so the lists stay aligned.
    results["win_id"].append(i)
    results["gt_fst"].append(np.nanmean(vcf_fst_values[first_site:stop_site]))
    results["arg_fst"].append(arg_fst_by_window[i])
    results["n_sites"].append(n_sites)

In [None]:
# Persist one row per reported window, tagged with this region's id.
df_out = pd.DataFrame(results).assign(regionID=snakemake.wildcards["n"])
df_out.to_csv(snakemake.output[0], index=False)