In [None]:
import tskit
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display

In [None]:
# Attach a habitat to every BAM listed in the Snakemake "bams" input and derive
# per-habitat individual indices plus a haplotype-node -> population map.
sample_df = pd.read_table(snakemake.config["samples"], sep="\t")
sample_df = sample_df[["Sample", "Habitat"]]

habitat_pop_map = {"Urban": 0, "Suburban": 1, "Rural": 2}

bams = pd.read_table(snakemake.input["bams"][0], names=["bam"])
# expand=False returns a Series (not a one-column DataFrame) for the single group.
bams["Sample"] = bams["bam"].str.extract(r"(s_\d+_\d+)", expand=False)
# many_to_one: a duplicated Sample in the samples table would otherwise add rows
# and break the row-index <-> haplotype-node alignment used below.
bams = bams.merge(sample_df, on="Sample", how="left", validate="many_to_one")[["Sample", "Habitat"]]

# Fail early and readably instead of a bare KeyError: nan in the map lookup below.
unknown = bams.loc[bams["Sample"].isna() | ~bams["Habitat"].isin(habitat_pop_map.keys())]
if not unknown.empty:
    raise ValueError(f"BAMs without a recognised habitat in the samples table:\n{unknown}")

# Row indices of `bams` (i.e. individual indices) for each habitat.
urban_sample_indices = bams.index[bams["Habitat"] == "Urban"].tolist()
rural_sample_indices = bams.index[bams["Habitat"] == "Rural"].tolist()
suburban_sample_indices = bams.index[bams["Habitat"] == "Suburban"].tolist()

# Individual i (row i of `bams`) contributes haplotype nodes 2*i and 2*i + 1.
sample_indices = {
    node: {"population": habitat_pop_map[habitat]}
    for i, habitat in enumerate(bams["Habitat"])
    for node in (2 * i, 2 * i + 1)
}

In [None]:
# Regions table; its "regionID" column lists the regions processed in later cells.
regions_path = snakemake.input["regions"][0]
regions_df = pd.read_table(regions_path)
region_ids = regions_df["regionID"].to_list()

In [None]:
# Group region IDs by the outlier class recorded in the "direction" column.
def _region_ids_with_direction(direction):
    """Return the regionIDs whose "direction" equals `direction`, in table order."""
    is_match = regions_df["direction"] == direction
    return regions_df.loc[is_match, "regionID"].tolist()


rural_region_ids = _region_ids_with_direction("Rural sel")
urban_region_ids = _region_ids_with_direction("Urban sel")
unsel_region_ids = _region_ids_with_direction("Not outlier")

In [None]:
arg_path = "/scratch/projects/trifolium/gwsd/results/args"
plt.rcParams['figure.figsize'] = [15, 5]

# Sliding-window layout: a window starts every sequence_length / N_STEPS and spans
# STEPS_PER_WINDOW steps (sequence_length / 40), so consecutive windows overlap.
N_STEPS = 200
STEPS_PER_WINDOW = 5


def _sliding_windows(seq_len, n_steps=N_STEPS, steps_per_window=STEPS_PER_WINDOW):
    """Return overlapping ``(start, end)`` windows lying inside ``[0, seq_len]``.

    Window ``k`` starts at ``k * seq_len / n_steps`` and spans ``steps_per_window``
    steps. Windows are enumerated by integer step count rather than filtered with a
    float ``start + width <= seq_len`` test, so rounding cannot silently drop the
    final window; ``end`` is clamped to ``seq_len`` for the same reason.
    """
    step = seq_len / n_steps
    return [
        (k * step, min((k + steps_per_window) * step, seq_len))
        for k in range(n_steps - steps_per_window + 1)
    ]


def _haplotype_nodes(individual_indices):
    """Expand individual (row) indices into haplotype sample node IDs.

    Assumes individual ``i`` owns nodes ``2*i`` and ``2*i + 1``, the layout the
    notebook uses when building ``sample_indices``.
    NOTE(review): confirm this matches how the ARGs were generated.
    """
    return [node for i in individual_indices for node in (2 * i, 2 * i + 1)]


def _windowed_fst(ts, sample_sets, windows):
    """Branch-mode Fst between ``sample_sets`` for each ``(start, end)`` window.

    tskit's ``windows=`` argument only takes contiguous, non-overlapping
    breakpoints, so each overlapping window is evaluated on a copy of ``ts``
    restricted to that interval via ``keep_intervals``.
    """
    return [
        ts.keep_intervals([win], simplify=False).Fst(sample_sets, mode="branch")
        for win in windows
    ]


# Fst sample sets must be node IDs, not individual indices.
urban_nodes = _haplotype_nodes(urban_sample_indices)
rural_nodes = _haplotype_nodes(rural_sample_indices)

# NOTE(review): rural- and urban-selected regions are both drawn in blue, as
# originally written; give them distinct colours if they should be told apart.
selected_region_ids = set(rural_region_ids) | set(urban_region_ids)

fig, ax = plt.subplots()
for region_id in region_ids:
    ts = tskit.load(f"{arg_path}/trees/region{region_id}/region{region_id}_199.trees")
    windows = _sliding_windows(ts.sequence_length)
    fsts = _windowed_fst(ts, [urban_nodes, rural_nodes], windows)
    win_centers = [(start + end) / 2 for start, end in windows]
    colour = "blue" if region_id in selected_region_ids else "black"
    ax.scatter(win_centers, fsts, color=colour, s=4)

# Window centres are raw tree-sequence coordinates (presumably bp), not kb.
ax.set_xlabel("Genome position (bp)")
ax.set_ylabel("Fst")
ax.set_title("Branch-mode Fst between urban and rural samples (all regions overlaid)")
# ax.vlines(x=475000, ymin=0, ymax=0.07, linewidth=2, color='red', linestyles="dashed")
# ax.vlines(x=525000, ymin=0, ymax=0.07, linewidth=2, color='red', linestyles="dashed")
plt.show()