In [1]:
!pip install -Uq malariagen_data

In [2]:
import logging
# Only let errors through from the chattier dask.distributed loggers.
for _noisy_logger in (
    "distributed.scheduler",
    "distributed.core",
    "distributed.deploy.adaptive",
    "distributed.utils_perf",
    "distributed.batched",
):
    logging.getLogger(_noisy_logger).setLevel(logging.ERROR)
del _noisy_logger  # do not leave the loop variable in the notebook namespace
# from dask_kubernetes import KubeCluster
# from dask.distributed import Client
# cluster = KubeCluster(n_workers=20, 
#                       env={'EXTRA_PIP_PACKAGES': 'malariagen_data'})
# client = Client(cluster)
# client


In [3]:
import os

from dask.distributed import Client

# Address of an already-running Dask scheduler. The literal below is the
# address this notebook was last run against; set the DASK_SCHEDULER_ADDRESS
# environment variable to point at a different scheduler without editing
# the notebook.
SCHEDULER_ADDRESS = os.environ.get("DASK_SCHEDULER_ADDRESS", "tcp://10.35.97.8:41445")

client = Client(SCHEDULER_ADDRESS)
client

0,1
Client  Scheduler: tcp://10.35.97.8:41445  Dashboard: /user/alimanfoo@googlemail.com/proxy/8787/status,Cluster  Workers: 0  Cores: 0  Memory: 0 B


In [4]:
import functools
import bisect
import allel
import numpy as np
import dask.array as da
from dask.diagnostics import ProgressBar
import malariagen_data
import bokeh.io
import bokeh.plotting
import bokeh.models
import bokeh
import seaborn as sns
from bokeh.core.enums import MarkerType
from matplotlib.colors import to_hex
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

In [5]:
# Direct bokeh output to the notebook so that plots shown later render inline.
bokeh.io.output_notebook()

In [6]:
# ProgressBar().register()

In [7]:
# Handle on the Ag3 data release at the given gs:// URL; run_pca() below reads
# sample metadata, SNP sites and SNP genotypes through this object.
ag3 = malariagen_data.Ag3("gs://vo_agam_release/")

In [8]:
@functools.lru_cache(maxsize=None)
def run_pca(
    contig="3L", 
    region_start=15_000_000, 
    region_stop=41_000_000,
    sample_sets="v3_wild",
    sample_query=None,
    site_mask="gamb_colu_arab",
    min_ac=3,
    max_an_missing=0,
    n_snps=100_000,
    snp_offset=0):
    """Run a PCA on SNP genotypes from one genome region.

    Results are memoised with ``functools.lru_cache``: every argument must be
    hashable, and calls with the same arguments return the same objects, so
    callers should not modify the returned values in place.

    Parameters
    ----------
    contig : str
        Contig passed to ``ag3.snp_sites`` and ``ag3.snp_genotypes``.
    region_start, region_stop : int or None
        Inclusive bounds on SNP position. A falsy value leaves that side of
        the region open.
    sample_sets : hashable
        Passed to ``ag3.sample_metadata`` and ``ag3.snp_genotypes``.
    sample_query : str or None
        Expression evaluated with ``DataFrame.eval`` against the sample
        metadata; only samples for which it is true are analysed.
    site_mask : str
        Site mask passed to ``ag3.snp_sites`` and ``ag3.snp_genotypes``.
    min_ac : int
        Minimum count required for both the reference and the non-reference
        allele at a site.
    max_an_missing : int
        Maximum number of missing allele calls tolerated at a site.
    n_snps : int
        Approximate number of SNPs to keep after thinning.
    snp_offset : int
        Index (within the filtered sites) of the first SNP kept by thinning.

    Returns
    -------
    df_samples : pandas.DataFrame
        Sample metadata, restricted to ``sample_query`` when given.
    coords, model
        The two values returned by ``allel.pca``.

    Raises
    ------
    ValueError
        If no site passes the filters.
    """

    # load sample metadata
    df_samples = ag3.sample_metadata(sample_sets=sample_sets)

    # load SNP positions
    pos = ag3.snp_sites(contig=contig, field='POS', site_mask=site_mask).compute()

    # access SNP genotypes
    gt = ag3.snp_genotypes(contig=contig, sample_sets=sample_sets, site_mask=site_mask)

    # locate genome region; bisect assumes positions are sorted ascending
    if region_start or region_stop:
        loc_region = slice(
            bisect.bisect_left(pos, region_start) if region_start else None,
            bisect.bisect_right(pos, region_stop) if region_stop else None,
        )
        gt = gt[loc_region]

    # locate selected samples
    if sample_query:
        loc_samples = df_samples.eval(sample_query).values
        df_samples = df_samples.loc[loc_samples, :]
        gt = da.compress(loc_samples, gt, axis=1)

    # perform allele count
    ac = allel.GenotypeDaskArray(gt).count_alleles(max_allele=3).compute()

    # calculate some convenience variables
    n_chroms = gt.shape[1] * 2
    an_called = ac.sum(axis=1)
    an_missing = n_chroms - an_called

    # locate segregating sites above threshold frequency;
    # here we choose biallelic sites involving the reference allele.
    # At a biallelic site the non-reference count is an_called - ac[:, 0], so
    # the upper bound on the reference count must use the number of *called*
    # alleles; using n_chroms would let rarer variants through whenever
    # max_an_missing > 0 (the two are identical when max_an_missing == 0).
    loc_seg = np.nonzero(ac.is_biallelic() &
                         (ac[:, 0] >= min_ac) &
                         (ac[:, 0] <= an_called - min_ac) &
                         (an_missing <= max_an_missing))[0]
    if loc_seg.shape[0] == 0:
        raise ValueError(
            "no segregating sites pass the filters; "
            "try relaxing min_ac / max_an_missing or widening the region"
        )

    # thin SNPs to desired number; the step is clamped to 1 so that having
    # fewer than n_snps candidate sites keeps them all instead of producing a
    # zero slice step
    step = max(1, loc_seg.shape[0] // n_snps)
    loc_seg_ds = loc_seg[snp_offset::step]

    # subset genotypes to selected sites
    gt_seg = da.take(gt, loc_seg_ds, axis=0)

    # convert to genotype alt counts
    gn_seg = allel.GenotypeDaskArray(gt_seg).to_n_alt().compute()

    # remove any edge-cases where all genotypes are identical
    loc_var = np.any(gn_seg != gn_seg[:, 0, np.newaxis], axis=1)
    gn_var = np.compress(loc_var, gn_seg, axis=0)

    # run PCA
    coords, model = allel.pca(gn_var)

    return df_samples, coords, model
    

In [139]:
# Marker name for each value of the `species` column; this is the default
# `markers` mapping of plot_pca().
species_markers = dict(
    arabiensis='triangle',
    gambiae='circle',
    coluzzii='square',
    intermediate_arabiensis_gambiae='plus',
    intermediate_gambiae_coluzzii='star',
)


def plot_pca(
    df_samples,
    coords,
    pcx=1, 
    pcy=2, 
    color_field='country',
    colors='colorblind',
    marker_field='species',
    markers=species_markers,
    width=500,
    height=300,
    title='PCA',
    marker_size=8,
    ):
    """Show an interactive bokeh scatter plot of two principal components.

    Parameters
    ----------
    df_samples : pandas.DataFrame
        Sample metadata with one row per row of ``coords``, in the same
        order. It is not modified. The hover tooltips read the columns
        ``sample_id``, ``species``, ``country``, ``location``, ``year`` and
        ``sample_set``.
    coords : 2-D array
        PCA coordinates; column ``k - 1`` holds component ``k``.
    pcx, pcy : int
        1-based numbers of the components drawn on the x and y axes.
    color_field : str
        Column of ``df_samples`` that determines fill colour and the legend.
    colors : str or dict
        Either the name of a seaborn palette, from which one colour is drawn
        per distinct value of ``color_field``, or a ready-made mapping from
        value to colour.
    marker_field : str
        Column of ``df_samples`` looked up in ``markers`` when it is a dict.
    markers : str or dict
        A single marker name used for every sample, or a mapping from
        ``marker_field`` value to marker name.
    width, height : int
        Figure size passed to ``bokeh.plotting.figure``.
    title : str
        Figure title.
    marker_size : int
        Glyph size passed to ``fig.scatter``.

    Raises
    ------
    TypeError
        If ``markers`` is neither a str nor a dict.
    """

    # work on a copy so the caller's frame is left untouched
    data = df_samples.copy()
    data['x'] = coords[:, pcx - 1]
    data['y'] = coords[:, pcy - 1]

    # markers
    if isinstance(markers, str):
        data['marker'] = markers
    elif isinstance(markers, dict):
        data['marker'] = [markers[v] for v in data[marker_field]]
    else:
        raise TypeError(
            "markers must be a marker name (str) or a dict mapping "
            "{} values to marker names".format(marker_field)
        )

    # color by whatever you ask for
    if isinstance(colors, str):
        color_keys = data[color_field].unique().tolist()
        color_palette = sns.color_palette(colors, n_colors=len(color_keys))
        colors = dict(zip(color_keys, map(to_hex, color_palette)))
    data['color'] = [colors[v] for v in data[color_field]]

    # shuffle so we don't get overplotting. This has to happen only after the
    # marker/color columns are attached: they are plain lists assigned by
    # position, so shuffling first would pair each sample with the marker and
    # colour of a different sample.
    data = data.sample(frac=1)

    source = bokeh.plotting.ColumnDataSource(data)
    tools = "pan,wheel_zoom,box_zoom,reset,hover,save"
    fig = bokeh.plotting.figure(title=title, tools=tools, active_scroll="wheel_zoom", 
                                width=width, height=height)
    # the legend groups by the same field that drives the colours
    fig.scatter('x', 'y', marker="marker", size=marker_size, source=source, 
                line_color="black", fill_color='color', alpha=.9, legend_field=color_field)

    # setup hover tooltips
    hover = fig.select(dict(type=bokeh.models.HoverTool))
    hover.tooltips = {
        "sample_id": "@sample_id",
        "species": "@species",
        "country": "@country",
        "location": "@location",
        "year": "@year",
        "sample_set": "@sample_set",
    }

    # axis labels
    fig.xaxis.axis_label = "PC{}".format(pcx)
    fig.yaxis.axis_label = "PC{}".format(pcy)

    # color legend: hide the one drawn inside the plot area and re-add its
    # items as a separate legend to the right of the figure
    fig.legend.visible = False
    legend = bokeh.models.Legend(items=fig.legend.items, location='center')
    legend.title = color_field.capitalize()
    legend.label_text_font_size = '0.8em'
    legend.spacing = 0
    fig.add_layout(legend, 'right')

    bokeh.plotting.show(fig)

In [140]:
# An. arabiensis samples only, PC1 vs PC2, one marker shape for all samples.
df_samples, coords, model = run_pca(sample_query='species == "arabiensis"', site_mask='arab')

plot_pca(
    df_samples,
    coords,
    pcx=1,
    pcy=2,
    title='An. arabiensis',
    colors='colorblind',
    markers='circle',
)

In [141]:
# An. coluzzii samples only, PC3 vs PC4, one marker shape for all samples.
df_samples, coords, model = run_pca(sample_query='species == "coluzzii"', site_mask='gamb_colu')

plot_pca(
    df_samples,
    coords,
    pcx=3,
    pcy=4,
    title='An. coluzzii',
    markers='circle',
)

In [132]:
# An. gambiae samples only, PC1 vs PC2, on a larger canvas with the 'husl' palette.
df_samples, coords, model = run_pca(sample_query='species == "gambiae"', site_mask='gamb_colu')

plot_pca(
    df_samples,
    coords,
    pcx=1,
    pcy=2,
    title='An. gambiae',
    colors='husl',
    width=700,
    height=450,
)

distributed.client - ERROR - Failed to reconnect to scheduler after 10.00 seconds, closing client
distributed.utils - ERROR - 
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/distributed/utils.py", line 663, in log_errors
    yield
  File "/opt/conda/lib/python3.7/site-packages/distributed/client.py", line 1296, in _close
    await gen.with_timeout(timedelta(seconds=2), list(coroutines))
concurrent.futures._base.CancelledError
distributed.utils - ERROR - 
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/distributed/utils.py", line 663, in log_errors
    yield
  File "/opt/conda/lib/python3.7/site-packages/distributed/client.py", line 1025, in _reconnect
    await self._close()
  File "/opt/conda/lib/python3.7/site-packages/distributed/client.py", line 1296, in _close
    await gen.with_timeout(timedelta(seconds=2), list(coroutines))
concurrent.futures._base.CancelledError
