In [1]:
import pyreadr
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from Bio import SeqIO
import sys
import itertools
from tqdm import tqdm
import random
import re
import scipy
import pickle
import os
import time
from pandarallel import pandarallel
from scipy.stats import multinomial, chisquare, power_divergence
import matplotlib.ticker as ticker

In [2]:
from statsmodels.stats.multitest import multipletests

In [3]:
sys.path.insert(0, os.path.abspath(".."))
import suffix_array

In [4]:
pandarallel.initialize(progress_bar=False)

INFO: Pandarallel will run on 1 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


In [5]:
comp_trans = str.maketrans("ACGTMRWSYKVHDBN", "TGCAKYWSRMBDHVN")
def reverse_complement(seq):
    return seq.translate(comp_trans)[::-1]

In [6]:
def get_all_kmers(l=3):
    return ["".join(tup) for tup in itertools.product("ATGC", repeat=l)]

In [7]:
def get_kmer_frequencies(seq, k=3):
    seq_ = str(seq)
    all_kmers = get_all_kmers(k)
    kmer_counts = {kmer:0 for kmer in all_kmers}
    for i in range(len(seq)-k):
        kmer_counts[seq_[i:i+k]] += 1
        kmer_counts[reverse_complement(seq_[i:i+k])] += 1
    d = pd.DataFrame(list(kmer_counts.items()), columns=['kmer', 'count']).set_index('kmer')
    d = d / d.sum()
    return d.loc[all_kmers]

In [8]:
def kmers(row, n=7, k=3):
    for start in range(n):
        kmer = row['s'][start:start+k]
        if len(kmer) == k:
            row[kmer] = 1
    return row

In [9]:
def select_positions(fwd, rev, thr, only_CA=False, peak_dist=0, min_cov=1, min_dist=1, dist_rem='all', subsample=1., k=3, k_window=2, kmer_quantile=1., iterative_kmer_selection=False):
    # select based on threshold
    fwd_pos, rev_pos = np.where(fwd <= thr)[0], np.where(rev <= thr)[0]
    #peak selection
    if peak_dist > 0:
        for i,pos in enumerate([fwd_pos, rev_pos]):
            dists = pos[1:] - pos[:-1]
            rem_flank = 'left'
            while np.any(dists <= peak_dist):
                dists = np.pad(dists, 1, constant_values=[min_dist+1, min_dist+1])#[1,0])
                diff = np.diff((dists <= peak_dist).astype(int))
                clus = np.column_stack((np.where(diff == 1)[0], np.where(diff == -1)[0]))
                # iteratively remove left or right flank of clusters in alternating fashion
                if rem_flank == 'left':
                    rem = clus[:,0]
                    rem_flank = 'right'
                else:
                    rem = clus[:,1]
                    rem_flank = 'left'
                pos = np.delete(pos, rem)
                dists = pos[1:] - pos[:-1]
            if i == 0:
                fwd_pos = pos
            else:
                rev_pos = pos
    # require minimum coverage
    if min_cov > 1:
        fwd_pos = fwd_pos[fwd_cov[fwd_pos] >= min_cov]
        rev_pos = rev_pos[rev_cov[rev_pos] >= min_cov]
    # kmer-based filtering
    if kmer_quantile < 1.:
        if iterative_kmer_selection:
            n = 2*(k_window+k-1)+1-(k-1)
            all_kmers = get_all_kmers(k)
            n_most_frequent = int(kmer_quantile * len(all_kmers))
            frequent_kmers = []
            fwd_pos_ = fwd_pos.copy()
            rev_pos_ = rev_pos.copy()
            for i in range(n_most_frequent):
                d_fwd = pd.DataFrame(fwd_pos_, columns=['pos'])
                d_fwd['s'] = d_fwd.apply(lambda row: "".join(seq[row.pos - (k_window+k-1): row.pos + (k_window+k)]), axis=1).to_frame()
                d_fwd = d_fwd.reindex(d_fwd.columns.tolist() + all_kmers, axis=1, fill_value=0)
                d_fwd = d_fwd.parallel_apply(kmers, n=n, k=k, axis=1).set_index('pos').drop('s', axis=1)
                d_rev = pd.DataFrame(rev_pos_, columns=['pos'])
                d_rev['s'] = d_rev.apply(lambda row: "".join(seq[row.pos - (2+k-1): row.pos + (2+k)]), axis=1).to_frame()
                d_rev = d_rev.reindex(d_rev.columns.tolist() + all_kmers, axis=1, fill_value=0)
                d_rev = d_rev.parallel_apply(kmers, n=n, k=k, axis=1).set_index('pos').drop('s', axis=1)
                d = d_fwd.sum() + d_rev.sum()
                d = d - (kmer_frequencies[k].loc[d.index]['count'] * d.sum()).astype(int)
                frequent_kmers.append(d.idxmax())
                fwd_pos_ = d_fwd[d_fwd[[frequent_kmers[-1]]].sum(axis=1) == 0].index.to_numpy()
                rev_pos_ = d_rev[d_rev[[frequent_kmers[-1]]].sum(axis=1) == 0].index.to_numpy()
            d_fwd = pd.DataFrame(fwd_pos, columns=['pos'])
            d_fwd['s'] = d_fwd.apply(lambda row: "".join(seq[row.pos - (k_window+k-1): row.pos + (k_window+k)]), axis=1).to_frame()
            d_fwd = d_fwd.reindex(d_fwd.columns.tolist() + all_kmers, axis=1, fill_value=0)
            d_fwd = d_fwd.parallel_apply(kmers, n=n, k=k, axis=1).set_index('pos').drop('s', axis=1)
            d_rev = pd.DataFrame(rev_pos, columns=['pos'])
            d_rev['s'] = d_rev.apply(lambda row: "".join(seq[row.pos - (2+k-1): row.pos + (2+k)]), axis=1).to_frame()
            d_rev = d_rev.reindex(d_rev.columns.tolist() + all_kmers, axis=1, fill_value=0)
            d_rev = d_rev.parallel_apply(kmers, n=n, k=k, axis=1).set_index('pos').drop('s', axis=1)
            d = d_fwd.sum() + d_rev.sum()
            d = d - (kmer_frequencies[k].loc[d.index]['count'] * d.sum()).astype(int)
            #print(d[frequent_kmers].sort_values())
            fwd_pos = d_fwd[d_fwd[frequent_kmers].sum(axis=1) > 0].index.to_numpy()
            rev_pos = d_rev[d_rev[frequent_kmers].sum(axis=1) > 0].index.to_numpy()
        else:
            n = 2*(k_window+k-1)+1-(k-1)
            all_kmers = get_all_kmers(k)
            #for i,pos in enumerate([fwd_pos, rev_pos]):
            d_fwd = pd.DataFrame(fwd_pos, columns=['pos'])
            d_fwd['s'] = d_fwd.apply(lambda row: "".join(seq[row.pos - (k_window+k-1): row.pos + (k_window+k)]), axis=1).to_frame()
            d_fwd = d_fwd.reindex(d_fwd.columns.tolist() + all_kmers, axis=1, fill_value=0)
            d_fwd = d_fwd.parallel_apply(kmers, n=n, k=k, axis=1).set_index('pos').drop('s', axis=1)
            d_rev = pd.DataFrame(rev_pos, columns=['pos'])
            d_rev['s'] = d_rev.apply(lambda row: "".join(seq[row.pos - (2+k-1): row.pos + (2+k)]), axis=1).to_frame()
            d_rev = d_rev.reindex(d_rev.columns.tolist() + all_kmers, axis=1, fill_value=0)
            d_rev = d_rev.parallel_apply(kmers, n=n, k=k, axis=1).set_index('pos').drop('s', axis=1)
            d = d_fwd.sum() + d_rev.sum()
            d = d - (kmer_frequencies[k].loc[d.index]['count'] * d.sum()).astype(int)
            frequent_kmers = d[d > d.quantile(1 - kmer_quantile)].index
            #print(d[frequent_kmers].sort_values())
            fwd_pos = d_fwd[d_fwd[frequent_kmers].sum(axis=1) > 0].index.to_numpy()
            rev_pos = d_rev[d_rev[frequent_kmers].sum(axis=1) > 0].index.to_numpy()
    # de-select when too close to other hit
    if min_dist > 1:
        for i,pos in enumerate([fwd_pos, rev_pos]):
            dists = pos[1:] - pos[:-1]
            if dist_rem == 'all':
                dists = np.pad(dists, (1,0), constant_values=0)
                pos = pos[~((dists < min_dist) | (np.pad(dists[1:], (0,1), constant_values=0) < min_dist))]
            elif dist_rem == 'middle':
                while np.any(dists < min_dist):
                    dists = np.pad(dists, 1, constant_values=[min_dist+1, min_dist+1])#[1,0])
                    diff = np.diff((dists < min_dist).astype(int))
                    clus = np.column_stack((np.where(diff == 1)[0], np.where(diff == -1)[0]))
                    rem = clus[:,0] + 1 + (clus[:,1] - clus[:,0]) // 2 # remove center elem of a cluster of close ones
                    pos = np.delete(pos, rem)
                    dists = pos[1:] - pos[:-1]
            if i == 0:
                fwd_pos = pos
            else:
                rev_pos = pos
    # de-select non-CA
    if only_CA:
        fwd_pos = fwd_pos[fwd_ca[fwd_pos] == True]
        rev_pos = rev_pos[rev_ca[rev_pos] == True]
    # subsample randomly
    if subsample < 1.:
        fwd_pos = np.random.choice(fwd_pos, int(subsample*len(fwd_pos)), replace=False)
        rev_pos = np.random.choice(rev_pos, int(subsample*len(rev_pos)), replace=False)
    return fwd_pos, rev_pos

In [10]:
sample_dir = "../samples"

In [11]:
samples = [d for d in os.listdir(sample_dir)]
samples

['OC8',
 'LA_Y3C',
 'Bak13',
 'Riv16',
 'Red_Oak2',
 'LM10',
 'Riv19',
 'Filmore',
 'De_Donno',
 'Bak8',
 'Oak_35874',
 'Temecula_1',
 'Riv25']

In [12]:
context = 22
for sample in samples:
    print("processing sample", sample)
    # load data
    fp = f"../samples/{sample}/diff/{sample}_difference.RDS"
    df = pyreadr.read_r(fp)[None]
    # correct positions
    df['position'] = df['position'].astype(int) - 1 + 3 # convert to 0-based indexing
    df.loc[df.dir == 'rev', 'position'] += 1
    df = df.set_index(df.position)
    # read genome sequence
    fp = f"../samples/{sample}/ref_genome/{sample}.fasta"
    seq = ""
    contig = None
    for record in SeqIO.parse(fp, "fasta"):
        if len(record.seq) > len(seq):
            seq = record.seq
            contig = record.id
    sa = suffix_array.get_suffix_array(contig, seq)
    # coverage
    fwd_cov = np.full(len(seq), np.nan)
    sel = ((df.dir == 'fwd') & ~pd.isna(df.u_test_pval))
    fwd_cov[df.loc[(df.dir == 'fwd') & ~pd.isna(df.u_test_pval), 'u_test_pval'].index] = np.min([df.loc[sel, 'N_wga'], df.loc[sel, 'N_nat']], axis=0)
    rev_cov = np.full(len(seq), np.nan)
    sel = ((df.dir == 'rev') & ~pd.isna(df.u_test_pval))
    rev_cov[df.loc[(df.dir == 'rev') & ~pd.isna(df.u_test_pval), 'u_test_pval'].index] = np.min([df.loc[sel, 'N_wga'], df.loc[sel, 'N_nat']], axis=0)
    # p-values
    fwd_p = np.full(len(seq), np.nan)
    fwd_p[df.loc[(df.dir == 'fwd') & ~pd.isna(df.u_test_pval), 'u_test_pval'].index] = df.loc[(df.dir == 'fwd') & ~pd.isna(df.u_test_pval), 'u_test_pval']
    rev_p = np.full(len(seq), np.nan)
    rev_p[df.loc[(df.dir == 'rev') & ~pd.isna(df.u_test_pval), 'u_test_pval'].index] = df.loc[(df.dir == 'rev') & ~pd.isna(df.u_test_pval), 'u_test_pval']
    # bonferoni-hochberg corrected p-values
    _,corr_p_values,_,_ = multipletests(np.concatenate([fwd_p[~np.isnan(fwd_p)], rev_p[~np.isnan(rev_p)]]), alpha=0.05, method='fdr_bh')
    fwd_p_corr = fwd_p.copy()
    fwd_p_corr[~np.isnan(fwd_p)] = corr_p_values[:(~np.isnan(fwd_p)).sum()]
    rev_p_corr = rev_p.copy()
    rev_p_corr[~np.isnan(rev_p)] = corr_p_values[(~np.isnan(fwd_p)).sum():]
    # calculate kmer frequencies in genome
    kmer_frequencies = {}
    for k in [3,4]:
        kmer_frequencies[k] = get_kmer_frequencies(seq, k)
    
    thr = 0.001
    for kwargs in [{},
                   {'peak_dist':2, 'min_cov':20, 'min_dist':20},
                   {'peak_dist':2, 'min_cov':20, 'min_dist':20, 'k':3, 'kmer_quantile':0.25}
                  ]:
        fwd_pos, rev_pos = select_positions(fwd_p_corr, rev_p_corr, thr, **kwargs)
        fn = "".join([f"{sample}_peaks_uBH_{thr}"] + [f"_{kw}_{val}" for kw,val in kwargs.items()]) + ".fasta"
        print(fn, len(fwd_pos), len(rev_pos))
        fp = os.path.join(sample_dir, sample, "peaks")
        os.makedirs(fp, exist_ok=True)
        fp = os.path.join(fp, fn)
        with open(fp, "w") as f:
            for pos in fwd_pos:
                print(f">{contig}_fwd_{pos}\n{seq[pos-context:pos+context+1].upper()}", file=f)
            for pos in rev_pos:
                print(f">{contig}_rev_{pos}\n{reverse_complement(str(seq[pos-context:pos+context+1].upper()))}", file=f)
    suffix_array.delete_suffix_array(sa)

processing sample OC8
OC8_peaks_uBH_0.001.fasta 9063 8905
OC8_peaks_uBH_0.001_peak_dist_2_min_cov_20_min_dist_20.fasta 5621 5400
OC8_peaks_uBH_0.001_peak_dist_2_min_cov_20_min_dist_20_k_3_kmer_quantile_0.25.fasta 4967 4769
processing sample LA_Y3C
LA_Y3C_peaks_uBH_0.001.fasta 7717 7773
LA_Y3C_peaks_uBH_0.001_peak_dist_2_min_cov_20_min_dist_20.fasta 4686 4679
LA_Y3C_peaks_uBH_0.001_peak_dist_2_min_cov_20_min_dist_20_k_3_kmer_quantile_0.25.fasta 3904 4017
processing sample Bak13
Bak13_peaks_uBH_0.001.fasta 8810 9013
Bak13_peaks_uBH_0.001_peak_dist_2_min_cov_20_min_dist_20.fasta 5392 5522
Bak13_peaks_uBH_0.001_peak_dist_2_min_cov_20_min_dist_20_k_3_kmer_quantile_0.25.fasta 4773 4964
processing sample Riv16
Riv16_peaks_uBH_0.001.fasta 2814 2271
Riv16_peaks_uBH_0.001_peak_dist_2_min_cov_20_min_dist_20.fasta 1896 1465
Riv16_peaks_uBH_0.001_peak_dist_2_min_cov_20_min_dist_20_k_3_kmer_quantile_0.25.fasta 1670 1283
processing sample Red_Oak2
Red_Oak2_peaks_uBH_0.001.fasta 8872 8161
Red_Oak2_pea