## DNA Data Storage
Recovering the data stored on DNA by clustering the noisy reads.

In [1]:
import numpy as np
from random import shuffle
from Bio.Align.Applications import MuscleCommandline
from Bio import AlignIO
from skbio.alignment import local_pairwise_align_ssw
from skbio import DNA
import itertools
import operator
import time
import zipfile
import subprocess as sp
import hashlib
from include import *
from random import random

## Data preparation
You may choose to work with either a simulated dataset or a real synthesized dataset. The overall steps you need to take for each data are as follows:<br> <br>
Simulated data: <br>
In this example, the following stages are tested: <br>
1- Loading File_1.zip and encoding it using a Reed-Solomon encoder. <br>
2- Generating noisy reads from the encoded file. <br>
3- Clustering the reads using either the Trivial clustering or the Locality-Sensitive Hashing (LSH) method. <br>
4- Generating candidates from the clusters by performing multiple sequence alignment and majority voting. <br>
5- Putting the candidates into a Reed-Solomon decoder and recovering the oiginal data. <br>
*Encoding note: <br>
Note that n, k, N, K, nus, and numblocks (this one is printed out when the following function is called) will be needed for decoding. <br> <br>
Real synthesized data: <br>
1- Load files File1_ODNA.txt and I16_S2_R1_001.fastq which contain the original sequences and the noisy reads, respectively. <br>
2- Clustering the reads using either the Trivial clustering or the Locality-Sensitive Hashing (LSH) method. <br>
3- Generating candidates from the clusters by performing multiple sequence alignment and majority voting. <br>
4- Putting the candidates into a Reed-Solomon decoder and recovering the oiginal data. <br>

In [2]:
def encode(infile):
    stat = sp.Popen(["./simulate/texttodna", "--n=16383", "--k=10977", "--N=20", "--K=18", "--nuss=6", "--l=4", \
                     "--encode", infile,"--output=./data/data_encoded.txt"],stdout = sp.PIPE)
    for line in stat.stdout.read().splitlines():
        print(line)
    return True

In [3]:
def generate_noisy_reads(infile,p,n):
    # infile: encoded file
    # p: deletion probabilty per character (this increases per character index to mimick the synthesis procedure)
    # n: number of noisy reads per original sequence
    
    orig_seqs = []
    f = open(infile,"r")
    for seq in f.read().splitlines():
        orig_seqs += [seq]
    
    
    rand_replace = lambda c,d: c if random() > p*d else '' # this function induces deletions
    
    population = []
    for seq in orig_seqs:
        for i in range(n):
            s = ''.join([rand_replace(c,j) for j,c in enumerate(seq)])
            population += [s]
    shuffle(population)
    return orig_seqs,population

In [4]:
SIMULATE = False

In [5]:
if SIMULATE:
    ##### Encode
    s = time.time()
    encode("--input=./data/File_1.zip")
    print("Encoding time:",round(time.time()-s,1),"s")
    ##### Generate noisy reads by inducing deletions (perturbation step)
    orig_seqs,reads = generate_noisy_reads("./data/data_encoded.txt",0.0005,6)
else:
    ##### Load synthesized data
    orig_seqs = file_to_list("data/File1_ODNA.txt")
    filename = "data/I16_S2_R1_001.fastq"
    seqs = fastq_to_list(filename)
    print("all sequences: ", len(seqs))
    print("all orig sequences: ", len(orig_seqs))
    reads = [seq for seq in seqs if len(seq) >= 55 and len(seq)<=70]
    print("all trimmed sequences: ",  len(reads))

all sequences:  29303855
all orig sequences:  16383
all trimmed sequences:  29303855


## Clustering
You may choose to work with one of the following clustering methods: <br>
1. Trivial clustering: looking at the first "nbeg" characters of each sequence to cluster the reads. This option is way faster than the other one, especially if you choose to work with the syntesized dataset. This is important since choosing this option reducing runtime of the following sections. <br>
2. Locality-Sensititve Hashing (LSH): clustering the reads based on the computed LSH signatures. If this option is chosen when working on the synthesized dataset, both clusering and alignment stages will be slower than the trivial clustering.

In [6]:
#===== assign numbers to shingles of each sequence=====#
def kmerDNA(seq,k=3):
    kmer = []
    for ell in range(len(seq)-k+1):
        nstr = seq[ell:ell+k]
        index = 0
        for j,c in enumerate(nstr):
            if c == 'A':
                i = 0
            elif c == 'C':
                i = 1
            elif c == 'G':
                i = 2
            elif c == 'T':
                i = 3
            else:
                index = -1
                break
            index += i*(4**j)
        kmer += [index]
    return kmer
#=====min-hash object=====#
class minhashsig():
    # min-hash of k-mers
    def __init__(self,m,k):
        # m is the number of signatures
        self.tables = [np.random.permutation(4**k) for i in range(m)]
        self.k = k
    def generate_signature(self,seq):
        kmer = kmerDNA(seq,self.k)
        sig = [ min([table[i] for i in kmer]) for table in self.tables]
        return sig
#=====pair detection=====#
def extract_similar_pairs(sigs,m,k_lsh,ell_lsh,maxsig):
    # sigs: minhash signatures
    # ell_lsh: number of LSH signatures
    # k_lsh: number of MH signatures to be concatenated
    # we use generatrs to yield a number of pairs at a time for the sake of memory efficiency
    
    pairs = set([])
    
    # generate ell_lsh random indices
    for ell in range(ell_lsh):
        pair_count = 0
        s = time.time()
        lshinds = np.random.permutation(m)[:k_lsh]
        # generate LSh signatures
        lshsigs = []
        for sig in sigs:
            lshsig = 0
            for i,lshind in enumerate(lshinds):
                lshsig += sig[lshind]*(maxsig**i)
            lshsigs += [lshsig]
        d = {}
        for ind,sig in enumerate(lshsigs):
            if sig in d:
                d[sig] += [ind]
            else:
                d[sig] = [ind]
        for candidates in d.values():
            cent = set([])
            if len(candidates) > 1:
                for pair in itertools.combinations(candidates,2):
                    cent.add(pair[0])
                    if len(cent)==1:
                        pairs.add(pair)
                    else:
                        break
                        
        yield pairs,ell
        pair_count += len(pairs)
        pairs = set([])
#=====form clusters based on pairs=====#
def center_cluster(pairs):
    clusters = {}
    hold = 0
    t_counter = 0
    ell_copy = 0
    pairsize = 0
    while not hold:
        
        try:
            out = next(pairs)
            pairs_sort = list(out[0])
            ell = out[1]
            pairsize += len(pairs_sort)
            pairs_sort.sort()
            s = time.time()
            for (u,v) in pairs_sort:
                if u in clusters:
                    clusters[u] += [v]
        
                if v in clusters:
                    clusters[v] += [u]
        
                if v not in clusters and u not in clusters:
                    clusters[u] = [v]
    
        except StopIteration:
            hold = 1
            print("clustering completed","---",pairsize,"pairs clustered")
        if ell==ell_copy:
            t_counter += time.time()-s
        else:
            print("Clustering time for LSH",ell_copy,":",t_counter,'\n')
            t_counter = time.time()-s
            ell_copy = ell
 
    return clusters

#=====LSH clustering (main function)=====#
def lsh_cluster(seqs,m,k,k_lsh=2,ell_lsh=4):
    # This is the main function
    maxsig = 4**k
    minhash = minhashsig(m,k)
    sigs = [minhash.generate_signature(seq[:40]) for seq in seqs]
    pairs = extract_similar_pairs(sigs,m,k_lsh,ell_lsh,maxsig)
    clusters = center_cluster(pairs)
    return clusters

In [7]:
def filter_nonunique(seqs):
    # filter out sequences that appear many times
    d = {}
    ctr = 0
    for i,seq in enumerate(seqs):
        if seq in d:
            d[seq] += [i]
        else:
            d[seq] = [i]
    import operator
    sorted_d = sorted(d.items(), key=operator.itemgetter(1))
    sorted_d.reverse()
    return d,ctr

In [8]:
TRIVIAL = True # when real synthesized data is selected, set this flag to "True" (this dramatically reduces the runtime)

In [9]:
if TRIVIAL:
    start = time.time()
    nbeg=14
    d,ctr = filter_nonunique([seq[:nbeg] for seq in reads])
    clusters = [d[a] for a in d if len(d[a]) > 3]
    end = time.time()
    print("Runtime:",round(end-start,1),"s")
    print(len(clusters),"number of clusters created.")
    fclusts = clusters.copy()
else:
    # set up the parameters and call the lsh_cluster function
    k_lsh = 4
    sim = 0.5
    ell_lsh = int(1/(sim**k_lsh))
    m,k=50,5
    start = time.time()
    clusters = lsh_cluster(reads,m,k,k_lsh,ell_lsh)
    end = time.time()

    print("Runtime:",round(end-start,1),"s")
    print(len(clusters),"number of clusters created")

Runtime: 43.6 s
369425 number of clusters created.


## Filtering the clusters (do this step only if you have chosen LSH clustering)
In this stage, we use the max_match function to filter the clusters. This function checks the simlarity of each cluster member and its cluster center based on local alignment.

In [10]:
# adding the center of each cluter to the cluster (also removing duplicates from each cluster)
clusts = [ [c] + list(set(clusters[c])) for c in clusters if len(clusters[c]) > 3 ]

In [11]:
#=====max matching=====# 
def max_match(seq1,seq2):
    # This function checks whether seq1 and seq2 are similar or not
    # Checking all pairs within a cluster dramatically increases the time complexity, 
    # so by default, in the next cell, we call this function to only check the pairs
    # that one of their members is the cluster center
    
    alignment,score,start_end_positions \
        = local_pairwise_align_ssw(DNA(seq1) , DNA(seq2) , match_score=2,mismatch_score=-3)
    a = str(alignment[0])
    b = str(alignment[1])
    ctr = 0
    for i,j in zip(a,b):
        if i==j:
            ctr += 1
    return ctr

In [12]:
th = 35 # filtering threshold

k = len(clusts)
s = time.time()
fclusts = []
for i,c in enumerate(clusts):
    cent = reads[c[0]]
    cc = [c[0]]
    for e in c[1:]:
        score = max_match(cent,reads[e])
        if score >= th:
            cc += [e]
    fclusts += [cc]
    if i%1000 == 0:
        print("%",round(i*100/len(clusts),2),"of the clusters are filtered.")
print("filtering time for",k,"clusters:",round(time.time()-s,2),"s")

% 0.0 of the clusters are filtered.
% 3.88 of the clusters are filtered.
% 7.75 of the clusters are filtered.
% 11.63 of the clusters are filtered.
% 15.5 of the clusters are filtered.
% 19.38 of the clusters are filtered.
% 23.25 of the clusters are filtered.
% 27.13 of the clusters are filtered.
% 31.0 of the clusters are filtered.
% 34.88 of the clusters are filtered.
% 38.75 of the clusters are filtered.
% 42.63 of the clusters are filtered.
% 46.5 of the clusters are filtered.
% 50.38 of the clusters are filtered.
% 54.26 of the clusters are filtered.
% 58.13 of the clusters are filtered.
% 62.01 of the clusters are filtered.
% 65.88 of the clusters are filtered.
% 69.76 of the clusters are filtered.
% 73.63 of the clusters are filtered.
% 77.51 of the clusters are filtered.
% 81.38 of the clusters are filtered.
% 85.26 of the clusters are filtered.
% 89.13 of the clusters are filtered.
% 93.01 of the clusters are filtered.
% 96.88 of the clusters are filtered.
filtering time for 

## Aligning the clusters
As for the alignment stage, there are a number of options differing in accuracy and speed. For the sake of having the highest accuracy, we chose the Muscle function which is the most accurate multiple-alignment function among all python built-in functions.

In [10]:
def multiple_alignment_muscle(cluster,out=False):
    # write cluster to file
    file = open("clm.fasta","w") 
    for i,c in enumerate(cluster):
        file.write(">S%d\n" % i)
        file.write(c)
        file.write("\n")
    file.close()

    muscle_exe = r"~/muscle3.8.31_i86linux64" # assuming you've already put this in the main directory
    output_alignment = "clmout.fasta"
    muscle_cline = MuscleCommandline(muscle_exe, input="clm.fasta", out=output_alignment)
    stdout, stderr = muscle_cline()
    msa = AlignIO.read(output_alignment, "fasta")
    if out:
        print(msa)
    alignedcluster = []
    for i in msa:
        alignedcluster += [i.seq]
    return alignedcluster

In [11]:
fresults = []
def align_clusters(clusters,masize = 15):
    ### align clusters, generate candidates
    for i, clusterinds in enumerate(clusters):
        cluster = [reads[i] for i in clusterinds]
        if len(cluster) < 3:
            continue
        if len(cluster) > masize:
            for j in range(5):
                shuffle(cluster)
                ma = multiple_alignment_muscle(cluster[:masize])
                fresults.append(ma)
        else:
            ma = multiple_alignment_muscle(cluster[:masize])
            fresults.append(ma)
            
        if i % 1000 == 0:
            print("%",round(i*100/len(clusters),2),"of the clusters are aligned.")

In [12]:
s = time.time()
align_clusters(fclusts,15)
print("Alignment Runtime:",round(time.time()-s,2),"s")

% 0.0 of the clusters are aligned.
% 0.27 of the clusters are aligned.
% 0.54 of the clusters are aligned.
% 0.81 of the clusters are aligned.
% 1.08 of the clusters are aligned.
% 1.35 of the clusters are aligned.
% 1.62 of the clusters are aligned.
% 1.89 of the clusters are aligned.
% 2.17 of the clusters are aligned.
% 2.44 of the clusters are aligned.
% 2.71 of the clusters are aligned.
% 2.98 of the clusters are aligned.
% 3.25 of the clusters are aligned.
% 3.52 of the clusters are aligned.
% 3.79 of the clusters are aligned.
% 4.06 of the clusters are aligned.
% 4.33 of the clusters are aligned.
% 4.6 of the clusters are aligned.
% 4.87 of the clusters are aligned.
% 5.14 of the clusters are aligned.
% 5.41 of the clusters are aligned.
% 5.68 of the clusters are aligned.
% 5.96 of the clusters are aligned.
% 6.23 of the clusters are aligned.
% 6.5 of the clusters are aligned.
% 6.77 of the clusters are aligned.
% 7.04 of the clusters are aligned.
% 7.31 of the clusters are alig

% 60.36 of the clusters are aligned.
% 60.63 of the clusters are aligned.
% 60.91 of the clusters are aligned.
% 61.18 of the clusters are aligned.
% 61.45 of the clusters are aligned.
% 61.72 of the clusters are aligned.
% 61.99 of the clusters are aligned.
% 62.26 of the clusters are aligned.
% 62.53 of the clusters are aligned.
% 62.8 of the clusters are aligned.
% 63.07 of the clusters are aligned.
% 63.34 of the clusters are aligned.
% 63.61 of the clusters are aligned.
% 63.88 of the clusters are aligned.
% 64.15 of the clusters are aligned.
% 64.42 of the clusters are aligned.
% 64.7 of the clusters are aligned.
% 64.97 of the clusters are aligned.
% 65.24 of the clusters are aligned.
% 65.51 of the clusters are aligned.
% 65.78 of the clusters are aligned.
% 66.05 of the clusters are aligned.
% 66.32 of the clusters are aligned.
% 66.59 of the clusters are aligned.
% 66.86 of the clusters are aligned.
% 67.13 of the clusters are aligned.
% 67.4 of the clusters are aligned.
% 67

## Majority merging
At this stage, we merge each aligned cluster by putting up a voting for each position within the aligned sequences. Then, the fraction of all original sequences recovered is computed.

In [13]:
# This function returns the fraction of origignal squences recovered given a number of candidates
def fraction_recovered(candidates,orig_seqs):
    d = {}
    for seq in orig_seqs:
        d[seq] = 0
    for cand in candidates:
        if cand in d:
            d[cand] += 1
    av = sum([ d[seq]>0 for seq in d]) / len(d)
    print("Fraction of recovered sequences: ", av )
    if av>0:
        print("Fraction of recovered sequences: ", sum([ d[seq] for seq in d]) / len(d) / av )

In [14]:
def majority_merge(reads,weight = 0.4):
    # assume reads have the same length
    res = ""
    for i in range(len(reads[0])):
        counts = {'A':0,'C':0,'G':0,'T':0,'-':0,'N':0}
        for j in range(len(reads)):
            counts[reads[j][i]] +=1
        counts['-'] *= weight
        mv = max(counts.items(), key=operator.itemgetter(1))[0]
        if mv != '-':
            res += mv
    return res

In [15]:
candidates = []
for ma in fresults:
    candidates.append(majority_merge(ma,weight=0.5))
fraction_recovered( [seq[:60] for seq in candidates] , orig_seqs)

Fraction of recovered sequences:  0.7269730818531405
Fraction of recovered sequences:  2.1608732157850548


In [16]:
f = open("./data/data_drawnseg.txt","w")
for i,seq in enumerate(candidates):
    if len(seq) >= 60:
        f.write(seq[:60])
        f.write('\n')
f.close()

## Decoding

In [17]:
def decode(infile):
    stat = sp.Popen(["./simulate/texttodna", "--decode",  "--n=16383", "--k=10977", "--N=20", "--K=18", "--nuss=6", "--l=4", \
                    "--numblocks=1", infile, "--output=./data/data_rec.zip"],stdout=sp.PIPE)
    for line in stat.stdout.read().splitlines():
        print(line)
    return True

In [None]:
s = time.time()
decode("--input=./data/data_drawnseg.txt")
runtime = round(time.time()-s,2)

In [None]:
print("Decoding time:",runtime,"s")

## Evaluation
Let's first extract the decoded file (data_rec.zip)

In [19]:
with zipfile.ZipFile('./data/data_rec.zip', 'r') as zipObj:
   # Extract zip file contents to the "decoded" directory
   zipObj.extractall(path = "./data/decoded")

By looking at the files in File_1.zip and the decoded files located in the 'decoded' folder, we see that the original files are completely recovered.