## Setup

In [None]:
# Load required modules
import re

import cyvcf2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tskit
from IPython.display import display

In [None]:
# Load sample metadata (sample ID and habitat)
sample_df = pd.read_table(snakemake.config["samples"], sep="\t")
sample_df = sample_df[["Sample", "Habitat"]]

# Load order of BAM files and merge with sample DF.
# The merge returns a fresh RangeIndex, so a row's index is its position in the BAM list.
bams = pd.read_table(snakemake.input["bams"][0], names=["bam"])
bams["Sample"] = bams["bam"].str.extract(r"(s_\d+_\d+)", expand=False)
bams = bams.merge(sample_df, on="Sample", how="left")[["Sample", "Habitat"]]

# Map habitat names to integer population IDs
habitat_pop_map = {"Urban": 0, "Suburban": 1, "Rural": 2}

# Fail early with a readable message if a BAM could not be matched to a known habitat.
# The left merge leaves Habitat as NaN for unmatched samples, which would otherwise
# surface as an opaque KeyError when building the haplotype map below.
unknown_habitat = ~bams["Habitat"].isin(list(habitat_pop_map))
if unknown_habitat.any():
    raise ValueError(
        "BAM samples with missing or unrecognised habitat: "
        f"{bams.loc[unknown_habitat, 'Sample'].tolist()}"
    )

# Get indices of urban, rural, and suburban samples in the BAM list.
# This corresponds to their indices in the VCFs used to build the ARGs.
urban_sample_indices = bams.index[bams["Habitat"] == "Urban"].tolist()
rural_sample_indices = bams.index[bams["Habitat"] == "Rural"].tolist()
suburban_sample_indices = bams.index[bams["Habitat"] == "Suburban"].tolist()

# Sample i owns haplotypes 2*i and 2*i + 1; record the population ID of every haplotype
sample_indices = {
    hap: {"population": habitat_pop_map[habitat]}
    for sample_idx, habitat in enumerate(bams["Habitat"])
    for hap in (2 * sample_idx, 2 * sample_idx + 1)
}

# Get haplotype indices for each population of interest
urban_hap_indices = [k for k, v in sample_indices.items() if v["population"] == 0]
rural_hap_indices = [k for k, v in sample_indices.items() if v["population"] == 2]

## Extract site-based Fst estimates

### Estimates from `SINGER` ARGs

- Site-based estimate of Nei's Fst (Bhatia et al. 2013) and branch-based estimator (Slatkin 1991), both computed with tskit's `Fst` method on each posterior ARG sample.
    - See [this GitHub issue](https://github.com/tskit-dev/tskit/issues/858)

In [None]:
# Dictionary to store results. Every list must end up the same length, because the
# whole dictionary is turned into one DataFrame at the end of the notebook.
results = {"pos_index": [], 
           "pos": [], 
           "arg_branch_fst": [], 
           "arg_site_fst": [], 
           "gt_hudson_num": [],
           "gt_hudson_denom": [],
           "gt_hudson_fst": [], 
           "gt_nei_num": [], 
           "gt_nei_denom": [], 
           "gt_nei_fst": [], 
           "sfs_hudson_num": [],
           "sfs_hudson_denom": [],
           "sfs_hudson_fst": [],
}

# Parse inputs and parameters
region = f"region{snakemake.wildcards['n']}"
n_samples = int(snakemake.params["n_samples"])
arg_path = snakemake.params['arg_path']

# Iterate over posterior ARG samples and get branch and site Fst at every site.
# Stored as lists with n_samples elements, where each element is a 1D array of size ts.num_sites.
# `ts` is intentionally left bound to the last loaded sample; later cells read ts.num_sites.
fst_sample_sets = [urban_hap_indices, rural_hap_indices]
branch_fsts = []
site_fsts = []
for s in range(n_samples):
    ts = tskit.load(f"{arg_path}/trees/{region}/{region}_{s}.trees")
    branch_fst = ts.Fst(fst_sample_sets, mode="branch", windows="sites", span_normalise=False)
    branch_fsts.append(branch_fst)
    site_fst = ts.Fst(fst_sample_sets, mode="site", windows="sites", span_normalise=False)
    # Store the *site* estimate. Previously `branch_fst` was appended here, which made
    # arg_site_fst an exact copy of arg_branch_fst.
    site_fsts.append(site_fst)

In [None]:
# Sanity check: number of sites in the last posterior ARG sample loaded above
ts.num_sites

In [None]:
# Average each site's Fst over the posterior ARG samples.
# Rows are ARG samples and columns are sites; NaN entries are ignored.
mean_branch_fsts, mean_site_fsts = (
    np.nanmean(np.asarray(per_sample_fsts), axis=0)
    for per_sample_fsts in (branch_fsts, site_fsts)
)

In [None]:
# Record the ARG-based Fst means in `results`: one entry per site of `ts`,
# which is the last tree sequence loaded above
site_range = range(ts.num_sites)
results["pos_index"].extend(site_range)
results["arg_branch_fst"].extend(mean_branch_fsts[j] for j in site_range)
results["arg_site_fst"].extend(mean_site_fsts[j] for j in site_range)

In [None]:
# Sanity check: one ARG-based entry was appended per site, so this should equal ts.num_sites
len(results["arg_site_fst"])

### Estimates from genotypes

- Estimates both Nei's and Hudson's Fst following Bhatia et al. (2013), using the per-population alternative allele frequencies from the VCF.

In [None]:
# Open the genotype VCF once per population, restricted to that population's samples
vcf_path = snakemake.input["vcf"][0]
urban_vcf, rural_vcf = (
    cyvcf2.VCF(vcf_path, samples=bams.loc[bams["Habitat"] == habitat, "Sample"].tolist())
    for habitat in ("Urban", "Rural")
)

In [None]:
# Map "CHROM:POS" -> [alternative allele frequency] for each population.
# Values are one-element lists; downstream code reads them as af[0].
# Iterating consumes each cyvcf2 reader, so each VCF object is read exactly once here.
urban_af_dict, rural_af_dict = (
    {f"{rec.CHROM}:{rec.POS}": [rec.aaf] for rec in population_vcf}
    for population_vcf in (urban_vcf, rural_vcf)
)

In [None]:
# Sanity check: number of SNPs read from the urban VCF. The lists in `results` must all
# have the same length to build the output table, so this should match ts.num_sites.
len(urban_af_dict.keys())

In [None]:
# Number of sampled haplotypes (allele copies) per population. Each sample contributes two
# haplotypes (see the haplotype index lists built in the setup cell).
# Hudson's finite-sample correction p(1-p)/(n-1) is unbiased for p(1-p)/n only when n counts
# allele copies. Using the number of individuals roughly doubled the correction terms.
# NOTE(review): if cyvcf2's `aaf` is computed from called genotypes only, the effective n is
# smaller at sites with missing data — confirm whether a per-site n is needed.
n_urban = len(urban_hap_indices)
n_rural = len(rural_hap_indices)


def hudson_fst_components(p1, p2, n1, n2):
    """Return Hudson's Fst (numerator, denominator, ratio) for one biallelic site.

    Follows Bhatia et al. (2013):
        N = (p1 - p2)^2 - p1(1 - p1)/(n1 - 1) - p2(1 - p2)/(n2 - 1)
        D = p1(1 - p2) + p2(1 - p1)

    p1, p2: alternative allele frequencies in populations 1 and 2.
    n1, n2: number of sampled haplotypes (allele copies) in each population.
    The ratio is NaN (not 0) when D == 0, i.e. when both populations are fixed
    for the same allele.
    """
    numerator = (p1 - p2) ** 2 - (p1 * (1 - p1)) / (n1 - 1) - (p2 * (1 - p2)) / (n2 - 1)
    denominator = (p1 * (1 - p2)) + (p2 * (1 - p1))
    fst = np.nan if denominator == 0 else numerator / denominator
    return numerator, denominator, fst


def nei_fst_components(p1, p2):
    """Return Nei's Fst (numerator, denominator, ratio) for one biallelic site.

        N = (p1 - p2)^2
        D = p_avg(1 - p_avg),  with p_avg = (p1 + p2) / 2

    The ratio is NaN when D == 0 (p_avg is 0 or 1, i.e. both populations are fixed
    for the same allele).
    """
    avg_af = (p1 + p2) / 2
    numerator = (p1 - p2) ** 2
    denominator = avg_af * (1 - avg_af)
    fst = np.nan if denominator == 0 else numerator / denominator
    return numerator, denominator, fst


for snp, af in urban_af_dict.items():
    u_af = af[0]  # Alternative allele frequency in urban population
    r_af = rural_af_dict[snp][0]  # Alternative allele frequency at same site in rural population

    gt_hudson_numerator, gt_hudson_denominator, gt_hudson_fst = hudson_fst_components(
        u_af, r_af, n_urban, n_rural
    )
    gt_nei_numerator, gt_nei_denominator, gt_nei_fst = nei_fst_components(u_af, r_af)

    results["pos"].append(snp)
    results["gt_hudson_num"].append(gt_hudson_numerator)
    results["gt_hudson_denom"].append(gt_hudson_denominator)
    results["gt_hudson_fst"].append(gt_hudson_fst)
    results["gt_nei_num"].append(gt_nei_numerator)
    results["gt_nei_denom"].append(gt_nei_denominator)
    results["gt_nei_fst"].append(gt_nei_fst)

## SFS-based Fst estimates

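- Hudson's Fst from ANGSD's SFS-based per-site numerators and denominators, looked up at the same positions as the genotype-based estimates.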

In [None]:
# Load ANGSD's readable per-site Hudson Fst output (position, numerator, denominator)
# for this region's chromosome, as a dict keyed by position.
chrom = results["pos"][0].split(":")[0]

# Select the input file for this chromosome. Match the chromosome name as a whole token
# rather than a bare substring, so that e.g. "chr1" cannot select a "chr10" file.
# If several files match, the first one is used (as before).
chrom_pattern = re.compile(rf"(?<![A-Za-z0-9]){re.escape(chrom)}(?![A-Za-z0-9])")
matching_sfs_paths = [p for p in snakemake.input["sfs_fst"] if chrom_pattern.search(p)]
if not matching_sfs_paths:
    raise ValueError(f"No SFS Fst input file found for chromosome {chrom!r}")
hudson_sfs_fst_path = matching_sfs_paths[0]

hudson_sfs_fst_df = pd.read_table(
    hudson_sfs_fst_path, sep="\t", usecols=[1, 2, 3], names=["pos", "num", "denom"]
).set_index("pos")
hudson_sfs_fst_dict = hudson_sfs_fst_df.to_dict(orient="index")

In [None]:
# For every genotype site, look up ANGSD's SFS-based Hudson numerator and denominator at the
# same position. Positions absent from the SFS output get NaN for all three values.
for site_key in results["pos"]:
    site_pos = int(site_key.split(":")[1])
    sfs_entry = hudson_sfs_fst_dict.get(site_pos)
    if sfs_entry is None:
        sfs_num, sfs_denom, sfs_fst = np.nan, np.nan, np.nan
    else:
        sfs_num = sfs_entry["num"]
        sfs_denom = sfs_entry["denom"]
        # NaN instead of dividing by a zero denominator
        sfs_fst = np.nan if sfs_denom == 0 else sfs_num / sfs_denom

    results["sfs_hudson_num"].append(sfs_num)
    results["sfs_hudson_denom"].append(sfs_denom)
    results["sfs_hudson_fst"].append(sfs_fst)

In [None]:
# The ARG columns come from the tree sequences and the genotype/SFS columns from the VCF,
# so check their lengths explicitly. This gives a readable error instead of pandas'
# generic "All arrays must be of the same length".
# NOTE(review): equal lengths do not guarantee that row i of the ARG columns is the same
# site as row i of `pos`. This assumes ARG sites and VCF records are in the same order — confirm.
column_lengths = {col: len(values) for col, values in results.items()}
if len(set(column_lengths.values())) > 1:
    raise ValueError(f"Result columns have unequal lengths: {column_lengths}")

df_out = pd.DataFrame(results)
df_out["regionID"] = snakemake.wildcards["n"]
df_out.to_csv(snakemake.output[0], index=False)