In [431]:
import tqdm
import time
import editdistance 
import matplotlib.pyplot as plt
import random

import numpy as np
import numba as nb
from numba import njit, jit, prange, cuda, objmode
from numba.typed import List
from numba.experimental import jitclass
import plotly.graph_objects as go
import random
from collections import defaultdict
from tqdm import tqdm
# Seed Python's `random` module and NumPy's legacy global RNG so runs repeat.
# NOTE(review): numba-compiled code (e.g. the random.randrange calls inside the TS
# jitclass) uses numba's own generator, which per the numba docs is not seeded by a
# random.seed() call made from regular Python — confirm before relying on it.
random.seed(0)
np.random.seed(0)

In [501]:
@njit(fastmath=True)
def l2_dist(a, b):
    """Squared Euclidean distance between `a` and `b` (no square root is taken)."""
    diff = a - b
    squared = diff * diff
    return np.sum(squared)


@njit(fastmath=True)
def hamming_distance(a, b):
    """Count the positions at which the arrays `a` and `b` hold different values."""
    mismatches = a != b
    return np.sum(mismatches)


# Packs 2 bits per symbol into an int64, so k must be <= 31 (k = 32 would need
# all 64 bits and 4**32 already overflows the signed type).
@njit(fastmath=True)
def extract_kmers(seq, k):
    """Encode every length-k window of `seq` as an integer.

    Parameters
    ----------
    seq : 1-D integer array
        Symbols, presumably in [0, 4) (2 bits each) — TODO confirm callers' encoding.
    k : int
        Window length, 1 <= k <= 31.

    Returns
    -------
    int64 array of length max(len(seq) - k + 1, 0). Entry j encodes seq[j:j+k]
    with seq[j] in the most significant position.
    """
    n_kmers = len(seq) - k + 1
    if n_kmers <= 0:
        return np.zeros(0, dtype=np.int64)
    mask = (np.int64(1) << np.int64(2 * k)) - np.int64(1)
    kmers = np.zeros(n_kmers, dtype=np.int64)
    kmer = np.int64(0)
    for i in range(len(seq)):
        # Shift in the new symbol; the mask drops the symbol that left the window.
        kmer = ((kmer << np.int64(2)) | np.int64(seq[i])) & mask
        # The first complete window ends at i = k - 1 and is stored at index 0.
        if i >= k - 1:
            kmers[i - k + 1] = kmer
    return kmers

@jitclass([
    ('A', nb.int32),
    ('t', nb.int32),
    ('D', nb.int32),
    ('normalize', nb.bool_),
    ('hashes', nb.int32[:, :]),
    ('signs', nb.float32[:, :])])
class TS():
    """Tensor Sketch: embeds an integer-encoded sequence into a D-dim float32 vector.

    Fields (see the jitclass spec):
      A         -- alphabet size; sequence symbols index the first axis of the tables.
      t         -- tuple length (number of levels in the sketch recursion).
      D         -- sketch dimension.
      normalize -- if True, the sketch is divided by C(len(seq), t).
      hashes    -- (A, t) int32 table of random offsets in [0, D).
      signs     -- (A, t) float32 table of random +1 / -1.
    """

    def __init__(self, t, D, A, normalize = True, seed = -1):
        """Draw the random hash and sign tables.

        seed: if >= 0, seed the random generator used by this compiled code
        before drawing. random.seed() called from regular Python does not seed
        numba's generator, so this is the way to get reproducible tables.
        The default (-1) keeps the previous, unseeded behavior.
        """
        if seed >= 0:
            random.seed(seed)
        self.A = A
        self.t = t
        self.D = D
        self.normalize = normalize

        # An A*t array of random integers in [0, D)
        self.hashes = np.empty((self.A, self.t), dtype=np.int32)
        # An A*t array of random +-1
        self.signs = np.empty((self.A, self.t), dtype=np.float32)
        for c in range(self.A):
            for k in range(self.t):
                self.hashes[c][k] = random.randrange(0, self.D)
                self.signs[c][k] = random.randrange(-1, 2, 2)

    def _full_sketch(self, seq: nb.int32[:]):
        """Return the (t + 1, D) table of partial sketches after consuming `seq`.

        Row 0 starts as the unit vector e_0; row k + 1 is built from row k, and
        row t is the final (unnormalized) sketch.
        """
        # NOTE: The sketch is stored as float64 here so counting won't overflow.
        T = np.zeros((self.t + 1, self.D), dtype=np.float64)
        T[0][0] = 1

        for c in seq:
            # Descending k: row k is read before this character updates it, so a
            # single character extends each partial tuple at most once.
            for k in range(self.t - 1, -1, -1):
                h = self.hashes[c][k]
                s = self.signs[c][k]
                for l in range(self.D):
                    # r = (l + h) mod D without a modulo (0 <= l, h < D).
                    r = l + h if l + h < self.D else l + h - self.D
                    T[k + 1][l] += s * T[k][r]

        return T

    def _normalize(self, seq, T):
        """Divide T in place by C(len(seq), t) when normalization is enabled."""
        if self.normalize:
            # Normalization factor C(n, t), built incrementally.
            n = len(seq)
            nct = nb.float64(1)
            for i in range(self.t):
                nct = nct * (n - i) / (i + 1)
            # n < t gives C(n, t) == 0. No length-t tuple exists then, so T is all
            # zeros; skip the division instead of producing 0/0 = NaN.
            if nct > 0:
                T /= nct
        return T

    def sketch_one(self, seq: np.ndarray) -> nb.float32[:]:
        """Return the float32 sketch of one integer-encoded sequence."""
        full_sketch = self._full_sketch(seq)

        # full_sketch[self.t] is a view, so this normalizes the last row in place.
        self._normalize(seq, full_sketch[self.t])

        return full_sketch[self.t].astype(np.float32)

    def sketch(self, seqs):
        """Return a list with the sketch of each sequence in `seqs`."""
        return [self.sketch_one(seq) for seq in seqs]

    def dist(self, s1, s2):
        """Squared L2 distance between two sketches."""
        return l2_dist(s1,s2)
    
class TSS():
    """Tensor Slide Sketch: sketches sliding windows of a sequence with an inner sketcher.

    Windows of length W start at 0, S, 2S, ... Each window is sketched by an
    instance of `sketch_class` (default TS) with dimension D2. D2 is chosen so
    that all window sketches together have at most D entries (when D is at
    least the number of windows).
    """

    def __init__(self, seq_len, t, W, S, D, A, normalize = True, sketch_class=TS):
        # Number of complete windows in a length-seq_len sequence (starts 0, S, 2S, ...
        # with start + W <= seq_len); this is the count sketch_one produces.
        # Clamped to >= 1 so D2 stays defined even when seq_len < W.
        num_windows = max(1, (seq_len - W) // S + 1)
        # Reduce the per-window sketch dim so the flattened sketch size is at most D
        # (at least 1 dimension per window).
        D2 = max(1, D // num_windows)
        self.sketcher = sketch_class(t=t, D=D2, A=A, normalize=normalize)
        self.t = t
        self.W = W
        self.S = S
        self.D2 = D2
        self.D = D

    def sketch_one(self, seq: nb.int32[:]) -> nb.float32[:,:]:
        """Sketch every window of `seq`.

        Returns a (D2, num_windows) float32 array with one column per window.
        Its layout differs from the flattened, window-major vectors that
        `sketch` returns. If there is no complete window, it returns (D2, 0).
        """
        starts = np.arange(0, len(seq) - self.W + 1, self.S)
        sketch = np.zeros((self.D2, len(starts)), dtype=np.float32)
        for col, start in enumerate(starts):
            sketch[:, col] = self.sketcher.sketch_one(seq[start:start + self.W])

        return sketch

    def sketch(self, seqs):
        """Sketch a batch of sequences and return a typed List of 1-D float32 arrays.

        All windows of all sequences go to the inner sketcher in one call. Entry i
        holds the window sketches of seqs[i] concatenated in window order, a
        vector of length num_windows * D2. Window starts come from len(seqs[0]),
        so all sequences are assumed to share that length.
        """
        sketches = List()
        if len(seqs) == 0:
            return sketches
        starts = np.arange(0, len(seqs[0]) - self.W + 1, self.S)
        n, m = len(seqs), len(starts)
        if m == 0:
            # No complete window means nothing to sketch. Also, an empty untyped
            # List likely cannot be passed to a numba-compiled sketcher.
            for _ in range(n):
                sketches.append(np.zeros(0, dtype=np.float32))
            return sketches
        all_seqs = List()
        for seq in seqs:
            for i in starts:
                all_seqs.append(seq[i:i + self.W])
        all_sketches = self.sketcher.sketch(all_seqs)
        for i in range(n):
            # Windows of sequence i occupy the slice [i*m, (i+1)*m).
            sketches.append(np.concatenate(all_sketches[i * m:(i + 1) * m]))
        return sketches

    def dist(self, s1, s2):
        """Squared L2 distance between two sketches."""
        return l2_dist(s1,s2)

In [504]:
def print_binary(x, dim):
    """Print the low `dim` bits of integer `x`, most significant bit first."""
    bits = [str((x >> pos) & 1) for pos in reversed(range(dim))]
    print("".join(bits))
    
def mutate(seq, alphabet, rate):
    """Randomly mutate `seq` with substitutions, insertions and deletions.

    At each position, with probability `rate` one of three operations is chosen
    uniformly at random:
      - substitution: emit a different symbol from `alphabet`, then advance;
      - insertion: emit a random symbol from `alphabet` without advancing, so the
        same position is considered again;
      - deletion: skip the current symbol.
    Otherwise the symbol is copied unchanged.

    Uses NumPy's global RNG (np.random), so a fixed np.random seed reproduces the
    result. Returns a str.
    """
    new_seq = []
    n = len(seq)
    i = 0
    while i < n:
        if np.random.uniform() > rate:
            # no mutation at this position
            new_seq.append(seq[i])
            i += 1
            continue
        op = np.random.choice([0, 1, 2])
        if op == 0:
            # Substitution. Candidates keep alphabet order (deduplicated) rather than
            # set order; set iteration order over strings changes with hash
            # randomization between interpreter runs, which broke seeding.
            candidates = [a for a in dict.fromkeys(alphabet) if a != seq[i]]
            new_seq.append(np.random.choice(candidates))
            i += 1
        elif op == 1:
            # insertion
            new_seq.append(np.random.choice(alphabet))
        else:
            # deletion
            i += 1
    return ''.join(new_seq)
    
def convert(seq, alphabet):
    """Map each symbol of `seq` to its index in `alphabet` and return a uint8 array.

    Raises KeyError for a symbol that is not in `alphabet`.
    """
    index_of = dict((symbol, idx) for idx, symbol in enumerate(alphabet))
    codes = [index_of[symbol] for symbol in seq]
    return np.asarray(codes, dtype=np.uint8)

def generate_paths(seq, path_len, num_paths, k, mutation_rate, alphabet):
    """Sample random reference substrings together with mutated copies.

    Args:
        seq: reference sequence (a str in this notebook).
        path_len: length of each sampled substring.
        num_paths: number of substrings to sample; starts may repeat.
        k: k-mer length, which determines the node list.
        mutation_rate: per-position rate passed to `mutate`.
        alphabet: symbols passed to `mutate`.

    Returns:
        A list of (path, mutated_path, nodes) tuples. path is
        seq[start:start + path_len], mutated_path is
        mutate(path, alphabet, mutation_rate), and nodes lists the reference
        start positions of the k-mers of path (start .. start + path_len - k).
    """
    # Valid starts are 0 .. len(seq) - path_len inclusive. np.random.choice(n)
    # samples range(n), so the +1 keeps the last start reachable.
    num_starts = len(seq) - path_len + 1
    paths = []
    for _ in range(num_paths):
        start = np.random.choice(num_starts)
        path = seq[start:start + path_len]
        nodes = list(range(start, start + path_len - k + 1))

        assert len(nodes) == len(path) - k + 1

        paths.append((path, mutate(path, alphabet, mutation_rate), nodes))
    return paths

def get_discrete_mmer_sketches(mmers, N, m, alphabet, G, sketcher):
    """Sketch every m-mer with `sketcher` and binarize each sketch with `discretize`.

    Each m-mer is first encoded with `convert`. N and m are used only to check
    that exactly N - m + 1 m-mers were supplied. Returns one integer per m-mer,
    in input order.
    """
    sketches = [
        discretize(G, sketcher.sketch_one(convert(mmer, alphabet)))
        for mmer in mmers
    ]
    assert len(sketches) == N - m + 1, f"{len(sketches)} != {N - m + 1}"
    return sketches

def get_kmer_sketches(mmer_sketches, N, k, m, stride, embed_dim):
    """Build one integer sketch per k-mer by concatenating m-mer sketches.

    For the k-mer starting at position p, the m-mer sketches at p, p + stride,
    ..., p + (num_windows - 1) * stride are concatenated, with the earliest in
    the most significant bits. The low `embed_dim` bits of each are kept, and
    num_windows = ceil((k - m + 1) / stride).

    Args:
        mmer_sketches: integer sketches (Python or NumPy ints), one per m-mer start.
        N: length of the underlying sequence.
        k: k-mer length.
        m: m-mer length.
        stride: step between the concatenated m-mers.
        embed_dim: number of bits taken from each m-mer sketch.

    Returns:
        A list of N - k + 1 Python ints, each below 2 ** (embed_dim * num_windows).
    """
    num_windows = int(np.ceil((k - m + 1) / stride))
    mask = (1 << embed_dim) - 1
    total_bits = embed_dim * num_windows

    kmer_sketches = []
    for kmer in range(N - k + 1):
        kmer_sketch = 0
        for mmer in range(kmer, kmer + num_windows * stride, stride):
            # int() keeps the accumulator an arbitrary-precision Python int. OR-ing
            # in a NumPy integer would make it a fixed-width np.int64 that silently
            # wraps once total_bits exceeds 63.
            kmer_sketch = (kmer_sketch << embed_dim) | (int(mmer_sketches[mmer]) & mask)
        assert kmer_sketch < (1 << total_bits)
        kmer_sketches.append(kmer_sketch)
    assert len(kmer_sketches) == N - k + 1, f"{len(kmer_sketches)} != {N - k + 1}"
    return kmer_sketches

def build_lookup(seq, sketches, N, k):
    """Index k-mer start positions by their sketch value.

    Returns a defaultdict(list) that maps each value among sketches[0 .. N-k] to
    the ascending list of positions where it occurs. `seq` is not used.
    """
    table = defaultdict(list)
    positions = range(N - k + 1)
    for pos in positions:
        key = sketches[pos]
        table[key].append(pos)
    return table

def discretize(G, mmer_sketch):
    """Binarize a real-valued sketch into an integer bit string.

    Bit i, counting from the most significant, is 1 iff mmer_sketch[i] >= 0
    (NaN yields 0). The result has len(mmer_sketch) bits and is a plain Python
    int, so it cannot overflow however long the sketch is.

    G is accepted for a disabled random-projection variant that would use the
    sign of G @ mmer_sketch; it is currently ignored.
    """
    result = 0
    for x in mmer_sketch:
        # int(...) keeps `result` a Python int. OR-ing in a NumPy bool promotes it
        # to a fixed-width NumPy integer that silently wraps beyond 63 bits.
        result = (result << 1) | int(x >= 0)
    return result

# Generate main sequence
# Reference length (number of symbols in the generated reference).
N = 1000000
# k = 20

# Sketch dimension per m-mer: used as TS's D and as the bits per m-mer in get_kmer_sketches.
embed_dim = 25
# Per-position mutation probability passed to mutate via generate_paths.
mutation_rate = 20/100
# Passed through to discretize, whose active code path does not use it. It is still
# drawn here, so it advances np.random before seqv is generated (order matters).
G = np.random.normal(size=(embed_dim, embed_dim))
# Tuple length t of the Tensor Sketch.
tuple_length = 3
# Number of sampled query paths per k.
num_paths = 1000
alphabet = ['A', 'C', 'T', 'G']

# Reference as a NumPy array of single-character strings; reference m-mers are sliced from it.
seqv = np.random.choice(alphabet, N)
# Same reference as a str; used for path sampling.
seq = ''.join(seqv)
# NOTE(review): TS draws its hash/sign tables with `random` inside numba-compiled code,
# which random.seed() called from Python may not seed — confirm reproducibility.
sketcher = TS(t=tuple_length, D=embed_dim, A=len(alphabet))    

In [None]:
# Sweep k. For each k: sketch every m-mer of the reference, combine them into k-mer
# sketches, index those in a lookup table, then query with mutated paths. Record
# per-k precision, recall and the fraction of the reference returned ("brute force").
precisions = {}
recalls = {}
brute_force = {}
for k in range(30, 90, 10):
    print(k)
    path_len = 200
    m = k//2
    stride = k//2
#     if k == 30:
#         m = 15
#         stride = 15
        
#     if k == 40:
#         m = 20
#         stride = 11        
        
#     if k == 50:
#         m = 30
#         stride = 11
        
#     if k == 60:
#         m = 40
#         stride = 11
        
#     if k == 70:
#         m = 50
#         stride = 20
        
#     if k == 80:
#         m = 60
#         stride = 15
        
    
    # Change this to use TSS!
    # Reference side: N - m + 1 m-mer sketches -> N - k + 1 k-mer keys -> lookup table.
    mmers = [seqv[i:i+m] for i in range(N - m + 1)]
    mmer_sketches = get_discrete_mmer_sketches(mmers, N, m, alphabet, G, sketcher)
    sketches = get_kmer_sketches(mmer_sketches, N, k, m, stride, embed_dim)
    # lut: sketch value -> list of reference k-mer start positions.
    lut = build_lookup(seq, sketches, N, k)
    
    paths = generate_paths(seq, path_len, num_paths, k, mutation_rate, alphabet)
    # Per path: True if any returned position lies near the path's true k-mer span.
    recall = np.array([False for _ in range(num_paths)])
    # Per path: total number of reference positions returned by all lookups.
    hits = np.array([0 for _ in range(num_paths)])
    # Per path: returned positions within recall_distance of the true span.
    tp = np.array([0 for _ in range(num_paths)])
    num_found_hits = 0  # NOTE(review): never used below
    # Slack (in positions) around the true span counted as correct; also used in plot titles.
    recall_distance = 5
    # Per path: hits as a fraction of the N - k + 1 reference k-mers.
    hits_out_of_total = ([0 for _ in range(num_paths)])

    for i, path in enumerate(tqdm(paths)):
        reference, query, node_path = path
        hit = False  # NOTE(review): never used below

        # Mmers
        query_mmers = [query[j:j+m] for j in range(len(query) - m + 1)]

        # Build sketches
        query_mmer_sketches = get_discrete_mmer_sketches(query_mmers, len(query), m, alphabet, G, sketcher)
        query_sketches = get_kmer_sketches(query_mmer_sketches, len(query), k, m, stride, embed_dim)
        
        # Lookup
        for query_sketch in query_sketches:
            if query_sketch in lut:
                hits[i] += len(lut[query_sketch])
                
                # node_path holds the reference k-mer starts of the unmutated path
                node_path_start = node_path[0] # this is the path of the unmutated sequence
                node_path_end = node_path[-1]

                for node in lut[query_sketch]: # for each node j in the matches
                    if node >= node_path_start - recall_distance and node <= node_path_end + recall_distance:
                        recall[i] = True # this path found its true location at least once
                        tp[i] += 1
        hits_out_of_total[i] = hits[i] / (N - k + 1)
    seqs_that_hit = np.argwhere(hits > 0)
    # Mean per-path precision over paths with at least one hit.
    # NOTE(review): if no path hits, len(seqs_that_hit) == 0 and this and the
    # brute_force line divide by zero (NumPy gives nan/inf with a warning) — consider guarding.
    precisions[k] = np.sum((tp[seqs_that_hit] / hits[seqs_that_hit])) / len(seqs_that_hit)
    recalls[k] = np.sum(recall) / num_paths
    # NOTE(review): sums over all paths but divides by the number of paths that hit
    # (non-hitting paths contribute 0) — confirm this is the intended average.
    brute_force[k] = np.sum(hits_out_of_total) / len(seqs_that_hit)
    print(precisions)
    print(recalls)
    print(brute_force)


30


 52%|████████████████████▋                   | 516/1000 [15:41<13:10,  1.63s/it]

In [500]:
# Per-k results from the sweep: each marker is one k (its label), with recall on y.
def plot_against_recall(x_values, labels, x_label, title_prefix):
    """Return a plotly figure of recall (y) against `x_values` (x), one point per k.

    Args:
        x_values: per-k metric values, in the same k order as `recalls`.
        labels: text label for each point (the k values).
        x_label: x-axis title.
        title_prefix: leading part of the figure title.
    """
    figure = go.Figure()
    figure.add_trace(
        go.Scatter(x=list(x_values), y=list(recalls.values()), text=labels, mode="lines+markers+text")
    )
    figure.update_traces(textposition="top center")
    figure.update_layout(
        title=f"{title_prefix} @ {mutation_rate} mutation rate, {recall_distance} recall threshold, {N} reference length",
        xaxis_title=x_label,
        yaxis_title="Recall",
    )
    return figure

# All three result dicts are filled in the same k order by the sweep loop.
keys = list(precisions.keys())

fig = plot_against_recall(precisions.values(), keys, "Precision", "PR")
fig.show()

# The x axis here is the brute-force fraction, so the title names it (not "PR").
fig = plot_against_recall(brute_force.values(), keys, "Brute Force", "Brute force vs. recall")
fig.show()

In [372]:
# Scratch arithmetic left over from exploration.
# NOTE(review): has the shape of (k - m + 1) / stride, the num_windows formula in
# get_kmer_sketches (before its ceil), perhaps for k=20, m=5, stride=3 — confirm or delete.
(20 - 5 + 1) / 3

5.333333333333333