# DTW Alignment to Virus Reference

Import all dependencies

In [None]:
from scipy import stats
import numpy as np
import random
import h5py
import matplotlib.pyplot as plt
%matplotlib inline
import multiprocessing as mp
from numba import njit

Define globals for selecting input data

In [None]:
# Reference genome and k-mer current model.
ref_fn = "../data/covid/reference.fasta"
kmer_model_fn = "../data/kmer_model.txt"
k = 6  # k-mer length (6-mer model)

# COVID reads (files under the 'fwd' and 'rev' directories) and how many to keep.
ref_read_fns = [
    "../data/covid/fast5/1000/fwd/reads.fast5",
    "../data/covid/fast5/1000/rev/reads.fast5",
]
max_ref_reads = 200

# Background (human) reads and how many to keep.
other_read_fns = [
    "../data/human/fast5/505/reads.fast5",
]
max_other_reads = 200

## Signal-Based Reference Setup
Define helper functions

In [None]:
def get_fasta(fasta_fn):
    ''' Get sequence from FASTA filename.

    Every header line (starting with '>') and every blank line is skipped, and
    surrounding whitespace is stripped from each remaining line before the
    lines are concatenated. If the file holds several records, their sequences
    are concatenated as well, so pass a single-record FASTA when one contiguous
    sequence is required.

    Args:
        fasta_fn: path to a FASTA text file.

    Returns:
        The sequence as a single string (letter case preserved).
    '''
    seq_lines = []
    with open(fasta_fn, 'r') as fasta:
        for line in fasta:
            line = line.strip()
            # Drop *all* headers, not just the first line, so '>' header text
            # from later records never leaks into the sequence.
            if line and not line.startswith('>'):
                seq_lines.append(line)
    return ''.join(seq_lines)

# Watson-Crick complement table; lowercase (e.g. soft-masked) bases are
# complemented too, not merely upper-cased.
_COMPLEMENT = str.maketrans('ACGTacgt', 'TGCAtgca')


def rev_comp(bases):
    ''' Get reverse complement of sequence.

    A/C/G/T are complemented in either letter case; any other character
    (e.g. N) is left as-is. The result is upper-cased and reversed.

    Args:
        bases: nucleotide sequence string.

    Returns:
        Upper-case reverse complement of `bases`.
    '''
    # One translate pass replaces the chained str.replace calls, which only
    # handled uppercase input.
    return bases.translate(_COMPLEMENT).upper()[::-1]

def load_model(kmer_model_fn):
    ''' Load k-mer model file into Python dict.

    Each non-blank line must contain exactly two whitespace-separated fields:
    the k-mer and its expected current (parsed as float). Blank lines, such as
    a trailing empty line, are ignored.

    Args:
        kmer_model_fn: path to the k-mer model text file.

    Returns:
        dict mapping k-mer string -> expected current (float).

    Raises:
        ValueError: if a non-blank line does not have exactly two fields, or
            its current is not a valid float.
    '''
    kmer_model = {}
    with open(kmer_model_fn, 'r') as model_file:
        for line_no, line in enumerate(model_file, start=1):
            fields = line.split()
            if not fields:
                # Blank line: previously crashed the two-value unpacking.
                continue
            if len(fields) != 2:
                raise ValueError("{}:{}: expected 'kmer current', got {!r}".format(
                    kmer_model_fn, line_no, line.rstrip()))
            kmer, current = fields
            kmer_model[kmer] = float(current)
    return kmer_model

def ref_signal(fasta, kmer_model, k=None):
    ''' Convert reference FASTA to expected reference signal (z-scores).

    Every overlapping k-mer of `fasta` is mapped to its expected current via
    `kmer_model`: start positions 0 .. len(fasta) - k inclusive, which is
    len(fasta) - k + 1 k-mers. The resulting series is z-score normalized.

    Args:
        fasta: reference sequence string; each of its k-mers must be a key of
            `kmer_model`.
        kmer_model: dict mapping k-mer string -> expected current.
        k: k-mer length. Defaults to the length of the model's first key
            (assumes all model keys share one length).

    Returns:
        1-D numpy array of z-scored expected currents.

    Raises:
        KeyError: if a k-mer of `fasta` is missing from `kmer_model`.
        ValueError: if `k` is not given and `kmer_model` is empty.
    '''
    if k is None:
        if not kmer_model:
            raise ValueError("cannot infer k from an empty k-mer model")
        k = len(next(iter(kmer_model)))
    # The last k-mer starts at len(fasta) - k, hence the +1 (it used to be dropped).
    signal = [kmer_model[fasta[start:start + k]]
              for start in range(len(fasta) - k + 1)]
    return stats.zscore(signal)

Build the COVID reference signal from the (z-score normalized) expected k-mer currents of the forward and reverse-complement reference sequences

In [None]:
# Expected current per k-mer, then the reference sequence itself.
kmer_model = load_model(kmer_model_fn)
ref_fasta = get_fasta(ref_fn)

# One signal per strand: the sequence as given and its reverse complement.
fwd_ref_sig = ref_signal(ref_fasta, kmer_model)
rev_ref_sig = ref_signal(rev_comp(ref_fasta), kmer_model)

# Single search target: forward signal immediately followed by reverse signal.
# NOTE(review): nothing separates the two halves, so an alignment could span
# the junction; confirm that is acceptable.
ref_sig = np.concatenate([fwd_ref_sig, rev_ref_sig])

## Read Preparation

Define preprocessing functions for converting raw FAST5 data to normalized alignable signals

In [None]:
def read_start():
    ''' Return the sample index where each read's analysed window begins.

    Currently a fixed heuristic offset shared by all reads.
    '''
    # Note: this will later change to using stall/adapter/barcode information.
    first_sample = 1000
    return first_sample

def read_stop():
    ''' Return the (exclusive) sample index where each read's analysed window ends.

    Currently a fixed heuristic offset shared by all reads.
    '''
    # Note: this may later become variable.
    end_sample = 4000
    return end_sample

def preprocess_reads(fast5_fn):
    ''' Returns all preprocessed reads from specified FAST5 file.

    For every top-level group in the file, the dataset '<group>/Raw/Signal'
    is loaded as int16, trimmed to [read_start(), read_stop()), and z-score
    normalized. Assumes every top-level group is a read holding 'Raw/Signal'
    (multi-read FAST5 layout); verify against the data.

    Reads are skipped when the trimmed window is empty (read too short), or
    when it is constant: its z-score would be all NaN, which would poison the
    downstream alignment costs.

    Args:
        fast5_fn: path to the FAST5 (HDF5) file.

    Returns:
        list of 1-D float arrays, one per kept read.
    '''
    reads = []
    # Context manager guarantees the HDF5 handle is closed (it was leaked before).
    with h5py.File(fast5_fn, 'r') as fast5_file:
        for read_name in fast5_file:
            signal = np.array(fast5_file[read_name]['Raw']['Signal'][:], dtype=np.int16)
            # Kept per read on purpose: the heuristics may become read-dependent.
            start, stop = read_start(), read_stop()
            trimmed_signal = signal[start:stop]
            if not len(trimmed_signal): continue # TODO: later check for != stop-start ?
            # min/max comparison (not ptp) avoids int16 overflow.
            if trimmed_signal.min() == trimmed_signal.max(): continue
            reads.append(stats.zscore(trimmed_signal))
    return reads

Preprocess the reads of each data source (ref/other) into a list of NumPy arrays, then keep a random subset of at most `max_ref_reads` / `max_other_reads` reads

In [None]:
def _random_subset(fast5_fns, max_reads):
    ''' Preprocess every read in `fast5_fns` (in file order), shuffle, keep at most `max_reads`. '''
    pooled = [read for fn in fast5_fns for read in preprocess_reads(fn)]
    random.shuffle(pooled)
    return pooled[:max_reads]


# COVID reads are drawn first, then human reads (same order of random calls).
ref_reads = _random_subset(ref_read_fns, max_ref_reads)
other_reads = _random_subset(other_read_fns, max_other_reads)

## DTW Alignment

In [None]:
try:
    from numba import njit as _jit
except ImportError:  # keep the notebook usable (slowly) without numba
    def _jit(func):
        ''' No-op stand-in for numba.njit. '''
        return func


@_jit
def _sdtw_min_cost(seq, ref):
    ''' Return the minimum accumulated subsequence-DTW cost of `seq` against `ref`.

    Local cost is the squared difference. Only two rows of the accumulated
    cost matrix are kept, so memory is O(len(ref)) instead of two full
    len(seq) x len(ref) float64 matrices.
    '''
    n_ref = ref.shape[0]
    # Row 0: the query may start anywhere in ref, but still pays its own local
    # cost there (the old version left this row at 0 and ignored seq[0]).
    prev = (seq[0] - ref) ** 2
    curr = np.empty(n_ref)
    for i in range(1, seq.shape[0]):
        x = seq[i]
        # Column 0 can only be reached vertically.
        curr[0] = prev[0] + (x - ref[0]) ** 2
        for j in range(1, n_ref):
            curr[j] = (x - ref[j]) ** 2 + min(prev[j - 1], curr[j - 1], prev[j])
        prev, curr = curr, prev
    # Free end: best cost anywhere along the last query row.
    return prev.min()


def sdtw(seq, ref=None):
    ''' Returns minimum alignment score for subsequence DTW.

    The whole query `seq` is aligned to its best-matching stretch of the
    reference (free start and end in the reference) with squared-difference
    local cost. The result is the square root of the minimum accumulated cost.

    Args:
        seq: 1-D query signal (e.g. a preprocessed read).
        ref: 1-D reference signal. Defaults to the module-level `ref_sig`,
            looked up at call time, so `pool.map(sdtw, reads)` keeps working.

    Returns:
        numpy float64 alignment cost (>= 0).

    Raises:
        ValueError: if `seq` or `ref` is empty.
    '''
    if ref is None:
        ref = ref_sig
    seq = np.asarray(seq, dtype=np.float64)
    ref = np.asarray(ref, dtype=np.float64)
    if seq.size == 0 or ref.size == 0:
        raise ValueError("sdtw needs a non-empty query and reference signal")
    return np.sqrt(_sdtw_min_cost(seq, ref))

In [None]:
# Score every read against the reference in parallel (mp.Pool() starts one
# worker per CPU by default); COVID reads are scored first, then human reads.
with mp.Pool() as workers:
    ref_dists, other_dists = [workers.map(sdtw, reads) for reads in (ref_reads, other_reads)]

## Data Analysis

In [None]:
# One shared set of bin edges for both distributions. Separate `bins=30` calls
# gave each histogram its own edges, so the overlaid bars did not line up and
# the two distributions could not be compared bar-for-bar.
shared_bins = np.histogram_bin_edges(np.concatenate([ref_dists, other_dists]), bins=30)

fig, ax = plt.subplots()
ax.hist(ref_dists, bins=shared_bins, facecolor='r', alpha=0.5)
ax.hist(other_dists, bins=shared_bins, facecolor='g', alpha=0.5)
ax.legend(['COVID', 'human'])
ax.set_xlabel('Alignment Cost')
ax.set_ylabel('Read Count')
plt.show()