## Denoiser for Genomic Data
### Setup
Remember to run `cd crates/python/ && maturin develop && cd ../..` from the `lz78_rust` directory after making any changes to the Rust code or the Python interface, so that the `lz78` package imported below is rebuilt!

In [1]:

import numpy as np
from sys import stdout
import numpy as np
import matplotlib.pyplot as plt

import scipy.signal as signal
from scipy.linalg import toeplitz
import dill as pickle
from lz78 import Sequence, LZ78Encoder, CharacterMap, BlockLZ78Encoder, LZ78SPA
from lz78 import encoded_sequence_from_bytes, spa_from_bytes

### Import Data

In [33]:
# Load the clean (noise-free) DNA sequence from test_data/DNA_100K_int.txt.
# The file stores the sequence as digit characters; all line breaks are removed
# so DNA_data is one contiguous string of symbols.
with open('test_data/DNA_100K_int.txt', 'r') as f:
    DNA_data = f.read().translate({ord('\n'): None})
# Quick sanity check: show the last 10 symbols.
print(DNA_data[-10:])

1202220303


In [38]:
# Build the LZ78SPA model and the CharacterMap used to encode strings into
# symbol indices. The first argument 4 is presumably the alphabet size
# (matching the "0123" alphabet below) — confirm against the lz78 docs.
stdout.flush()
spa = LZ78SPA(4, gamma=1) # gamma = 1/2 or 2 may work better for genomic data; worth trying
charmap_str = "0123"+DNA_data  ## IMPORTANT: list the alphabet first ("0123") so the symbol order is explicit; presumably CharacterMap assigns indices in order of first appearance — confirm. TODO: make this more robust.
char_map =  CharacterMap(charmap_str)



In [39]:
# Encode the clean DNA string with char_map and train the SPA on it.
stdout.flush() ## flush pending output before any long-running step
sequence = Sequence(DNA_data, charmap = char_map)
# Return value is kept as train_loss; presumably the training log-loss — confirm in lz78 docs.
train_loss = spa.train_on_block(sequence)

In [40]:
# Symmetric 4x4 channel matrix: 0.7 on the diagonal and 0.1 everywhere else,
# so every row sums to 1. The denoisers below use it as Pi[x, z], the
# probability of observing noisy symbol z when the clean symbol is x.
Pi = np.full((4, 4), 0.1)
np.fill_diagonal(Pi, 0.7)
# Human-readable names for the four symbol indices 0..3.
alphabet = ['A', 'G', 'T', 'C']

In [41]:
WL = 5  ## window length (number of context symbols)
LAH = 1  ## look-ahead length
# Short digit-encoded test string, converted to symbol indices with char_map.
test_data =  "032101233212323220203030232010212120200101030102022111321212112230033203021030"
test_data = char_map.encode(test_data)
# Probabilities the SPA assigns after traversing the first WL symbols; the
# denoisers below use this as the distribution of the symbol following the context.
pdf_traverse = spa.traverse_and_get_prob(test_data[:WL])
pdf_traverse ## one probability per alphabet symbol (4 entries for the "0123" alphabet)

[0.3425742574257426,
 0.2524752475247525,
 0.21485148514851485,
 0.1900990099009901]

In [42]:
# Same WL-symbol context, additionally passing LAH look-ahead symbols.
# NOTE(review): test_data[WL:WL+LAH] starts at position WL, i.e. it includes the
# symbol right after the context; confirm this is the intended look-ahead window.
pdf_lookahead = spa.traverse_and_get_prob_with_lookahead(input= test_data[:WL], lookahead=test_data[WL:WL+LAH]) 
pdf_lookahead

[0.36893203883495146,
 0.30097087378640774,
 0.2961165048543689,
 0.03398058252427184]

In [47]:
def universal_denoiser_naive(signal_to_filter, spa_tree, char_map, window_len,pi_matrix, x_alphabet):
    """Denoise a symbol sequence using only past context (no look-ahead).

    For each position t = window_len, ..., len(signal_to_filter) - 1:
      1. Take the SPA's probabilities for the noisy symbol Z_t given the
         preceding ``window_len`` symbols.
      2. Map them to clean-symbol probabilities with inv(pi_matrix.T).
      3. Weight by the channel column of the observed symbol and normalize.
      4. Return the posterior mean over ``x_alphabet``.

    Args:
        signal_to_filter: encoded noisy sequence of integer symbol indices.
        spa_tree: model exposing ``traverse_and_get_prob(context)``, which must
            return one probability per alphabet symbol.
        char_map: unused; kept so the signature matches the other denoisers.
        window_len: number of preceding symbols used as context.
        pi_matrix: square channel matrix, used as
            ``pi_matrix[x, z] = P(Z = z | X = x)``. Must be invertible.
        x_alphabet: numeric value assigned to each clean symbol.

    Returns:
        List with one estimate per position from ``window_len`` on. It is empty
        if the signal is not longer than ``window_len``. When the model assigns
        the observed symbol probability 0, the estimate falls back to
        ``mean(x_alphabet)``, matching ``universal_denoiser_lookahead``.
    """
    sequence_l = len(signal_to_filter)
    Zt_index = signal_to_filter[window_len:sequence_l]
    # The channel inverse does not depend on the position: compute it once.
    inv_pi_t = np.linalg.inv(np.transpose(pi_matrix))
    filtered_signal = []
    for il in range(sequence_l - window_len):
        prob = np.asarray(spa_tree.traverse_and_get_prob(signal_to_filter[il:il + window_len]))
        PX_t_Z_tm1_ele = np.matmul(inv_pi_t, prob)
        alphabet_ind = Zt_index[il]
        den = prob[alphabet_ind]
        if den == 0:
            # Avoid 0/0 -> NaN: use the alphabet mean as an uninformed estimate.
            xt_hat = np.mean(x_alphabet)
        else:
            pxt = pi_matrix[:, alphabet_ind] * PX_t_Z_tm1_ele / den
            pxt = pxt / np.sum(pxt)
            xt_hat = np.dot(pxt, x_alphabet)
        filtered_signal.append(xt_hat)

    return filtered_signal

def get_prob_from_MC(Dd, seed_str, spa_tree, char_map,num_exper, alphabet_size=4):
    """Estimate a symbol distribution by Monte-Carlo sampling from the SPA.

    Calls ``spa_tree.generate_data`` ``num_exper`` times. Each call generates
    ``Dd`` symbols from ``seed_str``. The function returns the empirical
    frequency of the last character of each generated output.

    Args:
        Dd: number of symbols to generate per experiment.
        seed_str: context passed to ``generate_data`` as ``seed_data``.
        spa_tree: model exposing ``generate_data(n, seed_data=..., temperature=...,
            min_context=...)`` that returns an ``(output, loss)`` pair.
        char_map: object whose ``str_to_symbol_list`` maps a character to
            symbol indices.
        num_exper: number of Monte-Carlo draws; must be positive.
        alphabet_size: number of symbols. The default of 4 matches the "0123"
            DNA alphabet. The previous hard-coded size of 3 made symbol 3
            raise IndexError.

    Returns:
        numpy array of length ``alphabet_size`` holding empirical frequencies.

    Raises:
        ValueError: if ``num_exper`` is not positive (would otherwise divide by 0).
    """
    if num_exper <= 0:
        raise ValueError("num_exper must be positive")
    prob = np.zeros(alphabet_size)
    for _ in range(num_exper):
        output, _loss = spa_tree.generate_data(Dd, seed_data=seed_str, temperature=1, min_context=3)
        # Count the final generated symbol of this draw.
        generated_ind = char_map.str_to_symbol_list(output[-1])
        prob[generated_ind] += 1
    return prob / num_exper

def universal_denoiser_delay_mc(Delay,signal_to_filter_mc, spa_tree, char_map, 
                                window_len,pi_matrix, x_alphabet, num_exp):
    """Denoise a symbol string using Monte-Carlo predictions with a delay.

    For each position from ``window_len`` on, the context window is shortened
    by ``Delay`` symbols. ``get_prob_from_MC`` then estimates the symbol
    distribution by generating ``Delay`` symbols from that context. The
    estimate is mapped through inv(pi_matrix.T), weighted by the channel
    column of the observed symbol, normalized, and reduced to the posterior
    mean over ``x_alphabet``.

    Args:
        Delay: number of symbols generated by Monte-Carlo per position.
        signal_to_filter_mc: noisy sequence as a string (encoded via
            ``char_map.str_to_symbol_list``).
        spa_tree: model passed through to ``get_prob_from_MC``.
        char_map: character map passed to ``get_prob_from_MC`` and used to
            encode the observed symbols.
        window_len: context length before shortening by ``Delay``.
        pi_matrix: square, invertible channel matrix used as
            ``pi_matrix[x, z] = P(Z = z | X = x)``.
        x_alphabet: numeric value assigned to each clean symbol.
        num_exp: number of Monte-Carlo draws per position.

    Returns:
        List with one estimate per position from ``window_len`` on. An
        empirical probability of 0 for the observed symbol (common with few
        draws) falls back to ``mean(x_alphabet)`` instead of producing NaN.
    """
    sequence_len = len(signal_to_filter_mc)
    Zt_index = char_map.str_to_symbol_list(signal_to_filter_mc[window_len:sequence_len])
    # Loop-invariant channel inverse.
    inv_pi_t = np.linalg.inv(np.transpose(pi_matrix))
    filtered_signal_delay_mc = []
    for idmc in range(sequence_len - window_len):
        # NOTE(review): the seed ends Delay symbols before position window_len + idmc.
        # If generate_data returns only newly generated symbols, the last of the Delay
        # generated symbols lines up with position window_len + idmc - 1, one before the
        # observed Zt_index[idmc], and Delay=0 yields an empty output. Confirm the
        # intended alignment.
        Zt_seed_str = signal_to_filter_mc[idmc:window_len + idmc - Delay]
        # char_map was previously omitted here, shifting num_exp into its slot (TypeError).
        prob_mc = get_prob_from_MC(Delay, Zt_seed_str, spa_tree, char_map, num_exp)

        PX_t_Z_tm1_mc = np.matmul(inv_pi_t, prob_mc)
        alphabet_ind = Zt_index[idmc]
        den = prob_mc[alphabet_ind]
        if den == 0:
            xt_hat_mc = np.mean(x_alphabet)
        else:
            pxt = pi_matrix[:, alphabet_ind] * PX_t_Z_tm1_mc / den
            pxt = pxt / np.sum(pxt)
            xt_hat_mc = np.dot(pxt, x_alphabet)
        filtered_signal_delay_mc.append(xt_hat_mc)

    return filtered_signal_delay_mc
def universal_denoiser_lookahead(Lookahead,signal_to_filter_la, spa_tree, char_map: "CharacterMap", 
                                window_len,pi_matrix, x_alphabet):
    """Denoise a symbol sequence using past context plus look-ahead symbols.

    For each position t = window_len, ..., len(signal_to_filter_la) - 1:
      1. Pass the ``window_len`` preceding symbols and up to ``Lookahead``
         symbols starting at t to
         ``spa_tree.traverse_and_get_prob_with_lookahead``.
      2. Map the result to clean-symbol probabilities with inv(pi_matrix.T).
      3. Weight by the channel column of the observed symbol and normalize.
      4. Return the posterior mean over ``x_alphabet``.

    Args:
        Lookahead: maximum number of look-ahead symbols. The window is
            truncated near the end of the signal.
        signal_to_filter_la: encoded noisy sequence of integer symbol indices.
        spa_tree: model exposing
            ``traverse_and_get_prob_with_lookahead(context, lookahead)``, which
            must return one probability per alphabet symbol.
        char_map: unused; kept for signature compatibility.
        window_len: number of preceding symbols used as context.
        pi_matrix: square, invertible channel matrix used as
            ``pi_matrix[x, z] = P(Z = z | X = x)``.
        x_alphabet: numeric value assigned to each clean symbol.

    Returns:
        List with one estimate per position from ``window_len`` on. When the
        model assigns the observed symbol probability 0, the estimate is
        ``mean(x_alphabet)``.
    """
    sequence_len = len(signal_to_filter_la)
    Zt_index = signal_to_filter_la[window_len:sequence_len]
    # The channel inverse is the same for every position: compute it once.
    inv_pi_t = np.linalg.inv(np.transpose(pi_matrix))

    filtered_signal_lookahead = []
    for ila in range(sequence_len - window_len):
        context = signal_to_filter_la[ila:window_len + ila]
        # NOTE(review): this slice starts at the target position itself, so the
        # look-ahead includes Z_t; confirm that is the intended semantics.
        lookahead = signal_to_filter_la[ila + window_len:ila + window_len + Lookahead]
        prob_la = spa_tree.traverse_and_get_prob_with_lookahead(context, lookahead)

        PX_t_Z_tm1_la = np.matmul(inv_pi_t, prob_la)
        alphabet_ind = Zt_index[ila]
        den = prob_la[alphabet_ind]
        if den == 0:
            # Zero predicted probability for the observed symbol: fall back to
            # the average of x_alphabet rather than dividing by zero.
            xt_hat_la = np.mean(x_alphabet)
        else:
            pxt = pi_matrix[:, alphabet_ind] * PX_t_Z_tm1_la / den
            pxt = pxt / np.sum(pxt)
            xt_hat_la = np.dot(pxt, x_alphabet)
        filtered_signal_lookahead.append(xt_hat_la)

    return filtered_signal_lookahead

In [48]:
stdout.flush()
LAH = 1  # number of look-ahead symbols
WL = 5  # context window length
# Load the noisy sequence. Remove line breaks exactly as when DNA_data was
# loaded: char_map was built without '\n', and a stray newline inside the
# first 100 characters would also shift every later position.
with open('test_data/DNA_100K_int_noisy.txt', 'r') as f:
    Zt_str = f.read().translate({ord('\n'): None})
alphabet_Xt = [0, 1, 2, 3]  # A,G,T,C
Zt_to_filter = char_map.encode(Zt_str[:100])  ## converts the string to a vector of symbol indices
# Past-context-only alternative:
# Xt_hat = universal_denoiser_naive(Zt_to_filter, spa, char_map, WL, Pi, alphabet_Xt)
Xt_hat = universal_denoiser_lookahead(LAH, Zt_to_filter, spa, char_map, WL, Pi, alphabet_Xt)

In [51]:
# Load the clean sequence X_t for evaluation. Strip line breaks the same way as
# when DNA_data was loaded, so '\n' never reaches char_map.encode and positions
# are not shifted.
with open('test_data/DNA_100K_int.txt', 'r') as f:
    Xt = f.read().translate({ord('\n'): None})
# Encode the first 100 symbols into symbol indices.
Xt = char_map.encode(Xt[:100])


In [None]:
# Xt_hat[i] estimates the symbol at position WL + i, so align the clean data.
X_clean = Xt[WL:]
# Mean squared error over the non-NaN estimates only: NaN entries are skipped
# and excluded from the divisor (previously the count k was never incremented
# and the sum was divided by len(Xt_hat)).
MSE_loss = 0
k = 0  # number of non-NaN estimates included in the average
for est, ref in zip(Xt_hat, X_clean):
    if not np.isnan(est):
        MSE_loss += (est - ref) ** 2
        k += 1

MSE_loss = MSE_loss / k if k else float('nan')

MSE_loss
# MSE treats the symbol indices as numbers. For a different loss, also adapt the
# estimator (e.g. an argmin of the expected loss over x_alphabet instead of the
# posterior mean returned by the denoisers).


0.9765414609974967