# H12 signal detection

In [1]:
# Notebook parameters. Values here are for development only and
# will be overridden when running via snakemake and papermill.
# cohort_id = 'CD-NU_Gbadolite_gamb_2015_Q3'
# cohort to analyse; used to look up the calibration file and the cohorts table row
cohort_id = "ML-2_Kati_gamb_2014_Q3"
# version of the cohorts analysis, passed to malariagen_data.Ag3
cohorts_analysis = "20230223"
# contig to scan; also part of the output file name
contig = "2RL"
# sample sets passed to the H12 genome-wide scan
sample_sets = "3.0"
# cohort size bounds passed to the H12 genome-wide scan
min_cohort_size = 20
max_cohort_size = 50
# scheduler name passed to dask.config.set
dask_scheduler = "threads"
# thresholds forwarded to fit_exponential_peak as min_delta_aic / min_stat_max
h12_signal_detection_min_delta_aic = 1000
h12_signal_detection_min_stat_max = 0.1
# flank sizes tried at every scan position (forwarded as gflank)
h12_signal_detection_gflanks = [6]
# selects the build/<analysis_version>/ directory used for inputs and outputs
analysis_version = "dev"

## Setup

In [None]:
import yaml
import pandas as pd
import malariagen_data
from pyprojroot import here
import numpy as np
import os
import dask

dask.config.set(scheduler=dask_scheduler)
import matplotlib.pyplot as plt
from bisect import bisect_left, bisect_right
import lmfit

%matplotlib inline
%config InlineBackend.figure_format = 'retina'
%run {here()}/workflow/notebooks/peak-utils.ipynb

In [None]:
# echo the sample sets parameter so the value used is recorded in the executed notebook
sample_sets

In [None]:
# set up the data access object used for the H12 scan below
ag3 = malariagen_data.Ag3(
    # pin the version of the cohorts analysis for reproducibility
    cohorts_analysis=cohorts_analysis,
    # results cache location: build/malariagen_data_cache under the project root
    results_cache=(here() / "build" / "malariagen_data_cache").as_posix(),
)
ag3

In [None]:
# Read the H12 window size for this cohort from the YAML file
# written by the h12-calibration step.
calibration_dir = f"build/{analysis_version}/h12-calibration"
calibration_path = here() / calibration_dir / f"{cohort_id}.yaml"
with open(calibration_path) as calibration_file:
    calibration_params = yaml.safe_load(calibration_file)

window_size = calibration_params["h12_window_size"]
window_size

In [None]:
# Look up this cohort's row in the cohorts table; the row carries the
# sample query and taxon used further down.
cohorts_path = here() / "build" / analysis_version / "cohorts.csv"
df_cohorts = pd.read_csv(cohorts_path)
df_cohorts = df_cohorts.set_index("cohort_id")
cohort = df_cohorts.loc[cohort_id]
cohort

In [None]:
# sample query for this cohort, taken from the cohorts table row
sample_query = cohort.sample_query
sample_query

In [None]:
# arabiensis cohorts use the "arab" phasing analysis; every other taxon
# uses "gamb_colu"
phasing_analysis = "arab" if cohort.taxon == "arabiensis" else "gamb_colu"
phasing_analysis

## Run signal detection

In [None]:
# set parameters for signal detection
# (the amplitude/decay/baseline values and skew bound below are forwarded to
# fit_exponential_peak as initial values and bounds)
filter_size = 20  # hampel filter parameter
filter_t = 2  # hampel filter parameter
scan_interval = 1  # step in cM
min_baseline = 0
# upper bound on the baseline, as a percentile of the filtered H12 values
max_baseline_percentile = 95
min_amplitude = 0.03
init_amplitude = 0.5
max_amplitude = 1.5
min_decay = 0.1
init_decay = 0.5
# skew is bounded to [-max_abs_skew, +max_abs_skew]
max_abs_skew = 0.5
# None means "use the default" chosen later from the genetic map extent
scan_start = None
scan_stop = None
# TODO maybe set false in production to avoid too many plots?
# debug enables the diagnostic plots in this notebook and is also passed to the peak fit
debug = True

In [None]:
# load gwss data
# h12_gwss returns three values here: window positions, H12 values, and a
# third value that this notebook does not use
ppos, h12, _ = ag3.h12_gwss(
    contig=contig,
    # window size comes from the h12-calibration output for this cohort
    window_size=window_size,
    analysis=phasing_analysis,
    sample_sets=sample_sets,
    sample_query=sample_query,
    min_cohort_size=min_cohort_size,
    max_cohort_size=max_cohort_size,
)

In [None]:
# convert the window positions to int (they are mapped to genetic positions next)
ppos = ppos.astype(int)

In [None]:
# map physical to genetic position
# NOTE(review): ag_p2g is not defined in this notebook; presumably it comes from
# the peak-utils notebook run during setup - confirm
gpos = ag_p2g(contig=contig, ppos=ppos)

In [None]:
if debug:
    # sanity check: genetic position plotted against physical position
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.plot(ppos, gpos)
    ax.set(
        xlabel="Physical position (bp)",
        ylabel="Genetic position (cM)",
        title=contig,
    )
    fig.tight_layout()


In [None]:
# filter outliers
# NOTE(review): hampel_filter is not defined in this notebook; presumably it comes
# from the peak-utils notebook run during setup - confirm
h12_filtered = hampel_filter(h12, size=filter_size, t=filter_t)

In [None]:
if debug:
    # Draw H12 along the genetic map before and after outlier filtering,
    # using identical styling so the two panels are directly comparable.
    point_style = dict(
        marker="o",
        linestyle=" ",
        mfc="none",
        markersize=1,
        mew=0.5,
        color="k",
    )
    panels = (
        ("Unfiltered", h12),
        ("Filtered", h12_filtered),
    )
    for panel_title, panel_values in panels:
        fig, ax = plt.subplots(figsize=(8, 2))
        ax.plot(gpos, panel_values, **point_style)
        ax.set_title(panel_title)
        ax.set_ylim(0, 1)
        fig.tight_layout()


In [None]:
# Derive data-dependent fit parameters from the filtered H12 values.
# The baseline starts at the median and is capped at the configured percentile.
init_baseline = np.median(h12_filtered)
max_baseline = np.percentile(h12_filtered, max_baseline_percentile)
# skew is bounded symmetrically around an initial value of zero
min_skew, init_skew, max_skew = -max_abs_skew, 0, max_abs_skew
# Fill in the scan range only when it was left unset. Compare against None
# explicitly so that a deliberate value of 0 is honoured rather than being
# silently replaced by the default.
if scan_start is None:
    # default: start 2 units into the genetic map
    scan_start = 2
if scan_stop is None:
    # default: stop 2 units before the last genetic position
    scan_stop = gpos[-1] - 2

In [None]:
# Arguments that are identical for every fit; only the window centre and
# flank size change from one call to the next.
fit_settings = dict(
    ppos=ppos,
    gpos=gpos,
    stat_filtered=h12_filtered,
    scan_interval=scan_interval,
    init_amplitude=init_amplitude,
    min_amplitude=min_amplitude,
    max_amplitude=max_amplitude,
    init_decay=init_decay,
    min_decay=min_decay,
    init_skew=init_skew,
    min_skew=min_skew,
    max_skew=max_skew,
    init_baseline=init_baseline,
    min_baseline=min_baseline,
    max_baseline=max_baseline,
    min_delta_aic=h12_signal_detection_min_delta_aic,
    min_stat_max=h12_signal_detection_min_stat_max,
    debug=debug,
)

# Walk along the genome in steps of scan_interval and, at each centre, try
# every configured flank size. Fits that return None are skipped.
results = []
for gcenter in np.arange(scan_start, scan_stop, scan_interval):
    for gflank in h12_signal_detection_gflanks:
        result = fit_exponential_peak(
            gcenter=gcenter,
            gflank=gflank,
            **fit_settings,
        )
        if result is None:
            continue
        results.append(result)

# one row per detected signal
df_signals = pd.DataFrame.from_records(results)
df_signals

In [None]:
def dedup_signals(df_signals):
    """Drop signals that overlap a stronger signal.

    A signal is discarded when any other signal's ``span1`` interval overlaps
    its own and that other signal has a strictly larger ``delta_i``. Every
    signal is compared against all rows of the input, including rows that are
    themselves discarded. Overlapping signals with equal ``delta_i`` are all
    kept.

    Parameters
    ----------
    df_signals : pandas.DataFrame
        One row per signal, with numeric columns ``span1_gstart``,
        ``span1_gstop`` and ``delta_i``. May be empty, in which case the
        columns need not exist.

    Returns
    -------
    pandas.DataFrame
        A copy of the retained rows, in their original order and with their
        original index labels.
    """
    # An empty frame (e.g. built from an empty list of records) has no
    # columns to look up, and there is nothing to deduplicate.
    if len(df_signals) == 0:
        return df_signals.copy()

    # Work on positional arrays so the result does not depend on what the
    # frame's index labels happen to be.
    gstart = df_signals["span1_gstart"].to_numpy(dtype=float)
    gstop = df_signals["span1_gstop"].to_numpy(dtype=float)
    delta_i = df_signals["delta_i"].to_numpy(dtype=float)

    # Pairwise comparison: row axis is "this", column axis is "that".
    # thank you Ned Batchelder
    # https://nedbatchelder.com/blog/201310/range_overlap_in_two_compares.html
    disjoint = (gstart[None, :] > gstop[:, None]) | (
        gstop[None, :] < gstart[:, None]
    )
    stronger = delta_i[None, :] > delta_i[:, None]

    # A signal is never strictly stronger than itself, so the diagonal needs
    # no special handling.
    dominated = (~disjoint & stronger).any(axis=1)

    return df_signals.iloc[~dominated].copy()


In [None]:
# among overlapping signals, keep only those not overlapped by one with a larger delta_i
df_signals_dedup = dedup_signals(df_signals)
df_signals_dedup

## Write outputs

In [None]:
outdir = f"build/{analysis_version}/h12-signal-detection"

# one CSV per cohort and contig
output_path = here() / outdir / f"{cohort_id}_{contig}.csv"
# make sure the target directory exists, e.g. when the notebook is run
# interactively rather than through the workflow
output_path.parent.mkdir(parents=True, exist_ok=True)
# Pass the path rather than an open text-mode handle: pandas then manages
# newline handling itself (a handle would need to be opened with newline="").
df_signals_dedup.to_csv(output_path, index=False)