In [1]:
%run setup.ipynb

In [2]:
def mbp2cm(chromosome, pos):
    """Convert a physical position in Mbp to a genetic position in cM.

    Parameters
    ----------
    chromosome
        Chromosome key, passed through unchanged to ``bp2cm``.
    pos : number
        Physical position in Mbp.

    Returns
    -------
    Whatever ``bp2cm`` returns for the equivalent position in bp.
    """
    # Scaling a decimal Mbp value by 1e6 is not exact in floating point and
    # can land a hair below (or above) the intended base. Snap to the nearest
    # whole base here so the lookup never depends on how a float gets
    # truncated further down.
    return bp2cm(chromosome, int(round(pos * 1e6)))


def bp2cm(chromosome, pos):
    """Convert a physical position in bp to a genetic position in cM.

    The value is looked up in the module-level ``gmap`` at index ``pos - 1``,
    i.e. positions are treated as 1-based.

    Parameters
    ----------
    chromosome
        Key into ``gmap``.
    pos : number
        Physical position in bp. Non-integer values are rounded to the
        nearest base.

    Returns
    -------
    The ``gmap[chromosome]`` entry for that base.

    Raises
    ------
    IndexError
        If ``pos`` rounds to less than 1. A position past the end of the map
        raises whatever the underlying container raises for an out-of-range
        index.
    """
    # Round rather than truncate: a float such as 2299999.9999999995 means
    # base 2300000, not 2299999.
    index = int(round(pos)) - 1
    # A negative index would silently wrap around to the far end of the map
    # and return a plausible-looking but wrong value.
    if index < 0:
        raise IndexError(
            f"position {pos!r} is before the start of chromosome "
            f"{chromosome!r} (positions are 1-based)"
        )
    return gmap[chromosome][index]


In [3]:
import numba

In [85]:
def plot_genes(chromosome, center, flank, label, ax=None, genetic_distance=False):
    """Draw a gene track for a window around a locus.

    Genes overlapping the window are selected from the module-level
    ``df_genes`` table (columns ``chromosome``, ``chromosome_start``,
    ``chromosome_end`` and ``strand`` are used) and drawn as open boxes:
    ``'+'`` strand genes in the upper row, ``'-'`` strand genes in the lower
    row. The locus itself is marked with a vertical line and an arrow
    annotation carrying ``label``.

    Parameters
    ----------
    chromosome : str
        Chromosome name, matched against ``df_genes.chromosome``.
    center : number
        Locus position in Mbp.
    flank : number, or tuple/list of two numbers
        Window size in Mbp. A single number is used on both sides of
        ``center``; a pair is ``(left, right)``.
    label : str
        Text of the annotation pointing at the locus.
    ax : matplotlib axes, optional
        Axes to draw on. If None, a new 10x1 inch figure is created.
    genetic_distance : bool, optional
        If True, x coordinates are converted to cM with ``mbp2cm`` /
        ``bp2cm``; otherwise they are in Mbp.

    Raises
    ------
    ValueError
        If ``flank`` is neither a real number nor a tuple/list.
    """
    import numbers

    # Resolve the window (in Mbp) before touching matplotlib, so that bad
    # input fails without leaving a half-drawn figure behind. numbers.Real
    # also covers numpy scalar types, which plain int/float checks miss.
    if isinstance(flank, numbers.Real):
        start, stop = center - flank, center + flank
    elif isinstance(flank, (tuple, list)):
        start, stop = center - flank[0], center + flank[1]
    else:
        raise ValueError(
            "flank must be a number or a (left, right) pair, got "
            f"{type(flank).__name__}"
        )

    if ax is None:
        _, ax = plt.subplots(figsize=(10, 1), facecolor='w')

    # x coordinates of the locus and of the window edges, in plot units
    if genetic_distance:
        center_x = mbp2cm(chromosome, center)
        xlim = mbp2cm(chromosome, start), mbp2cm(chromosome, stop)
    else:
        center_x = center
        xlim = start, stop

    # plot center line
    ax.axvline(center_x, color='w', linestyle='-', lw=4, zorder=0)

    # get genes overlapping the window (gene coordinates are compared in bp)
    df_genes_plot = df_genes.query(
        f"chromosome == '{chromosome}' "
        f"and chromosome_start < {stop*1e6} "
        f"and chromosome_end > {start*1e6}"
    )

    def _xranges(strand):
        """Return (x_start, width) pairs for the genes on one strand."""
        genes = df_genes_plot[df_genes_plot['strand'] == strand]
        spans = []
        for chrom, bp_start, bp_end in zip(genes['chromosome'],
                                           genes['chromosome_start'],
                                           genes['chromosome_end']):
            if genetic_distance:
                left, right = bp2cm(chrom, bp_start), bp2cm(chrom, bp_end)
            else:
                left, right = bp_start / 1e6, bp_end / 1e6
            spans.append((left, right - left))
        return spans

    # plot genes: forward strand on top, reverse strand below
    ax.broken_barh(_xranges('+'), (0.6, .3), edgecolor='k', facecolor='w', linewidth=1)
    ax.broken_barh(_xranges('-'), (0.1, .3), edgecolor='k', facecolor='w', linewidth=1)

    # tidy
    ax.set_xlim(*xlim)
    ax.set_ylim(0, 1)
    ax.set_yticks([0.25, 0.75])
    ax.set_yticklabels(['reverse', 'forward'])
    ax.set_xticklabels([])
    ax.text(xlim[0], 1, 'Genes', ha='left', va='bottom')
    ax.annotate(label, xy=(center_x, 1), xytext=(0, 20), textcoords='offset points',
                ha='center', va='bottom', fontstyle='italic',
                arrowprops=dict(arrowstyle='simple', connectionstyle='arc3', color='k'))


In [216]:
def fig_locus(chromosome, center, flank, label, tracks, figw=10, savefig=None,
              savefig_dpi=150, genetic_distance=False, track_height=None,
              plot_kwargs=None):
    """Plot a gene track plus one or more data tracks around a locus.

    Parameters
    ----------
    chromosome, center, flank, label
        Passed to ``plot_genes`` and (except ``label``) to every track
        plotting function.
    tracks : sequence of (plot, kwargs) pairs
        ``plot`` is called as ``plot(chromosome=..., center=..., flank=...,
        ax=..., genetic_distance=..., **kwargs)`` and must return a mutable
        mapping of statistics; a ``'locus'`` entry set to ``label`` is added
        to it.
    figw : number, optional
        Figure width in inches.
    savefig : path-like, optional
        If given, the figure is saved there.
    savefig_dpi : int, optional
        Resolution used when saving.
    genetic_distance : bool, optional
        Forwarded to all plotting functions; also selects the x axis unit
        label (cM instead of Mbp).
    track_height : number, optional
        If given, figure height is ``1.5 + len(tracks) * track_height``;
        otherwise the height is ``figw * 11.75 / 8.25``.
    plot_kwargs : dict, optional
        Extra keyword arguments applied on top of every track's own kwargs.

    Returns
    -------
    pandas.DataFrame
        One row per track, built from the returned statistics.
    """
    if track_height:
        figh = 1.5 + len(tracks) * track_height
    else:
        figh = (11.75/8.25) * figw

    fig = plt.figure(figsize=(figw, figh), facecolor='w')
    gs = fig.add_gridspec(ncols=1, nrows=1+len(tracks),
                          height_ratios=[.5] + [2] * len(tracks))

    # genes
    ax = fig.add_subplot(gs[0])
    plot_genes(chromosome, center, flank, label, ax=ax, genetic_distance=genetic_distance)

    # tracks
    track_stats = []
    for i, (plot, kwargs) in enumerate(tracks):
        # Merge into a fresh dict: the caller's kwargs are typically reused
        # across loci, so updating them in place would leak plot_kwargs from
        # one call into every later one.
        call_kwargs = dict(kwargs)
        if plot_kwargs:
            call_kwargs.update(plot_kwargs)
        ax = fig.add_subplot(gs[i+1])
        stats = plot(chromosome=chromosome, center=center, flank=flank, ax=ax,
                     genetic_distance=genetic_distance, **call_kwargs)
        stats['locus'] = label
        track_stats.append(stats)
        # only the bottom track keeps its x tick labels
        if i < len(tracks) - 1:
            ax.set_xticklabels([])
    df_stats = pd.DataFrame.from_records(track_stats)

    units = 'cM' if genetic_distance else 'Mbp'
    # `ax` is the last axes added, i.e. the bottom panel
    ax.set_xlabel(f"Chromosome {chromosome} position ({units})")

    fig.tight_layout()

    if savefig:
        fig.savefig(savefig, bbox_inches='tight', dpi=savefig_dpi)
    plt.show()
    # close this figure explicitly rather than whichever one is current
    plt.close(fig)

    return df_stats
    

In [None]:
def analyse_peak_stats(all_stats, slug, label):
    """Write per-locus peak statistics to CSV and LaTeX tables.

    Three files are written under ``here() / 'tables'``:

    - ``locus_peaks_{slug}.csv``: the concatenated per-population peaks;
    - ``locus_stats_{slug}.tex``: count/min/max of the peak value and
      min/max/``mean_absolute`` of the peak position, per locus;
    - ``locus_peaks_{slug}.tex``: peak values pivoted to one row per
      population and one column per locus.

    Parameters
    ----------
    all_stats : mapping
        Maps lower-cased gene names to DataFrames that have at least the
        columns ``locus``, ``pop``, ``peak_value`` and ``peak_pos``. Entries
        are looked up for the six notebook-level gene records
        (gste2, cyp6p3, cyp9k1, vgsc, gaba, ace1).
    slug : str
        Suffix used in the output file names.
    label : str
        LaTeX name of the statistic, used in the column headers.

    Returns
    -------
    None
    """
    loci = (gste2, cyp6p3, cyp9k1, vgsc, gaba, ace1)

    # concat stats from all genes; peak positions are scaled by 1e3 (the
    # table header below labels them as kbp). Using assign builds a new
    # column instead of writing back into a slice of the concat result.
    df_stats = (
        pd.concat([all_stats[gene['Name'].lower()] for gene in loci])
        [['locus', 'pop', 'peak_value', 'peak_pos']]
        .assign(peak_pos=lambda df: df['peak_pos'] * 1e3)
    )
    df_stats.to_csv(here() / f'tables/locus_peaks_{slug}.csv', index=False)

    # aggregate by locus
    df_agg = (
        df_stats
        .groupby('locus', sort=False).agg({
            'peak_value': ['count', 'min', 'max'],
            'peak_pos': ['min', 'max', mean_absolute],
        })
    )

    def _num(spec):
        """Build a formatter that renders NaN as '-' and other values with `spec`."""
        return lambda v: '-' if np.isnan(v) else spec.format(v)

    # output aggregate stats to latex; the rename applies to both levels of
    # the column index, and there is one formatter per output column
    (df_agg
        .reset_index()
        .rename({
            'locus': 'Locus',
            'count': 'No. populations',
            'peak_value': f'${label}_{{peak}}$',
            'peak_pos': f'$pos({label}_{{peak}})$ (kbp)',
            'min': 'Min',
            'max': 'Max',
            'mean_absolute': 'MAE',
        }, axis=1)
        .to_latex(
            here() / f'tables/locus_stats_{slug}.tex',
            escape=False,
            formatters=[
                '\\textit{{{}}}'.format,   # locus
                None,                      # no. populations
                _num('{:.2f}'),            # peak value min
                _num('{:.2f}'),            # peak value max
                _num('{:+.1f}'),           # peak position min
                _num('{:+.1f}'),           # peak position max
                _num('{:.1f}'),            # peak position MAE
            ],
            index=False
        )
    )

    # output peak vals to latex
    locus_names = ['Gste2', 'Cyp6p3', 'Cyp9k1', 'Vgsc', 'Gaba', 'Ace1']
    # Raw strings are essential here: in a normal literal '\t' is a TAB, so
    # '\textit{...}' would reach LaTeX as '<TAB>extit{...}'.
    headers = {'pop': 'Population'}
    headers.update({name: r'\textit{' + name + '}' for name in locus_names})
    (
        df_stats
        .pivot_table(index='pop', columns='locus', values='peak_value', fill_value='-', dropna=False)
        .rename(index={k: tex_italicize_species(pop_defs[k]['label']) for k in pop_defs})
        .reset_index()
        [['pop'] + locus_names]
        .rename(columns=headers)
        .to_latex(
            here() / f'tables/locus_peaks_{slug}.tex',
            index=False,
            escape=False,
            float_format='{:.2f}'.format
        )
    )
