# DTW Alignment to Virus Reference

Manage all imports

In [None]:
from sklearn import metrics
from numba import njit
from glob import glob
from scipy import stats

import random, h5py, re, os

import numpy as np
import matplotlib.pyplot as plt
import multiprocessing as mp
import seaborn as sns

%matplotlib inline

Define globals for selecting input data

In [None]:
# root directory holding the k-mer model and the per-organism datasets
data_dir = "../data"

# expected-current k-mer model file and its k-mer size (6-mer model)
kmer_model_fn = f"{data_dir}/dna_kmer_model.txt"
k = 6

# dataset selectors: target organism, background organism, molecule type
# and dataset identifiers (used to build the paths below)
virus = "lambda"
other = "human"
dna_type = "DNA"
virus_ds = "0"
other_ds = "0"

# input locations derived from the selectors above
ref_fn = f"{data_dir}/{virus}/{dna_type}/{virus_ds}/reference.fasta"
virus_fast5_dir = f"{data_dir}/{virus}/{dna_type}/{virus_ds}/fast5"
other_fast5_dir = f"{data_dir}/{other}/{dna_type}/{other_ds}/fast5"

# output location for cached scores and accuracy metrics
results_dir = "./results/all_modifications"

# maximum number of reads drawn from each dataset
virus_max_reads = 1000
other_max_reads = 1000

## Signal-Based Reference Setup
Define helper functions

In [None]:
def get_fasta(fasta_fn):
    ''' Return the contents of a FASTA file, minus its first line, as one string. '''
    with open(fasta_fn, 'r') as fasta:
        fasta.readline()          # discard the first line (the header)
        remainder = fasta.read()
    # join the remaining lines by dropping the newlines between them
    return remainder.replace('\n', '')

def rev_comp(bases):
    '''
    Get reverse complement of sequence.

    The result is always upper-case. Lower-case bases are complemented too
    (they used to be upper-cased without being complemented). Characters
    other than A/C/G/T are kept, upper-cased, in their reversed position.
    '''
    complement = str.maketrans('ACGT', 'TGCA')
    return bases.upper().translate(complement)[::-1]

def load_model(kmer_model_fn):
    '''
    Load k-mer model file into Python dict.

    Each non-blank line must hold two whitespace-separated fields: the k-mer
    and its expected current. Blank lines (e.g. a trailing empty line) are
    skipped. Returns a dict mapping k-mer string -> float current.
    Raises ValueError for a line that does not have exactly two fields or
    whose second field is not a number.
    '''
    kmer_model = {}
    with open(kmer_model_fn, 'r') as model_file:
        for line in model_file:
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines instead of failing to unpack
            kmer, current = fields
            kmer_model[kmer] = float(current)
    return kmer_model

def discrete_normalize(seq, bits=8, minval=-4, maxval=4):
    '''
    Approximate normalization which converts signal to integer of desired precision.

    The signal is centred on its integer-truncated mean, divided by its
    integer-truncated mean absolute deviation, clipped to [minval, maxval]
    and rescaled so that `minval` maps to 0 and `maxval` maps to 2**bits.

    seq:    array-like signal (not modified)
    bits:   exponent of the output scale
    minval: lower clipping bound, in deviation units
    maxval: upper clipping bound, in deviation units
    Returns an integer numpy array with the same length as `seq`.
    '''
    seq = np.asarray(seq)
    mean = int(np.mean(seq))
    mean_avg_dev = int(np.mean(np.abs(seq - mean)))

    # a (near-)constant signal truncates to a deviation of 0; dividing by it
    # would fill the output with NaN/inf, so fall back to unit scale
    if mean_avg_dev == 0:
        mean_avg_dev = 1

    norm_seq = np.clip((seq - mean) / mean_avg_dev, minval, maxval)

    # NOTE(review): a value clipped to `maxval` maps to exactly 2**bits, one
    # above the largest `bits`-bit integer -- confirm whether that is intended
    scale = 2**bits / (maxval - minval)
    return ((norm_seq - minval) * scale).astype(int)

def ref_signal(fasta, kmer_model):
    '''
    Convert reference FASTA to expected reference signal (discretely normalized).

    fasta:      reference sequence string
    kmer_model: dict mapping each k-mer to its expected current
    Returns one normalized value per k-mer of the sequence, i.e.
    len(fasta) - k + 1 values, where `k` is the module-level k-mer size.
    Raises KeyError if a k-mer of the sequence is missing from the model.
    '''
    # every k-mer start position, including the last one; the signal is no
    # longer padded with trailing zeros, which used to skew the normalization
    n_kmers = max(len(fasta) - k + 1, 0)
    signal = np.array(
        [kmer_model[fasta[start:start+k]] for start in range(n_kmers)],
        dtype=float)
    return discrete_normalize(signal*100) # increase dist between floats before rounding

Create the virus reference signal from the normalized expected k-mer currents of the forward and reverse-complement reference FASTA sequences

In [None]:
# load the reference sequence and the k-mer -> expected current model
ref_fasta = get_fasta(ref_fn)
kmer_model = load_model(kmer_model_fn)
# expected signal for the reference and for its reverse complement
fwd_ref_sig = ref_signal(ref_fasta, kmer_model)
rev_ref_sig = ref_signal(rev_comp(ref_fasta), kmer_model)
# both strands back to back form the single reference that reads are aligned to
ref_sig = np.concatenate((fwd_ref_sig, rev_ref_sig))

## Read Preparation

Define preprocessing functions for converting raw FAST5 data to normalized alignable signals

In [None]:
def get_stall_end(signal, stall_threshold=3, 
                  stall_events=2, stall_event_len=3):
    '''
    Locate the sample index at which the stall region ends.

    The signal is averaged in windows of `stall_event_len` samples. The scan
    first looks for `stall_events` consecutive windows above
    `stall_threshold`, then for `stall_events` consecutive windows below it.
    Returns the number of windows consumed times `stall_event_len`.
    '''

    # window means smooth out sample-to-sample variation
    window_means = [np.mean(signal[start:start+stall_event_len])
                    for start in range(0, len(signal), stall_event_len)]

    # phase 1: advance until a run of windows above the threshold is seen
    consumed = 0
    run_length = 0
    for value in window_means:
        consumed += 1
        run_length = run_length + 1 if value > stall_threshold else 0
        if run_length == stall_events:
            break

    # phase 2: keep advancing until a run of windows below the threshold is seen
    run_length = 0
    for value in window_means[consumed:]:
        consumed += 1
        run_length = run_length + 1 if value < stall_threshold else 0
        if run_length == stall_events:
            break

    return consumed * stall_event_len


def trim(signal):
    '''
    Cut a fixed window out of the signal, positioned after the detected stall.

    The stall end is found on the z-scored signal. The returned window starts
    1000 samples after it and is at most 5000 samples long.
    Returns (window, stall_end).
    '''
    normalized = stats.zscore(signal)
    stall_end = get_stall_end(normalized)
    window_start = stall_end + 1000
    window_stop = stall_end + 6000
    return signal[window_start:window_stop], stall_end


def _neighbor_fill(signal, idx, fallback):
    ''' Replacement for an isolated outlier at idx: neighbor average, the only neighbor, or fallback. '''
    has_prev = idx > 0
    has_next = idx + 1 < len(signal)
    if has_prev and has_next:
        return (signal[idx-1] + signal[idx+1]) / 2
    if has_next:
        return signal[idx+1]
    if has_prev:
        return signal[idx-1]
    # single-sample signal: there is no neighbor to copy from
    return fallback


def filter_outliers(signal, minval=-4, maxval=4):
    '''
    Clamp or interpolate samples outside [minval, maxval], in place.

    A sample above `maxval` is set to `maxval` when a neighbor is also high;
    otherwise it is treated as an isolated spike and replaced by the average
    of its neighbors (or by its only neighbor at either end of the signal).
    Samples below `minval` are handled the same way in a second pass.
    A one-sample signal holding an outlier is clamped rather than raising
    IndexError.

    signal: mutable sequence of samples; modified in place and returned
    '''

    # return empty signals as-is
    if not len(signal): return signal

    n = len(signal)

    # upper threshold
    # NOTE(review): the "next sample is high" test uses the literal 3 rather
    # than a value derived from maxval -- confirm this is intended
    for idx, x in enumerate(signal):
        if x > maxval:
            if (idx+1 < n and signal[idx+1] > 3) or \
               (idx > 0 and signal[idx-1] > maxval):
                # other values above max -> threshold to max
                signal[idx] = maxval
            else:
                # single outlier -> interpolate
                signal[idx] = _neighbor_fill(signal, idx, maxval)

    # lower threshold (mirror image of the pass above, literal -3 included)
    for idx, x in enumerate(signal):
        if x < minval:
            if (idx+1 < n and signal[idx+1] < -3) or \
               (idx > 0 and signal[idx-1] < minval):
                # other values below min -> threshold to min
                signal[idx] = minval
            else:
                # single outlier -> interpolate
                signal[idx] = _neighbor_fill(signal, idx, minval)

    return signal

def preprocess_read(uuid):
    '''
    Return preprocessed read from specified FAST5 file.

    uuid: read identifier; the file is looked up in the module-level
          `full_index` and the read is expected under the group `read_<uuid>`
    Returns (signal, trimmed, length):
      signal  - normalized float array of the trimmed window (empty array if
                nothing is left after trimming)
      trimmed - stall end position reported by `trim`
      length  - number of raw samples in the read
    '''
    readname = f"read_{uuid}"
    # copy the raw samples out, then close the file right away so worker
    # processes do not accumulate open HDF5 handles
    with h5py.File(full_index[uuid], 'r') as fast5_file:
        signal = np.array(fast5_file[readname]['Raw']['Signal'][:], dtype=np.int16)
    length = signal.shape[0]
    signal, trimmed = trim(signal)
    if not len(signal): return np.array([]), trimmed, length

    # running normalization: each 500-sample chunk is normalized using the
    # statistics of everything seen up to and including that chunk
    new_signal = np.array(signal, dtype=float)
    for start in range(0, len(signal), 500):
        new_signal[start:start+500] = \
            discrete_normalize(signal[:start+500])[start:start+500]
    return new_signal, trimmed, length

def get_index(index_filename):
    '''
    Read index data structure from file.

    Each non-blank line holds a read UUID and a filename separated by one or
    more tabs. Blank lines are skipped. Returns a dict mapping UUID ->
    filename (trailing whitespace stripped).
    Raises ValueError for a line that does not split into exactly two fields.
    '''
    index = {}
    # the context manager closes the file even if a malformed line raises
    with open(index_filename, 'r') as index_file:
        for line in index_file:
            if not line.strip():
                continue
            uuid, fname = re.split(r'\t+', line)
            index[uuid] = fname.rstrip()
    return index


def create_index(fast5_dir, force=False):
    '''
    Create file which stores read FAST5 to UUID mappings.

    fast5_dir: directory searched recursively for `.fast5` files; the index
               is written to `<fast5_dir>/index.db`
    force:     rebuild the index even if the index file already exists
    Returns the index as a dict mapping read UUID -> FAST5 filename.
    '''

    # return existing index if possible
    index_fn = f'{fast5_dir}/index.db'
    if not force and os.path.exists(index_fn):
        return get_index(index_fn)

    # create new index (opening in 'w' mode truncates any existing one)
    with open(index_fn, 'w') as index_file:

        # iterate through all FAST5 files in directory
        for subdir, _dirs, files in os.walk(fast5_dir):
            for filename in files:
                ext = os.path.splitext(filename)[-1].lower()
                if ext != ".fast5":
                    continue

                # print read uuid and filename to index
                fast5_path = os.path.join(subdir, filename)
                with h5py.File(fast5_path, 'r') as fast5_file:
                    if 'Raw' in fast5_file: # single-FAST5
                        reads = fast5_file['Raw']['Reads']
                        for readname in reads:
                            uuid = reads[readname].attrs['read_id']
                            # the attribute may come back as bytes or as str
                            if isinstance(uuid, bytes):
                                uuid = uuid.decode('utf-8')
                            print('{}\t{}'.format(uuid, fast5_path), file=index_file)
                    else: # multi-FAST5
                        for readname in fast5_file:
                            uuid = readname[5:] # remove 'read_' naming prefix
                            print('{}\t{}'.format(uuid, fast5_path), file=index_file)

    # read the freshly written index back (this used to reference an
    # undefined name and raise NameError)
    return get_index(index_fn)

In [None]:
def segment(signal):
    '''
    Collapse a signal into per-segment means, with boundaries at the strongest edges.

    For every position, the sums of the `width` samples before and after it
    are compared; the positions where this difference peaks are used as
    checkpoints. Up to int(len(signal)*450/4000) checkpoints are kept, fewer
    if the signal does not contain that many peaks.

    Returns an array with one mean per checkpoint: the mean of the samples
    before the first checkpoint, then the mean between each pair of
    consecutive checkpoints. Samples after the last checkpoint are not
    included. Returns an empty array when no checkpoint is found.
    '''
    width = 5
    # NOTE(review): 450/4000 presumably is bases-per-second over
    # samples-per-second, i.e. the expected number of segments -- confirm
    npts = int((len(signal)*450)/4000)

    # get difference between all neighboring 'width' regions
    cumsum = np.cumsum(np.concatenate([[0.0], signal]))
    vals = np.abs( (2 * cumsum[width:-width]) - cumsum[:-2*width] - cumsum[2*width:])
    cand_poss = np.argsort(vals)[::-1]

    # keep up to 'npts' best checkpoints; the scan also stops when the
    # candidates run out (this used to raise IndexError)
    chkpts = []
    blacklist = set()
    for edge_pos in cand_poss.tolist():
        if len(chkpts) >= npts:
            break
        if edge_pos in blacklist:
            continue
        chkpts.append(edge_pos+width)

        # blacklist nearby values (only use peaks): walk downhill both ways
        right = edge_pos
        while right+1 < len(vals) and vals[right] > vals[right+1]:
            right += 1
            blacklist.add(right)
        left = edge_pos
        while left > 0 and vals[left] > vals[left-1]:
            left -= 1
            blacklist.add(left)

    if not chkpts:
        return np.array([])

    bounds = [0] + sorted(chkpts)
    return np.array([np.mean(signal[start:stop])
                     for start, stop in zip(bounds[:-1], bounds[1:])])

In [None]:
# create read UUID -> FAST5 filename mapping
virus_index = create_index(virus_fast5_dir)
other_index = create_index(other_fast5_dir)
# combined lookup table; entries of other_index win if a UUID occurs in both
full_index = {**virus_index, **other_index}

In [None]:
# select random subset of reads
# random.sample draws without replacement, so no read is picked (and later
# scored) twice; at most *_max_reads reads are taken, fewer if the dataset
# is smaller
virus_readnames = random.sample(list(virus_index), k=min(virus_max_reads, len(virus_index)))
other_readnames = random.sample(list(other_index), k=min(other_max_reads, len(other_index)))

In [None]:
# trim all reads
# preprocess_read returns a (signal, stall end, raw length) tuple per read;
# zip(*...) regroups the per-read tuples into three parallel lists
with mp.Pool() as pool:
    virus_reads, virus_trims, virus_lengths = list(map(list, zip(*pool.map(preprocess_read, virus_readnames))))
    other_reads, other_trims, other_lengths = list(map(list, zip(*pool.map(preprocess_read, other_readnames))))

## DTW Alignment

In [None]:
# reference signal that reads are aligned against (read as a global by sdtw)
ref = ref_sig
#aln_lens = np.array([int((x*450)/4000) for x in range(500,5001,500)])
# read prefix lengths at which an alignment cost is reported: 500, 1000, ..., 5000
aln_lens = np.array(range(500,5001,500))
# number of prefix lengths
nthresh = len(aln_lens)

In [None]:
@njit()
def sdtw(seq):
    '''
    Returns minimum alignment score for subsequence DTW.

    Aligns `seq` against the module-level reference `ref` using absolute
    difference as the local cost. For every prefix length in the
    module-level `aln_lens` that fits in `seq`, the smallest cost in the
    cost-matrix row of that prefix's last sample is reported.

    Returns a float array with one entry per `aln_lens` value; entries stay
    0 for prefix lengths longer than `seq`, so an empty read yields all
    zeros.

    Only two rows of the cost matrix are held in memory at a time.
    `aln_lens` values are assumed to be positive.
    '''
    n_seq = len(seq)
    n_ref = len(ref)
    cost_mins = np.zeros((len(aln_lens),))

    # nothing to align; indexing seq[0]/ref[0] below would be out of bounds
    if n_seq == 0 or n_ref == 0:
        return cost_mins

    # first row: only column 0 carries a cost, all other cells start at 0
    # NOTE(review): this leaves the first sample's cost out for alignments
    # starting at ref positions > 0 -- confirm this is intended
    row = np.zeros(n_ref)
    row[0] = abs(seq[0]-ref[0])
    scratch = np.zeros(n_ref)

    for i in range(n_seq):
        if i > 0:
            # compute row i from row i-1, then make it the current row
            scratch[0] = row[0] + abs(seq[i]-ref[0])
            for j in range(1, n_ref):
                scratch[j] = min(row[j-1], scratch[j-1], row[j]) + \
                             abs(seq[i]-ref[j])
            row, scratch = scratch, row

        # record the best cost for every prefix that ends on this row
        for t in range(len(aln_lens)):
            if aln_lens[t] == i + 1:
                cost_mins[t] = np.min(row)

    return cost_mins

In [None]:
# limit cores since each aligner takes ~4GB of RAM to align
# each pool.map call returns one array of per-prefix-length scores per read
with mp.Pool(16) as pool:
    print(f'Aligning {virus} reads...', flush=True)
    virus_scores_list = pool.map(sdtw, virus_reads)
    print(f'Aligning {other} reads...', flush=True)
    other_scores_list = pool.map(sdtw, other_reads)

## Analyze Errors
Look at low-scoring human reads and high-scoring lambda reads to determine reason for error

In [None]:
# plot the preprocessed signal of a single virus read for visual inspection
plt.figure(figsize=(25,5))
#plt.plot(virus_reads[100])
plt.plot(virus_reads[0])
plt.show()
# NOTE(review): the numbers below look like read indices pasted from an
# earlier argsort output -- confirm before relying on them
#608,  14, 863, 225, 442, 297, 966, 604, 767, 620, 956, 179,  11,
#       987, 976, 752, 723, 255, 663, 480, 212, 691, 926, 102, 461, 338,
#       113,  63, 483, 973, 950, 176, 394, 681, 661, 481, 569, 634, 740,

In [None]:
# arrange the per-read score vectors as (prefix length x read) matrices
vscores = np.zeros((nthresh,len(virus_scores_list)))
for read_idx, read_scores in enumerate(virus_scores_list):
    vscores[:, read_idx] = read_scores[:nthresh]

oscores = np.zeros((nthresh,len(other_scores_list)))
for read_idx, read_scores in enumerate(other_scores_list):
    oscores[:, read_idx] = read_scores[:nthresh]

In [None]:
# virus read indices ordered from highest to lowest score at the longest prefix length
np.argsort(vscores[-1])[::-1]
#np.argsort(oscores[-1])

In [None]:
# score of virus read 608 at the longest prefix length
# NOTE(review): the index is hard-coded and reads are drawn at random, so it
# refers to a different read on every run -- confirm it is still meaningful
vscores[-1,608]

## Data Analysis

#### Save Results

In [None]:
# move data to numpy array for easy sorting/calculations:
# one row per prefix length, one column per read
virus_scores = np.zeros((nthresh,len(virus_scores_list)))
for read_idx, read_scores in enumerate(virus_scores_list):
    virus_scores[:, read_idx] = read_scores[:nthresh]

other_scores = np.zeros((nthresh,len(other_scores_list)))
for read_idx, read_scores in enumerate(other_scores_list):
    other_scores[:, read_idx] = read_scores[:nthresh]

# save results
os.makedirs(results_dir, exist_ok=True)
np.save(f"{results_dir}/prefix_lengths", aln_lens)

# virus reads: stall positions, raw lengths, alignment scores
np.save(f"{results_dir}/virus_trims", virus_trims)
np.save(f"{results_dir}/virus_lengths", virus_lengths)
np.save(f"{results_dir}/virus_scores", virus_scores)

# background reads: stall positions, raw lengths, alignment scores
np.save(f"{results_dir}/other_trims", other_trims)
np.save(f"{results_dir}/other_lengths", other_lengths)
np.save(f"{results_dir}/other_scores", other_scores)

#### Load Results

In [None]:
# reload the cached arrays; np.save appended the ".npy" extension when the
# files were written, and np.load needs the full filename to find them
prefix_lengths = np.load(f"{results_dir}/prefix_lengths.npy")
virus_trims = np.load(f"{results_dir}/virus_trims.npy")
virus_lengths = np.load(f"{results_dir}/virus_lengths.npy")
virus_scores = np.load(f"{results_dir}/virus_scores.npy")
other_trims = np.load(f"{results_dir}/other_trims.npy")
other_lengths = np.load(f"{results_dir}/other_lengths.npy")
other_scores = np.load(f"{results_dir}/other_scores.npy")

In [None]:
def get_stats(virus_scores, other_scores, thresh):
    '''
    Return F-scores, precisions and recalls (assumes sorted input).

    virus_scores, other_scores: 2-D arrays with one row per prefix length,
        each row sorted in ascending order; a score of 0 marks a read too
        short to be scored at that prefix length
    thresh: reads scoring strictly below this value are classified as virus
    Returns (fscores, precs, recalls), each with one entry per row of
    `virus_scores`.
    '''
    # take the row count from the data rather than from a module-level global
    n_rows = virus_scores.shape[0]
    fscores = np.zeros(n_rows)
    precs = np.zeros(n_rows)
    recalls = np.zeros(n_rows)
    for i in range(n_rows):
        # short reads don't receive a score, so ignore in accuracy metrics;
        # their zeros sort first, so they are subtracted from the count of
        # scores below the threshold
        long_virus = np.count_nonzero(virus_scores[i])
        short_virus = virus_scores.shape[1]-long_virus
        tp = np.searchsorted(virus_scores[i], thresh) - short_virus
        fn = long_virus - tp
        long_other = np.count_nonzero(other_scores[i])
        short_other = other_scores.shape[1]-long_other
        fp = np.searchsorted(other_scores[i], thresh) - short_other
        precs[i] = 0 if not tp+fp else tp / (tp+fp)
        recalls[i] = 0 if not tp+fn else tp / (tp+fn)
        fscores[i] = 0 if not tp+fp+fn else tp / (tp + 0.5*(fp + fn))
    return fscores, precs, recalls

In [None]:
# sort arrays (for fast f-score calculation)
# np.sort orders each row (prefix length) independently
virus_scores = np.sort(virus_scores)
other_scores = np.sort(other_scores)
max_score = max(np.max(virus_scores), np.max(other_scores))

# calculate all f-scores, and save the best thresholds
best_threshs = np.zeros(nthresh)
best_fscores = np.zeros(nthresh)
best_precs = np.zeros(nthresh)
best_recalls = np.zeros(nthresh)
# sweep candidate thresholds in steps of 1% of the largest observed score
for thresh in np.arange(max_score/100, max_score, max_score/100):
    fscores, precs, recalls = get_stats(virus_scores, other_scores, thresh)
    for i in range(nthresh):
        # keep the first threshold that reaches the best F-score per prefix length
        if fscores[i] > best_fscores[i]:
            best_fscores[i] = fscores[i]
            best_precs[i] = precs[i]
            best_recalls[i] = recalls[i]
            # NOTE(review): the stored threshold is offset by 0.01 from the
            # one that was evaluated -- reason not visible here, confirm
            best_threshs[i] = thresh + 0.01

Plot score distribution for each signal prefix length

In [None]:
# one histogram figure per prefix length, with the best threshold marked
for i, l in enumerate(aln_lens):
    cutoff = best_threshs[i]
    # shared bin edges for both histograms; zero scores (unscored reads) are left out
    bin_edges = np.arange(1,cutoff*2, cutoff/30)
    virus_row = virus_scores[i]
    other_row = other_scores[i]

    fig, ax = plt.subplots()
    ax.set_xlim(0, cutoff*2)
    ax.hist(virus_row[virus_row > 0], bins=bin_edges, facecolor='r', alpha=0.5)
    ax.hist(other_row[other_row > 0], bins=bin_edges, facecolor='g', alpha=0.5)
    ax.axvline(cutoff, color='k', linestyle='--')
    ax.legend([virus, other])
    ax.set_title('{} Samples'.format(l))
    ax.set_xlabel('DTW Alignment Cost')
    ax.set_ylabel('Read Count')
    plt.show()

Cache all results in .npy files

In [None]:
# store the best metrics per prefix length; np.save appends the ".npy" extension
os.makedirs(results_dir, exist_ok=True)
np.save(f"{results_dir}/fscores", best_fscores)
np.save(f"{results_dir}/precisions", best_precs)
np.save(f"{results_dir}/recalls", best_recalls)
np.save(f"{results_dir}/thresholds", best_threshs)

Generate accuracy plots for alignment method evaluation

In [None]:
# result sets to compare; each is read from results/<name>/
results_dirs = ["baseline", "abs_value", "discrete_norm", "running_norm", "all_modifications"]
# three side-by-side panels: F-score, precision, recall
fig, axs = plt.subplots(1,3, figsize=(15,5))
for d in results_dirs:
    # one line per result set in every panel, plotted against prefix length
    # (this rebinds the module-level fscores/precs/recalls names)
    fscores = np.load(f"results/{d}/fscores.npy")
    axs[0].plot(aln_lens, fscores)
    precs = np.load(f"results/{d}/precisions.npy")
    axs[1].plot(aln_lens, precs)
    recalls = np.load(f"results/{d}/recalls.npy")
    axs[2].plot(aln_lens, recalls)
axs[0].set_title('F-score')
# same y-range on all panels so they can be compared directly
axs[0].set_ylim(0.5,1.03)
axs[1].set_ylim(0.5,1.03)
axs[2].set_ylim(0.5,1.03)
axs[0].set_xlabel('Samples')
axs[1].set_xlabel('Samples')
axs[2].set_xlabel('Samples')
axs[0].set_ylabel('Score')
axs[1].set_title('Precision')
axs[2].set_title('Recall')
# legend entries follow the plotting order of results_dirs
axs[2].legend(results_dirs)
plt.show()

## Read Until Runtime and Precision
Analyze runtime as a function of accuracy and multi-stage thresholding

In [None]:
class Flowcell:
    ''' Parameters of the sequencing flowcell. '''

    def __init__(self, version='minion'):
        self.version = version
        self.chemistry = 'r9.4.1'
        self.sampling_rate = 4000     # samples/sec
        self.reversal_latency = 0.070 # sec
        self.channels = 512


class Reads:
    ''' Cached virus read results loaded from a results directory. '''

    def __init__(self, results_dir):
        self.virus_conc = 0.01        # proportion virus
        # TODO: the original assignment had no right-hand side; the intended
        # value is unknown, so this is left unset
        self.threshold_poss = None
        self.virus_scores = np.load(f'{results_dir}/virus_scores.npy')
        self.virus_lengths = np.load(f'{results_dir}/virus_lengths.npy')
        self.virus_trims = np.load(f'{results_dir}/virus_trims.npy')


class Run:
    ''' Parameters of a sequencing run over a set of reads and a flowcell. '''

    def __init__(self, reads, flowcell=None):
        self.target_coverage = 30.0
        self.coverage_bias = 1.0
        self.fwd_tr = 400.0           # bases / sec
        self.rev_tr = 100_000.0       # bases / sec
        self.reads = reads
        # build the default per instance so runs never share one Flowcell object
        self.flowcell = Flowcell() if flowcell is None else flowcell
    

In [None]:
# average of other_lengths (raw sample counts of the background reads)
np.mean(other_lengths)