In [None]:
# Notebook parameters. Values here are for development only and
# will be overridden when running via snakemake and papermill.

# Identifier of the cohort to report on; looked up in the "cohort_id" index
# of build/final_cohorts.geojson.
cohort_id = 'BF-09_Houet_colu_2012_Q3'
# Version of the cohorts analysis, passed to malariagen_data.Ag3.
cohorts_analysis = "20230223"
# Contigs for which selection scan figures are generated.
contigs = ['3L']
# Sample sets (or release identifier) passed to every selection scan call.
sample_sets = "3.0"
# Cohort size bounds passed to every selection scan call.
min_cohort_size = 20
max_cohort_size = 50
# Contig used for the H12 calibration diagnostic plot.
h12_calibration_contig = '3L'
# When True, Ag3 data are read through a local simplecache of the GCS bucket.
use_gcs_cache = False
# Scheduler name handed to dask.config.set.
dask_scheduler = "threads"

In [None]:
from IPython.display import Markdown, HTML
import malariagen_data
import pandas as pd
from pyprojroot import here
import yaml
import dask
# Configure dask with the scheduler named in the notebook parameters.
dask.config.set(scheduler=dask_scheduler);

import geopandas as gpd
import bokeh.layouts as bklay
import bokeh.plotting as bkplt
import plotly.express as px

from bokeh.io import output_notebook # enables inline bokeh plots in the Jupyter notebook

# N.B., do not add the "remove-output" tag to this cell!!! If you do,
# the bokeh javascript libraries will not get loaded in the generated
# HTML page. The call to output_notebook() injects javascript in the
# cell output which triggers the bokeh javascript libraries to be loaded
# in the page.
output_notebook(hide_banner=True)

In [None]:
# Read the table of final cohorts and pick out the row for this notebook's
# cohort; its 'sample_query' field is used below to select the samples.
cohorts_path = here() / "build" / "final_cohorts.geojson"
df_cohorts = gpd.read_file(cohorts_path)
df_cohorts = df_cohorts.set_index("cohort_id")
cohort = df_cohorts.loc[cohort_id]
cohort

In [None]:
# Optional arguments that route data access through a local cache of the
# GCS bucket; left empty when the cache is disabled.
extra_params = {}
if use_gcs_cache:
    gcs_cache_dir = (here() / "gcs_cache").as_posix()
    extra_params.update(
        url="simplecache::gs://vo_agam_release",
        simplecache={"cache_storage": gcs_cache_dir},
    )

results_cache_dir = (here() / "malariagen_data_cache").as_posix()
ag3 = malariagen_data.Ag3(
    # pin the version of the cohorts analysis for reproducibility
    cohorts_analysis=cohorts_analysis,
    results_cache=results_cache_dir,
    **extra_params,
)
ag3

In [None]:
# Metadata for the samples in this cohort. Restrict to the same sample sets
# that the selection scans below are run on, so that the map and collection
# date summaries describe the same samples as the scans.
df_samples = (
    ag3.sample_metadata(sample_sets=sample_sets)
    .query(cohort['sample_query'])
)

In [None]:
def add_empty_months(df, cohort_id):
    """Pad per-month collection counts with zero rows for unsampled months.

    Parameters
    ----------
    df : pandas.DataFrame
        Collection counts with columns 'year', 'month' (month number) and
        'count'.
    cohort_id : str
        Cohort identifier whose last two characters name the quarter
        ('Q1', 'Q2', 'Q3' or 'Q4').

    Returns
    -------
    tuple
        ``(df, quarter, year)`` where ``df`` contains one zero-count row per
        year for every month of the quarter that was missing, sorted by year
        and month with a fresh index; ``quarter`` is the quarter label; and
        ``year`` is the first year found in the input.

    Raises
    ------
    KeyError
        If ``cohort_id`` does not end with a quarter label.
    """
    quarter = cohort_id[-2:]
    q_months = {'Q1': [1, 2, 3],
                'Q2': [4, 5, 6],
                'Q3': [7, 8, 9],
                'Q4': [10, 11, 12]}

    years = df['year'].unique()
    missing_months = sorted(set(q_months[quarter]) - set(df['month']))

    if missing_months:
        # Build all padding rows at once rather than concatenating in a loop.
        padding = pd.DataFrame(
            [
                {'year': y, 'month': m, 'count': 0}
                for m in missing_months
                for y in years
            ]
        )
        df = pd.concat([df, padding], ignore_index=True)

    # Keep rows in chronological order so that plots which follow the row
    # order show the months of the quarter in sequence.
    df = df.sort_values(['year', 'month']).reset_index(drop=True)

    return (df, quarter, years[0])

# Number of samples collected in each (year, month).
df_collection_dates = (
    df_samples.groupby(['year', 'month'])
    .size()
    .reset_index(name='count')
)

# A month of -1 is treated as "no per-month collection data"; in that case
# the monthly summary is skipped entirely.
month_collections = not df_collection_dates['month'].eq(-1).any()

if month_collections:
    df_collection_dates, quarter, year = add_empty_months(df_collection_dates, cohort_id)

    # Find the first and last sampled month while months are still numbers.
    # Taking min/max of month *names* would compare alphabetically
    # (e.g. "August" < "July").
    sampled_months = df_collection_dates.loc[df_collection_dates['count'] > 0, 'month']
    first_month = sampled_months.min()
    last_month = sampled_months.max()

    month_names = pd.to_datetime(df_collection_dates['month'], format='%m').dt.month_name()
    month_name_by_number = dict(zip(df_collection_dates['month'], month_names))
    df_collection_dates['month'] = month_names

    start_month = month_name_by_number[first_month]
    end_month = month_name_by_number[last_month]

In [None]:
# Number of collection sites, counted as distinct coordinate pairs among the
# cohort's samples.
# NOTE(review): assumes one site per distinct (latitude, longitude) - confirm.
n_sites = len(df_samples[['latitude', 'longitude']].drop_duplicates())
site_label = 'site' if n_sites == 1 else 'sites'

# TODO: replace this placeholder with the real contributing study.
study_id = 'who knows????'

# Only mention months when per-month collection data exist for the cohort.
if month_collections:
    collection_period = f", between {start_month} and {end_month} {cohort['year']}"
else:
    collection_period = f" in {cohort['year']}"

display(Markdown(
    f"This cohort comprises {cohort['cohort_size']} An. {cohort['taxon']} samples "
    f"collected from {n_sites} {site_label} within {cohort['admin1_name']} province, "
    f"{cohort['admin2_name']} region, {cohort['country']}{collection_period}. "
    f"Samples were contributed by {study_id}."
))


## Sampling information
### Collection locations

In [None]:
from ipyleaflet import AwesomeIcon, Map, Marker, basemaps

# Marker colour per taxon; any other taxon is drawn in gray.
taxon_colors = {
    'gambiae': 'red',
    'coluzzii': 'cadetblue',
    'arabiensis': 'lightgreen',
}

center = cohort[['latitude', 'longitude']].to_list()
m = Map(center=center, zoom=9, basemap=basemaps.OpenTopoMap)

# Draw one marker per distinct (location, taxon) instead of one per sample,
# so that samples sharing a location do not stack identical markers.
df_sites = df_samples[['latitude', 'longitude', 'taxon']].drop_duplicates()

for lat, long, taxon in df_sites.itertuples(index=False, name=None):
    # Marker itself has no colour option; the colour is carried by the icon.
    icon = AwesomeIcon(
        name='circle',
        marker_color=taxon_colors.get(taxon, 'gray'),
    )
    # Markers represent fixed sampling locations, so they are not draggable.
    marker = Marker(location=(lat, long), icon=icon, draggable=False, opacity=0.7)
    m.add_layer(marker)

display(m)

### Collection dates

In [None]:
# Bar chart of samples per month, or a short notice when the cohort has no
# per-month collection data.
if not month_collections:
    display(Markdown("No per-month collection data exists for this cohort."))
else:
    fig = px.bar(
        df_collection_dates,
        x='month',
        y='count',
        title=f"Collection dates {quarter} {year}",
    )
    display(fig)

## Selection scans

In [None]:
# Window size for the iHS scan.
ihs_window_size = 100

# The phasing analysis is chosen from the cohort's taxon.
phasing_analysis = 'arab' if cohort.taxon == 'arabiensis' else 'gamb_colu'

# Read the H12 window size written for this cohort by the calibration step.
calibration_dir = "build/h12-calibration"
calibration_path = here() / calibration_dir / f"{cohort_id}.yaml"
with open(calibration_path) as calibration_file:
    calibration_params = yaml.safe_load(calibration_file)
window_size = calibration_params["h12_window_size"]

def plot_h12_g123_ihs_tracks(
        contig, 
        window_size, 
        ihs_window_size,
        phasing_analysis, 
        sample_sets, 
        sample_query, 
        min_cohort_size,
        max_cohort_size,
        sizing_mode='stretch_width', 
        show=False, 
        width=800, 
        genes_height=100
    ):
    """Stack H12, G123 and iHS scan tracks above a gene track for one contig.

    Uses the module-level ``ag3`` client. All tracks share the x range of the
    H12 track; only the gene track keeps a visible x axis.

    Parameters
    ----------
    contig
        Contig to scan; also used as the region of the gene track.
    window_size
        Window size for the H12 and G123 scans.
    ihs_window_size
        Window size for the iHS scan.
    phasing_analysis
        Passed as ``analysis`` to the H12 and iHS scans and as ``site_mask``
        to the G123 scan.
    sample_sets, sample_query, min_cohort_size, max_cohort_size
        Cohort selection arguments forwarded to all three scans.
    sizing_mode, show, width
        Layout arguments forwarded to every track.
    genes_height
        Height of the gene track.

    Returns
    -------
    The bokeh grid layout holding the four tracks in a single column.
    """
    # Arguments that are identical across the three scan calls.
    cohort_kwargs = dict(
        sample_sets=sample_sets,
        sample_query=sample_query,
        min_cohort_size=min_cohort_size,
        max_cohort_size=max_cohort_size,
    )
    # Arguments that are identical across all four tracks.
    layout_kwargs = dict(sizing_mode=sizing_mode, width=width, show=show)

    h12_track = ag3.plot_h12_gwss_track(
        contig=contig,
        window_size=window_size,
        analysis=phasing_analysis,
        **cohort_kwargs,
        **layout_kwargs,
    )
    # Every other track is linked to the H12 track's x range.
    shared_x_range = h12_track.x_range

    # TODO Calibrate G123 window size separately?
    g123_track = ag3.plot_g123_gwss_track(
        contig=contig,
        window_size=window_size,
        site_mask=phasing_analysis,
        title="",
        x_range=shared_x_range,
        **cohort_kwargs,
        **layout_kwargs,
    )

    ihs_track = ag3.plot_ihs_gwss_track(
        contig=contig,
        window_size=ihs_window_size,
        analysis=phasing_analysis,
        title="",
        x_range=shared_x_range,
        **cohort_kwargs,
        **layout_kwargs,
    )

    # Hide the x axis on the scan tracks.
    for scan_track in (h12_track, g123_track, ihs_track):
        scan_track.xaxis.visible = False

    genes_track = ag3.plot_genes(
        region=contig,
        height=genes_height,
        x_range=shared_x_range,
        **layout_kwargs,
    )

    return bklay.gridplot(
        [h12_track, g123_track, ihs_track, genes_track],
        ncols=1,
        toolbar_location="above",
        merge_tools=True,
        sizing_mode=sizing_mode,
    )

In [None]:
# Arguments that do not change from one contig to the next.
scan_kwargs = dict(
    window_size=window_size,
    ihs_window_size=ihs_window_size,
    phasing_analysis=phasing_analysis,
    sample_sets=sample_sets,
    sample_query=cohort['sample_query'],
    min_cohort_size=min_cohort_size,
    max_cohort_size=max_cohort_size,
)

# One heading and one stacked scan figure per contig.
for contig in contigs:
    display(HTML(f"<h3>Chromosome {contig}</h3>"))
    fig = plot_h12_g123_ihs_tracks(contig=contig, **scan_kwargs)
    bkplt.show(fig)


## Diagnostics
### H12 calibration

In [None]:
# Candidate window sizes shown in the calibration plot.
window_sizes = (100, 200, 500, 1000, 2000, 5000, 10000, 20000)

h12_calibration_kwargs = dict(
    contig=h12_calibration_contig,
    analysis=phasing_analysis,
    sample_sets=sample_sets,
    sample_query=cohort['sample_query'],
    min_cohort_size=min_cohort_size,
    max_cohort_size=max_cohort_size,
    window_sizes=window_sizes,
)
# Trailing semicolon suppresses the cell's return-value repr.
ag3.plot_h12_calibration(**h12_calibration_kwargs);

### G123 calibration

In [None]:
# Disabled: G123 calibration plot, using the same contig, cohort selection
# and window sizes as the H12 calibration call.
# NOTE(review): the reason this call is commented out is not recorded here -
# confirm before re-enabling.
# ag3.plot_g123_calibration(
#     contig=h12_calibration_contig,
#     sites=phasing_analysis,
#     site_mask=phasing_analysis,
#     sample_sets=sample_sets,
#     sample_query=cohort['sample_query'],
#     min_cohort_size=min_cohort_size,
#     max_cohort_size=max_cohort_size,
#     window_sizes=window_sizes,
# );

In [None]:
# Placeholder cell: no content has been implemented here yet.
print("TODO")