In [1]:
import pandas as pd
import numpy as np
import seaborn as sns
import re
from matplotlib.pyplot import subplots, style, rc, rc_context, close
from tqdm import tqdm
from venn import venn, pseudovenn
from collections import defaultdict
from itertools import count, islice
from functools import lru_cache
from argparse import Namespace

In [2]:
from glob import glob
from scipy.spatial.distance import squareform
from scipy.cluster.hierarchy import dendrogram, fcluster, linkage
from sklearn.metrics import silhouette_score
from matplotlib.gridspec import GridSpec
%matplotlib inline

In [3]:
from scipy.cluster.hierarchy import cophenet
from scipy.stats import pearsonr, wilcoxon

In [4]:
from edgecaselib.formats import load_index, load_kmerscan
from edgecaselib.densityplot import interpret_arguments
from edgecaselib.util import natsorted_chromosomes
from pickle import dump, load
from os import path
from tempfile import NamedTemporaryFile
from subprocess import check_output, CalledProcessError
from pysam import AlignmentFile
from scipy.stats import chi2_contingency
from statsmodels.stats.multitest import multipletests
from scipy.ndimage import uniform_filter1d
from matplotlib.patches import FancyArrowPatch, Rectangle
from matplotlib.lines import Line2D

In [5]:
# Filter triple handed to load_kmerscan() when the density cache below has to be rebuilt.
# NOTE(review): presumably [flags to require, flags to exclude, minimum mapping quality] --
# confirm against edgecaselib. 3844 == 0xF04, which under the SAM flag convention would be
# unmapped | secondary | QC-fail | duplicate | supplementary.
SAMFILTERS = ["is_q|tract_anchor", 3844, 0]
# Index loaded from the .ecx file; later cells read its "rname", "flag", "prime" and "pos" columns.
ecx = load_index("assets/hg38ext.fa.ecx")

In [6]:
# Maps every subject ID to the family group it belongs to.
# Written as (group, members) pairs and flattened, so the grouping is visible at a glance;
# insertion order of the resulting dict is HG001..HG007.
SUBJECT_TO_TRIO = {
    member: trio
    for trio, members in (
        ("NA12878", ("HG001",)),
        ("AshkenazimTrio", ("HG002", "HG003", "HG004")),
        ("ChineseTrio", ("HG005", "HG006", "HG007")),
    )
    for member in members
}

In [7]:
# One square matrix per "*-matrix.tsv" file, keyed by the file's base name with the
# "-matrix.tsv" suffix stripped. The first TSV column becomes the row index.
# Key order follows glob(), i.e. it is whatever order the filesystem reports.
RAW_GLOBAL_LDS = dict(
    (
        re.sub(r'-matrix\.tsv$', "", matrix_tsv.split("/")[-1]),
        pd.read_csv(matrix_tsv, sep="\t", index_col=0),
    )
    for matrix_tsv in glob("PacBio/haplotypes/levenshtein-q_arm/*-matrix.tsv")
)

In [8]:
# The parsed kmerscanner output is cached as a pickle so that later runs can skip load_kmerscan().
KMERSCANNER_PKL = "PacBio/kmerscanner-q_arm.pkl"
KMERSCANNER_DAT = "PacBio/kmerscanner-q_arm.dat.gz"

if not path.isfile(KMERSCANNER_PKL):
    # Cache miss: parse the .dat.gz with the SAMFILTERS defined earlier, then persist the result.
    # NOTE(review): the positional `True` and `10` are passed through as-is; their meaning
    # is defined by edgecaselib.formats.load_kmerscan -- confirm there.
    DENSITIES = load_kmerscan(KMERSCANNER_DAT, True, SAMFILTERS, 10)
    with open(KMERSCANNER_PKL, mode="wb") as pkl:
        dump(DENSITIES, pkl)
else:
    # Cache hit. Unpickling executes arbitrary code, so only ever load a cache produced by this notebook.
    with open(KMERSCANNER_PKL, mode="rb") as pkl:
        DENSITIES = load(pkl)

In [9]:
# Width, in density columns, that plot_densities() zero-pads every motif-density panel to;
# it also determines how many x ticks that panel gets.
MAXLEN = 1700

In [10]:
class GridFig():
    """A figure paired with a gapless GridSpec whose cells are sized by ratio lists.

    Attributes:
        figure: the matplotlib figure that all subplots are added to.
        gs: GridSpec with len(height_ratios) rows and len(width_ratios) columns,
            with wspace and hspace both set to 0.
    """

    def __init__(self, width_ratios, height_ratios, scale=1):
        """Create an empty figure and its grid.

        Args:
            width_ratios: relative column widths; their sum times `scale` is the figure width.
            height_ratios: relative row heights; their sum times `scale` is the figure height.
            scale: multiplier applied to both figure dimensions.
        """
        # Imported locally so this class does not depend on edits to the notebook's import cells.
        from matplotlib.pyplot import figure
        # An axes-free figure is all that is needed here. Asking subplots() for a 0x0 grid
        # is rejected by current matplotlib (GridSpec requires positive row/column counts).
        self.figure = figure(
            figsize=(sum(width_ratios)*scale, sum(height_ratios)*scale),
        )
        self.gs = GridSpec(
            ncols=len(width_ratios), wspace=0, width_ratios=width_ratios,
            nrows=len(height_ratios), hspace=0, height_ratios=height_ratios,
            figure=self.figure,
        )

    def subplot(self, gridspec_slice, aspect="auto", frame=False):
        """Add an axes at `gridspec_slice` and return it.

        Args:
            gridspec_slice: a cell or slice of `self.gs`.
            aspect: forwarded to `Figure.add_subplot`.
            frame: False removes the frame entirely; True keeps every spine;
                otherwise an iterable of spine names ("top", "right", "bottom",
                "left") to keep visible -- all other spines are hidden.

        Returns:
            The newly created axes.
        """
        ax = self.figure.add_subplot(gridspec_slice, aspect=aspect)
        if frame is False:
            ax.set(frame_on=False)
        elif frame is not True:
            # Hide every spine the caller did not explicitly ask for.
            for spine in {"top", "right", "bottom", "left"} - set(frame):
                ax.spines[spine].set_visible(False)
        return ax

In [11]:
def plot_dendrogram(Z, gf):
    """Draw the linkage `Z` as a left-oriented, all-black dendrogram in grid cell [0, 0] of `gf`.

    Ticks and axis labels are removed, and the y-axis is reversed after drawing
    (presumably so leaf order runs top-to-bottom like the heatmap rows -- confirm).
    """
    ax = gf.subplot(gf.gs[0,0])
    thin_lines = {"lines.linewidth": .5}
    with rc_context(thin_lines):
        dendrogram(
            Z, ax=ax, orientation="left",
            link_color_func=lambda _: "black",
        )
    # Read the limits only after drawing, then hand them back reversed.
    flipped_ylim = ax.get_ylim()[::-1]
    ax.set(
        xticks=[], xlabel=None,
        yticks=[], ylabel=None,
        ylim=flipped_ylim,
    )

In [12]:
def plot_heatmap(data2d, gf, cmap="gray_r", vmax=.15):
    """Render `data2d` as a colorbar-less heatmap in grid cell [0, 1] of `gf`.

    The color scale runs from 0 to `vmax`; ticks and axis labels are stripped.
    """
    ax = gf.subplot(gf.gs[0,1])
    heatmap_options = dict(cmap=cmap, cbar=False, vmin=0, vmax=vmax, ax=ax)
    sns.heatmap(data2d, **heatmap_options)
    ax.set(xticks=[], yticks=[], xlabel=None, ylabel=None)

In [13]:
def get_plottable_density_section(densities, chrom, motif, data2d, ecx):
    """Return per-read density columns for one motif, aligned to the rows of `data2d`.

    Args:
        densities: mapping from chromosome name to a DataFrame that has "name" and
            "motif" columns followed by position columns.
        chrom: key into `densities`; also matched against `ecx["rname"]`.
        motif: motif to select. None selects the "TTAGGG" rows and replaces the
            values with a presence mask (1/3 where a value exists, 0 where it is null).
        data2d: frame whose index gives the read names and their order.
        ecx: index table with "rname", "flag", "prime" and "pos" columns.

    Returns:
        DataFrame indexed like `data2d`, with integer position columns restricted
        to positions at or beyond the anchor.
    """
    presence_only = motif is None
    wanted_motif = "TTAGGG" if presence_only else motif
    per_chrom = densities[chrom]
    motif_rows = per_chrom[per_chrom["motif"]==wanted_motif]
    # After indexing by read name, the first 8 remaining columns are skipped
    # (presumably metadata preceding the position columns -- confirm against load_kmerscan).
    section = motif_rows.set_index("name").reindex(data2d.index).iloc[:,8:].copy()
    if presence_only:
        section = section.notnull().astype(int) / 3
    section.columns = section.columns.astype(int)
    # NOTE(review): flag 0x4000 with prime == 3 presumably marks the tract anchor
    # on the 3' end -- confirm against the ecx format.
    is_anchor = (ecx["rname"]==chrom) & (ecx["flag"]==0x4000) & (ecx["prime"]==3)
    anchor = ecx.loc[is_anchor, "pos"].iloc[0]
    kept_positions = [position for position in section.columns if position>=anchor]
    return section[kept_positions]

In [14]:
def get_absentees(lds, densities, chrom, ecx):
    """Return the labels in `lds.index` that have no "TTAGGG" density value at all.

    A label counts as absent when every column of its row in the section
    produced by get_plottable_density_section() is null.
    """
    ttaggg_section = get_plottable_density_section(densities, chrom, "TTAGGG", lds, ecx)
    entirely_missing = ttaggg_section.isnull().all(axis=1)
    return entirely_missing.loc[entirely_missing].index

In [15]:
def section_to_RGB(ps, color, alpha_factor=1.2):
    """Turn a 2-D array into an RGBA image of one flat color with per-cell opacity.

    Args:
        ps: 2-D array; each value scaled by `alpha_factor` becomes that cell's alpha.
        color: sequence whose first three items are the R, G and B channel values.
        alpha_factor: multiplier applied to `ps` before the alpha is capped at 1.

    Returns:
        Array of shape ps.shape + (4,), channels last.
    """
    red, green, blue = (np.full_like(ps, color[channel]) for channel in range(3))
    # Only an upper bound is applied to the alpha channel.
    alpha = np.clip(ps*alpha_factor, a_min=None, a_max=1)
    channels_first = np.array([red, green, blue, alpha])
    return channels_first.transpose(1, 2, 0)

In [16]:
def draw_fancy_arrow(
    y, start, end, ax, lw=.25,
    csty="angle3,angleA=45,angleB=-45",
    asty="Simple, tail_width=.25, head_width=2, head_length=3"
):
    """Add a gray, unclipped arrow to `ax` running from (start, y) to (end, y).

    Args:
        y: shared y coordinate of both arrow ends.
        start, end: x coordinates of the tail and the head.
        ax: axes that receives the patch.
        lw: line width of the patch.
        csty: connection style string passed to FancyArrowPatch.
        asty: arrow style string passed to FancyArrowPatch.
    """
    tail, head = (start, y), (end, y)
    arrow = FancyArrowPatch(
        tail, head,
        connectionstyle=csty,
        arrowstyle=asty,
        lw=lw, color="#888", clip_on=False,
    )
    ax.add_patch(arrow)

In [17]:
# Marker color per subject; subjects that share a color are listed together.
POPULATION_COLORS = {
    member: shade
    for shade, members in (
        ("black", ("HG001",)),
        ("green", ("HG002", "HG003", "HG004")),
        ("steelblue", ("HG005", "HG006", "HG007")),
    )
    for member in members
}

def plot_subjects(dispatcher, gf, s=10, arrows=True):
    """Draw one narrow membership column per known subject, starting at grid column 3.

    Args:
        dispatcher: DataFrame with one row per read; a boolean column named after
            a subject marks the rows belonging to that subject.
        gf: GridFig providing `gs` and `subplot()`.
        s: marker size forwarded to scatter().
        arrows: when true, two arrows are added below the HG002 and HG005 columns.
    """
    n_rows = len(dispatcher)
    for grid_column, subject in enumerate(sorted(SUBJECT_TO_TRIO), start=3):
        sax = gf.subplot(gf.gs[0,grid_column])
        # Thin vertical guide line spanning all rows.
        sax.plot([0, 0], [0, n_rows], lw=.5, color="#888")
        if subject in dispatcher:  # i.e. the subject has a column in the frame
            membership = dispatcher[subject].reset_index(drop=True)
            hit_rows = membership[membership].index
            # Three adjacent "_" markers per row read as one wider dash.
            for x_offset in (-.1, 0, .1):
                sax.scatter(
                    x=[x_offset]*len(hit_rows), y=hit_rows,
                    marker="_", s=s, color=POPULATION_COLORS[subject],
                )
        sax.set(
            xticks=[0], xticklabels=[subject+" "],
            yticks=[], xlabel=None, ylabel=None,
            xlim=(-.5, .5),
            ylim=(n_rows, -1),
        )
        for label in sax.get_xticklabels():
            label.set_rotation(90)
        if arrows and (subject in ("HG002", "HG005")):
            # Arrows start at x=1 and x=2 in this axes' data coordinates, i.e. outside its
            # x-limits (presumably over the next two subject columns -- confirm layout).
            for tail_x in (1, 2):
                draw_fancy_arrow(n_rows, tail_x, 0, sax)
        sax.tick_params(axis="both", which="both", length=0)

In [18]:
# RGB color (components in 0..1) for each motif layer drawn by plot_densities().
# Layers are painted in this insertion order, so later entries end up on top.
# The None entry is the layer get_plottable_density_section() builds when no motif is given.
IMSHOW_PALETTE = {
    None:      [0.7, 0.7, 0.7],
    "TTAGGG":  [0.1, 0.5, 0.2],
    "TGAGGG":  [1, 1, 0],
    "TTGGGG":  [0.6, 0.27, 0.5],
    "TTAGGGG": [0.5, 0.9, 1],
}

def plot_densities(densities, chrom, data2d, ecx, gf, extent, bin_size=100):
    """Overlay one semi-transparent image per palette motif in the last grid column of `gf`.

    Args:
        densities: per-chromosome density frames, as accepted by
            get_plottable_density_section().
        chrom: chromosome whose densities are drawn.
        data2d: frame whose index fixes the row order of the image.
        ecx: index table, forwarded to get_plottable_density_section().
        gf: GridFig providing `gs` and `subplot()`.
        extent: [xmin, xmax, ymin, ymax] passed to imshow(); its first two items
            also position the x ticks.
        bin_size: number of density columns between consecutive x ticks. The axis
            label calls one tick step a Kbp, which presumably assumes 10 bp per
            column at the default of 100 -- confirm against the kmerscanner bin width.
    """
    ax = gf.subplot(gf.gs[0,-1])
    n_steps = MAXLEN//bin_size
    # Index of the last tick that still carries a label; stays 0 if the palette is empty.
    breakat = 0
    for motif, color in IMSHOW_PALETTE.items():
        ps = get_plottable_density_section(densities, chrom, motif, data2d, ecx).values
        breakat = ps.shape[1]//bin_size
        if ps.shape[1] < MAXLEN:
            # Zero-pad on the right so every panel is MAXLEN columns wide.
            ps = np.pad(ps, ((0, 0), (0, MAXLEN-ps.shape[1])))
        # Smooth along the position axis with a 5-column moving average before coloring.
        pa = section_to_RGB(np.clip(uniform_filter1d(ps, 5, 1), a_min=0.0, a_max=1.0), color, 1.5)
        ax.imshow(pa, extent=extent, interpolation="nearest")
    # Integer labels 0..n_steps; everything past the actual data width is blanked.
    ticklabels = np.arange(n_steps+1).astype(str)
    ticklabels[breakat+1:] = ""
    xmin, xmax = extent[:2]
    ax.set(
        xticks=np.linspace(xmin, xmax, n_steps+1),
        xticklabels=ticklabels,
        xlabel="Kbp of telomeric tract",
        yticks=[], ylabel=None,
    )
    ax.tick_params(axis="both", which="both", length=0)
    # Baseline drawn only under the labelled part of the axis (axes-fraction coordinates).
    ax.axhline(0, 0, (breakat+1)/len(ticklabels), lw=.5, c="black")

In [19]:
@lru_cache(maxsize=None)
def convname(cn):
    """Return "chr" plus the leading digits of `cn`, or `cn` itself if it has none.

    Results are memoized without a size limit.
    """
    leading_digits = re.search(r'^\d+', cn)
    return cn if leading_digits is None else "chr" + leading_digits.group()

In [20]:
def fixup_labels(gf, chrom, subject):
    """Adjust titles and x labels of one panel row according to which panel it is.

    The ("5qtel_1-500K_1_12_12_rc", "HG001") panel receives the three column
    titles. Every panel except ("chrX", "HG007") has its x tick labels and
    x-axis labels removed.
    """
    axes = gf.figure.get_axes()
    if (chrom == "5qtel_1-500K_1_12_12_rc") and (subject == "HG001"):
        axes[1].set_title("Pairwise relative Levenshtein distances   ", loc="right", fontsize=15)
        axes[5].set_title("Subjects", fontsize=15)
        axes[-1].set_title("   Motif densities", loc="left", fontsize=15)
    if (chrom == "chrX") and (subject == "HG007"):
        # This panel keeps its x labels.
        return
    for ax in axes[:-1]:
        ax.set(xticklabels=[], xlabel=None)
    # The last axes only loses its axis label; its tick labels are left alone here.
    axes[-1].set(xlabel=None)

In [21]:
# For every chromosome arm and every subject seen on it: cluster that subject's reads by
# their pairwise distances and save a dendrogram / heatmap / subject-column / motif-density
# panel as a PDF.
for chrom in tqdm(RAW_GLOBAL_LDS):
    raw_lds = RAW_GLOBAL_LDS[chrom]
    # Reads without any TTAGGG density cannot be plotted; remove them from both axes.
    # drop() returns a new frame, so the cached matrix in RAW_GLOBAL_LDS is left untouched.
    absentees = get_absentees(raw_lds, DENSITIES, chrom, ecx)
    lds = raw_lds.drop(index=absentees, columns=absentees)
    # The subject is the second ":"-separated field of each read label.
    constrainer = pd.DataFrame(
        index=lds.index, columns=["subject"],
        data=lds.index.map(lambda s: s.split(":")[1])
    ).sort_values(by="subject")
    for subject in constrainer["subject"].drop_duplicates():
        gf = None
        try:
            subject_index = constrainer[constrainer["subject"]==subject].index
            subject_lds = lds.loc[subject_index, subject_index]
            subject_Z = linkage(squareform(subject_lds), metric="euclidean", method="ward", optimal_ordering=True)
            subject_leaves = dendrogram(subject_Z, no_plot=True)["leaves"]
            # Reorder the matrix so rows/columns follow the dendrogram leaf order.
            subject_data2d = subject_lds.iloc[subject_leaves, subject_leaves]
            dispatcher = pd.DataFrame(index=subject_index)
            dispatcher.index.name = "read"
            to_subject = dispatcher.index.map(lambda s: s.split(":")[1])
            # Distinct loop variable: the outer `subject` is still needed below.
            for column_subject in sorted(to_subject.drop_duplicates()):
                dispatcher[column_subject] = (to_subject==column_subject)
            # Panel height grows with the number of reads; the density panel width is fixed.
            h = 6*len(subject_lds)/200
            w = 30
            gf = GridFig([h/3,h,.3]+[.5]*7+[w], [h], scale=.4)
            plot_dendrogram(subject_Z, gf=gf)
            plot_heatmap(subject_data2d, gf=gf)
            plot_subjects(dispatcher, gf=gf, s=7, arrows=(subject=="HG007"))
            plot_densities(DENSITIES, chrom, subject_data2d, ecx, gf=gf, extent=[0,w,0,h])
            fixup_labels(gf, chrom, subject)
            if subject != "HG007":
                gf.figure.get_axes()[-1].set(xticks=[])
        except ValueError: # too few observations
            # Skip this subject, but do not leave a half-built figure open.
            if gf is not None:
                close(gf.figure)
            continue
        try:
            gf.figure.savefig(
                "PacBio/haplotypes/clusters-q_arm/constrained/"+chrom+"-"+subject+".pdf", bbox_inches="tight", pad_inches=0,
            )
        finally:
            # Release the figure even if writing the PDF fails.
            close(gf.figure)

100%|██████████| 18/18 [00:19<00:00,  1.06s/it]
