# DTW Alignment to Virus Reference

Manage all imports

In [None]:
from sklearn import metrics
from itertools import repeat
from numba import njit
from glob import glob
from scipy import stats
from pyguppy_client_lib.pyclient import PyGuppyClient
from pyguppy_client_lib.helper_functions import package_read, basecall_with_pyguppy

import random, h5py, re, os, mappy

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import multiprocessing as mp
import seaborn as sns

%matplotlib inline
mpl.rcParams['figure.dpi'] = 300

Define globals for selecting input data

In [None]:
# root of all input data
data_dir = "../data"

# table of expected currents, one row per 6-mer
k = 6
kmer_model_fn = f"{data_dir}/dna_kmer_model.txt"

# target ("virus") and background ("other") datasets
virus, virus_ds = "lambda", "0"
other, other_ds = "human", "2"
dna_type = "DNA"

ref_fn = f"{data_dir}/{virus}/{dna_type}/{virus_ds}/reference.fasta"
virus_fast5_dir = f"{data_dir}/{virus}/{dna_type}/{virus_ds}/fast5"
other_fast5_dir = f"{data_dir}/{other}/{dna_type}/{other_ds}/fast5"

# where this experiment's arrays are written
results_dir = "./results/move_bonus_8"

# number of reads of each class to evaluate
virus_max_reads = 1000
other_max_reads = 1000

# signal prefix lengths at which alignment scores are recorded
#prefix_lengths = np.array([int((x*450)/4000) for x in range(500,5001,500)])
prefix_lengths = np.array(range(500, 5001, 500))
nprefixes = len(prefix_lengths)

## Signal-Based Reference Setup
Define helper functions

In [None]:
def get_fasta(fasta_fn):
    ''' Read a FASTA file and return its sequence as one string.

    The first line of the file (the header) is dropped; every remaining
    line is concatenated with the newlines removed.
    '''
    with open(fasta_fn, 'r') as handle:
        handle.readline()  # skip the header line
        return handle.read().replace('\n', '')

def rev_comp(bases):
    ''' Return the reverse complement of an upper-case DNA sequence.

    Only upper-case A/C/G/T are complemented; any other character is kept
    (lower-case letters are upper-cased but not complemented).
    '''
    complement = str.maketrans('ATGC', 'TACG')
    return bases.translate(complement).upper()[::-1]

def load_model(kmer_model_fn):
    ''' Parse a k-mer model file into a {kmer: current} dict.

    Every line must hold exactly two whitespace-separated fields: the k-mer
    and its current, which is converted to float.
    '''
    with open(kmer_model_fn, 'r') as model_file:
        rows = (line.split() for line in model_file)
        return {kmer: float(current) for kmer, current in rows}

def discrete_normalize(seq, bits=8, minval=-4, maxval=4):
    ''' Approximate normalization which converts signal to integer of desired precision.

    The signal is centred on its (truncated) mean, divided by its (truncated)
    mean absolute deviation, clipped to [minval, maxval] and mapped linearly
    onto integers.

    seq:     numpy array of samples
    bits:    precision of the output; minval maps to 0, maxval to 2**bits
    minval:  lower clipping bound, in deviations from the mean
    maxval:  upper clipping bound, in deviations from the mean

    Returns a new integer numpy array of the same shape; `seq` is untouched.
    '''
    mean = int(np.mean(seq))
    mean_avg_dev = int(np.mean(np.abs(seq - mean)))

    # A (near-)constant signal truncates to a deviation of 0; dividing by it
    # would fill the result with inf/NaN before the integer cast, so fall
    # back to unit scale in that case.
    if mean_avg_dev == 0:
        mean_avg_dev = 1

    norm_seq = np.clip((seq - mean) / mean_avg_dev, minval, maxval)

    # NOTE: samples clipped at maxval map to exactly 2**bits, one more than
    # the largest value that fits in `bits` bits.
    return ((norm_seq - minval) * (2**bits / (maxval - minval))).astype(int)

def ref_signal(fasta, kmer_model):
    ''' Convert reference bases to the expected, discretely normalized signal.

    Position i of the result holds the model current of the k-mer starting at
    base i (k is the notebook-level k-mer size). The output has one entry per
    base; the positions from len(fasta)-k onward receive no k-mer current and
    stay 0 before normalization.
    '''
    levels = np.array(
        [kmer_model[fasta[start:start+k]] for start in range(len(fasta) - k)],
        dtype=float,
    )
    signal = np.zeros(len(fasta))
    signal[:len(levels)] = levels

    # scale up first so the integer truncation in discrete_normalize keeps
    # more of the difference between currents
    return discrete_normalize(signal * 100)
    #return stats.zscore(signal*100)

Create the virus reference signal from the expected k-mer currents (discretized with `discrete_normalize`) of the forward and reverse-complement reference FASTA

In [None]:
# reference bases read from the FASTA file
ref_fasta = get_fasta(ref_fn)
# k-mer -> expected current lookup
kmer_model = load_model(kmer_model_fn)
# expected signal of the forward strand and of its reverse complement
fwd_ref_sig = ref_signal(ref_fasta, kmer_model)
rev_ref_sig = ref_signal(rev_comp(ref_fasta), kmer_model)
# single reference array: forward signal followed by reverse-complement signal
ref = np.concatenate((fwd_ref_sig, rev_ref_sig))

## Read Preparation

Define preprocessing functions for converting raw FAST5 data to normalized alignable signals

In [None]:
def get_stall_end(signal, stall_threshold=3,
                  stall_events=2, stall_event_len=3):
    ''' Determine the end of the DNA stall region.

    The signal is averaged in chunks of `stall_event_len` samples. The chunks
    are scanned until `stall_events` consecutive chunk means are above
    `stall_threshold`, then scanning continues until `stall_events`
    consecutive chunk means are below it. Returns the number of chunks
    consumed by both scans multiplied by `stall_event_len`.
    '''

    # average a few samples at a time to reduce variation
    chunk_means = [np.mean(signal[start:start+stall_event_len])
                   for start in range(0, len(signal), stall_event_len)]

    def consumed_until_streak(values, qualifies):
        ''' Count values read until `stall_events` in a row qualify. '''
        streak, consumed = 0, 0
        for value in values:
            consumed += 1
            streak = streak + 1 if qualifies(value) else 0
            if streak == stall_events:
                break
        return consumed

    # rise above the threshold, then fall back below it
    rise = consumed_until_streak(
        chunk_means, lambda value: value > stall_threshold)
    fall = consumed_until_streak(
        chunk_means[rise:], lambda value: value < stall_threshold)

    return (rise + fall) * stall_event_len


def trim(signal):
    ''' Cut a fixed window out of the signal, positioned after the stall.

    The stall end is located on the z-scored signal; the returned window
    starts 1000 samples after it and ends 6000 samples after it.
    Returns (windowed signal, stall end position).
    '''
    normalized = stats.zscore(signal)
    stall_end = get_stall_end(normalized)
    window = slice(stall_end + 1000, stall_end + 6000)
    return signal[window], stall_end


def _repair_outliers(signal, limit, neighbour_limit, beyond):
    ''' Clamp or interpolate, in place, every sample x with beyond(x, limit). '''
    last = len(signal) - 1
    for idx, x in enumerate(signal):
        if not beyond(x, limit):
            continue

        next_extreme = idx < last and beyond(signal[idx+1], neighbour_limit)
        prev_extreme = idx > 0 and beyond(signal[idx-1], limit)
        if next_extreme or prev_extreme or last == 0:
            # part of a run of extreme values, or the only sample (nothing
            # to interpolate from) -> threshold to the limit
            signal[idx] = limit
        elif idx == 0:
            signal[idx] = signal[1]
        elif idx == last:
            signal[idx] = signal[idx-1]
        else:
            # isolated outlier -> average of its two neighbours
            signal[idx] = (signal[idx-1] + signal[idx+1]) / 2


def filter_outliers(signal, minval=-4, maxval=4):
    ''' Remove samples outside [minval, maxval] from a normalized signal.

    An out-of-range sample that has an extreme neighbour is clamped to the
    violated bound; an isolated one is replaced by its neighbour(s). The next
    neighbour counts as extreme beyond the fixed values 3 / -3, the previous
    one beyond maxval / minval. All values above maxval are handled before
    any value below minval.

    The signal is modified in place and also returned. Empty signals are
    returned as-is; a one-sample signal is clamped to the bound.
    '''

    # return empty signals as-is
    if not len(signal): return signal

    # upper threshold
    _repair_outliers(signal, maxval, 3, lambda value, bound: value > bound)

    # lower threshold
    _repair_outliers(signal, minval, -3, lambda value, bound: value < bound)

    return signal

class Read:
    ''' Raw read plus the metadata handed to the basecaller. '''

    def __init__(self, signal, read_id, offset=0, scaling=1.0):
        # identity: caller-supplied id plus a random 32-bit tag
        self.read_id = read_id
        self.read_tag = random.randint(0, 2**32 - 1)

        # raw samples and their count
        self.signal = signal
        self.total_samples = len(signal)

        # offset / scaling values stored alongside the raw samples
        self.daq_offset = offset
        self.daq_scaling = scaling

def ba_preprocess_read(uuid, length):
    ''' Return a Read holding the first `length` trimmed raw samples of a read.

    uuid:    read id, looked up in the notebook-level `full_index` to find
             the FAST5 file
    length:  number of trimmed samples to keep for basecalling

    Returns None when the trimmed signal is shorter than the longest entry of
    `prefix_lengths`, so the same reads are kept for every prefix length.
    '''
    readname = f"read_{uuid}"

    # the context manager closes the FAST5 file on every exit path
    with h5py.File(full_index[uuid], 'r') as fast5_file:
        read_group = fast5_file[readname]
        signal = np.array(read_group['Raw']['Signal'][:], dtype=np.int16)
        signal, _ = trim(signal)
        if len(signal) < max(prefix_lengths): return None

        channel_attrs = read_group['channel_id'].attrs
        signal_dig = channel_attrs['digitisation']
        signal_offset = channel_attrs['offset']
        signal_range = channel_attrs['range']

    # keep only the requested prefix; previously `length` was ignored and
    # the whole trimmed window was basecalled for every prefix length
    signal = signal[:length]

    signal_scaling = signal_range / signal_dig
    return Read(signal, readname, offset=signal_offset, scaling=signal_scaling)

    
def preprocess_read(uuid):
    ''' Return preprocessed read from specified FAST5 file.

    uuid: read id, looked up in the notebook-level `full_index` to find the
          FAST5 file

    Returns (normalized signal, stall end found by trim, untrimmed length),
    or None when the trimmed signal is shorter than the longest entry of
    `prefix_lengths`.
    '''
    readname = f"read_{uuid}"

    # the context manager closes the FAST5 file once the samples are copied
    with h5py.File(full_index[uuid], 'r') as fast5_file:
        signal = np.array(fast5_file[readname]['Raw']['Signal'][:], dtype=np.int16)

    length = signal.shape[0]
    signal, trimmed = trim(signal)
    if len(signal) < max(prefix_lengths): return None

    # running normalization: each block of 500 samples is normalized using
    # only the samples from the start of the read up to the end of that block
    new_signal = np.array(signal, dtype=float)
    for start in range(0, len(signal), 500):
        new_signal[start:start+500] = \
            discrete_normalize(signal[:start+500])[start:start+500]
            #stats.zscore(signal[:start+500])[start:start+500]
    #signal = filter_outliers(np.array(new_signal))
    #signal = discrete_normalize(signal)
    #signal = segment(signal)
    return new_signal, trimmed, length

def get_index(index_filename):
    ''' Read index data structure from file.

    Each line holds a read UUID and a filename separated by one or more tabs.
    Returns a dict mapping UUID -> filename (trailing whitespace stripped).
    Raises ValueError if a line does not split into exactly two fields.
    '''
    index = {}
    # `with` closes the file even when a malformed line raises
    with open(index_filename, 'r') as index_file:
        for line in index_file:
            uuid, fname = re.split(r'\t+', line)
            index[uuid] = fname.rstrip()
    return index


def create_index(fast5_dir, force=False):
    '''
    Create file which stores read FAST5 to UUID mappings.

    fast5_dir: directory searched recursively for *.fast5 files; the index
               is written to <fast5_dir>/index.db
    force:     rebuild the index even if the file already exists

    Returns the index as a dict (UUID -> FAST5 path), as read by get_index.
    '''

    # return existing index if possible
    index_fn = f'{fast5_dir}/index.db'
    if not force and os.path.exists(index_fn):
        return get_index(index_fn)

    # remove existing index
    if os.path.exists(index_fn):
        os.remove(index_fn)

    # create new index; both the index and every FAST5 file are closed by
    # their context managers, even if reading one of them fails
    with open(index_fn, 'w') as index_file:

        # iterate through all FAST5 files in directory
        for subdir, _dirs, files in os.walk(fast5_dir):
            for filename in files:
                if os.path.splitext(filename)[-1].lower() != ".fast5":
                    continue
                fast5_path = os.path.join(subdir, filename)

                # print read uuid and filename to index
                with h5py.File(fast5_path, 'r') as fast5_file:
                    if 'Raw' in fast5_file: # single-FAST5
                        for readname in fast5_file['Raw']['Reads']:
                            uuid = fast5_file['Raw']['Reads'][readname].attrs['read_id']
                            # the attribute may come back as bytes or as
                            # str; only bytes need decoding
                            if isinstance(uuid, bytes):
                                uuid = uuid.decode('utf-8')
                            print('{}\t{}'.format(uuid, fast5_path), file=index_file)
                    else: # multi-FAST5
                        for readname in fast5_file:
                            uuid = readname[5:] # remove 'read_' naming prefix
                            print('{}\t{}'.format(uuid, fast5_path), file=index_file)

    # return results
    return get_index(index_fn)

def segment(signal):
    ''' Collapse a signal into the means of segments between change points.

    Candidate change points are ranked by the absolute difference between the
    sums of the two adjacent `width`-sample windows. The
    int(len(signal)*450/4000) highest-ranked local peaks are kept as
    checkpoints, and the mean of the signal between consecutive checkpoints
    is returned as a numpy array. Samples after the last checkpoint do not
    contribute to the output.
    '''
    width = 5
    min_obs = 1  # not used below
    npts = int((len(signal)*450)/4000)

    # get difference between all neighboring 'width' regions
    # (cumsum has a leading 0, so differences of cumsum are window sums)
    cumsum = np.cumsum(np.concatenate([[0.0], signal]))
    # candidate positions, ordered from largest to smallest difference
    cand_poss = np.argsort(np.abs( (2 * cumsum[width:-width]) -
        cumsum[:-2*width] - cumsum[2*width:])).astype(int)[::-1]
    # the same differences, in positional order
    vals = np.abs( (2 * cumsum[width:-width]) - cumsum[:-2*width] - cumsum[2*width:])

    # keep 'npts' best checkpoints
    # NOTE(review): cand_idx is never checked against len(cand_poss); if fewer
    # than npts candidates survive the blacklist this loop raises IndexError.
    chkpts = []
    cand_idx = 0
    ct = 0
    blacklist = set()
    while ct < npts:
        edge_pos = cand_poss[cand_idx]
        if edge_pos not in blacklist:
            # the boundary between the two windows lies 'width' samples on
            chkpts.append(edge_pos+width)
            ct += 1

            # blacklist nearby values (only use peaks)
            # walk right while the difference keeps decreasing
            right = 0
            while edge_pos+right+1 < len(vals) and vals[edge_pos + right] > vals[edge_pos + right+1]:
                right += 1
                blacklist.add(edge_pos+right)
            # walk left while the difference keeps decreasing
            left = 0
            while edge_pos+left > 0 and vals[edge_pos + left] > vals[edge_pos + left-1]:
                left -= 1
                blacklist.add(edge_pos+left)
        cand_idx += 1

    # mean of every stretch between consecutive checkpoints, the first
    # stretch starting at sample 0
    chkpts = np.sort(chkpts)
    new_signal = [np.mean(signal[0:chkpts[0]])]
    for i in range(len(chkpts)-1):
        new_signal.append(np.mean(signal[chkpts[i]:chkpts[i+1]]))
    return np.array(new_signal)

In [None]:
# create read UUID -> FAST5 filename mapping
# (create_index reuses <fast5_dir>/index.db when that file already exists)
virus_index = create_index(virus_fast5_dir)
other_index = create_index(other_fast5_dir)
# combined lookup for both classes; if a UUID occurs in both indexes the
# entry from other_index is the one kept
full_index = {**virus_index, **other_index}

In [None]:
# select random subset of reads, twice as many as needed because
# preprocessing discards reads that are too short.
# random.sample draws without replacement, so no read is evaluated twice
# (random.choices could return the same read several times); the request is
# capped at the index size because sample() rejects k > population.
virus_readnames = random.sample(
    list(virus_index.keys()), k=min(len(virus_index), virus_max_reads*2))
other_readnames = random.sample(
    list(other_index.keys()), k=min(len(other_index), other_max_reads*2))

## Basecalling and Alignment
Initialize aligner and basecaller for DNA lambda

In [None]:
# aligner over the virus reference FASTA, used to score basecalled reads
aligner = mappy.Aligner(
    fn_idx_in = ref_fn,
    preset = "map-ont", # "splice" for RNA
    best_n = 1,
    k = 15 # 14 for RNA
)

# client for a Guppy basecall server; a server must already be listening at
# this address before connect() is called
basecaller = PyGuppyClient(
    address = "127.0.0.1:1234", 
    config = "dna_r9.4.1_450bps_fast.cfg",
    server_file_load_timeout=10
)
basecaller.connect()

In [None]:
def basecall(packets):
    ''' Send packaged reads to the basecaller and collect the results.

    packets: list of read packets accepted by `basecaller.pass_read`

    Returns the list of completed reads reported by the basecaller. If a
    read cannot be submitted, an error is printed, submission stops, and
    only the reads submitted so far are waited for.
    '''
    calls = []

    # submit reads until one is rejected
    sent = 0
    for packet in packets:
        if not basecaller.pass_read(packet):
            print('ERROR: Failed to basecall read.')
            break
        sent += 1

    # wait only for reads that were actually submitted; waiting for
    # len(packets) results would never finish after a failed submission
    rcvd = 0
    while rcvd < sent:
        result = basecaller.get_completed_reads()
        rcvd += len(result)
        calls.extend(result)
    return calls

In [None]:
def _to_packets(reads):
    ''' Package each Read for submission to the basecaller. '''
    return [
        package_read(
            read_tag = read.read_tag,
            read_id = read.read_id,
            raw_data = read.signal,
            daq_offset = float(read.daq_offset),
            daq_scaling = float(read.daq_scaling)
        )
        for read in reads
    ]


def _record_mapq(calls, score_row):
    ''' Store the mapping quality of each call's first alignment, if any. '''
    for call_idx, call in enumerate(calls):
        hit = next(aligner.map(call['datasets']['sequence']), None)
        if hit is not None:
            score_row[call_idx] = hit.mapq
        # no alignment -> score stays 0


# one row per prefix length, one column per read
ba_virus_scores = np.zeros((nprefixes, virus_max_reads))
ba_other_scores = np.zeros((nprefixes, other_max_reads))

with mp.Pool() as pool:
    for prefix_idx, length in enumerate(prefix_lengths):

        # trim reads, dropping those that were rejected (None)
        ba_virus_reads = [
            read for read in pool.starmap(
                ba_preprocess_read, zip(virus_readnames, repeat(length)))
            if read
        ][:virus_max_reads]
        ba_other_reads = [
            read for read in pool.starmap(
                ba_preprocess_read, zip(other_readnames, repeat(length)))
            if read
        ][:other_max_reads]

        # package read data
        virus_pkts = _to_packets(ba_virus_reads)
        other_pkts = _to_packets(ba_other_reads)

        # basecall
        virus_calls = basecall(virus_pkts)
        other_calls = basecall(other_pkts)

        # align
        _record_mapq(virus_calls, ba_virus_scores[prefix_idx])
        _record_mapq(other_calls, ba_other_scores[prefix_idx])

In [None]:
# number of called bases per read, from the most recent basecalling pass
virus_call_lengths = [len(call['datasets']['sequence']) for call in virus_calls]
other_call_lengths = [len(call['datasets']['sequence']) for call in other_calls]

# overlay both classes on shared bins
call_length_bins = np.linspace(0, 800, num=100)
for call_lengths, colour in ((virus_call_lengths, 'r'), (other_call_lengths, 'g')):
    plt.hist(call_lengths, bins=call_length_bins, facecolor=colour, alpha=0.5)

plt.legend([virus, other])
plt.xlabel('Call length (5000 samples)')
plt.ylabel('Read Count')
plt.show()

## sDTW Alignment

In [None]:
def _preprocess_all(pool, readnames):
    ''' Preprocess reads in parallel; return (signals, trims, lengths) lists. '''
    kept = [result for result in pool.map(preprocess_read, readnames) if result]
    signals, trims, lengths = zip(*kept)
    return list(signals), list(trims), list(lengths)


# trim and normalize all reads, discarding the ones rejected by preprocess_read
with mp.Pool() as pool:
    virus_reads, virus_trims, virus_lengths = _preprocess_all(pool, virus_readnames)
    other_reads, other_trims, other_lengths = _preprocess_all(pool, other_readnames)

In [None]:
# report when fewer usable reads remain than were requested
if len(virus_reads) < virus_max_reads:
    print(f'ERROR: only {len(virus_reads)} virus reads long enough, requested {virus_max_reads}')
if len(other_reads) < other_max_reads:
    print(f'ERROR: only {len(other_reads)} other reads long enough, requested {other_max_reads}')

# keep at most the requested number of reads, with matching metadata
virus_reads = virus_reads[:virus_max_reads]
virus_trims = virus_trims[:virus_max_reads]
virus_lengths = virus_lengths[:virus_max_reads]

other_reads = other_reads[:other_max_reads]
other_trims = other_trims[:other_max_reads]
other_lengths = other_lengths[:other_max_reads]

In [None]:
@njit()
def sdtw(seq):
    ''' Returns minimum alignment score for subsequence DTW.

    Aligns `seq` against the notebook-level reference signal `ref` using
    absolute differences as the cost. Returns one value per entry of the
    notebook-level `prefix_lengths`: the smallest cost in the matrix row of
    the last sample of that prefix, or 0 when `seq` is shorter than the
    prefix. A later cell defines another function with this same name.
    '''
    
    # initialize cost matrix
    # row 0 keeps the zeros from np.zeros for every column but the first, so
    # an alignment may begin at any reference position
    cost_mat = np.zeros((len(seq), len(ref)))
    cost_mat[0, 0] = abs(seq[0]-ref[0])#*(seq[0]-ref[0])
    for i in range(1, len(seq)):
        cost_mat[i, 0] = cost_mat[i-1, 0] + abs(seq[i]-ref[0])#*(seq[i]-ref[0])
        
    # compute entire cost matrix
    # each sample either advances one reference position (i-1, j-1) or stays
    # on the same one (i-1, j); the (i, j-1) predecessor is not considered
    for i in range(1, len(seq)):
        for j in range(1, len(ref)):
            #cost_mat[i, j] = min(cost_mat[i-1, j-1], cost_mat[i, j-1], cost_mat[i-1, j]) + \
            #                 abs(seq[i]-ref[j])#*(seq[i]-ref[j]) 
            cost_mat[i, j] = min(cost_mat[i-1, j-1], cost_mat[i-1, j]) + \
                 abs(seq[i]-ref[j])#*(seq[i]-ref[j]) 
    
    # return cost of optimal alignment
    # (entries for prefixes longer than seq are left at 0)
    cost_mins = np.zeros((len(prefix_lengths),))
    for i in range(len(prefix_lengths)):
        if prefix_lengths[i] <= len(seq):
            cost_mins[i] = min(cost_mat[prefix_lengths[i]-1,:])
    return cost_mins

In [None]:
@njit()
def sdtw(seq):
    ''' Returns minimum alignment score for subsequence DTW.

    Variant that discounts the cost of advancing along the reference after
    staying on one reference position: the diagonal predecessor's cost is
    reduced by `bonus` times a per-row stay counter (capped at 10).

    Aligns `seq` against the notebook-level reference signal `ref` using
    absolute differences as the cost. Returns one value per entry of the
    notebook-level `prefix_lengths`: the smallest cost in the matrix row of
    the last sample of that prefix, or 0 when `seq` is shorter than the
    prefix.
    '''
    
    # initialize cost matrix
    # row 0 keeps the zeros from np.zeros for every column but the first, so
    # an alignment may begin at any reference position
    cost_mat = np.zeros((len(seq), len(ref)))
    cost_mat[0, 0] = abs(seq[0]-ref[0])
    for i in range(1, len(seq)):
        cost_mat[i, 0] = cost_mat[i-1, 0] + abs(seq[i]-ref[0])
    
    # stay counters for the previous and the current reference column,
    # indexed by sample
    prev_consec = np.zeros((len(seq)))
    curr_consec = np.zeros((len(seq)))
    
    # compute entire cost matrix
    for j in range(1, len(ref)):
        bonus = 32
        for i in range(1, len(seq)):
            # advance along the reference when the discounted diagonal cost
            # beats the cost of staying on reference position j
            move = cost_mat[i-1, j-1] - prev_consec[i-1]*bonus < cost_mat[i-1, j]
            if move:
                curr_consec[i] = 0
                cost_mat[i, j] = cost_mat[i-1, j-1] - prev_consec[i-1]*bonus + abs(seq[i]-ref[j])
            else:
                # NOTE(review): the counter continues from prev_consec[i]
                # (previous column, same sample), while the cost continues
                # from cost_mat[i-1, j] (same column, previous sample) --
                # confirm this is the intended counter.
                curr_consec[i] = min(10, prev_consec[i] + 1)
                cost_mat[i, j] = cost_mat[i-1, j] + abs(seq[i]-ref[j])
        # curr_consec is rebound to a fresh array on the next line, so
        # prev_consec keeps this column's counters
        prev_consec = curr_consec[:]
        curr_consec = np.zeros((len(seq)))
    
    # return cost of optimal alignment
    # (entries for prefixes longer than seq are left at 0)
    cost_mins = np.zeros((len(prefix_lengths),))
    for i in range(len(prefix_lengths)):
        if prefix_lengths[i] <= len(seq):
            cost_mins[i] = min(cost_mat[prefix_lengths[i]-1,:])
    return cost_mins

In [None]:
# limit cores since each aligner takes ~4GB of RAM to align
# (sdtw allocates a len(seq) x len(ref) cost matrix per read)
with mp.Pool(14) as pool:
    print(f'Aligning {virus} reads...', flush=True)
    # one array of per-prefix costs for every read
    virus_scores_list = pool.map(sdtw, virus_reads)
    print(f'Aligning {other} reads...', flush=True)
    other_scores_list = pool.map(sdtw, other_reads)

## Analyze Errors
Look at low-scoring human reads and high-scoring lambda reads to determine reason for error

In [None]:
def _scores_by_prefix(scores_list):
    ''' Arrange per-read score arrays as a (prefix, read) matrix. '''
    matrix = np.zeros((nprefixes, len(scores_list)))
    for read_idx, read_scores in enumerate(scores_list):
        matrix[:, read_idx] = read_scores[:nprefixes]
    return matrix


vscores = _scores_by_prefix(virus_scores_list)
oscores = _scores_by_prefix(other_scores_list)

# read indices ranked by cost at the longest prefix:
# virus reads from highest cost down, other reads from lowest cost up
high_virus = np.argsort(vscores[-1])[::-1]
low_other = np.argsort(oscores[-1])

In [None]:
# plot the preprocessed signal of the 20 virus reads with the highest
# alignment cost, one figure per read, titled with the read's index
for idx in high_virus[:20]:
    plt.figure(figsize=(25,5))
    plt.plot(virus_reads[idx])
    plt.title(idx)
plt.show()

## Data Analysis

#### Save Results

In [None]:
def _as_score_matrix(scores_list):
    ''' Arrange per-read score arrays as a (prefix, read) matrix. '''
    matrix = np.zeros((nprefixes, len(scores_list)))
    for read_idx, read_scores in enumerate(scores_list):
        matrix[:, read_idx] = read_scores[:nprefixes]
    return matrix


# move data to numpy array for easy sorting/calculations
virus_scores = _as_score_matrix(virus_scores_list)
other_scores = _as_score_matrix(other_scores_list)

# save results, one .npy file per array
os.makedirs(results_dir, exist_ok=True)
for array_name, array in (
    ("prefix_lengths", prefix_lengths),
    ("virus_trims", virus_trims),
    ("virus_lengths", virus_lengths),
    ("virus_scores", virus_scores),
    ("other_trims", other_trims),
    ("other_lengths", other_lengths),
    ("other_scores", other_scores),
):
    np.save(f"{results_dir}/{array_name}", array)

In [None]:
# basecall-and-align mapping qualities (rows: prefix lengths, columns: reads)
np.save(f"{results_dir}/ba_virus_scores", ba_virus_scores)
np.save(f"{results_dir}/ba_other_scores", ba_other_scores)

#### Load Results

In [None]:
# reload the arrays written by the save cell, so the analysis below can run
# without repeating the alignment
# NOTE: nprefixes is not recomputed here after prefix_lengths is reloaded
prefix_lengths = np.load(f"{results_dir}/prefix_lengths.npy")
virus_trims = np.load(f"{results_dir}/virus_trims.npy")
virus_lengths = np.load(f"{results_dir}/virus_lengths.npy")
virus_scores = np.load(f"{results_dir}/virus_scores.npy")
other_trims = np.load(f"{results_dir}/other_trims.npy")
other_lengths = np.load(f"{results_dir}/other_lengths.npy")
other_scores = np.load(f"{results_dir}/other_scores.npy")

In [None]:
# reload the basecall-and-align mapping qualities
ba_virus_scores = np.load(f"{results_dir}/ba_virus_scores.npy")
ba_other_scores = np.load(f"{results_dir}/ba_other_scores.npy")

In [None]:
def get_stats(virus_scores, other_scores, thresh):
    ''' Return F-scores, precisions and recalls (assumes sorted input).

    virus_scores: 2-D array, one row per prefix length, each row sorted in
                  ascending order; a score of 0 marks a read that was too
                  short to be scored
    other_scores: same layout for the non-target reads
    thresh:       reads scoring below this value count as classified "virus"

    Returns three arrays (fscores, precs, recalls) with one entry per row of
    `virus_scores`. The row count is taken from the input itself, so the
    function does not depend on notebook-level state.
    '''
    n_rows = virus_scores.shape[0]
    fscores = np.zeros(n_rows)
    precs = np.zeros(n_rows)
    recalls = np.zeros(n_rows)
    for i in range(n_rows):
        # short reads don't receive a score, so ignore in accuracy metrics
        long_virus = np.count_nonzero(virus_scores[i])
        short_virus = virus_scores.shape[1]-long_virus
        # searchsorted counts the scores below thresh; the unscored zeros
        # are among them (for thresh > 0) and are subtracted again
        tp = np.searchsorted(virus_scores[i], thresh) - short_virus
        fn = long_virus - tp
        long_other = np.count_nonzero(other_scores[i])
        short_other = other_scores.shape[1]-long_other
        fp = np.searchsorted(other_scores[i], thresh) - short_other
        # guard each ratio against an empty denominator
        precs[i] = 0 if not tp+fp else tp / (tp+fp)
        recalls[i] = 0 if not tp+fn else tp / (tp+fn)
        fscores[i] = 0 if not tp+fp+fn else tp / (tp + 0.5*(fp + fn))
    return fscores, precs, recalls

In [None]:
# sort every row in ascending order (get_stats counts with searchsorted)
virus_scores = np.sort(virus_scores)
other_scores = np.sort(other_scores)
min_score = min(np.min(virus_scores), np.min(other_scores))
max_score = max(np.max(virus_scores), np.max(other_scores))

# scan 100 evenly spaced thresholds and, per prefix length, remember the
# statistics of the threshold with the highest F-score seen so far
best_threshs = np.zeros(nprefixes)
best_fscores = np.zeros(nprefixes)
best_precs = np.zeros(nprefixes)
best_recalls = np.zeros(nprefixes)
for thresh in np.linspace(min_score, max_score, num=100):
    fscores, precs, recalls = get_stats(virus_scores, other_scores, thresh)
    improved = fscores > best_fscores
    best_fscores[improved] = fscores[improved]
    best_precs[improved] = precs[improved]
    best_recalls[improved] = recalls[improved]
    best_threshs[improved] = thresh + 0.01

for array_name, array in (
    ("fscores", best_fscores),
    ("precisions", best_precs),
    ("recalls", best_recalls),
):
    np.save(f"{results_dir}/{array_name}", array)

Plot score distribution for each signal prefix length

In [None]:
# one histogram figure per prefix length, with the best threshold marked
for row, n_samples in enumerate(prefix_lengths):
    lowest = min(np.min(virus_scores[row]), np.min(other_scores[row]))
    highest = max(np.max(virus_scores[row]), np.max(other_scores[row]))
    shared_bins = np.linspace(lowest, highest, num=100)

    fig, ax = plt.subplots()
    ax.set_xlim(0, best_threshs[row]*2)
    for class_scores, colour in ((virus_scores[row], 'r'), (other_scores[row], 'g')):
        ax.hist(class_scores, bins=shared_bins, facecolor=colour, alpha=0.5)
    ax.legend([virus, other])
    ax.set_xlabel('DTW Alignment Cost')
    ax.set_ylabel('Read Count')
    ax.axvline(best_threshs[row], color='k', linestyle='--')
    ax.set_title('{} Samples'.format(n_samples))
    plt.show()

#### Basecall Alignment Score Distribution

In [None]:
# histogram of basecall-and-align scores for the selected prefix lengths
for l in [1000]:
    # rows of the score matrices follow prefix_lengths, so look the row up by
    # length; enumerate() would pick row 0, which is not the 1000-sample row
    i = int(np.where(prefix_lengths == l)[0][0])
    fig, ax = plt.subplots()
    minval = min(np.min(ba_virus_scores[i]), np.min(ba_other_scores[i]))
    maxval = max(np.max(ba_virus_scores[i]), np.max(ba_other_scores[i]))
    ax.hist(ba_virus_scores[i], bins=np.linspace(minval, maxval, num=100), facecolor='r', alpha=0.5)
    ax.hist(ba_other_scores[i], bins=np.linspace(minval, maxval, num=100), facecolor='g', alpha=0.5)
    ax.legend([virus, other])
    ax.set_xlabel('Alignment Score')
    ax.set_ylabel('Read Count')
    ax.set_title('{} Samples'.format(l))
    plt.show()

Generate accuracy plots for alignment method evaluation

In [None]:
# experiments to compare; each has its own directory under results/
#results_dirs = ["baseline", "abs_value", "discrete_norm", "running_norm", "all_modifications"]
results_dirs = ["move_baseline", "move_bonus_4", "move_bonus_8", "move_bonus_16", "move_bonus_32"]

# one panel per metric: saved file name and panel title
metric_files = ("fscores", "precisions", "recalls")
metric_titles = ('F-score', 'Precision', 'Recall')

fig, axs = plt.subplots(1,3, figsize=(15,5))
for d in results_dirs:
    for panel, metric in zip(axs, metric_files):
        panel.plot(prefix_lengths, np.load(f"results/{d}/{metric}.npy"))

for panel, title in zip(axs, metric_titles):
    panel.set_title(title)
    panel.set_ylim(0.5,1.03)
    panel.set_xlabel('Samples')
axs[0].set_ylabel('Score')
axs[2].legend(results_dirs)
plt.show()

In [None]:
dtw_lengths = [1000, 3000, 5000]
ba_lengths = [1000]

# initialize plots
mpl.rcParams.update({'font.size': 14})
# row of the score matrices for each requested prefix length, as a plain int
# (indexing [0][0] avoids calling int() on a 1-element array, which NumPy
# has deprecated, and fails with IndexError if a length is not present)
dtw_indices = [int(np.where(prefix_lengths == t)[0][0]) for t in dtw_lengths]
ba_indices = [int(np.where(prefix_lengths == t)[0][0]) for t in ba_lengths]
fig, ax = plt.subplots()
# reference markers: the ideal corner and the diagonal of a random classifier
ax.plot(0, 1, color='k', marker='*', markersize=15, linestyle='None')
ax.plot([0,1], [0,1], color='k', marker='None', linestyle=':')

# all DTW plots
for i, l in zip(dtw_indices, dtw_lengths):
    minval = min(np.min(virus_scores[i]), np.min(other_scores[i]))-1
    maxval = max(np.max(virus_scores[i]), np.max(other_scores[i]))+1
    thresholds = np.linspace(minval, maxval, num=100)

    # a read is discarded when its alignment cost is above the threshold
    other_discard_rate, virus_discard_rate = [], []
    for t in thresholds:
        virus_discard_rate.append(np.sum(virus_scores[i] > t) / len(virus_scores[i]))
        other_discard_rate.append(np.sum(other_scores[i] > t) / len(other_scores[i]))
    ax.plot(virus_discard_rate, other_discard_rate, marker='o', alpha=0.5)
    
# all Guppy + Minimap2 plots
for i, l in zip(ba_indices, ba_lengths):
    minval = min(np.min(ba_virus_scores[i]), np.min(ba_other_scores[i]))-1
    maxval = max(np.max(ba_virus_scores[i]), np.max(ba_other_scores[i]))+1
    thresholds = np.linspace(minval, maxval, num=100)

    # a read is discarded when its mapping quality is below the threshold
    other_discard_rate, virus_discard_rate = [], []
    for t in thresholds:
        virus_discard_rate.append(sum(ba_virus_scores[i] < t) / len(ba_virus_scores[i]))
        other_discard_rate.append(sum(ba_other_scores[i] < t) / len(ba_other_scores[i]))
    ax.plot(virus_discard_rate, other_discard_rate, marker='o', alpha=0.5)
    
ax.set_xlabel('Lambda Phage Discard Rate')
ax.set_ylabel('Human Discard Rate')
#ax.set_title('Guppy-lite ')
ax.legend(["ideal", "random"] + \
        [f"SquiggleFilter {x} samples" for x in dtw_lengths] + \
        [f"Guppy-lite {x} samples" for x in ba_lengths], loc="lower right")      
ax.set_xlim((-0.1, 1.1))
ax.set_ylim((-0.1, 1.1))
plt.show()

## Read Until Runtime
Analyze runtime as a function of accuracy and multi-stage thresholding

In [None]:
class Reads():
    ''' Saved alignment results plus the assumed sample composition. '''

    def __init__(self, results_dir):

        def load(array_name):
            ''' Load one saved array from the results directory. '''
            return np.load(f'{results_dir}/{array_name}.npy')

        # assumed sample composition
        self.prop_virus = 0.01        # proportion virus
        self.prop_other = 1 - self.prop_virus

        self.prefix_lengths = load('prefix_lengths')

        # virus reads: scores, then lengths and trims with their averages
        self.ba_virus_scores = load('ba_virus_scores')
        self.virus_scores = load('virus_scores')
        self.virus_lengths = load('virus_lengths')
        self.virus_trims = load('virus_trims')
        self.avg_virus_length = np.mean(self.virus_lengths)
        self.avg_virus_trim = np.mean(self.virus_trims)

        # other reads: same set of arrays
        self.ba_other_scores = load('ba_other_scores')
        self.other_scores = load('other_scores')
        self.other_lengths = load('other_lengths')
        self.other_trims = load('other_trims')
        self.avg_other_length = np.mean(self.other_lengths)
        self.avg_other_trim = np.mean(self.other_trims)
    
    
class Flowcell():
    ''' Fixed sequencing-hardware parameters used by the runtime model. '''

    def __init__(self, channels = 512):
        # number of channels is the only configurable value
        self.channels = channels

        self.chemistry = 'r9.4.1'
        self.sampling_rate = 4000     # samples/sec
        self.minknow_latency = 0.070  # sec
        
        
class Classifier():
    ''' Throughput and latency figures for one Read Until classification method. '''

    def __init__(self, method='sf'):
        ''' Select the figures for `method`: 'sf' (SquiggleFilter) or 'ba' (BasecallAlign).

        Raises Exception for any other method name.
        '''
        self.method = method
        # (method, throughput, latency); latency is later added to a value in seconds.
        # NOTE(review): throughput is presumably in samples/sec -- confirm.
        # An earlier SquiggleFilter throughput setting was 15_318 * 2000.
        known_methods = (
            ('sf', 5000000000000, 0.0003264),  # SquiggleFilter
            ('ba', 550 * 2000, 0.149),         # BasecallAlign
        )
        for name, throughput, latency in known_methods:
            if method == name:
                self.throughput = throughput
                self.latency = latency
                return
        raise Exception("Unknown Read Until classifier type.")
    

class Run():
    ''' Analytical model of how long a sequencing run takes to reach target coverage.

    Combines read statistics (`reads`), a classifier's throughput/latency and a
    flowcell's channel count/sampling rate to estimate run duration, with and
    without Read Until.
    '''

    def __init__(self, reads, clf='sf', flowcell=None):
        ''' Set up a run.

        reads:    object exposing the attributes of `Reads` (proportions, scores,
                  average lengths and trims, prefix lengths).
        clf:      classifier method name forwarded to `Classifier`.
        flowcell: `Flowcell` to model; a fresh default `Flowcell()` when omitted.
        '''
        # A `Flowcell()` default in the signature would be built once, at
        # definition time, and shared by every Run; build one per instance.
        self.flowcell = Flowcell() if flowcell is None else flowcell
        self.clf = Classifier(clf)
        self.reads = reads

        self.target_coverage = 30.0
        self.target_genome_size = 30_000.0 # bases
        self.coverage_bias = 1.0
        self.fwd_tr = 400.0           # bases / sec
        self.rev_tr = 100_000.0       # bases / sec
        self.capture_time = 1.0       # sec

        self.sr = self.flowcell.sampling_rate

    def _simple_times(self):
        ''' Proportion-weighted (virus, other, useful) time per read without Read Until. '''
        reads = self.reads
        virus_time = reads.prop_virus * \
            (self.capture_time + reads.avg_virus_length/self.sr)
        other_time = reads.prop_other * \
            (self.capture_time + reads.avg_other_length/self.sr)
        # only the untrimmed part of a virus read counts as useful
        useful_time = reads.prop_virus * \
            (reads.avg_virus_length-reads.avg_virus_trim) / self.sr
        return virus_time, other_time, useful_time

    def _duration(self, useful_throughput):
        ''' Time to collect genome_size * coverage worth of samples at `useful_throughput`. '''
        # sr/fwd_tr converts bases to samples
        return self.target_genome_size * (self.sr/self.fwd_tr) * \
            self.target_coverage * self.coverage_bias / useful_throughput

    def get_simple_runtime(self):
        ''' Return the run duration when every read is sequenced in full (no Read Until). '''
        max_throughput = self.flowcell.channels * self.sr
        virus_time, other_time, useful_time = self._simple_times()
        useful_throughput = max_throughput * useful_time / (virus_time + other_time)
        return self._duration(useful_throughput)

    def get_read_until_runtime(self, prefix_indices, thresholds):
        ''' Return the run duration when reads are ejected by multi-stage thresholding.

        prefix_indices: indices into `reads.prefix_lengths`, one per decision stage.
        thresholds:     score cutoff per stage; a read is kept while its score at
                        that prefix is below the cutoff (basecall-align scores
                        are negated first so the same comparison applies).
        The two sequences are consumed pairwise, in order.
        '''
        reads, sr = self.reads, self.sr

        # get runtime without read until
        max_throughput = self.flowcell.channels * sr
        simple_virus_time, simple_other_time, simple_useful_time = self._simple_times()
        simple_total_time = simple_virus_time + simple_other_time
        simple_prop_useful_time = simple_useful_time / simple_total_time

        # What percentage of pores can perform Read Until?
        # - estimate required basecall throughput from simple sequencing
        bc_time = reads.prop_virus * reads.avg_virus_length/sr + \
            reads.prop_other * reads.avg_other_length/sr
        bc_throughput = max_throughput * (bc_time / simple_total_time)
        prop_ru = min(1.0, self.clf.throughput/bc_throughput)
        prop_simple = 1 - prop_ru

        # Scores of the reads still being sequenced. No copy is needed: the
        # arrays are only masked into new arrays below, never written to.
        if self.clf.method == 'sf':
            rem_virus_scores = reads.virus_scores
            rem_other_scores = reads.other_scores
        else:
            rem_virus_scores = -reads.ba_virus_scores
            rem_other_scores = -reads.ba_other_scores

        decision_latency = self.clf.latency + self.flowcell.minknow_latency
        eject_virus_time, eject_other_time = 0, 0
        for i, thresh in zip(prefix_indices, thresholds):

            # device continues sequencing as we make a read-until decision
            length = reads.prefix_lengths[i]
            samples = length + sr * decision_latency
            reversal_latency = samples / (self.rev_tr * (sr/self.fwd_tr))
            latency = decision_latency + reversal_latency
            eject_cost = self.capture_time + length/sr + latency

            # choose which reads to keep, count those ejected
            keep_virus = rem_virus_scores[i] < thresh
            n_eject_virus = len(keep_virus) - np.count_nonzero(keep_virus)
            rem_virus_scores = rem_virus_scores[:,keep_virus]
            keep_other = rem_other_scores[i] < thresh
            n_eject_other = len(keep_other) - np.count_nonzero(keep_other)
            rem_other_scores = rem_other_scores[:,keep_other]

            # update time spent sequencing ejected reads
            eject_virus_time += n_eject_virus * reads.prop_virus * eject_cost
            eject_other_time += n_eject_other * reads.prop_other * eject_cost

        # update total time spent sequencing each type of read
        # NOTE(review): these terms use raw read counts, so the virus/other mix
        # only equals prop_virus:prop_other when both score arrays hold the same
        # number of reads -- confirm.
        n_virus_kept = len(rem_virus_scores[0])
        n_other_kept = len(rem_other_scores[0])
        ru_useful_time = n_virus_kept * reads.prop_virus * \
            (reads.avg_virus_length - reads.avg_virus_trim) / sr
        ru_virus_time = eject_virus_time + n_virus_kept * reads.prop_virus * \
            (self.capture_time + reads.avg_virus_length / sr)
        ru_other_time = eject_other_time + n_other_kept * reads.prop_other * \
            (self.capture_time + reads.avg_other_length / sr)
        ru_prop_useful_time = ru_useful_time / (ru_virus_time + ru_other_time)

        # calculate duration based on simple/read until split
        prop_useful_time = prop_ru * ru_prop_useful_time + \
            prop_simple * simple_prop_useful_time
        # the small constant keeps the division finite when nothing useful is sequenced
        useful_throughput = prop_useful_time * max_throughput + 0.0001
        return self._duration(useful_throughput)

#### SquiggleFilter Read Until Runtime

Find optimal set of thresholds

In [None]:
# Grid-search two-stage Read Until settings: for every ordered pair of prefix
# lengths, try a 10x10 grid of score cutoffs and keep the fastest combination.
read_data = Reads(results_dir)
run = Run(read_data)

# Start from the runtime without Read Until; only strictly faster settings replace it.
best_threshold_time = run.get_simple_runtime()
best_threshold_indices, best_threshold_values = [], []

# Candidate cutoffs per prefix: 10 evenly spaced values between 1.1x the lowest
# score (floored at 1) and 0.6x the highest score seen at that prefix.
cutoff_grid = []
for idx, _ in enumerate(prefix_lengths):
    lowest = max(1, min(np.min(virus_scores[idx]), np.min(other_scores[idx])))
    highest = max(np.max(virus_scores[idx]), np.max(other_scores[idx]))
    cutoff_grid.append(np.linspace(lowest*1.1, highest*0.6, 10))

for i1, p1 in enumerate(prefix_lengths):
    for i2, p2 in enumerate(prefix_lengths):
        times = []
        for t1 in cutoff_grid[i1]:
            for t2 in cutoff_grid[i2]:
                runtime = run.get_read_until_runtime([i1, i2], [t1, t2])
                times.append(runtime)
                if runtime < best_threshold_time:
                    best_threshold_time = runtime
                    best_threshold_indices = [i1, i2]
                    best_threshold_values = [t1, t2]
        # progress line, overwritten in place for each prefix pair
        print(f"\r{p1}-{p2}: {min(times)}                ", end='')

print(f'\nBest Time: {best_threshold_time}')
for i, t in zip(best_threshold_indices, best_threshold_values):
    print(f'@ sample {prefix_lengths[i]}: cutoff {t}')

In [None]:
# Plot single-threshold SquiggleFilter Read Until runtime against the cutoff,
# one curve for every other prefix length, with the no-Read-Until runtime (red)
# and the best two-threshold runtime (black) as horizontal reference lines.
read_data = Reads(results_dir)
run = Run(read_data)

plt.axhline(run.get_simple_runtime(), linestyle=':', color='red')
# Walk the real indices of prefix_lengths[1::2]. Enumerating the slice would
# number the prefixes 0, 1, 2, ... and so pair each label with the scores and
# runtime of a different (shorter) prefix.
for i1 in range(1, len(prefix_lengths), 2):
    p1 = prefix_lengths[i1]
    max_score = max(np.max(virus_scores[i1]), np.max(other_scores[i1]))
    min_score = max(1, min(np.min(virus_scores[i1]), np.min(other_scores[i1])))
    thresholds = np.linspace(min_score, max_score, 100)
    times = [run.get_read_until_runtime([i1], [t]) for t in thresholds]
    plt.plot(thresholds, times)
    print(f"{p1}: {min(times)}")
plt.axhline(best_threshold_time, linestyle=':', color='black')
#plt.ylim(0,600)
#plt.xlim(0,20000)
plt.ylim(0, run.get_simple_runtime()*1.3)
# labels follow plotting order: red line, one curve per prefix, black line
plt.legend(['no Read Until'] + list(prefix_lengths[1::2]) + ['two thresholds'], loc='lower right')
plt.show()

#### Basecall Align Read Until Runtime

In [None]:
# Plot single-threshold basecall-align Read Until runtime against the cutoff
# for the first three prefix lengths, with the no-Read-Until runtime in red.
read_data = Reads(results_dir)
run = Run(read_data, 'ba')

plt.axhline(run.get_simple_runtime(), linestyle=':', color='red')
shown_prefixes = prefix_lengths[:3]
for i1, p1 in enumerate(shown_prefixes):
    # scores are negated so the cutoff sweep covers the negated-score range
    neg_virus, neg_other = -ba_virus_scores[i1], -ba_other_scores[i1]
    min_score = min(np.min(neg_virus), np.min(neg_other))
    max_score = max(np.max(neg_virus), np.max(neg_other))
    thresholds = np.linspace(min_score, max_score, 100)
    times = [run.get_read_until_runtime([i1], [t]) for t in thresholds]
    plt.plot(thresholds, times)
    print(f"{p1}: {min(times)}")
#plt.axhline(132.3, linestyle=':', color='green')
#plt.ylim(0, run.get_simple_runtime()*1.3)
plt.ylim(0,600)
#plt.xlim(0, 100_000)
# NOTE(review): five labels are passed but only four artists are drawn here
# (the second axhline is commented out), so 'two thresholds' has no line.
plt.legend(['no Read Until'] + list(shown_prefixes) + ['two thresholds'], loc='lower right')
plt.show()

Plot the decrease in Read Until sequencing time as the number of channels increases

In [None]:
# Compare sequencing time across flowcell sizes for three strategies: no Read
# Until, SquiggleFilter with the best thresholds found above, and basecall-align
# with a fixed single threshold.
read_data = Reads(results_dir)
channels = np.linspace(512, 51200, num=100)

simple_runtimes, sf_runtimes, ba_runtimes = [], [], []
for channel_count in channels:
    # both classifiers are modelled on the same flowcell
    flowcell = Flowcell(channel_count)
    sf_run = Run(read_data, 'sf', flowcell)
    ba_run = Run(read_data, 'ba', flowcell)
    simple_runtimes.append(sf_run.get_simple_runtime())
    sf_runtimes.append(
        sf_run.get_read_until_runtime(best_threshold_indices, best_threshold_values))
    ba_runtimes.append(ba_run.get_read_until_runtime([1], [-10]))

fig, ax = plt.subplots()
ax.set_yscale('log')
# plotting order must match the legend labels below
for runtimes in (simple_runtimes, sf_runtimes, ba_runtimes):
    ax.plot(channels, runtimes)
ax.set_xlabel('Portable Sequencer Channels')
ax.set_ylabel('Sequencing Time')
ax.legend(['No Read Until', 'SquiggleFilter', 'Guppy-lite'])
plt.show()