In [7]:
# Notebook parameters.
# cohort_id selects the row of the cohorts table and the samples to analyse.
cohort_id = "ML-2_Kati_gamb_2014_Q3"
# Version of the cohorts analysis handed to malariagen_data.Ag3.
cohorts_analysis = "20230223"
# NOTE(review): `contigs` is not read by any code visible in this notebook — confirm it is still needed.
contigs = ["2L"]
# Passed unchanged as `sample_sets` to the plotting calls below.
sample_sets = ""
# Cohorts with more samples than this are downsampled to this size.
max_cohort_size = 50
# Contig used for the H12 calibration diagnostic plot.
h12_calibration_contig = ""

In [8]:
from IPython.display import Markdown
import malariagen_data
import pandas as pd
from pyprojroot import here
import yaml

import geopandas as gpd
import bokeh.layouts as bklay
import bokeh.plotting as bkplt
import plotly.express as px

In [9]:
# Read the cohorts table from the build directory and pick out the row for
# this notebook's cohort; its fields are used to build the report text and
# to choose analysis settings further down.
cohorts_path = here() / "build" / "final_cohorts.geojson"
df_cohorts = gpd.read_file(cohorts_path).set_index("cohort_id")
cohort = df_cohorts.loc[cohort_id]

In [11]:
# Local cache locations, both relative to the project root.
gcs_cache_dir = (here() / "gcs_cache").as_posix()
results_cache_dir = (here() / "malariagen_data_cache").as_posix()

ag3 = malariagen_data.Ag3(
    # TODO in production build, remove use of simplecache if running inside google cloud
    # url = "gs://vo_agam_release",
    url="simplecache::gs://vo_agam_release",
    # pin the version of the cohorts analysis for reproducibility
    cohorts_analysis=cohorts_analysis,
    # TODO remove simplecache config in production
    simplecache={"cache_storage": gcs_cache_dir},
    results_cache=results_cache_dir,
)

# Keep only the samples assigned to this cohort; the same query string is
# reused by the selection-scan calls later in the notebook.
sample_query = f"cohort_admin2_quarter == '{cohort_id}'"
df_samples = ag3.sample_metadata().query(sample_query)

Load sample metadata:   0%|          | 0/28 [00:00<?, ?it/s]

In [12]:
def add_empty_months(df, cohort_id):
    """Pad monthly sample counts with zero rows for unsampled months of the quarter.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per sampled month, with columns 'year', 'month' (1-12) and
        'count'. The input frame is not modified.
    cohort_id : str
        Cohort identifier whose last two characters name the quarter
        ('Q1', 'Q2', 'Q3' or 'Q4').

    Returns
    -------
    tuple
        ``(padded, quarter, year)`` where ``padded`` is ``df`` plus a
        zero-count row for every month of the quarter that is absent from
        ``df`` (one such row per distinct year in ``df``), sorted by year and
        month with a fresh 0..n-1 index; ``quarter`` is the two-character
        quarter label; ``year`` is the first distinct year found in ``df``.

    Raises
    ------
    KeyError
        If the last two characters of ``cohort_id`` are not a quarter label.
    IndexError
        If ``df`` has no rows, so no year can be reported.
    """
    quarter = cohort_id[-2:]
    quarter_months = {
        'Q1': (1, 2, 3),
        'Q2': (4, 5, 6),
        'Q3': (7, 8, 9),
        'Q4': (10, 11, 12),
    }

    years = df['year'].unique()
    observed_months = set(df['month'])
    missing_months = [m for m in quarter_months[quarter] if m not in observed_months]

    if missing_months:
        # Build all padding rows at once instead of concatenating inside a loop.
        padding = pd.DataFrame(
            [
                {'year': y, 'month': m, 'count': 0}
                for m in missing_months
                for y in years
            ]
        )
        df = pd.concat([df, padding], ignore_index=True)

    # Padding rows are appended after the observed ones, so restore calendar
    # order and give every row a unique index label.
    df = df.sort_values(['year', 'month']).reset_index(drop=True)

    return df, quarter, years[0]

# Count samples per (year, month), then add zero-count rows for the months of
# the cohort's quarter in which nothing was collected.
df_collection_dates = df_samples.groupby(['year', 'month']).size().reset_index(name='count')
df_collection_dates, quarter, year = add_empty_months(df_collection_dates, cohort_id)

# Find the first and last sampled month while 'month' is still numeric.
# Taking min/max after the conversion to names would compare strings
# alphabetically (e.g. "Aug" < "Jul"), not chronologically.
sampled_month_numbers = df_collection_dates.loc[df_collection_dates['count'] > 0, 'month']
first_month_number = int(sampled_month_numbers.min())
last_month_number = int(sampled_month_numbers.max())

# Replace month numbers by three-letter month names for display.
df_collection_dates['month'] = pd.to_datetime(df_collection_dates['month'], format='%m').dt.month_name().str.slice(stop=3)

# Same three-letter form for the report text; the year and day given to
# Timestamp are arbitrary, only the month is used.
start_month = pd.Timestamp(year=2000, month=first_month_number, day=1).month_name()[:3]
end_month = pd.Timestamp(year=2000, month=last_month_number, day=1).month_name()[:3]

# TODO: placeholders — these values are interpolated into the cohort summary
# text and still need to be derived from the sample metadata.
n_sites = '???'
study_id = 'who knows????'

In [None]:
# Build the summary sentence without leading indentation: Markdown treats
# lines indented by four spaces as a preformatted code block, so the original
# indented triple-quoted string was rendered as code instead of prose.
cohort_summary = (
    f"This cohort comprises {cohort['cohort_size']} An. {cohort['taxon']} samples "
    f"collected from {n_sites} sites within {cohort['admin1_name']} province, "
    f"{cohort['admin2_name']} region, {cohort['country']}, "
    f"between {start_month} and {end_month} {cohort['year']}. "
    f"Samples were contributed by {study_id}."
)
Markdown(cohort_summary)

## Sampling information
### Collection locations

In [None]:
from ipyleaflet import AwesomeIcon, Map, Marker, basemaps

# Marker colour per taxon; any taxon not listed here is drawn in gray.
taxon_marker_colors = {
    'gambiae': 'red',
    'coluzzii': 'cadetblue',
    'arabiensis': 'lightgreen',
}

# Centre the map on the cohort's coordinates from the cohorts table.
center = cohort[['latitude', 'longitude']].to_list()
m = Map(center=center, zoom=9, basemap=basemaps.OpenTopoMap)

# One marker per distinct (location, taxon) pair rather than one per sample,
# so identical markers are not stacked on top of each other.
df_sites = df_samples[['latitude', 'longitude', 'taxon']].drop_duplicates()

for site in df_sites.itertuples(index=False):
    # Marker has no `color` trait, so a colour passed to it directly is ignored;
    # the colour has to be carried by an AwesomeIcon instead.
    icon = AwesomeIcon(
        name='circle',
        marker_color=taxon_marker_colors.get(site.taxon, 'gray'),
    )
    marker = Marker(
        location=(site.latitude, site.longitude),
        icon=icon,
        # Collection sites are fixed; readers should not be able to move them.
        draggable=False,
        opacity=0.7,
    )
    m.add_layer(marker)

# TODO: attach a popup to each marker linking to the relevant cohort page.

display(m)

### Collection dates

In [None]:
# Bar chart of the number of samples collected in each month of the quarter.
collection_dates_title = f"Collection dates {quarter} {year}"
px.bar(
    df_collection_dates,
    x='month',
    y='count',
    title=collection_dates_title,
)

In [None]:
# Read the H12 window size recorded for this cohort in its calibration file.
calibration_dir = "build/h12-calibration"
calibration_path = here() / calibration_dir / f"{cohort_id}.yaml"
with open(calibration_path) as calibration_file:
    calibration_params = yaml.safe_load(calibration_file)
window_size = calibration_params["h12_window_size"]

# Arabiensis cohorts use the 'arab' phasing analysis; every other taxon uses
# 'gamb_colu'.
phasing_analysis = 'arab' if cohort.taxon == 'arabiensis' else 'gamb_colu'

# Downsample large cohorts for computational efficiency; None means the
# cohort is used as it is.
cohort_size = max_cohort_size if cohort.cohort_size > max_cohort_size else None


def plot_h12_ihs_tracks(
        contig,
        window_size,
        phasing_analysis,
        sample_sets,
        sample_query,
        cohort_size,
        sizing_mode='stretch_width',
        show=False,
        width=800,
        genes_height=100
    ):
    """Stack two H12 GWSS tracks and a gene track for one contig.

    The three figures share the x range of the first track and are combined
    into a single-column grid with one merged toolbar placed above.

    Parameters
    ----------
    contig : passed as `contig` to the GWSS tracks and as `region` to the
        gene track.
    window_size, phasing_analysis, sample_sets, sample_query, cohort_size :
        forwarded to `ag3.plot_h12_gwss_track` (`phasing_analysis` as its
        `analysis` argument).
    sizing_mode, show, width : forwarded to every track; `sizing_mode` is
        also applied to the grid.
    genes_height : height given to the gene track.

    Returns
    -------
    The layout object produced by `bklay.gridplot`.

    Notes
    -----
    Uses the notebook-level `ag3` and `bklay` names.
    """
    # Keyword arguments common to both GWSS track calls.
    gwss_kwargs = dict(
        contig=contig,
        window_size=window_size,
        analysis=phasing_analysis,
        sample_sets=sample_sets,
        sample_query=sample_query,
        cohort_size=cohort_size,
        sizing_mode=sizing_mode,
        show=show,
        width=width,
    )

    top_track = ag3.plot_h12_gwss_track(**gwss_kwargs)
    # Only the bottom figure of the stack keeps its x axis.
    top_track.xaxis.visible = False

    # NOTE(review): the function name mentions iHS, but this second track is
    # another H12 track that differs only in its empty title and shared
    # x range — TODO confirm whether an iHS track was intended here.
    middle_track = ag3.plot_h12_gwss_track(
        title="",
        x_range=top_track.x_range,
        **gwss_kwargs,
    )
    middle_track.xaxis.visible = False

    gene_track = ag3.plot_genes(
        region=contig,
        show=show,
        sizing_mode=sizing_mode,
        width=width,
        height=genes_height,
        x_range=top_track.x_range,
    )

    return bklay.gridplot(
        [top_track, middle_track, gene_track],
        ncols=1,
        toolbar_location="above",
        merge_tools=True,
        sizing_mode=sizing_mode,
    )

## Selection scans

### Chromosome arm 2L

In [None]:
# Build the stacked selection-scan tracks for contig 2L using the cohort-wide
# settings defined above.
fig = plot_h12_ihs_tracks(
    contig='2L',
    window_size=window_size,
    phasing_analysis=phasing_analysis,
    sample_sets=sample_sets,
    sample_query=sample_query,
    cohort_size=cohort_size,
)

In [None]:
# Display the layout built in the previous cell (contig 2L).
bkplt.show(fig)

### Chromosome arm 3L

In [None]:
# Build the stacked selection-scan tracks for contig 3L using the cohort-wide
# settings defined above.
fig = plot_h12_ihs_tracks(
    contig='3L',
    window_size=window_size,
    phasing_analysis=phasing_analysis,
    sample_sets=sample_sets,
    sample_query=sample_query,
    cohort_size=cohort_size,
)

In [None]:
# Display the layout built in the previous cell (contig 3L).
bkplt.show(fig)

### Chromosome X 

In [None]:
# Build the stacked selection-scan tracks for contig X using the cohort-wide
# settings defined above.
fig = plot_h12_ihs_tracks(
    contig='X',
    window_size=window_size,
    phasing_analysis=phasing_analysis,
    sample_sets=sample_sets,
    sample_query=sample_query,
    cohort_size=cohort_size,
)

In [None]:
# Display the layout built in the previous cell (contig X).
bkplt.show(fig)

## Diagnostics
### H12 calibration

In [None]:
# Candidate H12 window sizes to compare in the calibration plot.
window_sizes = (
    100, 200, 500,
    1000, 2000, 5000,
    10000, 20000,
)

# Diagnostic plot for the calibration contig, using the same sample selection
# and downsampling as the selection scans above.
ag3.plot_h12_calibration(
    contig=h12_calibration_contig,
    analysis=phasing_analysis,
    sample_sets=sample_sets,
    sample_query=sample_query,
    cohort_size=cohort_size,
    window_sizes=window_sizes,
)