In [2]:
%run setup.ipynb
%run peak-utils.ipynb

In [3]:
def _plot_signal_filtering(gpos, signal, signal_filtered):
    """Diagnostic plot: raw H12 signal (top) vs. Hampel-filtered signal (bottom)."""
    fig, axs = plt.subplots(nrows=2, ncols=1, figsize=(10, 4), facecolor='w')
    axs[0].plot(gpos, signal, marker='o', linestyle=' ', markersize=1)
    axs[0].set_title('Original signal')
    axs[1].plot(gpos, signal_filtered, marker='o', linestyle=' ', markersize=1)
    axs[1].set_title('Filtered signal')
    fig.tight_layout()
    plt.show()
    plt.close(fig)


def _find_locus(chromosome, pstart, pstop):
    """Return the name of a known locus overlapping ``[pstart, pstop]`` on ``chromosome``, or None.

    Genes in ``ir_genes + [tep1]`` match on closed-interval overlap with their
    ``chromosome_start``/``chromosome_end``; entries of ``novel_loci`` (name -> (chrom, pos))
    match when ``pstart < pos < pstop``. Later matches overwrite earlier ones, so a
    novel locus wins over a gene, and among genes the last overlapping one wins.
    """
    locus = None
    for gene in ir_genes + [tep1]:
        if gene.chromosome == chromosome:
            disjoint = (gene.chromosome_end < pstart or gene.chromosome_start > pstop)
            if not disjoint:
                locus = gene['Name']
    for k, (c, p) in novel_loci.items():
        if c == chromosome and pstart < p < pstop:
            locus = k
    return locus


def _plot_peak_fit(peak_result, null_result, chromosome, gcenter, gflank,
                   fit_gcenter, decay_left, decay_right):
    """Diagnostic plot of one accepted peak fit with its 1x and 2x decay spans, then print the lmfit report."""
    fig, ax = plt.subplots(facecolor='w', figsize=(8, 4))
    peak_result.plot_fit(
        ax=ax, 
        xlabel=f'Chromosome {chromosome} position (cM)', 
        ylabel='$H12$',
        data_kws=dict(markersize=2), 
        fit_kws=dict(color='k', linestyle='--')
    )
    ax.axvline(gcenter, color='w', lw=2, zorder=0)
    ax.axvspan(fit_gcenter - decay_left, fit_gcenter + decay_right, zorder=0, color='red', alpha=.2)
    ax.axvspan(fit_gcenter - 2*decay_left, fit_gcenter + 2*decay_right, zorder=0, color='red', alpha=.2)
    ax.axvline(fit_gcenter, color='red', lw=2, zorder=0)
    ax.annotate(
        f'$AIC={peak_result.aic:.0f}$\n' +
        f'$BIC={peak_result.bic:.0f}$\n' +
        f'$\\chi^{2}={peak_result.chisqr:.3f}$\n' +
        f'$\\Delta_{{i}}={null_result.aic - peak_result.aic:.0f}$',
        xy=(0, 1), xycoords='axes fraction',
        xytext=(5, -5), textcoords='offset points',
        va='top', ha='left', fontsize=8,
    )
    ax.set_xlim(gcenter - gflank, gcenter + gflank)
    fig.tight_layout()
    plt.show()
    plt.close(fig)
    print(peak_result.fit_report())


@functools.lru_cache(maxsize=None)
def peak_scan(pop, chromosome, 
              filter_size=20, 
              filter_t=2, 
              gflanks=(4, 8), 
              scan_interval=1,
              min_aic=500,
              min_baseline=0,
              max_baseline_percentile=95,
              min_amplitude=0.03,
              init_amplitude=0.5,
              max_amplitude=1.5,
              min_decay=0.1,
              init_decay=0.5,
              max_abs_skew=0.5,
              diagnostics=False,
              scan_start=None,
              scan_stop=None,
              ):
    """Scan one chromosome's H12 signal for peaks by fitting a skewed exponential peak model.

    The H12 signal from ``load_h12_gwss(pop, chromosome)`` is Hampel-filtered. Then, for
    every scan centre ``gcenter`` in ``np.arange(scan_start, scan_stop, scan_interval)``
    (genetic positions, cM) and every flank in ``gflanks``, the filtered signal within
    ``gcenter +/- gflank`` is fitted with ``skewed_exponential_peak`` and with a constant
    (null) model. A record is emitted when the AIC improvement ``delta_i`` exceeds
    ``min_aic`` and the fitted centre lies within ``scan_interval`` of ``gcenter``.
    Overlapping fits from neighbouring scan centres are NOT merged here.

    Results are memoised with ``functools.lru_cache``: every argument must be hashable
    (e.g. ``gflanks`` must be a tuple), repeated calls return the *same* DataFrame
    object (callers must not modify it in place), and a cached call with
    ``diagnostics=True`` does not re-draw its plots.

    Parameters
    ----------
    pop, chromosome : str
        Population and chromosome passed to ``load_h12_gwss``.
    filter_size, filter_t :
        Window size and threshold for ``hampel_filter``.
    gflanks : tuple of float
        Half-widths (cM) of the fitting windows tried at each scan centre. The upper
        bound on ``decay`` is ``gflank / 3``.
    scan_interval : float
        Step (cM) between scan centres.
    min_aic : float
        Minimum (null AIC - peak AIC) for a fit to be reported.
    min_baseline, max_baseline_percentile :
        Baseline bounds; the upper bound is this percentile of the filtered signal.
    min_amplitude, init_amplitude, max_amplitude, min_decay, init_decay, max_abs_skew :
        Bounds / starting values for the peak model parameters.
    diagnostics : bool
        If True, plot the filtering step and every accepted fit.
    scan_start, scan_stop : float or None
        Scan range in cM. ``None`` means 2 and ``gmap[chromosome][-1] - 2`` respectively.

    Returns
    -------
    pandas.DataFrame
        One row per accepted fit. Empty scans yield an empty frame with no columns.
    """

    # load gwss data (only the genetic window midpoints and the H12 values are used)
    _, gwindows, _, h12, _, _ = load_h12_gwss(pop, chromosome)
    signal = h12
    gpos = gwindows.mean(axis=1)

    # filter outliers
    signal_filtered = hampel_filter(signal, size=filter_size, t=filter_t)

    if diagnostics:
        _plot_signal_filtering(gpos, signal, signal_filtered)

    # set parameters
    init_baseline = np.median(signal_filtered)
    max_baseline = np.percentile(signal_filtered, max_baseline_percentile)
    min_skew, init_skew, max_skew = -max_abs_skew, 0, max_abs_skew

    # models are loop-invariant; only the Parameters change per window
    peak_model = lmfit.Model(skewed_exponential_peak)
    null_model = lmfit.models.ConstantModel()

    # scan range; compare against None so an explicit 0 is honoured
    if scan_start is None:
        scan_start = 2
    if scan_stop is None:
        scan_stop = gmap[chromosome][-1] - 2

    records = []
    chrom_gmap = gmap[chromosome]

    for gcenter in np.arange(scan_start, scan_stop, scan_interval):
    
        for gflank in gflanks:

            # locate region to fit
            loc_region = slice(bisect_left(gpos, gcenter - gflank), 
                               bisect_right(gpos, gcenter + gflank))
            x = gpos[loc_region]
            y = signal_filtered[loc_region]

            # peak model parameters
            peak_params = lmfit.Parameters()
            peak_params['center'] = lmfit.Parameter('center', vary=True, 
                                                    value=gcenter, 
                                                    min=gcenter - gflank, 
                                                    max=gcenter + gflank)
            peak_params['amplitude'] = lmfit.Parameter('amplitude', vary=True, 
                                                       value=init_amplitude, 
                                                       min=min_amplitude, 
                                                       max=max_amplitude)
            peak_params['decay'] = lmfit.Parameter('decay', 
                                                   vary=True, 
                                                   value=init_decay, 
                                                   min=min_decay, 
                                                   max=gflank/3)
            peak_params['skew'] = lmfit.Parameter('skew', 
                                                  vary=True, 
                                                  value=init_skew, 
                                                  min=min_skew, 
                                                  max=max_skew)
            peak_params['baseline'] = lmfit.Parameter('baseline', vary=True, 
                                                      value=init_baseline, 
                                                      min=min_baseline, 
                                                      max=max_baseline)
            peak_params['ceiling'] = lmfit.Parameter('ceiling', vary=False, value=1)
            peak_params['floor'] = lmfit.Parameter('floor', vary=False, value=0)

            # a least-squares fit needs at least as many data points as free parameters;
            # sparse windows (e.g. near chromosome ends) would otherwise raise
            n_free = sum(p.vary for p in peak_params.values())
            if len(x) < n_free:
                continue

            peak_result = peak_model.fit(y, x=x, params=peak_params)

            # fit null model
            null_params = lmfit.Parameters()
            null_params['c'] = lmfit.Parameter('c', vary=True, value=init_baseline, min=0, max=1)
            null_result = null_model.fit(y, x=x, params=null_params)

            peak_delta_i = int(null_result.aic - peak_result.aic)

            # emit only if delta_i is above threshold AND the fitted centre lies within the
            # scan interval - if it is beyond, a different scan centre will fit it better
            fit_gcenter = peak_result.params['center'].value
            peak_in_scan_interval = ((gcenter - scan_interval) < fit_gcenter < (gcenter + scan_interval))
            if not (peak_delta_i > min_aic and peak_in_scan_interval):
                continue

            fit_params = peak_result.params
            fit_skew = fit_params['skew'].value
            fit_decay = fit_params['decay'].value
            decay_right = 2**(-fit_skew) * fit_decay
            decay_left = 2**fit_skew * fit_decay
            span1_gstart = fit_gcenter - 1*decay_left
            span1_gstop = fit_gcenter + 1*decay_right
            span2_gstart = fit_gcenter - 2*decay_left
            span2_gstop = fit_gcenter + 2*decay_right

            # chromosome physical position (via the genetic map)
            fit_pcenter = bisect_left(chrom_gmap, fit_gcenter)
            span1_pstart = bisect_left(chrom_gmap, span1_gstart)
            span1_pstop = bisect_right(chrom_gmap, span1_gstop)
            span2_pstart = bisect_left(chrom_gmap, span2_gstart)
            span2_pstop = bisect_right(chrom_gmap, span2_gstop)

            # chromosome arm physical position
            chromosome_arm, fit_pcenter_arm = chrom2arm(chromosome, fit_pcenter)

            locus = _find_locus(chromosome, span1_pstart, span1_pstop)

            # max value, pos max value (genetic, physical)
            # NOTE(review): this window runs from span1_gstart to span2_gstop (asymmetric);
            # "peak1" suggests span1_gstop was intended - confirm before changing, since a
            # narrow span1 window may contain no points and np.argmax would then raise.
            loc_peak1 = slice(bisect_left(x, span1_gstart), 
                              bisect_right(x, span2_gstop))
            x_peak1 = x[loc_peak1]
            y_peak1 = y[loc_peak1]
            loc_max = np.argmax(y_peak1)
            gpos_max = x_peak1[loc_max]
            ppos_max = bisect_left(chrom_gmap, gpos_max)
            _, ppos_max_arm = chrom2arm(chromosome, ppos_max)
            signal_max = y_peak1[loc_max]

            records.append(dict(
                pop=pop,
                chromosome=chromosome,
                gcenter=fit_gcenter,
                pcenter=fit_pcenter,
                delta_i=peak_delta_i,
                signal_max=signal_max,
                locus=locus,
                chromosome_arm=chromosome_arm,
                pcenter_arm=fit_pcenter_arm,
                gpos_max=gpos_max,
                ppos_max=ppos_max,
                ppos_max_arm=ppos_max_arm,
                span1_gstart=span1_gstart,
                span1_gstop=span1_gstop,
                span2_gstart=span2_gstart,
                span2_gstop=span2_gstop,
                span1_pstart=span1_pstart,
                span1_pstop=span1_pstop,
                span2_pstart=span2_pstart,
                span2_pstop=span2_pstop,
                amplitude=fit_params['amplitude'].value,
                decay=fit_decay,
                skew=fit_skew,
                decay_left=decay_left,
                decay_right=decay_right,
                baseline=fit_params['baseline'].value,
                aic=peak_result.aic,
                bic=peak_result.bic,
                rss=peak_result.chisqr,
                constant_aic=null_result.aic,
            ))

            if diagnostics:
                _plot_peak_fit(peak_result, null_result, chromosome, gcenter, gflank,
                               fit_gcenter, decay_left, decay_right)

    return pd.DataFrame.from_records(records)


In [4]:
def dedup_peaks(df_peaks):
    """Return the row positions of peaks to keep after removing overlapping duplicates.

    Two peaks overlap when their closed ``[span1_gstart, span1_gstop]`` intervals
    intersect (touching endpoints count as overlapping). A peak is dropped if it
    overlaps another peak with a strictly larger ``delta_i``, or with an equal
    ``delta_i`` and an earlier row position, so exactly one of a group of tied,
    overlapping peaks survives.

    Parameters
    ----------
    df_peaks : pandas.DataFrame
        Needs columns ``span1_gstart``, ``span1_gstop`` and ``delta_i``. The index is
        ignored; only row positions matter.

    Returns
    -------
    list of int
        Ascending row positions of the retained peaks, suitable for ``df_peaks.iloc``.
    """
    n = len(df_peaks)
    if n == 0:
        # an empty scan produces a frame with no columns at all
        return []

    start = df_peaks['span1_gstart'].to_numpy()
    stop = df_peaks['span1_gstop'].to_numpy()
    delta = df_peaks['delta_i'].to_numpy()
    pos = np.arange(n)

    # overlap[i, j] is True when peak j overlaps peak i
    # thank you Ned Batchelder
    # https://nedbatchelder.com/blog/201310/range_overlap_in_two_compares.html
    overlap = ~((start[None, :] > stop[:, None]) | (stop[None, :] < start[:, None]))

    # better[i, j] is True when peak j beats peak i (ties broken by earlier position);
    # the diagonal is always False, so a peak never removes itself
    better = (delta[None, :] > delta[:, None]) | (
        (delta[None, :] == delta[:, None]) & (pos[None, :] < pos[:, None])
    )

    dominated = (overlap & better).any(axis=1)
    return np.flatnonzero(~dominated).tolist()
    

In [5]:
def analyse_population(pop, min_aic=2000, diagnostics=False, chromosomes='23X'):
    """Scan each chromosome for H12 peaks in one population, deduplicate, and save as CSV.

    For every chromosome in ``chromosomes``, runs ``peak_scan`` and keeps only the rows
    selected by ``dedup_peaks``. The per-chromosome results are then concatenated with
    a fresh index and written to ``data/signals/signals_{pop}.csv`` under ``here()``.
    The output directory is created if missing.

    Parameters
    ----------
    pop : str
        Population identifier passed to ``peak_scan``.
    min_aic : float
        Minimum AIC improvement passed to ``peak_scan``.
    diagnostics : bool
        Passed to ``peak_scan`` to enable diagnostic plots.
    chromosomes : iterable of str
        Chromosomes to scan; the default ``'23X'`` iterates '2', '3', 'X'.

    Returns
    -------
    pandas.DataFrame
        The concatenated, deduplicated peaks that were written to disk.
    """
    dfs = []
    for chromosome in chromosomes:
        df_peaks = peak_scan(pop, chromosome, min_aic=min_aic, diagnostics=diagnostics)
        keep = dedup_peaks(df_peaks)
        # iloc builds a new frame, so the lru_cached peak_scan result is not mutated
        dfs.append(df_peaks.iloc[keep])

    df = pd.concat(dfs).reset_index(drop=True)

    out_path = here() / f'data/signals/signals_{pop}.csv'
    # on a fresh checkout the directory may not exist; to_csv would then fail
    out_path.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(out_path)
    
#     with open(here() / f'tables/signals_{pop}.tex', mode='w') as f:
#         (
#             df
#             [['chromosome_arm', 'pcenter_arm', 'ppos_max_arm', 'locus', 'delta_i', 'signal_max']]
#             .fillna('-')
#             .rename(columns={
#                 'chromosome_arm': 'Chrom',
#                 'pcenter_arm': '$pos(peak)$',
#                 'ppos_max_arm': '$pos(H12_{max})$',
#                 'locus': 'Locus',
#                 'delta_i': '$\\Delta_{i}$',
#                 'signal_max': '$H12_{max}$',
#             })
#             .to_latex(
#                 f,
#                 index=False, 
#                 escape=False,
#                 formatters=[
#                     None, 
#                     lambda v: '{:.2f}'.format(v/1e6),
#                     lambda v: '{:.2f}'.format(v/1e6),
#                     '\\textit{{{}}}'.format,
#                     '{:,}'.format,
#                     '{:.2f}'.format
#                 ]
#             )
#         )    
    
    return df


In [9]:
# Run the peak scan for every population; each call writes
# data/signals/signals_{pop}.csv (see analyse_population).
POPULATIONS = [
    # An. gambiae
    'bf_gam', 'gn_gam', 'cm_sav_gam', 'ug_gam', 'gh_gam', 'ga_gam', 'gq_gam', 'fr_gam',
    # An. coluzzii
    'bf_col', 'ci_col', 'gh_col', 'ao_col',
    # remaining populations
    'gw', 'gm',
]

for population in POPULATIONS:
    analyse_population(population)