In [None]:
# Notebook-level settings used by the cells below.
# Contig whose signal CSVs are loaded and whose genome length sets the plot x-range.
contig = '2L'
# Cohorts analysis version handed to malariagen_data.Ag3 (pinned for reproducibility).
cohorts_analysis = '20230223'
# When True, Ag3 is given a "simplecache::gs://..." url plus a local cache directory.
use_gcs_cache = False

In [None]:
from IPython.display import Markdown
import malariagen_data
import pandas as pd
from pyprojroot import here
import geopandas as gpd

import bokeh.layouts as bklay
import bokeh.plotting as bkplt
import bokeh.models as bkmod

In [None]:
# Introductory text for the page; {contig} is interpolated from the notebook parameters.
intro_text = f"""The plot below shows selection signals discovered in the major vector species *An. gambiae*, 
    *An. coluzzii* or *An. arabiensis*, all of which are members of the *Anopheles gambiae* species complex. 
    The reference genome used for these analyses is AgamP4, from the {contig} chromosome. Hover over a 
    signal for more information about the species, location, date and selection statistic in which the signal 
    was found. Click on a signal to see the underlying selection scan data."""

# Last expression of the cell, so the notebook renders it as Markdown.
Markdown(intro_text)

In [None]:
# Extra keyword arguments for Ag3; stays empty unless the GCS cache is enabled.
extra_params = {}
if use_gcs_cache:
    # Cache directory lives under the project root resolved by pyprojroot.
    gcs_cache_dir = (here() / "gcs_cache").as_posix()
    extra_params.update(
        url="simplecache::gs://vo_agam_release",
        simplecache={"cache_storage": gcs_cache_dir},
    )

results_cache_dir = (here() / "malariagen_data_cache").as_posix()

ag3 = malariagen_data.Ag3(
    # the cohorts analysis version is pinned so results are reproducible
    cohorts_analysis=cohorts_analysis,
    results_cache=results_cache_dir,
    **extra_params,
)
ag3

In [None]:
# Cohort metadata produced by the build step.
cohorts = gpd.read_file(here() / "build" / "final_cohorts.geojson")

# One CSV of H12 signals per cohort for this contig; tag each frame with the cohort's taxon.
signal_dir = here() / "build/h12-signal-detection/"
signal_frames = [
    pd.read_csv(signal_dir / f"{cohort_id}_{contig}.csv").assign(taxon=taxon)
    for cohort_id, taxon in zip(cohorts['cohort_id'], cohorts['taxon'])
]

df_signals = (
    pd.concat(signal_frames, axis=0)
    .assign(statistic="H12")
    .sort_values('taxon')
)

# Fill colour used for each taxon in the signal plot.
color_dict = dict(
    gambiae='#BEC4FF',
    coluzzii='#D7B2A6',
    arabiensis='#A6D7CA',
)

df_signals['color'] = df_signals['taxon'].map(color_dict)

In [None]:
def stack_overlaps(df, start_col, end_col, tolerance=10000):
    """Assign each interval a vertical level so nearby intervals do not share one.

    Rows are processed in the order they appear in ``df``. Each row is placed
    on the lowest level whose most recently placed interval ends more than
    ``tolerance`` before the row's start. The caller is expected to pass rows
    sorted by ``start_col``; the function does not sort.

    Parameters
    ----------
    df : pandas.DataFrame
        Table holding one interval per row.
    start_col, end_col : str
        Names of the columns holding interval start and end coordinates.
    tolerance : number, default 10000
        Padding added to a level's last end coordinate. A row whose start is
        less than or equal to ``end + tolerance`` is treated as colliding.

    Returns
    -------
    numpy.ndarray
        Integer level (0 = lowest) for each row, in row order. An empty
        integer array is returned for an empty ``df``.
    """
    import numpy as np

    # Only the end coordinate of the last interval on each level matters, so
    # keep those rather than whole rows, and read the two columns directly
    # instead of materialising a Series per row with iterrows().
    level_ends = []
    levels = []
    for start, end in zip(df[start_col].to_numpy(), df[end_col].to_numpy()):
        level = 0
        # search upwards to find the first level this interval fits on
        while level < len(level_ends) and start <= level_ends[level] + tolerance:
            level += 1
        if level == len(level_ends):
            level_ends.append(end)
        else:
            level_ends[level] = end
        levels.append(level)
    # An explicit dtype keeps the result integer even when there are no rows.
    return np.asarray(levels, dtype=int)

# Order signals by peak start before stacking, then record each signal's level.
df_signals = (
    df_signals
    .sort_values(by='span2_pstart')
    .assign(level=lambda d: stack_overlaps(d, 'span2_pstart', 'span2_pstop'))
)

In [None]:
# Flatten the index so each signal is one row of the Bokeh data source.
df = df_signals.reset_index()
source = bkmod.ColumnDataSource(data={
    'cohort': df.cohort_id,
    'statistic': df.statistic,
    'chromosome': df.contig,
    'score': df.delta_i.astype(int),
    'peak_start': df.span2_pstart,
    'peak_stop': df.span2_pstop,
    'focus_start': df.focus_pstart,
    'focus_stop': df.focus_pstop,
    'bottom': df.level,
    'top': df.level + .8,
    'color': df.color
})

hover = bkmod.HoverTool(tooltips=[
        ("Cohort", '@cohort'),
        ("Statistic", '@statistic'),
        ("Score", '@score'),
        ("Focus", "@focus_start{,} - @focus_stop{,}"),
    ])

# Both values are used more than once below, so compute them a single time.
max_level = int(df.level.max())
contig_length = ag3.genome_sequence(contig).shape[0]

# make figure
# `width`/`height` replace `plot_width`/`plot_height`, which Bokeh 3 no longer accepts.
fig1 = bkplt.figure(title='Selection signals',
                    width=900, height=200 + (10 * max_level),
                    tools="tap,xpan,xzoom_in,xzoom_out,xwheel_zoom,reset".split() + [hover],
                    toolbar_location='above', active_drag='xpan', active_scroll='xwheel_zoom')

# Each signal is drawn as three adjacent boxes: left flank of the peak in the
# taxon colour, the focal region in red, then the right flank in the taxon colour.
fig1.quad(bottom='bottom', top='top', left='peak_start', right='focus_start',
          source=source, color="color", alpha=.7, line_width=2)

fig1.quad(bottom='bottom', top='top', left='focus_start', right='focus_stop',
          source=source, color="red", alpha=.7, line_width=2)

fig1.quad(bottom='bottom', top='top', left='focus_stop', right='peak_stop',
          source=source, color="color", alpha=.7, line_width=2)

# The x-axis spans the whole contig and cannot be zoomed out beyond it.
fig1.x_range = bkmod.Range1d(0, contig_length)
fig1.y_range = bkmod.Range1d(-0.5, max_level + 1.3)
fig1.x_range.max_interval = contig_length
fig1.yaxis.visible = False
fig1.xaxis.visible = False
fig1.ygrid.visible = False

# Clicking a signal opens the page for its cohort.
url = '../cohort/@cohort.html'
taptool = fig1.select(type=bkmod.TapTool)
taptool.callback = bkmod.OpenURL(url=url)

# Gene track sharing the signal plot's x-range so both pan and zoom together.
fig2 = ag3.plot_genes(
    region=contig,
    sizing_mode="stretch_width",
    x_range=fig1.x_range,
    show=False)

fig = bklay.gridplot(
    [fig1, fig2],
    ncols=1,
    toolbar_location="above",
    merge_tools=True,
    sizing_mode="stretch_width",
)

bkplt.show(fig)

In [None]:
# Join cohort metadata onto the signals (merge uses the columns the two frames
# share) and keep only the columns needed for the summary table.
df_signals = df_signals.merge(cohorts)[['contig', 'focus_pstart', 'focus_pstop', 'cohort_label', 'statistic', 'delta_i']]

# Thousands-separated coordinate formatter. The previous lambdas passed a stray
# `axis=1` keyword to str.format, which was silently ignored.
format_position = "{:,}".format

# Human-readable focal region, e.g. "<contig> ( <start> - <stop> )".
df_signals = df_signals.assign(
    focal_region=(
        df_signals['contig'] + ' ( '
        + df_signals['focus_pstart'].map(format_position) + ' - '
        + df_signals['focus_pstop'].map(format_position) + " )"
    )
)
df_signals[['focal_region', 'cohort_label', 'statistic', 'delta_i']]