# DTW Alignment to Virus Reference

Manage all imports

In [None]:
from sklearn import metrics
import numpy as np
from glob import glob
import random
import h5py
import matplotlib.pyplot as plt
%matplotlib inline
import multiprocessing as mp
from numba import njit
import seaborn as sns

Define globals for selecting input data

In [None]:
# Reference genome and k-mer current model
ref_fn = "../data/covid/reference.fasta"
kmer_model_fn = "../data/dna_kmer_model.txt"
k = 6  # k-mer length used with the k-mer model (6-mer model)

# Target (COVID) reads; the commented glob selects every FAST5 file instead
ref_read_fns = ["../data/covid/fast5/SP1-mapped109.fast5"]
# ref_read_fns = glob("../data/covid/fast5/*.fast5")
max_ref_reads = 50

# Non-target reads drawn from the listed files; alternative selection commented out
other_read_fns = [
    "../data/human/fast5/1000/reads.fast5",
    "../data/zymo/fast5/907/reads.fast5",
]
# other_read_fns = glob("../data/human/chr20/*.fast5")
max_other_reads = 50

## Signal-Based Reference Setup
Define helper functions

In [None]:
def get_fasta(fasta_fn):
    ''' Get sequence from FASTA filename.

    Header lines (starting with '>') and blank lines are skipped; the remaining
    lines are stripped of surrounding whitespace and concatenated.  If the file
    holds several records, their sequences are joined end to end.

    Args:
        fasta_fn: path to a FASTA file.

    Returns:
        The sequence as a single string (case preserved).
    '''
    seq_lines = []
    with open(fasta_fn, 'r') as fasta:
        for line in fasta:
            line = line.strip()
            # skip headers so their text never ends up inside the sequence
            if line and not line.startswith('>'):
                seq_lines.append(line)
    return ''.join(seq_lines)

def rev_comp(bases):
    ''' Get reverse complement of sequence.

    The input is upper-cased before complementing, so lower-case (soft-masked)
    bases are complemented correctly; the result is always upper case.
    Characters other than A/C/G/T (e.g. N) are kept as-is (upper-cased).
    '''
    return bases.upper().translate(str.maketrans('ACGT', 'TGCA'))[::-1]

def load_model(kmer_model_fn):
    ''' Load k-mer model file into Python dict.

    Each non-blank line must hold exactly two whitespace-separated fields: a
    k-mer and its expected current (parsed as float).  Blank or
    whitespace-only lines are skipped; any other malformed line raises
    ValueError.

    Returns:
        dict mapping k-mer string -> float current.
    '''
    kmer_model = {}
    with open(kmer_model_fn, 'r') as model_file:
        for line in model_file:
            fields = line.split()
            if not fields:
                continue  # tolerate blank / trailing empty lines
            kmer, current = fields
            kmer_model[kmer] = float(current)
    return kmer_model

def discrete_normalize(seq, bits=8, minval=-4, maxval=4):
    ''' Approximate normalization which converts signal to integer of desired precision.

    The signal is centred on its integer-truncated mean and scaled by its
    integer-truncated mean absolute deviation (MAD).  It is then clipped to
    [minval, maxval] and linearly mapped onto the integers 0 .. 2**bits - 1.

    Args:
        seq: 1-D numeric array (anything accepted by np.asarray).
        bits: output precision; every result fits in `bits` unsigned bits.
        minval, maxval: clipping range, in units of MAD.

    Returns:
        Integer NumPy array with the same length as `seq`.
    '''
    seq = np.asarray(seq)
    if seq.size == 0:
        return np.zeros(0, dtype=int)
    mean = int(np.mean(seq))
    # The integer truncation is part of the approximation, but it yields 0 for
    # near-constant input; clamp to 1 so we never compute 0/0 = NaN.
    mean_avg_dev = max(int(np.mean(np.abs(seq - mean))), 1)
    norm_seq = (seq - mean) / mean_avg_dev

    norm_seq = np.clip(norm_seq, minval, maxval)
    levels = 2 ** bits
    norm_seq = ((norm_seq - minval) * (levels / (maxval - minval))).astype(int)
    # A sample clipped to exactly maxval maps to 2**bits, one past the largest
    # `bits`-bit value; fold it into the top level.
    return np.minimum(norm_seq, levels - 1)

def ref_signal(fasta, kmer_model, kmer_len=None):
    ''' Convert reference FASTA to expected reference signal.

    Each of the len(fasta) - kmer_len + 1 overlapping k-mers is looked up in
    `kmer_model`.  The resulting expected currents are scaled by 100 and
    quantized with discrete_normalize.

    Args:
        fasta: reference sequence string.
        kmer_model: dict mapping k-mer strings to expected currents.  A k-mer
            missing from the model raises KeyError.
        kmer_len: k-mer length; defaults to the notebook-level `k`.

    Returns:
        Integer array with one entry per k-mer (empty if `fasta` is shorter
        than `kmer_len`).
    '''
    if kmer_len is None:
        kmer_len = k
    n_kmers = max(len(fasta) - kmer_len + 1, 0)
    if n_kmers == 0:
        return np.zeros(0, dtype=int)
    # One entry per k-mer start, including the last start at len(fasta) - kmer_len;
    # no zero-filled tail positions are appended.
    signal = np.array([kmer_model[fasta[start:start + kmer_len]]
                       for start in range(n_kmers)], dtype=float)
    return discrete_normalize(signal * 100)  # increase dist between floats before rounding

Create the COVID reference signal: convert the forward and reverse-complement reference sequences to expected k-mer currents, quantize them with `discrete_normalize`, and concatenate the two strands.

In [None]:
ref_fasta = get_fasta(ref_fn)
kmer_model = load_model(kmer_model_fn)

# Expected signal for each strand: forward sequence first, then its reverse complement.
fwd_ref_sig, rev_ref_sig = [ref_signal(strand, kmer_model)
                            for strand in (ref_fasta, rev_comp(ref_fasta))]
ref_sig = np.concatenate([fwd_ref_sig, rev_ref_sig])

## Read Preparation

Define preprocessing functions for converting raw FAST5 data to normalized alignable signals

In [None]:
# Parameters read by get_stall_end (compared against the discrete_normalize'd signal):
#   stall_threshold - event level separating stall from non-stall samples
#   stall_events    - consecutive events needed to enter / leave the stall
#   stall_event_len - number of samples averaged into one event
stall_threshold, stall_events, stall_event_len = 200, 2, 3

In [None]:
def get_stall_end(signal, threshold=None, n_events=None, event_len=None):
    ''' Determine the end of the DNA stall region.

    The signal is averaged over consecutive non-overlapping windows of
    `event_len` samples ("events"; the last window may be shorter).  The stall
    starts at the first run of `n_events` consecutive events above `threshold`
    and ends after the next run of `n_events` consecutive events below it.

    Args:
        signal: 1-D array in the same units as `threshold` (preprocess_reads
            passes the discrete_normalize'd signal).
        threshold, n_events, event_len: default to the notebook globals
            stall_threshold, stall_events and stall_event_len.

    Returns:
        Sample index just past the stall end.  Returns 0 when no stall start is
        found (nothing to trim).  When the stall starts but never ends, returns
        a value >= len(signal).
    '''
    threshold = stall_threshold if threshold is None else threshold
    n_events = stall_events if n_events is None else n_events
    event_len = stall_event_len if event_len is None else event_len

    # take average of a few samples to reduce variation
    events = [np.mean(signal[start:start + event_len])
              for start in range(0, len(signal), event_len)]

    # find where we exceed threshold for a few consecutive events
    run = 0
    search_from = None
    for idx, event in enumerate(events):
        run = run + 1 if event > threshold else 0
        if run == n_events:
            search_from = idx + 1
            break
    if search_from is None:
        # No stall detected: keep the whole read instead of trimming all of it.
        return 0

    # find where we go back below threshold
    run = 0
    end_event = len(events)
    for idx in range(search_from, len(events)):
        run = run + 1 if events[idx] < threshold else 0
        if run == n_events:
            end_event = idx + 1
            break

    return end_event * event_len


def trim_signal(signal):
    ''' Drop the leading stall region, as located by get_stall_end, from a signal.

    Only the stall is detected here; no separate adapter detection is done.
    '''
    return signal[get_stall_end(signal):]
    

def preprocess_reads(fast5_fn):
    ''' Returns all preprocessed reads from specified FAST5 file.

    Each top-level group's 'Raw/Signal' dataset is read as int16, quantized
    with discrete_normalize, and then stall-trimmed with trim_signal.
    Normalization happens first, so the stall threshold is compared against
    normalized values.

    NOTE(review): assumes a multi-read FAST5 layout where every top-level group
    is a read containing 'Raw/Signal'; confirm for other layouts.

    Returns:
        list of 1-D integer NumPy arrays, one per read.
    '''
    reads = []
    # Context manager guarantees the HDF5 handle is closed, even on error.
    with h5py.File(fast5_fn, 'r') as fast5_file:
        for read_name in fast5_file:
            raw = np.array(fast5_file[read_name]['Raw']['Signal'][:], dtype=np.int16)
            reads.append(trim_signal(discrete_normalize(raw)))
    return reads

Preprocess reads into a list of NumPy arrays for each data source (ref/other), selecting a random subset

In [None]:
def _load_random_subset(fast5_fns, max_reads):
    ''' Preprocess every read in `fast5_fns`, shuffle them, and keep at most `max_reads`. '''
    pooled = []
    for fn in fast5_fns:
        pooled += preprocess_reads(fn)
    random.shuffle(pooled)
    return pooled[:max_reads]


ref_reads = _load_random_subset(ref_read_fns, max_ref_reads)
other_reads = _load_random_subset(other_read_fns, max_other_reads)

## DTW Alignment

Define globals for use in DTW (reference signals and alignment lengths)

In [None]:
# Globals read by sdtw: the concatenated forward + reverse reference signal,
# and the query prefix lengths (1000, 2000, ..., 5000 samples) to score.
ref = ref_sig
aln_lens = np.arange(1000, 5001, 1000)

In [None]:
@njit()
def sdtw(seq):
    ''' Returns minimum alignment score for subsequence DTW.

    Aligns the first L samples of `seq` against any sub-span of the global
    reference `ref`, for every L in the global `aln_lens`.  The local cost is
    the absolute difference, and the steps are (i-1, j-1), (i, j-1) and
    (i-1, j).

    Only two rows of the cost matrix are kept, so memory is O(len(ref)) rather
    than O(len(seq) * len(ref)).  Rows beyond max(aln_lens) are never computed
    because no result depends on them.

    Returns:
        float64 array; entry a is the minimum cumulative cost over all reference
        end positions after aln_lens[a] query samples.

    Raises:
        ValueError: if `seq` is shorter than max(aln_lens).
    '''
    n_ref = len(ref)
    n_rows = aln_lens.max()
    # Numba does not bounds-check indexing by default, so a short read would
    # otherwise index out of bounds silently.
    if len(seq) < n_rows:
        raise ValueError("sdtw: read shorter than the longest alignment length")

    cost_mins = np.zeros((len(aln_lens),))
    prev = np.empty(n_ref)
    cur = np.empty(n_ref)
    for i in range(n_rows):
        s = seq[i]
        if i == 0:
            # free start: the alignment may begin at any reference position,
            # paying only the cost of matching the first sample there
            for j in range(n_ref):
                cur[j] = abs(s - ref[j])
        else:
            cur[0] = prev[0] + abs(s - ref[0])
            for j in range(1, n_ref):
                cur[j] = abs(s - ref[j]) + min(prev[j - 1], cur[j - 1], prev[j])
        # free end: best cost over every reference end position for this prefix length
        for a in range(len(aln_lens)):
            if aln_lens[a] == i + 1:
                cost_mins[a] = cur.min()
        prev, cur = cur, prev
    return cost_mins

In [None]:
# Score both read sets in parallel worker processes (one result array per read).
with mp.Pool() as pool:
    ref_dists, other_dists = [pool.map(sdtw, reads)
                              for reads in (ref_reads, other_reads)]

## Data Analysis

In [None]:
# Per-length cost thresholds (one per entry of aln_lens); the evaluation cell
# predicts COVID when a read's cost is below the threshold.
thresholds = [11000, 20000, 29000, 39000, 49000]
if len(thresholds) != len(aln_lens):
    raise ValueError("need exactly one threshold per alignment length")

for i, aln_len in enumerate(aln_lens):
    ref_costs = [x[i] for x in ref_dists]
    other_costs = [x[i] for x in other_dists]
    # Shared bin edges so the two overlaid histograms are directly comparable.
    bins = np.histogram_bin_edges(ref_costs + other_costs, bins=50)
    fig, ax = plt.subplots()
    ax.hist(ref_costs, bins=bins, facecolor='r', alpha=0.5)
    ax.hist(other_costs, bins=bins, facecolor='g', alpha=0.5)
    # other_read_fns mixes several non-COVID sources, so label that set generically.
    ax.legend(['COVID', 'Other'])
    ax.set_xlabel('DTW Alignment Cost')
    ax.set_ylabel('Read Count')
    ax.axvline(thresholds[i], color='k', linestyle='--')
    ax.set_title('{} Signals'.format(aln_len))
    plt.show()

In [None]:
for i, aln_len in enumerate(aln_lens):
    threshold = thresholds[i]
    # A read is predicted COVID when its cost is strictly below the threshold.
    ref_pred = [x[i] < threshold for x in ref_dists]
    other_pred = [x[i] < threshold for x in other_dists]
    pred = ref_pred + other_pred
    truth = [True]*len(ref_dists) + [False]*len(other_dists)
    # sklearn's signature is confusion_matrix(y_true, y_pred); explicit labels
    # make row/column 0 COVID and row/column 1 Other.
    cm = metrics.confusion_matrix(truth, pred, labels=[True, False])
    print("Alignment Length: {}".format(aln_len))
    print("\tCOVID Retention Rate: {}".format(np.sum(ref_pred)/len(ref_dists)))
    # Discarded = not predicted COVID (cost >= threshold), the exact complement of
    # pred, so a cost equal to the threshold is counted too.
    print("\tOther Discard Rate: {}\n".format(np.sum(np.logical_not(other_pred))/len(other_dists)))