In [1]:
import itertools
import pandas as pd
import numpy as np
import toytree
import toyplot
import arviz as az
import pymc3 as pm

### Load tree

In [3]:
EXCLUDE = [
    'Quercus|Quercus|Leucomexicana|Q.species', 
    'Quercus|Quercus|Roburoids|Q.vulcanica',
    'Quercus|Quercus|Roburoids|Q.imeretina', 
    'Quercus|Lobatae|Erythromexicana|Q.lowilliamsii', 
    'Quercus|Lobatae|Agrifoliae|Q.oxyadenia',
    'Quercus|Quercus|Dumosae|Q.pacifica',
    'Quercus|Quercus|Roburoids|Q.kotschyana',
    'Quercus|Quercus|Roburoids|Q.cedrorum', 
    'Quercus|Lobatae|Agrifoliae|Q.tamalpaiensis',
    'Cerris|Cyclobalanopsis|Semiserrata|Q.litoralis', 
    'Cerris|Cyclobalanopsis|Semiserrata|Q.patelliformis',
    'Cerris|Cyclobalanopsis|Glauca|Q.multinervis', 
    'Cerris|Ilex|Himalayansubalpine|Q.sp.nov.',
]

TREE = toytree.tree("/home/henry/oaks-thesis/full_crown2.tre").drop_tips(EXCLUDE)

# find duplicates (label ends in 1 or 2, we'll drop the 2)
DUPS = [i for i in TREE.get_tip_labels() if i.endswith('2')]
TREE = TREE.drop_tips(DUPS)

# relabel kept duplicate tips by stripping 1 from end
TREE = TREE.set_node_values(
    feature="name", 
    values={nidx: TREE.idx_dict[nidx].name.strip("1") for nidx in TREE.idx_dict}
)

# Scale tree to 1.0 length
TREE = TREE.mod.node_scale_root_height(1.0)

In [8]:
# Get crown nodes for eight clades.
clades = [
    "Quercus|Quercus", # teal
    "Quercus|Virentes", # orange
    "Quercus|Ponticae", # blue
    "Quercus|Protobalanus", # pink
    "Quercus|Lobatae", # green
    "Cerris|Ilex", # yellow
    "Cerris|Cerris", # tan
    "Cerris|Cyclobalanopsis", # gray
]
crowns = {
    TREE.get_mrca_idx_from_tip_labels(wildcard=i): i for i in clades
}
crowns

{419: 'Quercus|Quercus',
 420: 'Quercus|Virentes',
 434: 'Quercus|Ponticae',
 447: 'Quercus|Protobalanus',
 455: 'Quercus|Lobatae',
 450: 'Cerris|Ilex',
 451: 'Cerris|Cerris',
 457: 'Cerris|Cyclobalanopsis'}

In [52]:
# draw the tree
TREE.draw(
    layout='d', 
    edge_colors=TREE.get_edge_values_mapped({
            j: toytree.colors[i] for i,j in enumerate(crowns)
        }),
    tip_labels=False,
    height=350,
    node_labels=[
        str(i) if int(i) in crowns else "" 
        for i in TREE.get_node_values("idx", 1, 1)
    ],
    node_labels_style={
        "font-size": "16px",
        "-toyplot-anchor-shift": "15px",
    },
);

In [10]:
# make group index (gidx)
crown_dict = {i: TREE.get_tip_labels(i) for i in crowns}
gidx = np.zeros(TREE.ntips, dtype=int)
for tidx, tip in enumerate(TREE.get_tip_labels()):
    for cidx, clade in enumerate(crown_dict):
        if tip in crown_dict[clade]:
            gidx[tidx] = cidx
gidx

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,
       1, 1, 1, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
       4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
       4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
       4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
       6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
       7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7])

In [131]:
# True param values
𝛼_mean = 0.05
𝛼_std = 0.02
𝛽_mean = 3.0
𝛽_std = 0.2
𝜓_mean = 0.0
𝜓_std = 0.33

# 8 different clade effects on rate of RI (used for partial-pooling data)
𝜓_Quercus_mean = -0.8
𝜓_Quercus_std = 0.2
𝜓_Virentes_mean = -3.5
𝜓_Virentes_std = 0.05
𝜓_Ponticae_mean = -5.0
𝜓_Ponticae_std = 0.05
𝜓_Protobalanus_mean = 5.0
𝜓_Protobalanus_std = 0.1
𝜓_Lobatae_mean = 0.5
𝜓_Lobatae_std = 0.2
𝜓_Ilex_mean = 1.0
𝜓_Ilex_std = 0.15
𝜓_Cerris_mean = 5.0
𝜓_Cerris_std = 0.15
𝜓_Cyclobalanopsis_mean = 3.5
𝜓_Cyclobalanopsis_std = 0.15

In [132]:
SPECIES_DATA = pd.DataFrame({
    "Species": ["Quercus " + "{}".format(
        TREE.idx_dict[idx].name.split("|")[-1].split(".")[-1]) for idx in range(len(TREE.get_tip_labels()))],
    "b": np.random.normal(𝛽_mean, 𝛽_std, TREE.ntips),
    "psi": np.random.normal(𝜓_mean, 𝜓_std, TREE.ntips),
    "psi_x": np.concatenate([
        np.random.normal(𝜓_Quercus_mean, 𝜓_Quercus_std, len(gidx[gidx == 0])),
        np.random.normal(𝜓_Virentes_mean, 𝜓_Virentes_std, len(gidx[gidx == 1])),
        np.random.normal(𝜓_Ponticae_mean, 𝜓_Ponticae_std, len(gidx[gidx == 2])),
        np.random.normal(𝜓_Protobalanus_mean, 𝜓_Protobalanus_std, len(gidx[gidx == 3])),
        np.random.normal(𝜓_Lobatae_mean, 𝜓_Lobatae_std, len(gidx[gidx == 4])),
        np.random.normal(𝜓_Ilex_mean, 𝜓_Ilex_std, len(gidx[gidx == 5])),
        np.random.normal(𝜓_Cerris_mean, 𝜓_Cerris_std, len(gidx[gidx == 6])),
        np.random.normal(𝜓_Cyclobalanopsis_mean, 𝜓_Cyclobalanopsis_std, len(gidx[gidx == 7])),
    ]),
    "gidx": gidx
})

In [133]:
SPECIES_DATA

Unnamed: 0,Species,b,psi,psi_x,gidx
0,Quercus arizonica,2.649839,0.124749,-0.829386,0
1,Quercus oblongifolia,3.375341,0.027754,-0.856423,0
2,Quercus laeta,2.967354,-0.262966,-0.469912,0
3,Quercus ajoensis,2.909137,0.177614,-0.864610,0
4,Quercus turbinella,2.728273,-0.439628,-0.549844,0
...,...,...,...,...,...
226,Quercus rex,3.265054,-0.123982,3.702592,7
227,Quercus chungii,3.254171,0.036968,3.299497,7
228,Quercus delavayi,3.210858,-0.440764,3.485134,7
229,Quercus championii,3.269819,-0.300722,3.704606,7


### Generate crossing data

In [134]:
def get_dist(tree, idx0, idx1):
    "returns the genetic distance between two nodes on a tree"
    dist = tree.treenode.get_distance(
        tree.idx_dict[idx0], 
        tree.idx_dict[idx1],
    )
    return dist

# get all combinations of two sampled taxa
a, b = zip(*itertools.combinations(range(TREE.ntips), 2))

# organize into DF and get genetic distance between pairs
DATA = pd.DataFrame({
    "sidx0": a,
    "sidx1": b,
    "dist": [get_dist(TREE, i, j) / 2 for (i, j) in zip(a, b)],
})

# get pairwise velocity
DATA['b'] = (
    np.random.normal(𝛽_mean, 𝛽_std, DATA.shape[0])
)

DATA['velo'] = (
    DATA['b']
    + SPECIES_DATA['psi'][DATA.sidx0].values
    + SPECIES_DATA['psi'][DATA.sidx1].values
)

DATA['velo_x'] = (
    DATA['b']
    + SPECIES_DATA['psi_x'][DATA.sidx0].values
    + SPECIES_DATA['psi_x'][DATA.sidx1].values
)

DATA['intercept'] = np.random.normal(𝛼_mean, 𝛼_std, DATA.shape[0])

# get logits
DATA['logit_b'] = (
    1 / (1 + np.exp(-(DATA.intercept + DATA.dist * DATA.b)))
)
DATA['logit'] = (
    1 / (1 + np.exp(-(DATA.intercept + DATA.dist * DATA.velo)))
)
DATA['logit_x'] = (
    1 / (1 + np.exp(-(DATA.intercept + DATA.dist * DATA.velo_x)))
)

# get RI estimates
DATA['RI_pooled'] = np.random.binomial(n=1, p=DATA.logit_b / DATA.logit_b.max())
DATA['RI_unpooled'] = np.random.binomial(n=1, p=DATA.logit / DATA.logit.max())
DATA['RI_partpooled'] = np.random.binomial(n=1, p=DATA.logit_x / DATA.logit_x.max())

DATA.head()

Unnamed: 0,sidx0,sidx1,dist,b,velo,velo_x,intercept,logit_b,logit,logit_x,RI_pooled,RI_unpooled,RI_partpooled
0,0,1,0.001046,2.746947,2.89945,1.061138,0.027633,0.507626,0.507666,0.507185,1,0,0
1,0,2,0.002093,2.901469,2.763252,1.602171,0.01185,0.50448,0.504408,0.503801,0,1,1
2,0,3,0.004185,3.270655,3.573019,1.576659,0.075311,0.522235,0.522551,0.520466,0,0,0
3,0,4,0.008371,2.981583,2.666704,1.602353,0.07336,0.52456,0.523902,0.52168,0,1,0
4,0,5,0.016741,3.137269,3.501146,1.595426,0.002633,0.513785,0.515307,0.507335,0,0,1


In [135]:
NSAMPLES = 1000
SAMPLE = DATA.sample(NSAMPLES).copy().reset_index(drop=True)
SAMPLE.head()

Unnamed: 0,sidx0,sidx1,dist,b,velo,velo_x,intercept,logit_b,logit,logit_x,RI_pooled,RI_unpooled,RI_partpooled
0,19,205,1.0,2.862798,2.842064,5.557401,0.03554,0.947764,0.946728,0.99629,1,1,1
1,32,168,1.0,2.841939,2.503688,3.171095,0.056308,0.94776,0.928242,0.961853,1,1,1
2,216,224,0.862857,3.225487,1.869283,9.851656,0.027151,0.943229,0.837545,0.999802,1,1,1
3,66,137,0.927411,2.797488,1.427397,2.107891,0.07417,0.935148,0.801859,0.883814,1,1,1
4,58,172,1.0,3.025176,3.817973,3.23437,0.083669,0.957256,0.980192,0.965043,1,1,1


### Visualize data

In [19]:
def logit_plot(dist, logit, RI):
    canvas = toyplot.Canvas(width=500, height=250)
    ax0 = canvas.cartesian(
        label="pooled data (function)",
        xlabel="Genetic dist.",
        ylabel="Logit function",
        grid=(1, 2, 0),
    )
    ax1 = canvas.cartesian(
        label="pooled data (observation)",
        xlabel="Genetic dist.",
        ylabel="RI",
        grid=(1, 2, 1),
    )

    # points are jittered on x-axis for visibility
    ax0.scatterplot(
        dist,
        logit,
        size=5,
        opacity=0.33,
        color=toyplot.color.Palette()[0],
    );
    ax1.scatterplot(
        dist,
        RI,
        size=10,
        opacity=0.2,
        marker="|",
        mstyle={
            "stroke": toyplot.color.Palette()[1],
            "stroke-width": 3,
        },
    );
    return canvas, (ax0, ax1)

In [136]:
logit_plot(SAMPLE.dist, SAMPLE.logit_b, SAMPLE.RI_pooled);

In [137]:
logit_plot(SAMPLE.dist, SAMPLE.logit, SAMPLE.RI_unpooled);

In [138]:
logit_plot(SAMPLE.dist, SAMPLE.logit_x, SAMPLE.RI_partpooled);

### Define models

In [23]:
def toytrace(trace, var_names, titles):
    """
    Plot posterior trace with toyplot
    """
    nvars = len(var_names)
    
    # setup canvase
    canvas = toyplot.Canvas(width=500, height=200 * nvars)
    
    # store axes
    axes = []
    
    # iter over params
    for pidx, param in enumerate(var_names):
        
        # get param posterior
        posterior = trace.get_values(param)
        
        # setup axes 
        ax = canvas.cartesian(grid=(nvars, 1, pidx))
        ax.y.show = False
        ax.x.spine.style = {"stroke-width": 1.5}
        ax.x.ticks.labels.style = {"font-size": "12px"}
        ax.x.ticks.show = True
        ax.x.label.text = f"param='{titles[pidx]}'"        
        
        # iterate over shape of param
        for idx in range(posterior.shape[1]):
            mags, bins = np.histogram(posterior[:, idx], bins=100)
            ax.plot(bins[1:], mags, stroke_width=2, opacity=0.6)
        axes.append(ax)
    return canvas, axes

In [35]:
def pooled_logistic(x, y, **kwargs):
    
    # define model
    with pm.Model() as model:  

        # parameters and error
        𝛼 = pm.Normal('𝛼', mu=0., sigma=10., shape=1)
        𝛽 = pm.Normal('𝛽', mu=0., sigma=10., shape=1)
        
        # link function
        effect = 𝛼 + 𝛽 * x
        logit = pm.Deterministic("logit", pm.invlogit(effect))
        
        # data likelihood
        y = pm.Bernoulli("y", p=logit, observed=y)
        
        # sample posterior, skip burnin
        trace = pm.sample(**kwargs)[1000:]
    
        # show summary table
        stats = pm.summary(trace)
        
    # organize results
    result_dict = {
        'model': model, 
        'trace': trace,
        'stats': stats,
    }
    return result_dict

In [36]:
def unpooled_logistic(x, y, idx0, idx1, **kwargs):
    
    # define model
    with pm.Model() as model:
        
        # indexers
        sidx0 = pm.Data("spp_idx0", idx0.values)
        sidx1 = pm.Data("spp_idx1", idx1.values)

        # parameters and error
        𝜓_mean = pm.Normal('𝜓_mean', mu=0., sigma=5., shape=1)
        𝜓_std = pm.HalfNormal('𝜓_std', 5., shape=1)
        𝜓_offset = pm.Normal('𝜓_offset', mu=0, sigma=1., shape=TREE.ntips)
        𝜓 = pm.Deterministic('𝜓', 𝜓_mean + 𝜓_std * 𝜓_offset)
        𝛼 = pm.Normal('𝛼', mu=0., sigma=10., shape=1)
        𝛽 = pm.Normal('𝛽', mu=0., sigma=10., shape=1)
        
        # link function
        effect = 𝛼 + (𝛽 + 𝜓[sidx0] + 𝜓[sidx1]) * x
        logit = pm.Deterministic("logit", pm.invlogit(effect))
        
        # data likelihood
        y = pm.Bernoulli("y", p=logit, observed=y)
        
        # sample posterior, skip burnin
        trace = pm.sample(**kwargs)[1000:]
    
        # show summary table
        stats = pm.summary(trace)
        
    # organize results
    result_dict = {
        'model': model, 
        'trace': trace,
        'stats': stats,
    }
    return result_dict

In [37]:
def partpooled_logistic(x, y, idx0, idx1, gidx, **kwargs):
    
    # define model
    with pm.Model() as model:
        
        # indexers
        sidx0 = pm.Data("spp_idx0", idx0)
        sidx1 = pm.Data("spp_idx1", idx1)
        gidx = pm.Data("gidx", gidx)

        # parameters and error
        𝜓_mean = pm.Normal('𝜓_mean', mu=0., sigma=5., shape=8)
        𝜓_std = pm.HalfNormal('𝜓_std', 5., shape=8)
        𝜓_offset = pm.Normal('𝜓_offset', mu=0, sigma=1., shape=TREE.ntips)
        𝜓 = pm.Deterministic('𝜓', 𝜓_mean[gidx] + 𝜓_std[gidx] * 𝜓_offset)
        𝛽 = pm.Normal('𝛽', mu=0., sigma=10., shape=1)
        𝛼 = pm.Normal('𝛼', mu=0., sigma=10., shape=1)
        
        # linear model prediction
        effect = 𝛼 + (𝛽 + 𝜓[sidx0] + 𝜓[sidx1]) * x
        logit = pm.Deterministic("logit", pm.invlogit(effect))
        
        # data likelihood (normal distributed errors)
        y = pm.Bernoulli("y", p=logit, observed=y)

        # sample posterior, skip burnin
        trace = pm.sample(**kwargs)[1000:]

        # show summary table
        stats = pm.summary(trace)
        
    # organize results
    result_dict = {
        'model': model, 
        'trace': trace,
        'stats': stats,
    }
    return result_dict

In [98]:
# MCMC sampler kwargs
sample_kwargs = dict(
    tune=10000,
    draws=10000,
    target_accept=0.99,
    return_inferencedata=False,
    progressbar=True,
)

In [142]:
# model input
model_args = [
    SAMPLE.dist,
    SAMPLE.RI_pooled,
    SAMPLE.sidx0,
    SAMPLE.sidx1,
    gidx
]

# pooled model
pooled_sub = pooled_logistic(*model_args[:2], **sample_kwargs)

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [𝛽, 𝛼]


Sampling 4 chains for 10_000 tune and 10_000 draw iterations (40_000 + 40_000 draws total) took 222 seconds.
The number of effective samples is smaller than 25% for some parameters.


In [143]:
pooled_sub['stats']

Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_mean,ess_sd,ess_bulk,ess_tail,r_hat
𝛼[0],-0.083,0.258,-0.561,0.407,0.003,0.002,7791.0,7447.0,7779.0,9700.0,1.0
𝛽[0],3.709,0.377,2.987,4.402,0.004,0.003,7813.0,7802.0,7789.0,9707.0,1.0
logit[0],0.974,0.005,0.964,0.983,0.000,0.000,15364.0,15361.0,15184.0,17144.0,1.0
logit[1],0.974,0.005,0.964,0.983,0.000,0.000,15364.0,15361.0,15184.0,17144.0,1.0
logit[2],0.957,0.007,0.944,0.970,0.000,0.000,20274.0,20267.0,20055.0,18433.0,1.0
...,...,...,...,...,...,...,...,...,...,...,...
logit[995],0.974,0.005,0.964,0.983,0.000,0.000,15364.0,15361.0,15184.0,17144.0,1.0
logit[996],0.966,0.006,0.954,0.977,0.000,0.000,17572.0,17566.0,17350.0,17887.0,1.0
logit[997],0.966,0.006,0.955,0.977,0.000,0.000,17428.0,17422.0,17207.0,17949.0,1.0
logit[998],0.952,0.007,0.938,0.965,0.000,0.000,21878.0,21869.0,21630.0,18755.0,1.0


In [144]:
# model input
model_args = [
    SAMPLE.dist,
    SAMPLE.RI_unpooled,
    SAMPLE.sidx0,
    SAMPLE.sidx1,
    gidx
]

# unpooled model
unpooled_sub = unpooled_logistic(*model_args[:4], **sample_kwargs)

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [𝛽, 𝛼, 𝜓_offset, 𝜓_std, 𝜓_mean]


Sampling 4 chains for 10_000 tune and 10_000 draw iterations (40_000 + 40_000 draws total) took 8543 seconds.
The number of effective samples is smaller than 25% for some parameters.


In [145]:
toytrace(unpooled_sub['trace'], ['𝜓_mean', '𝜓_offset', '𝜓'], ['psi-mean', 'psi-offset', 'psi-spp']);

In [146]:
# show plot of TRUE vs. ESTIMATED rates
c, a, m = toyplot.scatterplot(
    unpooled_sub['trace']['𝜓'].mean(axis=0),         # estimated
    SPECIES_DATA['psi'],                             # true
    width=400,
    height=250,
    xlabel="ESTIMATED species velocity",
    ylabel="TRUE species velocity",
    # color=[toyplot.color.Palette()[i] for i in SPECIES_DATA.gidx],
);

In [139]:
# model input
model_args = [
    SAMPLE.dist,
    SAMPLE.RI_partpooled,
    SAMPLE.sidx0,
    SAMPLE.sidx1,
    gidx
]

# partpooled model
partpooled_sub = partpooled_logistic(*model_args, **sample_kwargs)

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [𝛼, 𝛽, 𝜓_offset, 𝜓_std, 𝜓_mean]


Sampling 4 chains for 10_000 tune and 10_000 draw iterations (40_000 + 40_000 draws total) took 5161 seconds.
The number of effective samples is smaller than 25% for some parameters.


In [140]:
toytrace(partpooled_sub['trace'], ['𝜓_mean', '𝜓_offset', '𝜓'], ['psi-mean', 'psi-offset', 'psi-spp']);

In [141]:
# show plot of TRUE vs. ESTIMATED rates
c, a, m = toyplot.scatterplot(
    partpooled_sub['trace']['𝜓'].mean(axis=0),         # estimated
    SPECIES_DATA['psi_x'],                             # true
    width=400,
    height=250,
    xlabel="ESTIMATED species velocity",
    ylabel="TRUE species velocity",
    color=[toytree.colors[i] for i in SPECIES_DATA.gidx],
);