# H12 signal detection

In [None]:
# Notebook parameters. Values here are for development only and
# will be overridden when running via snakemake and papermill.

# Path to the configuration file handed to AtlasSetup.
config_file = "../../../config/agam.yaml"
# Cohort to analyse; used to look up the cohort row and to format file paths.
cohort_id = "BF-09_Houet_colu_2012_Q3"
# Contig to scan for signals.
contig = "2R"

## Setup

In [None]:
import yaml
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from selection_atlas.setup import AtlasSetup
from selection_atlas import peak_utils

# Initialise the atlas setup from the configuration file named in the
# notebook parameters. The resulting `setup` object supplies the file paths,
# data API handle and analysis settings used throughout this notebook.
setup = AtlasSetup(config_file)

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

In [None]:
# Locate the calibration output for this cohort and read the H12
# window size recorded in it.
calibration_path_template = setup.h12_calibration_files.as_posix()
h12_calibration_file = calibration_path_template.format(cohort=cohort_id)
with open(h12_calibration_file) as calibration_file:
    calibration_params = yaml.safe_load(calibration_file)
window_size = calibration_params["h12_window_size"]
window_size

In [None]:
# Read the cohorts table, key it by cohort ID, then pull out the row
# describing the cohort being analysed.
df_cohorts = pd.read_csv(setup.cohorts_file)
df_cohorts = df_cohorts.set_index("cohort_id")
cohort = df_cohorts.loc[cohort_id]
cohort

In [None]:
# Sample query for this cohort, taken from the cohorts table; it is passed
# as `sample_query` when the H12 scan is loaded.
cohort_query = cohort.sample_query
cohort_query

In [None]:
# Look up the phasing analysis configured for this cohort's taxon.
phasing_analysis = setup.taxon_phasing_analysis[cohort.taxon]
phasing_analysis

## Run signal detection

In [None]:
# set parameters for signal detection
# -- outlier filtering (passed to peak_utils.hampel_filter as size and t)
filter_size = 20  # hampel filter parameter
filter_t = 2  # hampel filter parameter
# -- spacing between candidate peak centres along the genetic map
scan_interval = 1  # step in cM
# -- bounds and starting values passed to peak_utils.fit_exponential_peak
min_baseline = 0
max_baseline_percentile = 95  # percentile of filtered H12 used as max_baseline
min_amplitude = 0.03
init_amplitude = 0.5
max_amplitude = 1.5
min_decay = 0.1
init_decay = 0.5
max_abs_skew = 0.5  # skew is bounded to [-max_abs_skew, max_abs_skew]
# -- scan range in genetic position; None means a default is filled in later
scan_start = None
scan_stop = None
# -- when True, draw the diagnostic plots and pass debug=True to the peak fit
debug = True

In [None]:
# load gwss data
# Run (or load) the H12 genome-wide selection scan for this cohort on the
# chosen contig, using the calibrated window size. Only the first two return
# values are used here; the third is discarded.
ppos, h12, _ = setup.malariagen_api.h12_gwss(
    contig=contig,
    window_size=window_size,
    analysis=phasing_analysis,
    sample_sets=setup.sample_sets,
    sample_query=cohort_query,
    min_cohort_size=setup.min_cohort_size,
    max_cohort_size=setup.max_cohort_size,
)

In [None]:
# convert the physical positions returned by the scan to int; they are
# mapped to genetic positions in the next step
ppos = ppos.astype(int)
ppos

In [None]:
# map physical to genetic position; gpos is the coordinate used for
# plotting and for the peak scan in the rest of this notebook
gpos = peak_utils.p2g(setup, contig=contig, ppos=ppos)
gpos

In [None]:
if debug:
    # Sanity-check the physical-to-genetic mapping by plotting one
    # coordinate against the other.
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.plot(ppos, gpos)
    ax.set(
        xlabel="Physical position (bp)",
        ylabel="Genetic position (cM)",
        title=contig,
    )
    fig.tight_layout()

In [None]:
# filter outliers from the H12 values with a Hampel filter; size and t
# come from the signal detection parameters set above
h12_filtered = peak_utils.hampel_filter(h12, size=filter_size, t=filter_t)

In [None]:
if debug:
    # Draw the H12 values against genetic position twice, first raw and
    # then after outlier filtering, so the effect of the filter is visible.
    point_style = dict(
        marker="o",
        linestyle=" ",
        mfc="none",
        markersize=1,
        mew=0.5,
        color="k",
    )
    panels = (("Unfiltered", h12), ("Filtered", h12_filtered))
    for panel_title, panel_values in panels:
        fig, ax = plt.subplots(figsize=(8, 2))
        ax.plot(gpos, panel_values, **point_style)
        ax.set_title(panel_title)
        ax.set_ylim(0, 1)
        fig.tight_layout()

In [None]:
# Derive the remaining fit settings from the filtered data.
# The baseline starts at the median of the filtered H12 values and is capped
# at the configured percentile of those values.
init_baseline = np.median(h12_filtered)
max_baseline = np.percentile(h12_filtered, max_baseline_percentile)
# Skew starts symmetric and is bounded equally on both sides.
min_skew, init_skew, max_skew = -max_abs_skew, 0, max_abs_skew
# Fill in the default scan range only when none was given. Compare against
# None rather than relying on truthiness, so that an explicit value of 0 is
# honoured instead of being silently replaced by the default.
if scan_start is None:
    # NOTE(review): assumes genetic positions start near 0 - confirm.
    scan_start = 2
if scan_stop is None:
    # Stop 2 units of genetic position before the last window.
    scan_stop = gpos[-1] - 2

In [None]:
# Keyword arguments that are the same for every fit; only the centre and
# the flank size change from one call to the next.
fit_kwargs = dict(
    setup=setup,
    contig=contig,
    cohort_id=cohort_id,
    gpos=gpos,
    stat_filtered=h12_filtered,
    scan_interval=scan_interval,
    init_amplitude=init_amplitude,
    min_amplitude=min_amplitude,
    max_amplitude=max_amplitude,
    init_decay=init_decay,
    min_decay=min_decay,
    init_skew=init_skew,
    min_skew=min_skew,
    max_skew=max_skew,
    init_baseline=init_baseline,
    min_baseline=min_baseline,
    max_baseline=max_baseline,
    min_delta_aic=setup.h12_signal_detection_min_delta_aic,
    min_stat_max=setup.h12_signal_detection_min_stat_max,
    debug=debug,
)

# Step along the contig in genetic position. At each centre, try every
# configured flank size and collect the fits that returned a result.
results = []
for gcenter in np.arange(scan_start, scan_stop, scan_interval):
    for gflank in setup.h12_signal_detection_gflanks:
        result = peak_utils.fit_exponential_peak(
            gcenter=gcenter,
            gflank=gflank,
            **fit_kwargs,
        )
        if result is None:
            continue
        results.append(result)

# One row per accepted fit.
df_signals = pd.DataFrame.from_records(results)
df_signals

In [None]:
def dedup_signals(df_signals):
    """Drop signals that overlap a stronger signal.

    Two signals overlap when their closed intervals
    ``[span1_gstart, span1_gstop]`` intersect (touching end points count as
    overlapping). A signal is discarded when at least one overlapping signal
    has a strictly larger ``delta_i``; overlapping signals with equal
    ``delta_i`` are all kept.

    Rows are compared by position, so the result does not depend on the
    index labels of ``df_signals`` (a non-default or non-unique index is
    fine).

    Parameters
    ----------
    df_signals : pandas.DataFrame
        One row per signal, with columns ``span1_gstart``, ``span1_gstop``
        and ``delta_i``. May be empty, in which case the columns are not
        required.

    Returns
    -------
    pandas.DataFrame
        A copy of the retained rows, in their original order and with their
        original index labels.
    """
    # Nothing to compare; this also covers a frame built from zero records,
    # which has no columns at all.
    if len(df_signals) == 0:
        return df_signals.copy()

    gstart = df_signals["span1_gstart"].to_numpy()
    gstop = df_signals["span1_gstop"].to_numpy()
    delta_i = df_signals["delta_i"].to_numpy()

    # Pairwise comparison: entry [i, j] compares "this" = row i with
    # "that" = row j.
    # Range overlap in two compares, thank you Ned Batchelder:
    # https://nedbatchelder.com/blog/201310/range_overlap_in_two_compares.html
    disjoint = (gstart[None, :] > gstop[:, None]) | (
        gstop[None, :] < gstart[:, None]
    )
    stronger = delta_i[None, :] > delta_i[:, None]

    # A row is dropped if any other row overlaps it and is stronger. The
    # diagonal never triggers because a value is not strictly greater than
    # itself.
    dropped = (~disjoint & stronger).any(axis=1)

    # A boolean numpy mask selects rows by position.
    return df_signals[~dropped].copy()

In [None]:
# Remove signals that overlap a stronger signal; the deduplicated table is
# what gets written to the output file.
df_signals_dedup = dedup_signals(df_signals)
df_signals_dedup

## Write outputs

In [None]:
# Resolve the output path for this cohort and contig.
h12_signal_file = setup.h12_signal_files.as_posix().format(
    cohort=cohort_id, contig=contig
)
# Give pandas the path rather than a text-mode file handle. pandas documents
# that a text handle passed to to_csv should be opened with newline="";
# a handle opened without it can produce doubled carriage returns on
# platforms that translate newlines. Letting pandas open the file avoids that.
df_signals_dedup.to_csv(h12_signal_file, index=False)